• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

alan-turing-institute / deepsensor / 19842460617

08 Oct 2025 10:03AM UTC coverage: 81.663%. Remained the same
19842460617

push

github

web-flow
Update README.md, adding reference to GIANT project

2053 of 2514 relevant lines covered (81.66%)

1.63 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

74.11
/deepsensor/data/task.py
1
import deepsensor
2✔
2

3
from typing import Callable, Union, Tuple, List, Optional
2✔
4
import numpy as np
2✔
5
import lab as B
2✔
6
import plum
2✔
7
import copy
2✔
8

9
from ..errors import TaskSetIndexError, GriddedDataError
2✔
10

11

12
class Task(dict):
    """Task dictionary class.

    Inherits from ``dict`` and adds methods for printing and modifying the
    data.

    Every task carries an ``"ops"`` entry: a list of string flags recording
    which transformations (batching, casting, masking, ...) have been applied.
    Several methods consult this list to enforce a valid ordering of
    operations.

    Args:
        task_dict (dict):
            Dictionary containing the task.
    """

    def __init__(self, task_dict: dict) -> None:
        super().__init__(task_dict)

        if "ops" not in self:
            # List of operations (str) to indicate how the task has been modified
            #   (reshaping, casting, etc)
            self["ops"] = []

    @classmethod
    def summarise_str(cls, k, v):
        """Return a compact summary of one entry for the ``__str__`` method.

        Args:
            k (str):
                Key of the task dictionary (only forwarded on recursion).
            v (object):
                Value of the task dictionary.

        Returns:
            The ``shape`` of ``v`` if it is numeric, a tuple of shapes if it is
            a tuple, a list of recursively summarised elements if it is a
            list, and ``v`` itself otherwise.
        """
        if plum.isinstance(v, B.Numeric):
            return v.shape
        elif plum.isinstance(v, tuple):
            # Tuple elements are expected to expose ``.shape`` directly.
            return tuple(vi.shape for vi in v)
        elif plum.isinstance(v, list):
            return [cls.summarise_str(k, vi) for vi in v]
        else:
            return v

    @classmethod
    def summarise_repr(cls, k, v) -> Union[str, tuple, list]:
        """Summarise one entry of the task in a representation that can be printed.

        Args:
            k (str):
                Key of the task dictionary (only forwarded on recursion).
            v (object):
                Value of the task dictionary.

        Returns:
            str | tuple | list:
                A ``type/dtype/shape`` style string for array-like values, a
                ``type/value`` string for other scalars, and a tuple/list of
                such summaries for tuple/list values.
        """
        if v is None:
            return "None"
        elif plum.isinstance(v, B.Numeric):
            return f"{type(v).__name__}/{v.dtype}/{v.shape}"
        elif plum.isinstance(v, deepsensor.backend.nps.mask.Masked):
            return f"{type(v).__name__}/(y={v.y.dtype}/{v.y.shape})/(mask={v.mask.dtype}/{v.mask.shape})"
        elif plum.isinstance(v, tuple):
            return tuple([cls.summarise_repr(k, vi) for vi in v])
        elif plum.isinstance(v, list):
            return [cls.summarise_repr(k, vi) for vi in v]
        else:
            return f"{type(v).__name__}/{v}"

    def __str__(self) -> str:
        """Print a convenient summary of the task dictionary.

        For array entries, print their shape, otherwise print the value.
        Entries whose value is ``None`` are omitted.
        """
        return "".join(
            f"{k}: {Task.summarise_str(k, v)}\n"
            for k, v in self.items()
            if v is not None
        )

    def __repr__(self) -> str:
        """Print a convenient summary of the task dictionary.

        Print the type of each entry and if it is an array, print its shape,
        otherwise print the value.
        """
        return "".join(
            f"{k}: {Task.summarise_repr(k, v)}\n" for k, v in self.items()
        )

    def op(self, f: Callable, op_flag: Optional[str] = None):
        """Apply function f to the array elements of a task dictionary.

        Useful for recasting to a different dtype or reshaping (e.g. adding a
        batch dimension). The original task is left untouched: the function
        is applied to a deep copy, which is returned.

        Args:
            f (callable):
                Function to apply to the array elements of the task.
            op_flag (str, optional):
                Flag to append to the task dictionary's ``ops`` key. If
                ``None`` (the default), nothing is appended.

        Returns:
            :class:`deepsensor.data.task.Task`:
                Copy of the task with f applied to the array elements and
                op_flag recorded in the ``ops`` key.
        """

        def recurse(k, v):
            # Exact type checks are deliberate: only plain lists/tuples are
            # treated as containers to descend into.
            if type(v) is list:
                return [recurse(k, vi) for vi in v]
            elif type(v) is tuple:
                # Recurse over every element so tuples of any length are
                # rebuilt faithfully (not just the first two entries).
                return tuple(recurse(k, vi) for vi in v)
            elif isinstance(
                v,
                (np.ndarray, np.ma.MaskedArray, deepsensor.backend.nps.Masked),
            ):
                return f(v)
            else:
                return v  # covers metadata entries

        task = copy.deepcopy(self)  # don't modify the original
        for k, v in task.items():
            task[k] = recurse(k, v)
        if op_flag is not None:
            # ``ops`` is a list of str flags; never pollute it with ``None``.
            task["ops"].append(op_flag)

        return task

    def add_batch_dim(self):
        """Add a batch dimension to the arrays in the task dictionary.

        Returns:
            :class:`deepsensor.data.task.Task`:
                Task with batch dimension added to the array elements.
        """
        return self.op(lambda x: x[None, ...], op_flag="batch_dim")

    def cast_to_float32(self):
        """Cast the arrays in the task dictionary to float32.

        Returns:
            :class:`deepsensor.data.task.Task`:
                Task with arrays cast to float32.
        """
        return self.op(lambda x: x.astype(np.float32), op_flag="float32")

    def flatten_gridded_data(self):
        """Convert any gridded data in ``Task`` to flattened arrays.

        Necessary for AR sampling, which doesn't yet permit gridded context sets.

        Note that, unlike :meth:`op`, this modifies the task in place.

        Returns:
            :class:`deepsensor.data.task.Task`:
                This task, with ``X_c``/``Y_c`` (and ``X_t``/``Y_t`` when not
                ``None``) flattened and ``"gridded_data_flattened"`` appended
                to ``ops``.
        """
        self["X_c"] = [flatten_X(X) for X in self["X_c"]]
        self["Y_c"] = [flatten_Y(Y) for Y in self["Y_c"]]
        if self["X_t"] is not None:
            self["X_t"] = [flatten_X(X) for X in self["X_t"]]
        if self["Y_t"] is not None:
            self["Y_t"] = [flatten_Y(Y) for Y in self["Y_t"]]

        self["ops"].append("gridded_data_flattened")

        return self

    def remove_context_nans(self):
        """If NaNs are present in task["Y_c"], remove them (and corresponding task["X_c"]).

        Modifies the task in place. If no NaNs are found the task is returned
        unchanged and no flag is added to ``ops``.

        Returns:
            :class:`deepsensor.data.task.Task`:
                This task, with NaN context observations dropped.

        Raises:
            ValueError:
                If a batch dimension has already been added.
        """
        if "batch_dim" in self["ops"]:
            raise ValueError(
                "Cannot remove NaNs from task if a batch dim has been added."
            )

        # First check whether there are any NaNs that we need to remove
        nans_present = False
        for Y_c in self["Y_c"]:
            if B.any(B.isnan(Y_c)):
                nans_present = True
                break

        if not nans_present:
            return self

        # NaNs present in self - remove NaNs
        for i, (X, Y) in enumerate(zip(self["X_c"], self["Y_c"])):
            Y_c_nans = B.isnan(Y)
            if B.any(Y_c_nans):
                if isinstance(X, tuple):
                    # Gridded data - need to flatten to remove NaNs
                    X = flatten_X(X)
                    Y = flatten_Y(Y)
                    Y_c_nans = flatten_Y(Y_c_nans)
                # Drop a location if any of its variables is NaN
                Y_c_nans = B.any(Y_c_nans, axis=0)  # shape (n_contexts,)
                self["X_c"][i] = X[:, ~Y_c_nans]
                self["Y_c"][i] = Y[:, ~Y_c_nans]

        self["ops"].append("context_nans_removed")

        return self

    def remove_target_nans(self):
        """If NaNs are present in task["Y_t"], remove them (and corresponding task["X_t"]).

        Modifies the task in place. If no NaNs are found the task is returned
        unchanged and no flag is added to ``ops``. When a ``"Y_t_aux"`` entry
        exists it is flattened and filtered with the same mask.

        Returns:
            :class:`deepsensor.data.task.Task`:
                This task, with NaN target observations dropped.

        Raises:
            ValueError:
                If a batch dimension has already been added.
        """
        if "batch_dim" in self["ops"]:
            raise ValueError(
                "Cannot remove NaNs from task if a batch dim has been added."
            )

        # First check whether there are any NaNs that we need to remove
        nans_present = False
        for Y_t in self["Y_t"]:
            if B.any(B.isnan(Y_t)):
                nans_present = True
                break

        if not nans_present:
            return self

        # NaNs present in self - remove NaNs
        for i, (X, Y) in enumerate(zip(self["X_t"], self["Y_t"])):
            Y_t_nans = B.isnan(Y)
            if "Y_t_aux" in self.keys():
                # flatten_Y only reshapes 3-D input, so repeating this on
                # later iterations is a no-op.
                self["Y_t_aux"] = flatten_Y(self["Y_t_aux"])
            if B.any(Y_t_nans):
                if isinstance(X, tuple):
                    # Gridded data - need to flatten to remove NaNs
                    X = flatten_X(X)
                    Y = flatten_Y(Y)
                    Y_t_nans = flatten_Y(Y_t_nans)
                # Drop a location if any of its variables is NaN
                Y_t_nans = B.any(Y_t_nans, axis=0)  # shape (n_targets,)
                self["X_t"][i] = X[:, ~Y_t_nans]
                self["Y_t"][i] = Y[:, ~Y_t_nans]
                if "Y_t_aux" in self.keys():
                    # NOTE(review): with several target sets containing NaNs
                    # the aux data is filtered once per set - confirm that
                    # this is intended.
                    self["Y_t_aux"] = self["Y_t_aux"][:, ~Y_t_nans]

        self["ops"].append("target_nans_removed")

        return self

    def mask_nans_numpy(self):
        """Replace NaNs with zeroes and set a mask to indicate where the NaNs
        were.

        Plain arrays containing NaNs become ``numpy`` masked arrays (via
        ``np.ma.fix_invalid``); arrays without NaNs are left as they are.
        Entries that are already ``Masked`` objects get their NaNs zeroed and
        their mask recomputed.

        Returns:
            :class:`deepsensor.data.task.Task`:
                Task with NaNs set to zeros and a mask indicating where the
                missing values are.

        Raises:
            ValueError:
                If ``add_batch_dim`` has not been called first.
        """
        if "batch_dim" not in self["ops"]:
            raise ValueError("Must call `add_batch_dim` before `mask_nans_numpy`")

        def f(arr):
            if isinstance(arr, deepsensor.backend.nps.Masked):
                # NOTE(review): the mask built here is True where data is
                # missing, whereas `mask_nans_nps` builds masks that are True
                # where data is observed - confirm the intended convention.
                nps_mask = arr.mask == 0
                nan_mask = np.isnan(arr.y)
                mask = np.logical_or(nps_mask, nan_mask)
                # Collapse axis 1, keeping it as a size-1 dim
                mask = np.any(mask, axis=1, keepdims=True)
                data = arr.y
                data[nan_mask] = 0.0
                arr = deepsensor.backend.nps.Masked(data, mask)
            else:
                mask = np.isnan(arr)
                if np.any(mask):
                    arr = np.ma.fix_invalid(arr, fill_value=0.0)
            return arr

        return self.op(f, op_flag="numpy_mask")

    def mask_nans_nps(self):
        """Convert ``numpy`` masked arrays into ``Masked`` objects.

        Each ``np.ma.MaskedArray`` in the task is replaced by a
        ``deepsensor.backend.nps.Masked`` holding the underlying data and a
        mask (cast to the data's dtype) that is true for observed and false
        for missing points. Other entries are left unchanged.

        Returns:
            :class:`deepsensor.data.task.Task`:
                Task with masked arrays converted to ``Masked`` objects.

        Raises:
            ValueError:
                If ``add_batch_dim`` or ``mask_nans_numpy`` has not been
                called first.
        """
        if "batch_dim" not in self["ops"]:
            raise ValueError("Must call `add_batch_dim` before `mask_nans_nps`")
        if "numpy_mask" not in self["ops"]:
            raise ValueError("Must call `mask_nans_numpy` before `mask_nans_nps`")

        def f(arr):
            if isinstance(arr, np.ma.MaskedArray):
                # Mask array (True for observed, False for missing). Keep size 1 variable dim.
                mask = ~B.any(arr.mask, axis=1, squeeze=False)
                mask = B.cast(B.dtype(arr.data), mask)
                arr = deepsensor.backend.nps.Masked(arr.data, mask)
            return arr

        return self.op(f, op_flag="nps_mask")

    def convert_to_tensor(self):
        """Convert to tensor object based on deep learning backend.

        ``Masked`` entries have both their data and their mask converted.

        Returns:
            :class:`deepsensor.data.task.Task`:
                Task with arrays converted to deep learning tensor objects.
        """

        def f(arr):
            if isinstance(arr, deepsensor.backend.nps.Masked):
                arr = deepsensor.backend.nps.Masked(
                    deepsensor.backend.convert_to_tensor(arr.y),
                    deepsensor.backend.convert_to_tensor(arr.mask),
                )
            else:
                arr = deepsensor.backend.convert_to_tensor(arr)
            return arr

        return self.op(f, op_flag="tensor")
2✔
328

329

330
def append_obs_to_task(
    task: Task,
    X_new: B.Numeric,
    Y_new: B.Numeric,
    context_set_idx: int,
):
    """Return a copy of ``task`` with one extra observation in a context set.

    The input task is deep-copied, so the original object is not affected.

    ..
        TODO: for speed during active learning algs, consider a shallow copy
        option plus ability to remove observations.

    Args:
        task (:class:`deepsensor.data.task.Task`:): The task to modify.
        X_new (array-like): New observation coordinates.
        Y_new (array-like): New observation values.
        context_set_idx (int): Index of the context set to append to.

    Returns:
        :class:`deepsensor.data.task.Task`:
            Task with new observation appended to the context set.

    Raises:
        TaskSetIndexError:
            If ``context_set_idx`` is not a valid (non-negative) index into
            the task's context sets.
        GriddedDataError:
            If the selected context set is gridded (coords stored as a tuple).
    """
    n_context_sets = len(task["X_c"])
    if context_set_idx < 0 or context_set_idx > n_context_sets - 1:
        raise TaskSetIndexError(context_set_idx, n_context_sets, "context")

    current_X = task["X_c"][context_set_idx]
    if isinstance(current_X, tuple):
        raise GriddedDataError("Cannot append to gridded data")

    updated = copy.deepcopy(task)

    # Promote the new values so they can be joined along the last axis:
    # a 0-d Y gains both a data and an observation dim, 1-d inputs gain a
    # trailing size-1 observation dim.
    if Y_new.ndim == 0:
        Y_new = Y_new[None, None]
    if X_new.ndim == 1:
        X_new = X_new[:, None]
    if Y_new.ndim == 1:
        Y_new = Y_new[:, None]

    current_Y = task["Y_c"][context_set_idx]
    # Coordinates of the context set plus the proposed observation
    updated["X_c"][context_set_idx] = np.concatenate([current_X, X_new], axis=-1)
    # Matching values for the appended observation
    updated["Y_c"][context_set_idx] = np.concatenate([current_Y, Y_new], axis=-1)

    return updated
2✔
384

385

386
def flatten_X(X: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]) -> np.ndarray:
    """Convert tuple of gridded coords to (2, N) array if necessary.

    Args:
        X (:class:`numpy:numpy.ndarray` | Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`]):
            Either an array of coordinates, which is returned unchanged, or a
            tuple ``(x1, x2)`` of 1D grid coordinates.

    Returns:
        :class:`numpy:numpy.ndarray`
            For tuple input, an array of shape ``(2, len(x1) * len(x2))``
            holding every grid point, with ``x2`` varying fastest
            (``"ij"`` indexing). Otherwise ``X`` itself.
    """
    # Use isinstance (not an exact type check) so this agrees with callers
    # that decide whether data is gridded via ``isinstance(X, tuple)``.
    if isinstance(X, tuple):
        X1, X2 = np.meshgrid(X[0], X[1], indexing="ij")
        X = np.stack([X1.ravel(), X2.ravel()], axis=0)
    return X
2✔
401

402

403
def flatten_Y(Y: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]) -> np.ndarray:
    """Collapse the two grid axes of gridded data into a single axis.

    Data of shape (N_dim, N_x1, N_x2) becomes (N_dim, N_x1 * N_x2). Input
    with any other number of dimensions is passed through untouched.

    Args:
        Y (:class:`numpy:numpy.ndarray` | Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`]):
            Data to flatten; must expose ``ndim``.

    Returns:
        :class:`numpy:numpy.ndarray`
            The reshaped data if ``Y`` is 3-dimensional, otherwise ``Y``
            itself.
    """
    if Y.ndim != 3:
        # Already flat (or not gridded in the expected layout): nothing to do
        return Y
    leading = Y.shape[0]
    return Y.reshape(leading, -1)
2✔
418

419

420
def concat_tasks(tasks: List[Task], multiple: int = 1) -> Task:
    """Concatenate a list of tasks into a single task containing multiple batches.

    Each task is first prepared (target NaNs removed, batch dim added, cast
    to float32) if its ``ops`` list does not already record those steps.
    Context sets are then merged with padding, target sets are concatenated
    along the batch axis, and NaN masking is applied to the merged task.

    The caller's ``tasks`` list is not rebound. If the list holds a single
    task, that task is returned as-is.

    ..

    Todo:
        - Consider moving to ``nps.py`` as this leverages ``neuralprocesses``
          functionality.
        - Raise error if ``aux_t`` values passed (not supported I don't think)

    Args:
        tasks (List[:class:`deepsensor.data.task.Task`:]):
            List of tasks to concatenate into a single task.
        multiple (int, optional):
            Contexts are padded to the smallest multiple of this number that is
            greater than the number of contexts in each task. Defaults to 1
            (padded to the largest number of contexts in the tasks). Setting
            to a larger number will increase the amount of padding but decrease
            the range of tensor shapes presented to the model, which simplifies
            the computational graph in graph mode.

    Returns:
        :class:`~.data.task.Task`: Task containing multiple batches.

    Raises:
        ValueError:
            If any task has already had NaNs masked.
        ValueError:
            If the tasks have different numbers of target sets.
        ValueError:
            If the tasks have different numbers of targets.
        ValueError:
            If the tasks have different types of target sets (gridded/
            non-gridded).
    """
    if len(tasks) == 1:
        return tasks[0]

    # Collect the prepared tasks in a new list rather than writing them back
    # into the caller's list.
    # NOTE: `remove_target_nans` still edits a task object in place when it
    # finds NaNs.
    prepared = []
    for task in tasks:
        if "numpy_mask" in task["ops"] or "nps_mask" in task["ops"]:
            raise ValueError(
                "Cannot concatenate tasks that have had NaNs masked. "
                "Masking will be applied automatically after concatenation."
            )
        if "target_nans_removed" not in task["ops"]:
            task = task.remove_target_nans()
        if "batch_dim" not in task["ops"]:
            task = task.add_batch_dim()
        if "float32" not in task["ops"]:
            task = task.cast_to_float32()
        prepared.append(task)
    tasks = prepared

    # Assert number of target sets equal
    n_target_sets = [len(task["Y_t"]) for task in tasks]
    if not all(n == n_target_sets[0] for n in n_target_sets):
        raise ValueError(
            f"All tasks must have the same number of target sets to concatenate: got {n_target_sets}. "
        )
    n_target_sets = n_target_sets[0]

    for target_set_i in range(n_target_sets):
        # Raise error if target sets have different numbers of targets across tasks
        n_target_obs = [task["Y_t"][target_set_i].size for task in tasks]
        if not all(n == n_target_obs[0] for n in n_target_obs):
            raise ValueError(
                f"All tasks must have the same number of targets to concatenate: got {n_target_obs}. "
                "To train with Task batches containing differing numbers of targets, "
                "run the model individually over each task and average the losses."
            )

        # Raise error if target sets are different types (gridded/non-gridded)
        # across tasks. Compare every task against the first one so that a
        # mismatch is caught in either direction.
        first_is_gridded = isinstance(tasks[0]["X_t"][target_set_i], tuple)
        for task in tasks:
            if isinstance(task["X_t"][target_set_i], tuple) != first_is_gridded:
                raise ValueError(
                    "All tasks must have the same type of target set (gridded or non-gridded) "
                    f"to concatenate. For target set {target_set_i}, got {type(task['X_t'][target_set_i])}."
                )

    # For each task, store list of tuples of (x_c, y_c) (one tuple per context set)
    contexts = [list(zip(task["X_c"], task["Y_c"])) for task in tasks]

    # List of tuples of merged (x_c, y_c) along batch dim with padding
    # (up to the smallest multiple of `multiple` greater than the number of contexts in each task)
    merged_context = [
        deepsensor.backend.nps.merge_contexts(*contexts_i, multiple=multiple)
        for contexts_i in zip(*contexts)
    ]

    merged_task = copy.deepcopy(tasks[0])

    # Convert list of tuples of (x_c, y_c) to list of x_c and list of y_c
    merged_task["X_c"] = [c[0] for c in merged_context]
    merged_task["Y_c"] = [c[1] for c in merged_context]

    # This assumes that all tasks have the same number of targets
    for i in range(n_target_sets):
        if isinstance(tasks[0]["X_t"][i], tuple):
            # Target set is gridded with tuple of coords for `X_t`
            merged_task["X_t"][i] = (
                B.concat(*[t["X_t"][i][0] for t in tasks], axis=0),
                B.concat(*[t["X_t"][i][1] for t in tasks], axis=0),
            )
        else:
            # Target set is off-the-grid with tensor for `X_t`
            merged_task["X_t"][i] = B.concat(*[t["X_t"][i] for t in tasks], axis=0)
        merged_task["Y_t"][i] = B.concat(*[t["Y_t"][i] for t in tasks], axis=0)

    # One time entry per concatenated task
    merged_task["time"] = [t["time"] for t in tasks]

    merged_task = Task(merged_task)

    # Apply masking
    merged_task = merged_task.mask_nans_numpy()
    merged_task = merged_task.mask_nans_nps()

    return merged_task
2✔
540

541

542
if __name__ == "__main__":  # pragma: no cover
    # Manual smoke test: load two tasks from the xarray tutorial dataset,
    # inject a NaN into one of them and concatenate them.
    import os

    # Show the working directory
    cwd = os.getcwd()
    print(os.path.abspath(cwd))

    import deepsensor.tensorflow as deepsensor
    from deepsensor.data.processor import DataProcessor
    from deepsensor.data.loader import TaskLoader
    from deepsensor.model.convnp import ConvNP
    from deepsensor.data.task import concat_tasks

    import xarray as xr
    import numpy as np

    raw_dataset = xr.tutorial.open_dataset("air_temperature")
    processor = DataProcessor(x1_name="lat", x2_name="lon")
    processed = processor(raw_dataset)

    loader = TaskLoader(context=processed, target=processed)

    date = "2014-01-01"
    first_task = loader(date, 50)
    # Inject a NaN into the first context set of the first task
    first_task["Y_c"][0][0, 0] = np.nan
    second_task = loader(date, 100)

    # Tasks are passed unbatched and unmasked; concat_tasks adds the batch
    # dimension and applies masking itself.
    merged = concat_tasks([first_task, second_task])
    print(repr(merged))

    print("got here")
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc