
alan-turing-institute / deepsensor / 9828217984

07 Jul 2024 02:59PM UTC · coverage: 81.333% (remained the same)
push · github · web-flow: use unittest's `setUpClass` instead of overriding `__init__` (#117)
1965 of 2416 relevant lines covered (81.33%) · 1.63 hits per line

Source file: /deepsensor/data/task.py (74.11% of relevant lines covered)

import deepsensor

from typing import Callable, Union, Tuple, List, Optional
import numpy as np
import lab as B
import plum
import copy

from ..errors import TaskSetIndexError, GriddedDataError


class Task(dict):
    """
    Task dictionary class.

    Inherits from ``dict`` and adds methods for printing and modifying the
    data.

    Args:
        task_dict (dict):
            Dictionary containing the task.
    """

    def __init__(self, task_dict: dict) -> None:
        super().__init__(task_dict)

        if "ops" not in self:
            # List of operations (str) to indicate how the task has been modified
            #   (reshaping, casting, etc)
            self["ops"] = []

    @classmethod
    def summarise_str(cls, k, v):
        if plum.isinstance(v, B.Numeric):
            return v.shape
        elif plum.isinstance(v, tuple):
            return tuple(vi.shape for vi in v)
        elif plum.isinstance(v, list):
            return [cls.summarise_str(k, vi) for vi in v]
        else:
            return v

    @classmethod
    def summarise_repr(cls, k, v) -> str:
        """
        Summarise the task in a representation that can be printed.

        Args:
            cls (:class:`deepsensor.data.task.Task`:):
                Task class.
            k (str):
                Key of the task dictionary.
            v (object):
                Value of the task dictionary.

        Returns:
            str: String representation of the task.
        """
        if v is None:
            return "None"
        elif plum.isinstance(v, B.Numeric):
            return f"{type(v).__name__}/{v.dtype}/{v.shape}"
        elif plum.isinstance(v, deepsensor.backend.nps.mask.Masked):
            return f"{type(v).__name__}/(y={v.y.dtype}/{v.y.shape})/(mask={v.mask.dtype}/{v.mask.shape})"
        elif plum.isinstance(v, tuple):
            # return tuple(vi.shape for vi in v)
            return tuple([cls.summarise_repr(k, vi) for vi in v])
        elif plum.isinstance(v, list):
            return [cls.summarise_repr(k, vi) for vi in v]
        else:
            return f"{type(v).__name__}/{v}"

    def __str__(self) -> str:
        """
        Print a convenient summary of the task dictionary.

        For array entries, print their shape, otherwise print the value.
        """
        s = ""
        for k, v in self.items():
            if v is None:
                continue
            s += f"{k}: {Task.summarise_str(k, v)}\n"
        return s

    def __repr__(self) -> str:
        """
        Print a convenient summary of the task dictionary.

        Print the type of each entry and if it is an array, print its shape,
        otherwise print the value.
        """
        s = ""
        for k, v in self.items():
            s += f"{k}: {Task.summarise_repr(k, v)}\n"
        return s

    def op(self, f: Callable, op_flag: Optional[str] = None):
        """
        Apply function f to the array elements of a task dictionary.

        Useful for recasting to a different dtype or reshaping (e.g. adding a
        batch dimension).

        Args:
            f (callable):
                Function to apply to the array elements of the task.
            op_flag (str):
                Flag to set in the task dictionary's ``ops`` key.

        Returns:
            :class:`deepsensor.data.task.Task`:
                Task with f applied to the array elements and op_flag set in
                the ``ops`` key.
        """

        def recurse(k, v):
            if type(v) is list:
                return [recurse(k, vi) for vi in v]
            elif type(v) is tuple:
                return (recurse(k, v[0]), recurse(k, v[1]))
            elif isinstance(
                v,
                (np.ndarray, np.ma.MaskedArray, deepsensor.backend.nps.Masked),
            ):
                return f(v)
            else:
                return v  # covers metadata entries

        self = copy.deepcopy(self)  # don't modify the original
        for k, v in self.items():
            self[k] = recurse(k, v)
        self["ops"].append(op_flag)

        return self  # return the deep-copied, modified task
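
    # Illustrative sketch: the convenience methods below all delegate to `op`,
    # but it can also be called directly, e.g. to cast to a hypothetical dtype
    # not covered by the built-in helpers:
    #
    #   >>> task = task.op(lambda x: x.astype(np.float16), op_flag="float16")
    #   >>> "float16" in task["ops"]
    #   True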

    def add_batch_dim(self):
        """
        Add a batch dimension to the arrays in the task dictionary.

        Returns:
            :class:`deepsensor.data.task.Task`:
                Task with batch dimension added to the array elements.
        """
        return self.op(lambda x: x[None, ...], op_flag="batch_dim")

    def cast_to_float32(self):
        """
        Cast the arrays in the task dictionary to float32.

        Returns:
            :class:`deepsensor.data.task.Task`:
                Task with arrays cast to float32.
        """
        return self.op(lambda x: x.astype(np.float32), op_flag="float32")

    def flatten_gridded_data(self):
        """
        Convert any gridded data in ``Task`` to flattened arrays.

        Necessary for AR sampling, which doesn't yet permit gridded context sets.

        Returns:
            :class:`deepsensor.data.task.Task`:
                Task with gridded context and target sets flattened into
                (2, N) coordinate arrays and (N_dim, N) data arrays.
        """
        self["X_c"] = [flatten_X(X) for X in self["X_c"]]
        self["Y_c"] = [flatten_Y(Y) for Y in self["Y_c"]]
        if self["X_t"] is not None:
            self["X_t"] = [flatten_X(X) for X in self["X_t"]]
        if self["Y_t"] is not None:
            self["Y_t"] = [flatten_Y(Y) for Y in self["Y_t"]]

        self["ops"].append("gridded_data_flattened")

        return self

    def remove_context_nans(self):
        """
        If NaNs are present in task["Y_c"], remove them (and the corresponding
        task["X_c"] entries).

        Returns:
            :class:`deepsensor.data.task.Task`:
                Task with NaN observations removed from the context sets.
        """
        if "batch_dim" in self["ops"]:
            raise ValueError(
                "Cannot remove NaNs from task if a batch dim has been added."
            )

        # First check whether there are any NaNs that we need to remove
        nans_present = False
        for Y_c in self["Y_c"]:
            if B.any(B.isnan(Y_c)):
                nans_present = True
                break

        if not nans_present:
            return self

        # NaNs present in self - remove NaNs
        for i, (X, Y) in enumerate(zip(self["X_c"], self["Y_c"])):
            Y_c_nans = B.isnan(Y)
            if B.any(Y_c_nans):
                if isinstance(X, tuple):
                    # Gridded data - need to flatten to remove NaNs
                    X = flatten_X(X)
                    Y = flatten_Y(Y)
                    Y_c_nans = flatten_Y(Y_c_nans)
                Y_c_nans = B.any(Y_c_nans, axis=0)  # shape (n_contexts,)
                self["X_c"][i] = X[:, ~Y_c_nans]
                self["Y_c"][i] = Y[:, ~Y_c_nans]

        self["ops"].append("context_nans_removed")

        return self

    def remove_target_nans(self):
        """
        If NaNs are present in task["Y_t"], remove them (and the corresponding
        task["X_t"] entries).

        Returns:
            :class:`deepsensor.data.task.Task`:
                Task with NaN observations removed from the target sets.
        """
        if "batch_dim" in self["ops"]:
            raise ValueError(
                "Cannot remove NaNs from task if a batch dim has been added."
            )

        # First check whether there are any NaNs that we need to remove
        nans_present = False
        for Y_t in self["Y_t"]:
            if B.any(B.isnan(Y_t)):
                nans_present = True
                break

        if not nans_present:
            return self

        # NaNs present in self - remove NaNs
        for i, (X, Y) in enumerate(zip(self["X_t"], self["Y_t"])):
            Y_t_nans = B.isnan(Y)
            if "Y_t_aux" in self.keys():
                self["Y_t_aux"] = flatten_Y(self["Y_t_aux"])
            if B.any(Y_t_nans):
                if isinstance(X, tuple):
                    # Gridded data - need to flatten to remove NaNs
                    X = flatten_X(X)
                    Y = flatten_Y(Y)
                    Y_t_nans = flatten_Y(Y_t_nans)
                Y_t_nans = B.any(Y_t_nans, axis=0)  # shape (n_targets,)
                self["X_t"][i] = X[:, ~Y_t_nans]
                self["Y_t"][i] = Y[:, ~Y_t_nans]
                if "Y_t_aux" in self.keys():
                    self["Y_t_aux"] = self["Y_t_aux"][:, ~Y_t_nans]

        self["ops"].append("target_nans_removed")

        return self
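
    # Illustrative sketch of NaN removal on an assumed off-grid target set
    # (key names and shapes chosen purely for illustration):
    #
    #   >>> t = Task({"X_t": [np.zeros((2, 3))],
    #   ...           "Y_t": [np.array([[1.0, np.nan, 3.0]])]})
    #   >>> t = t.remove_target_nans()
    #   >>> t["Y_t"][0].shape  # the NaN column and its coordinates are dropped
    #   (1, 2)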

    def mask_nans_numpy(self):
        """
        Replace NaNs with zeroes and set a mask to indicate where the NaNs
        were.

        Returns:
            :class:`deepsensor.data.task.Task`:
                Task with NaNs set to zeros and a mask indicating where the
                missing values are.
        """
        if "batch_dim" not in self["ops"]:
            raise ValueError("Must call `add_batch_dim` before `mask_nans_numpy`")

        def f(arr):
            if isinstance(arr, deepsensor.backend.nps.Masked):
                nps_mask = arr.mask == 0
                nan_mask = np.isnan(arr.y)
                mask = np.logical_or(nps_mask, nan_mask)
                mask = np.any(mask, axis=1, keepdims=True)
                data = arr.y
                data[nan_mask] = 0.0
                arr = deepsensor.backend.nps.Masked(data, mask)
            else:
                mask = np.isnan(arr)
                if np.any(mask):
                    # arr = np.ma.MaskedArray(arr, mask=mask, fill_value=0.0)
                    arr = np.ma.fix_invalid(arr, fill_value=0.0)
            return arr

        return self.op(lambda x: f(x), op_flag="numpy_mask")

    def mask_nans_nps(self):
        """
        Convert NaN-masked numpy arrays (produced by ``mask_nans_numpy``) into
        ``neuralprocesses`` ``Masked`` objects.

        Returns:
            :class:`deepsensor.data.task.Task`:
                Task with masked numpy arrays replaced by ``nps.Masked`` objects.
        """
        if "batch_dim" not in self["ops"]:
            raise ValueError("Must call `add_batch_dim` before `mask_nans_nps`")
        if "numpy_mask" not in self["ops"]:
            raise ValueError("Must call `mask_nans_numpy` before `mask_nans_nps`")

        def f(arr):
            if isinstance(arr, np.ma.MaskedArray):
                # Mask array (True for observed, False for missing). Keep size 1 variable dim.
                mask = ~B.any(arr.mask, axis=1, squeeze=False)
                mask = B.cast(B.dtype(arr.data), mask)
                arr = deepsensor.backend.nps.Masked(arr.data, mask)
            return arr

        return self.op(lambda x: f(x), op_flag="nps_mask")

    def convert_to_tensor(self):
        """
        Convert to tensor object based on deep learning backend.

        Returns:
            :class:`deepsensor.data.task.Task`:
                Task with arrays converted to deep learning tensor objects.
        """

        def f(arr):
            if isinstance(arr, deepsensor.backend.nps.Masked):
                arr = deepsensor.backend.nps.Masked(
                    deepsensor.backend.convert_to_tensor(arr.y),
                    deepsensor.backend.convert_to_tensor(arr.mask),
                )
            else:
                arr = deepsensor.backend.convert_to_tensor(arr)
            return arr

        return self.op(lambda x: f(x), op_flag="tensor")
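

The methods above are designed to be chained. Below is a minimal, purely illustrative sketch of the usual preprocessing order on a hand-built Task; the keys, shapes and the NaN are assumptions (in practice tasks come from a TaskLoader, and the later mask_nans_nps / convert_to_tensor steps additionally require a deepsensor backend such as deepsensor.tensorflow to be imported).

# Purely illustrative: a tiny hand-built Task and the standard chain of ops.
_demo_task = Task(
    {
        "X_c": [np.random.rand(2, 5)],                      # one off-grid context set
        "Y_c": [np.array([[1.0, np.nan, 3.0, 4.0, 5.0]])],  # with a missing value
        "X_t": [np.random.rand(2, 3)],
        "Y_t": [np.random.rand(1, 3)],
    }
)
_demo_task = _demo_task.add_batch_dim().cast_to_float32().mask_nans_numpy()
# Each step is recorded: _demo_task["ops"] == ["batch_dim", "float32", "numpy_mask"]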


def append_obs_to_task(
    task: Task,
    X_new: B.Numeric,
    Y_new: B.Numeric,
    context_set_idx: int,
):
    """
    Append a single observation to a context set in ``task``.

    Makes a deep copy of the data structure to avoid affecting the original
    object.

    ..
        TODO: for speed during active learning algs, consider a shallow copy
        option plus ability to remove observations.

    Args:
        task (:class:`deepsensor.data.task.Task`:): The task to modify.
        X_new (array-like): New observation coordinates.
        Y_new (array-like): New observation values.
        context_set_idx (int): Index of the context set to append to.

    Returns:
        :class:`deepsensor.data.task.Task`:
            Task with new observation appended to the context set.
    """
    if not 0 <= context_set_idx <= len(task["X_c"]) - 1:
        raise TaskSetIndexError(context_set_idx, len(task["X_c"]), "context")

    if isinstance(task["X_c"][context_set_idx], tuple):
        raise GriddedDataError("Cannot append to gridded data")

    task_with_new = copy.deepcopy(task)

    if Y_new.ndim == 0:
        # Add size-1 observation and data dimension
        Y_new = Y_new[None, None]

    # Add size-1 observation dimension
    if X_new.ndim == 1:
        X_new = X_new[:, None]
    if Y_new.ndim == 1:
        Y_new = Y_new[:, None]

    # Context set with proposed latent sensors
    task_with_new["X_c"][context_set_idx] = np.concatenate(
        [task["X_c"][context_set_idx], X_new], axis=-1
    )

    # Append proxy observations
    task_with_new["Y_c"][context_set_idx] = np.concatenate(
        [task["Y_c"][context_set_idx], Y_new], axis=-1
    )

    return task_with_new
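

A hedged usage sketch of appending one proposed observation to the first context set; the coordinates, value and shapes below are assumptions chosen purely for illustration.

# Purely illustrative: append one new observation to context set 0.
_task = Task(
    {
        "X_c": [np.zeros((2, 4))],
        "Y_c": [np.zeros((1, 4))],
        "X_t": [np.zeros((2, 2))],
        "Y_t": [np.zeros((1, 2))],
    }
)
_x_new = np.array([0.5, 0.5])  # (x1, x2) coordinates of the new observation
_y_new = np.array([1.0])       # observed value
_task_with_obs = append_obs_to_task(_task, _x_new, _y_new, context_set_idx=0)
# Context set 0 grows from 4 to 5 observations: X_c -> (2, 5), Y_c -> (1, 5)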


def flatten_X(X: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]) -> np.ndarray:
    """
    Convert tuple of gridded coords to (2, N) array if necessary.

    Args:
        X (:class:`numpy:numpy.ndarray` | Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`]):
            Either a (2, N) array of off-grid coordinates (returned unchanged)
            or a tuple of two 1D coordinate arrays defining a grid.

    Returns:
        :class:`numpy:numpy.ndarray`
            Coordinates as a (2, N) array.
    """
    if type(X) is tuple:
        X1, X2 = np.meshgrid(X[0], X[1], indexing="ij")
        X = np.stack([X1.ravel(), X2.ravel()], axis=0)
    return X


def flatten_Y(Y: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]) -> np.ndarray:
    """
    Convert gridded data of shape (N_dim, N_x1, N_x2) to (N_dim, N_x1 * N_x2)
    array if necessary.

    Args:
        Y (:class:`numpy:numpy.ndarray` | Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`]):
            Gridded data of shape (N_dim, N_x1, N_x2), or already-flat data
            (returned unchanged).

    Returns:
        :class:`numpy:numpy.ndarray`
            Data flattened to shape (N_dim, N_x1 * N_x2).
    """
    if Y.ndim == 3:
        Y = Y.reshape(*Y.shape[:-2], -1)
    return Y
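

A short illustrative sketch of the two flattening helpers on an assumed 2 x 3 grid (values chosen for illustration only).

# Purely illustrative: flatten gridded coordinates and data for a 2 x 3 grid.
_x1 = np.array([0.0, 1.0])        # grid coordinates along x1
_x2 = np.array([0.0, 0.5, 1.0])   # grid coordinates along x2
_X_flat = flatten_X((_x1, _x2))   # shape (2, 6): every (x1, x2) pair
_Y_grid = np.zeros((1, 2, 3))     # one variable on the 2 x 3 grid
_Y_flat = flatten_Y(_Y_grid)      # shape (1, 6), matching _X_flat's ordering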


def concat_tasks(tasks: List[Task], multiple: int = 1) -> Task:
    """
    Concatenate a list of tasks into a single task containing multiple batches.

    ..
        TODO:
        - Consider moving to ``nps.py`` as this leverages ``neuralprocesses``
          functionality.
        - Raise error if ``aux_t`` values passed (not currently supported).

    Args:
        tasks (List[:class:`deepsensor.data.task.Task`:]):
            List of tasks to concatenate into a single task.
        multiple (int, optional):
            Contexts are padded to the smallest multiple of this number that is
            greater than the number of contexts in each task. Defaults to 1
            (padded to the largest number of contexts in the tasks). Setting
            to a larger number will increase the amount of padding but decrease
            the range of tensor shapes presented to the model, which simplifies
            the computational graph in graph mode.

    Returns:
        :class:`~.data.task.Task`: Task containing multiple batches.

    Raises:
        ValueError:
            If the tasks have different numbers of target sets.
        ValueError:
            If the tasks have different numbers of targets.
        ValueError:
            If the tasks have different types of target sets (gridded/
            non-gridded).
    """
    if len(tasks) == 1:
        return tasks[0]

    for i, task in enumerate(tasks):
        if "numpy_mask" in task["ops"] or "nps_mask" in task["ops"]:
            raise ValueError(
                "Cannot concatenate tasks that have had NaNs masked. "
                "Masking will be applied automatically after concatenation."
            )
        if "target_nans_removed" not in task["ops"]:
            task = task.remove_target_nans()
        if "batch_dim" not in task["ops"]:
            task = task.add_batch_dim()
        if "float32" not in task["ops"]:
            task = task.cast_to_float32()
        tasks[i] = task

    # Assert number of target sets equal
    n_target_sets = [len(task["Y_t"]) for task in tasks]
    if not all([n == n_target_sets[0] for n in n_target_sets]):
        raise ValueError(
            f"All tasks must have the same number of target sets to concatenate: got {n_target_sets}. "
        )
    n_target_sets = n_target_sets[0]

    for target_set_i in range(n_target_sets):
        # Raise error if target sets have different numbers of targets across tasks
        n_target_obs = [task["Y_t"][target_set_i].size for task in tasks]
        if not all([n == n_target_obs[0] for n in n_target_obs]):
            raise ValueError(
                f"All tasks must have the same number of targets to concatenate: got {n_target_obs}. "
                "To train with Task batches containing differing numbers of targets, "
                "run the model individually over each task and average the losses."
            )

        # Raise error if target sets are different types (gridded/non-gridded) across tasks
        if isinstance(tasks[0]["X_t"][target_set_i], tuple):
            for task in tasks:
                if not isinstance(task["X_t"][target_set_i], tuple):
                    raise ValueError(
                        "All tasks must have the same type of target set (gridded or non-gridded) "
                        f"to concatenate. For target set {target_set_i}, got {type(task['X_t'][target_set_i])}."
                    )

    # For each task, store list of tuples of (x_c, y_c) (one tuple per context set)
    contexts = []
    for i, task in enumerate(tasks):
        contexts_i = list(zip(task["X_c"], task["Y_c"]))
        contexts.append(contexts_i)

    # List of tuples of merged (x_c, y_c) along batch dim with padding
    # (up to the smallest multiple of `multiple` greater than the number of contexts in each task)
    merged_context = [
        deepsensor.backend.nps.merge_contexts(
            *[context_set for context_set in contexts_i], multiple=multiple
        )
        for contexts_i in zip(*contexts)
    ]

    merged_task = copy.deepcopy(tasks[0])

    # Convert list of tuples of (x_c, y_c) to list of x_c and list of y_c
    merged_task["X_c"] = [c[0] for c in merged_context]
    merged_task["Y_c"] = [c[1] for c in merged_context]

    # This assumes that all tasks have the same number of targets
    for i in range(n_target_sets):
        if isinstance(tasks[0]["X_t"][i], tuple):
            # Target set is gridded with tuple of coords for `X_t`
            merged_task["X_t"][i] = (
                B.concat(*[t["X_t"][i][0] for t in tasks], axis=0),
                B.concat(*[t["X_t"][i][1] for t in tasks], axis=0),
            )
        else:
            # Target set is off-the-grid with tensor for `X_t`
            merged_task["X_t"][i] = B.concat(*[t["X_t"][i] for t in tasks], axis=0)
        merged_task["Y_t"][i] = B.concat(*[t["Y_t"][i] for t in tasks], axis=0)

    merged_task["time"] = [t["time"] for t in tasks]

    merged_task = Task(merged_task)

    # Apply masking
    merged_task = merged_task.mask_nans_numpy()
    merged_task = merged_task.mask_nans_nps()

    return merged_task


if __name__ == "__main__":  # pragma: no cover
    # print working directory
    import os

    print(os.path.abspath(os.getcwd()))

    import deepsensor.tensorflow as deepsensor
    from deepsensor.data.processor import DataProcessor
    from deepsensor.data.loader import TaskLoader
    from deepsensor.model.convnp import ConvNP
    from deepsensor.data.task import concat_tasks

    import xarray as xr
    import numpy as np

    da_raw = xr.tutorial.open_dataset("air_temperature")
    data_processor = DataProcessor(x1_name="lat", x2_name="lon")
    da = data_processor(da_raw)

    task_loader = TaskLoader(context=da, target=da)

    task1 = task_loader("2014-01-01", 50)
    task1["Y_c"][0][0, 0] = np.nan
    task2 = task_loader("2014-01-01", 100)

    # task1 = task1.add_batch_dim().mask_nans_numpy().mask_nans_nps()
    # task2 = task2.add_batch_dim().mask_nans_numpy().mask_nans_nps()

    merged_task = concat_tasks([task1, task2])
    print(repr(merged_task))

    print("got here")