alan-turing-institute / deepsensor, build 14313118307 (github / web-flow)

07 Apr 2025 03:23PM UTC coverage: 82.511% (+0.8%) from 81.663%

Pull Request #135: Enable patchwise training and prediction
Merge ea3987b38 into 38ec5ef26

294 of 329 new or added lines in 4 files covered (89.36%)
1 existing line in 1 file now uncovered
2340 of 2836 relevant lines covered (82.51%)
1.65 hits per line

Source file: /deepsensor/data/task.py (73.25% of lines covered)
import deepsensor

from typing import Callable, Union, Tuple, List, Optional
import numpy as np
import lab as B
import plum
import copy

from ..errors import TaskSetIndexError, GriddedDataError


class Task(dict):
    """Task dictionary class.

    Inherits from ``dict`` and adds methods for printing and modifying the
    data.

    Args:
        task_dict (dict):
            Dictionary containing the task.
    """

    def __init__(self, task_dict: dict) -> None:
        super().__init__(task_dict)

        if "ops" not in self:
            # List of operations (str) to indicate how the task has been modified
            #   (reshaping, casting, etc)
            self["ops"] = []
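
    # Illustrative note (not from the original source): a Task produced by a
    # TaskLoader typically carries keys such as "X_c"/"Y_c" (lists of context
    # coordinates and values), "X_t"/"Y_t" (targets), optionally "Y_t_aux" and
    # "time", plus the "ops" list initialised above, e.g.
    #
    #   task = Task({"X_c": [...], "Y_c": [...], "X_t": [...], "Y_t": [...]})
    #   task["ops"]  # -> []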

    @classmethod
    def summarise_str(cls, k, v):
        """Return string summaries for the __str__ method."""
        if isinstance(v, float):
            return v
        elif plum.isinstance(v, B.Numeric):
            return v.shape
        elif plum.isinstance(v, tuple):
            return tuple(vi.shape for vi in v)
        elif plum.isinstance(v, list):
            return [cls.summarise_str(k, vi) for vi in v]
        else:
            return v

    @classmethod
    def summarise_repr(cls, k, v) -> str:
        """Summarise the task in a representation that can be printed.

        Args:
            cls (:class:`deepsensor.data.task.Task`:):
                Task class.
            k (str):
                Key of the task dictionary.
            v (object):
                Value of the task dictionary.

        Returns:
            str: String representation of the task.
        """
        if v is None:
            return "None"
        elif isinstance(v, float):
            return f"{type(v).__name__}"
        elif plum.isinstance(v, B.Numeric):
            return f"{type(v).__name__}/{v.dtype}/{v.shape}"
        elif plum.isinstance(v, deepsensor.backend.nps.mask.Masked):
            return f"{type(v).__name__}/(y={v.y.dtype}/{v.y.shape})/(mask={v.mask.dtype}/{v.mask.shape})"
        elif plum.isinstance(v, tuple):
            # return tuple(vi.shape for vi in v)
            return tuple([cls.summarise_repr(k, vi) for vi in v])
        elif plum.isinstance(v, list):
            return [cls.summarise_repr(k, vi) for vi in v]
        else:
            return f"{type(v).__name__}/{v}"

    def __str__(self) -> str:
        """Print a convenient summary of the task dictionary.

        For array entries, print their shape, otherwise print the value.
        """
        s = ""
        for k, v in self.items():
            if v is None:
                continue
            s += f"{k}: {Task.summarise_str(k, v)}\n"
        return s

    def __repr__(self) -> str:
        """Print a convenient summary of the task dictionary.

        Print the type of each entry and if it is an array, print its shape,
        otherwise print the value.
        """
        s = ""
        for k, v in self.items():
            s += f"{k}: {Task.summarise_repr(k, v)}\n"
        return s

    def op(self, f: Callable, op_flag: Optional[str] = None):
        """Apply function f to the array elements of a task dictionary.

        Useful for recasting to a different dtype or reshaping (e.g. adding a
        batch dimension).

        Args:
            f (callable):
                Function to apply to the array elements of the task.
            op_flag (str):
                Flag to set in the task dictionary's `ops` key.

        Returns:
            :class:`deepsensor.data.task.Task`:
                Task with f applied to the array elements and op_flag set in
                the ``ops`` key.
        """

        def recurse(k, v):
            if type(v) is list:
                return [recurse(k, vi) for vi in v]
            elif type(v) is tuple:
                return (recurse(k, v[0]), recurse(k, v[1]))
            elif isinstance(
                v,
                (np.ndarray, np.ma.MaskedArray, deepsensor.backend.nps.Masked),
            ):
                return f(v)
            else:
                return v  # covers metadata entries

        self = copy.deepcopy(self)  # don't modify the original
        for k, v in self.items():
            self[k] = recurse(k, v)
        self["ops"].append(op_flag)

        return self  # return the modified deep copy (the original task is unchanged)

    def add_batch_dim(self):
        """Add a batch dimension to the arrays in the task dictionary.

        Returns:
            :class:`deepsensor.data.task.Task`:
                Task with batch dimension added to the array elements.
        """
        return self.op(lambda x: x[None, ...], op_flag="batch_dim")

    def cast_to_float32(self):
        """Cast the arrays in the task dictionary to float32.

        Returns:
            :class:`deepsensor.data.task.Task`:
                Task with arrays cast to float32.
        """
        return self.op(lambda x: x.astype(np.float32), op_flag="float32")
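
    # Illustrative usage sketch (not part of the original file): `op` returns a
    # modified deep copy and records what was done in task["ops"]. Assuming an
    # off-grid context set and a `task_loader` as in the `__main__` demo below:
    #
    #   task = task_loader("2014-01-01", 50)
    #   task = task.add_batch_dim().cast_to_float32()
    #   task["ops"]               # -> ["batch_dim", "float32"]
    #   task["X_c"][0].dtype      # -> float32, with a leading batch dim of size 1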

    def flatten_gridded_data(self):
        """Convert any gridded data in ``Task`` to flattened arrays.

        Necessary for AR sampling, which doesn't yet permit gridded context sets.

        Returns:
            :class:`deepsensor.data.task.Task`:
                Task with gridded context and target data flattened to (2, N)
                coordinate arrays and (N_dim, N) value arrays.
        """
        self["X_c"] = [flatten_X(X) for X in self["X_c"]]
        self["Y_c"] = [flatten_Y(Y) for Y in self["Y_c"]]
        if self["X_t"] is not None:
            self["X_t"] = [flatten_X(X) for X in self["X_t"]]
        if self["Y_t"] is not None:
            self["Y_t"] = [flatten_Y(Y) for Y in self["Y_t"]]

        self["ops"].append("gridded_data_flattened")

        return self
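
    # Illustrative sketch (hypothetical grid sizes) for a task with one gridded
    # context set sampled on its native grid:
    #
    #   task["X_c"][0]          # -> (x1 coords of shape (25,), x2 coords of shape (53,))
    #   task["Y_c"][0].shape    # -> (1, 25, 53)
    #   task = task.flatten_gridded_data()
    #   task["X_c"][0].shape    # -> (2, 1325)
    #   task["Y_c"][0].shape    # -> (1, 1325)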

    def remove_context_nans(self):
        """If NaNs are present in task["Y_c"], remove them (and corresponding task["X_c"]).

        Returns:
            :class:`deepsensor.data.task.Task`:
                Task with NaN observations removed from the context sets.
        """
        if "batch_dim" in self["ops"]:
            raise ValueError(
                "Cannot remove NaNs from task if a batch dim has been added."
            )

        # First check whether there are any NaNs that we need to remove
        nans_present = False
        for Y_c in self["Y_c"]:
            if B.any(B.isnan(Y_c)):
                nans_present = True
                break

        if not nans_present:
            return self

        # NaNs present in self - remove NaNs
        for i, (X, Y) in enumerate(zip(self["X_c"], self["Y_c"])):
            Y_c_nans = B.isnan(Y)
            if B.any(Y_c_nans):
                if isinstance(X, tuple):
                    # Gridded data - need to flatten to remove NaNs
                    X = flatten_X(X)
                    Y = flatten_Y(Y)
                    Y_c_nans = flatten_Y(Y_c_nans)
                Y_c_nans = B.any(Y_c_nans, axis=0)  # shape (n_contexts,)
                self["X_c"][i] = X[:, ~Y_c_nans]
                self["Y_c"][i] = Y[:, ~Y_c_nans]

        self["ops"].append("context_nans_removed")

        return self

    def remove_target_nans(self):
        """If NaNs are present in task["Y_t"], remove them (and corresponding task["X_t"]).

        Returns:
            :class:`deepsensor.data.task.Task`:
                Task with NaN observations removed from the target sets.
        """
        if "batch_dim" in self["ops"]:
            raise ValueError(
                "Cannot remove NaNs from task if a batch dim has been added."
            )

        # First check whether there are any NaNs that we need to remove
        nans_present = False
        for Y_t in self["Y_t"]:
            if B.any(B.isnan(Y_t)):
                nans_present = True
                break

        if not nans_present:
            return self

        # NaNs present in self - remove NaNs
        for i, (X, Y) in enumerate(zip(self["X_t"], self["Y_t"])):
            Y_t_nans = B.isnan(Y)
            if "Y_t_aux" in self.keys():
                self["Y_t_aux"] = flatten_Y(self["Y_t_aux"])
            if B.any(Y_t_nans):
                if isinstance(X, tuple):
                    # Gridded data - need to flatten to remove NaNs
                    X = flatten_X(X)
                    Y = flatten_Y(Y)
                    Y_t_nans = flatten_Y(Y_t_nans)
                Y_t_nans = B.any(Y_t_nans, axis=0)  # shape (n_targets,)
                self["X_t"][i] = X[:, ~Y_t_nans]
                self["Y_t"][i] = Y[:, ~Y_t_nans]
                if "Y_t_aux" in self.keys():
                    self["Y_t_aux"] = self["Y_t_aux"][:, ~Y_t_nans]

        self["ops"].append("target_nans_removed")

        return self
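
    # Illustrative sketch (hypothetical values): with one context set of three
    # observations where the second value is NaN,
    #
    #   task["Y_c"][0]  # -> array([[1.2, nan, 0.7]]), task["X_c"][0] has shape (2, 3)
    #   task = task.remove_context_nans()
    #   task["Y_c"][0]  # -> array([[1.2, 0.7]]); the matching X_c column is dropped
    #
    # Both removal methods must be called before `add_batch_dim`.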

    def mask_nans_numpy(self):
        """Replace NaNs with zeroes and set a mask to indicate where the NaNs
        were.

        Returns:
            :class:`deepsensor.data.task.Task`:
                Task with NaNs set to zeros and a mask indicating where the
                missing values are.
        """
        if "batch_dim" not in self["ops"]:
            raise ValueError("Must call `add_batch_dim` before `mask_nans_numpy`")

        def f(arr):
            if isinstance(arr, deepsensor.backend.nps.Masked):
                nps_mask = arr.mask == 0
                nan_mask = np.isnan(arr.y)
                mask = np.logical_or(nps_mask, nan_mask)
                mask = np.any(mask, axis=1, keepdims=True)
                data = arr.y
                data[nan_mask] = 0.0
                arr = deepsensor.backend.nps.Masked(data, mask)
            else:
                mask = np.isnan(arr)
                if np.any(mask):
                    # arr = np.ma.MaskedArray(arr, mask=mask, fill_value=0.0)
                    arr = np.ma.fix_invalid(arr, fill_value=0.0)
            return arr

        return self.op(lambda x: f(x), op_flag="numpy_mask")

    def mask_nans_nps(self):
        """Convert the numpy masked arrays produced by ``mask_nans_numpy`` into
        ``neuralprocesses`` ``Masked`` objects.

        Returns:
            :class:`deepsensor.data.task.Task`:
                Task with masked arrays wrapped as ``neuralprocesses``
                ``Masked`` objects.
        """
        if "batch_dim" not in self["ops"]:
            raise ValueError("Must call `add_batch_dim` before `mask_nans_nps`")
        if "numpy_mask" not in self["ops"]:
            raise ValueError("Must call `mask_nans_numpy` before `mask_nans_nps`")

        def f(arr):
            if isinstance(arr, np.ma.MaskedArray):
                # Mask array (True for observed, False for missing). Keep size 1 variable dim.
                mask = ~B.any(arr.mask, axis=1, squeeze=False)
                mask = B.cast(B.dtype(arr.data), mask)
                arr = deepsensor.backend.nps.Masked(arr.data, mask)
            return arr

        return self.op(lambda x: f(x), op_flag="nps_mask")

    def convert_to_tensor(self):
        """Convert to tensor object based on deep learning backend.

        Returns:
            :class:`deepsensor.data.task.Task`:
                Task with arrays converted to deep learning tensor objects.
        """

        def f(arr):
            if isinstance(arr, deepsensor.backend.nps.Masked):
                arr = deepsensor.backend.nps.Masked(
                    deepsensor.backend.convert_to_tensor(arr.y),
                    deepsensor.backend.convert_to_tensor(arr.mask),
                )
            else:
                arr = deepsensor.backend.convert_to_tensor(arr)
            return arr

        return self.op(lambda x: f(x), op_flag="tensor")
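
# Illustrative sketch of the preprocessing order enforced by the checks above
# (not part of the original file; mirrors the usage in `concat_tasks` and the
# `__main__` demo at the bottom of this module, assuming the task contains
# NaN targets):
#
#   task = (
#       task.remove_target_nans()   # must happen before the batch dim is added
#       .add_batch_dim()            # required by both masking methods
#       .cast_to_float32()
#       .mask_nans_numpy()          # required before `mask_nans_nps`
#       .mask_nans_nps()
#       .convert_to_tensor()        # convert arrays to the active DL backend
#   )
#   task["ops"]
#   # -> ["target_nans_removed", "batch_dim", "float32", "numpy_mask", "nps_mask", "tensor"]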


def append_obs_to_task(
    task: Task,
    X_new: B.Numeric,
    Y_new: B.Numeric,
    context_set_idx: int,
):
    """Append a single observation to a context set in ``task``.

    Makes a deep copy of the data structure to avoid affecting the original
    object.

    ..
        TODO: for speed during active learning algs, consider a shallow copy
        option plus ability to remove observations.

    Args:
        task (:class:`deepsensor.data.task.Task`:): The task to modify.
        X_new (array-like): New observation coordinates.
        Y_new (array-like): New observation values.
        context_set_idx (int): Index of the context set to append to.

    Returns:
        :class:`deepsensor.data.task.Task`:
            Task with new observation appended to the context set.
    """
    if not 0 <= context_set_idx <= len(task["X_c"]) - 1:
        raise TaskSetIndexError(context_set_idx, len(task["X_c"]), "context")

    if isinstance(task["X_c"][context_set_idx], tuple):
        raise GriddedDataError("Cannot append to gridded data")

    task_with_new = copy.deepcopy(task)

    if Y_new.ndim == 0:
        # Add size-1 observation and data dimension
        Y_new = Y_new[None, None]

    # Add size-1 observation dimension
    if X_new.ndim == 1:
        X_new = X_new[:, None]
    if Y_new.ndim == 1:
        Y_new = Y_new[:, None]

    # Context set with proposed latent sensors
    task_with_new["X_c"][context_set_idx] = np.concatenate(
        [task["X_c"][context_set_idx], X_new], axis=-1
    )

    # Append proxy observations
    task_with_new["Y_c"][context_set_idx] = np.concatenate(
        [task["Y_c"][context_set_idx], Y_new], axis=-1
    )

    return task_with_new
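
# Illustrative sketch (hypothetical arrays): appending one observation at
# coordinates (0.2, 0.3) with value 1.5 to the first (off-grid) context set:
#
#   x_new = np.array([0.2, 0.3])   # shape (2,) -> promoted to (2, 1)
#   y_new = np.array([1.5])        # shape (1,) -> promoted to (1, 1)
#   new_task = append_obs_to_task(task, x_new, y_new, context_set_idx=0)
#   new_task["X_c"][0].shape[-1] == task["X_c"][0].shape[-1] + 1  # -> True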


def flatten_X(X: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]) -> np.ndarray:
    """Convert tuple of gridded coords to (2, N) array if necessary.

    Args:
        X (:class:`numpy:numpy.ndarray` | Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`]):
            Off-grid coordinates of shape (2, N), or a tuple of two 1D
            coordinate arrays describing a grid.

    Returns:
        :class:`numpy:numpy.ndarray`
            Coordinates as a (2, N) array.
    """
    if type(X) is tuple:
        X1, X2 = np.meshgrid(X[0], X[1], indexing="ij")
        X = np.stack([X1.ravel(), X2.ravel()], axis=0)
    return X


def flatten_Y(Y: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]) -> np.ndarray:
    """Convert gridded data of shape (N_dim, N_x1, N_x2) to (N_dim, N_x1 * N_x2)
    array if necessary.

    Args:
        Y (:class:`numpy:numpy.ndarray` | Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`]):
            Data values, either gridded with shape (N_dim, N_x1, N_x2) or
            already flat with shape (N_dim, N).

    Returns:
        :class:`numpy:numpy.ndarray`
            Data values as a (N_dim, N) array.
    """
    if Y.ndim == 3:
        Y = Y.reshape(*Y.shape[:-2], -1)
    return Y
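
# Illustrative sketch (hypothetical grid sizes) of the gridded-to-flat conversion:
#
#   x1 = np.linspace(0, 1, 25)                # grid coordinates along x1
#   x2 = np.linspace(0, 1, 53)                # grid coordinates along x2
#   flatten_X((x1, x2)).shape                 # -> (2, 1325)
#   flatten_Y(np.zeros((1, 25, 53))).shape    # -> (1, 1325), column order matches flatten_X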


def concat_tasks(tasks: List[Task], multiple: int = 1) -> Task:
    """Concatenate a list of tasks into a single task containing multiple batches.

    ..

    Todo:
        - Consider moving to ``nps.py`` as this leverages ``neuralprocesses``
          functionality.
        - Raise error if ``aux_t`` values passed (probably not supported).

    Args:
        tasks (List[:class:`deepsensor.data.task.Task`:]):
            List of tasks to concatenate into a single task.
        multiple (int, optional):
            Contexts are padded to the smallest multiple of this number that is
            greater than the number of contexts in each task. Defaults to 1
            (padded to the largest number of contexts in the tasks). Setting
            to a larger number will increase the amount of padding but decrease
            the range of tensor shapes presented to the model, which simplifies
            the computational graph in graph mode.

    Returns:
        :class:`~.data.task.Task`: Task containing multiple batches.

    Raises:
        ValueError:
            If the tasks have different numbers of target sets.
        ValueError:
            If the tasks have different numbers of targets.
        ValueError:
            If the tasks have different types of target sets (gridded/
            non-gridded).
    """
    if len(tasks) == 1:
        return tasks[0]

    for i, task in enumerate(tasks):
        if "numpy_mask" in task["ops"] or "nps_mask" in task["ops"]:
            raise ValueError(
                "Cannot concatenate tasks that have had NaNs masked. "
                "Masking will be applied automatically after concatenation."
            )
        if "target_nans_removed" not in task["ops"]:
            task = task.remove_target_nans()
        if "batch_dim" not in task["ops"]:
            task = task.add_batch_dim()
        if "float32" not in task["ops"]:
            task = task.cast_to_float32()
        tasks[i] = task

    # Assert number of target sets equal
    n_target_sets = [len(task["Y_t"]) for task in tasks]
    if not all([n == n_target_sets[0] for n in n_target_sets]):
        raise ValueError(
            f"All tasks must have the same number of target sets to concatenate: got {n_target_sets}."
        )
    n_target_sets = n_target_sets[0]

    for target_set_i in range(n_target_sets):
        # Raise error if target sets have different numbers of targets across tasks
        n_target_obs = [task["Y_t"][target_set_i].size for task in tasks]
        if not all([n == n_target_obs[0] for n in n_target_obs]):
            raise ValueError(
                f"All tasks must have the same number of targets to concatenate: got {n_target_obs}. "
                "To train with Task batches containing differing numbers of targets, "
                "run the model individually over each task and average the losses."
            )

        # Raise error if target sets are different types (gridded/non-gridded) across tasks
        if isinstance(tasks[0]["X_t"][target_set_i], tuple):
            for task in tasks:
                if not isinstance(task["X_t"][target_set_i], tuple):
                    raise ValueError(
                        "All tasks must have the same type of target set (gridded or non-gridded) "
                        f"to concatenate. For target set {target_set_i}, got {type(task['X_t'][target_set_i])}."
                    )

    # For each task, store list of tuples of (x_c, y_c) (one tuple per context set)
    contexts = []
    for i, task in enumerate(tasks):
        contexts_i = list(zip(task["X_c"], task["Y_c"]))
        contexts.append(contexts_i)

    # List of tuples of merged (x_c, y_c) along batch dim with padding
    # (up to the smallest multiple of `multiple` greater than the number of contexts in each task)
    merged_context = [
        deepsensor.backend.nps.merge_contexts(
            *[context_set for context_set in contexts_i], multiple=multiple
        )
        for contexts_i in zip(*contexts)
    ]

    merged_task = copy.deepcopy(tasks[0])

    # Convert list of tuples of (x_c, y_c) to list of x_c and list of y_c
    merged_task["X_c"] = [c[0] for c in merged_context]
    merged_task["Y_c"] = [c[1] for c in merged_context]

    # This assumes that all tasks have the same number of targets
    for i in range(n_target_sets):
        if isinstance(tasks[0]["X_t"][i], tuple):
            # Target set is gridded with tuple of coords for `X_t`
            merged_task["X_t"][i] = (
                B.concat(*[t["X_t"][i][0] for t in tasks], axis=0),
                B.concat(*[t["X_t"][i][1] for t in tasks], axis=0),
            )
        else:
            # Target set is off-the-grid with tensor for `X_t`
            merged_task["X_t"][i] = B.concat(*[t["X_t"][i] for t in tasks], axis=0)
        merged_task["Y_t"][i] = B.concat(*[t["Y_t"][i] for t in tasks], axis=0)

    merged_task["time"] = [t["time"] for t in tasks]

    merged_task = Task(merged_task)

    # Apply masking
    merged_task = merged_task.mask_nans_numpy()
    merged_task = merged_task.mask_nans_nps()

    return merged_task
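
# Illustrative sketch, mirroring the `__main__` demo below: two tasks with 50
# and 100 context observations are padded to a common context size (rounded up
# to a multiple of `multiple`) and stacked along the batch dimension, with the
# padded positions tracked by the nps mask applied in the final
# `mask_nans_numpy` / `mask_nans_nps` calls:
#
#   merged_task = concat_tasks([task1, task2], multiple=1)
#   print(repr(merged_task))  # context entries appear as Masked objects with a batch dim of 2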


if __name__ == "__main__":  # pragma: no cover
    # print working directory
    import os

    print(os.path.abspath(os.getcwd()))

    import deepsensor.tensorflow as deepsensor
    from deepsensor.data.processor import DataProcessor
    from deepsensor.data.loader import TaskLoader
    from deepsensor.model.convnp import ConvNP
    from deepsensor.data.task import concat_tasks

    import xarray as xr
    import numpy as np

    da_raw = xr.tutorial.open_dataset("air_temperature")
    data_processor = DataProcessor(x1_name="lat", x2_name="lon")
    da = data_processor(da_raw)

    task_loader = TaskLoader(context=da, target=da)

    task1 = task_loader("2014-01-01", 50)
    task1["Y_c"][0][0, 0] = np.nan
    task2 = task_loader("2014-01-01", 100)

    # task1 = task1.add_batch_dim().mask_nans_numpy().mask_nans_nps()
    # task2 = task2.add_batch_dim().mask_nans_numpy().mask_nans_nps()

    merged_task = concat_tasks([task1, task2])
    print(repr(merged_task))

    print("got here")