• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

alan-turing-institute / deepsensor / 19842460617

08 Oct 2025 10:03AM UTC coverage: 81.663%. Remained the same
19842460617

push

github

web-flow
Update README.md, adding reference to GIANT project

2053 of 2514 relevant lines covered (81.66%)

1.63 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

84.43
/deepsensor/model/pred.py
1
from typing import Union, List, Optional
2✔
2

3
import numpy as np
2✔
4
import pandas as pd
2✔
5
import xarray as xr
2✔
6

7
Timestamp = Union[str, pd.Timestamp, np.datetime64]
2✔
8

9

10
class Prediction(dict):
    """Object to store model predictions in a dictionary-like format.

    Maps from target variable IDs to xarray/pandas objects containing
    prediction parameters (depending on the output distribution of the model).

    For example, if the model outputs a Gaussian distribution, then the xarray/pandas
    objects in the ``Prediction`` will contain a ``mean`` and ``std``.

    If ``n_samples >= 1``, one ``sample_<i>`` entry per sample is added alongside
    the entries named in ``pred_params``.

    Entries can be looked up either by target variable ID or by integer position
    in ``target_var_IDs`` (``prediction[0]``).

    Args:
        target_var_IDs (List[str])
            List of target variable IDs.
        pred_params (List[str])
            Names of the prediction parameters to store for every target
            variable (data variables if on-grid, columns if off-grid).
        dates (List[Union[str, pd.Timestamp]])
            List of dates corresponding to the predictions.
        X_t (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`)
            Target locations to predict at. Can be an xarray object containing
            on-grid locations or a pandas object containing off-grid locations.
        X_t_mask (:class:`xarray.Dataset` | :class:`xarray.DataArray`, optional)
            2D mask to apply to gridded ``X_t``. This class stores it as given
            and uses it as a boolean index when assigning on-grid data.
            Default None (an all-True mask is built for on-grid predictions).
        coord_names (dict, optional)
            Dict mapping from normalised coord names to raw coord names,
            by default {"x1": "x1", "x2": "x2"}.
        n_samples (int)
            Number of ``sample_<i>`` entries to allocate. If 0, no sample
            entries are created. Default 0.
        forecasting_mode (bool)
            If True, stores forecast predictions with a ``lead_time`` and an
            ``init_time`` dimension. If False, stores predictions with only a
            single ``time`` dimension. Default False.
        lead_times (List[pd.Timedelta], optional)
            List of lead times to store in predictions. Must be provided if
            forecasting_mode is True. Default None.
    """

    def __init__(
        self,
        target_var_IDs: List[str],
        pred_params: List[str],
        dates: List[Timestamp],
        X_t: Union[
            xr.Dataset,
            xr.DataArray,
            pd.DataFrame,
            pd.Series,
            pd.Index,
            np.ndarray,
        ],
        X_t_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None,
        coord_names: Optional[dict] = None,
        n_samples: int = 0,
        forecasting_mode: bool = False,
        lead_times: Optional[List[pd.Timedelta]] = None,
    ):
        self.target_var_IDs = target_var_IDs
        self.X_t_mask = X_t_mask
        if coord_names is None:
            coord_names = {"x1": "x1", "x2": "x2"}
        self.x1_name = coord_names["x1"]
        self.x2_name = coord_names["x2"]

        self.forecasting_mode = forecasting_mode
        if forecasting_mode:
            assert (
                lead_times is not None
            ), "If forecasting_mode is True, lead_times must be provided."
        self.lead_times = lead_times

        self.mode = infer_prediction_modality_from_X_t(X_t)

        self.pred_params = pred_params
        if n_samples >= 1:
            # Build a new list so the caller's ``pred_params`` is not mutated.
            self.pred_params = [
                *pred_params,
                *[f"sample_{i}" for i in range(n_samples)],
            ]

        # Create empty xarray/pandas objects to store predictions
        if self.mode == "on-grid":
            # The extra leading dimension is the same for every variable.
            if self.forecasting_mode:
                prepend_dims = ["lead_time"]
                prepend_coords = {"lead_time": lead_times}
            else:
                prepend_dims = None
                prepend_coords = None

            for var_ID in self.target_var_IDs:
                pred_ds = create_empty_spatiotemporal_xarray(
                    X_t,
                    dates,
                    data_vars=self.pred_params,
                    coord_names=coord_names,
                    prepend_dims=prepend_dims,
                    prepend_coords=prepend_coords,
                )
                if self.forecasting_mode:
                    pred_ds = pred_ds.rename(time="init_time")
                self[var_ID] = pred_ds

            if self.X_t_mask is None:
                # Create 2D boolean array of True values to simplify indexing
                self.X_t_mask = (
                    create_empty_spatiotemporal_xarray(X_t, dates[0:1], coord_names)
                    .to_array()
                    .isel(time=0, variable=0)
                    .astype(bool)
                )
        elif self.mode == "off-grid":
            # Repeat target locs for each date to create multiindex
            # NOTE(review): ``*loc`` assumes every entry of ``X_t.index`` is a
            # tuple (i.e. a MultiIndex) -- confirm against callers.
            if self.forecasting_mode:
                index_names = ["lead_time", "init_time", *X_t.index.names]
                rows = [
                    (lt, date, *loc)
                    for lt in lead_times
                    for date in dates
                    for loc in X_t.index
                ]
            else:
                index_names = ["time", *X_t.index.names]
                rows = [(date, *loc) for date in dates for loc in X_t.index]
            index = pd.MultiIndex.from_tuples(rows, names=index_names)
            for var_ID in self.target_var_IDs:
                self[var_ID] = pd.DataFrame(index=index, columns=self.pred_params)

    def __getitem__(self, key):
        # Support self[i] syntax
        if isinstance(key, int):
            key = self.target_var_IDs[key]
        return super().__getitem__(key)

    def __str__(self):
        dict_repr = {var_ID: self.pred_params for var_ID in self.target_var_IDs}
        return f"Prediction({dict_repr}), mode={self.mode}"

    def assign(
        self,
        prediction_parameter: str,
        date: Union[str, pd.Timestamp],
        data: np.ndarray,
        lead_times: Optional[List[pd.Timedelta]] = None,
    ):
        """Write model output for one date into the stored prediction objects.

        Args:
            prediction_parameter (str)
                Name of the entry to write to. The special value ``"samples"``
                writes ``data[i]`` to the ``sample_<i>`` entry for every ``i``.
            date (Union[str, pd.Timestamp])
                Date (the init time in forecasting mode) to write at.
            data (np.ndarray)
                If off-grid: Shape (N_var, N_targets) or (N_samples, N_var, N_targets).
                If on-grid: Shape (N_var, N_x1, N_x2) or (N_samples, N_var, N_x1, N_x2).
                The sample-leading shapes are required when
                ``prediction_parameter`` is ``"samples"``.
            lead_times (List[pd.Timedelta], optional)
                One lead time per variable in ``data``; entry ``i`` is used for
                the ``i``-th target variable. Required if forecasting_mode is
                True, ignored otherwise. Default None.

        Raises:
            AssertionError
                If sample data does not have the expected number of dimensions,
                or, in forecasting mode, if ``lead_times`` is missing or its
                length does not match the number of variables in ``data``.
        """
        is_samples = prediction_parameter == "samples"

        if is_samples:
            # Validate dimensionality first so the variable axis used below exists.
            if self.mode == "on-grid":
                assert len(data.shape) == 4, (
                    f"If prediction_parameter is 'samples', and mode is 'on-grid', data must "
                    f"have shape (N_samples, N_var, N_x1, N_x2). Got {data.shape}."
                )
            elif self.mode == "off-grid":
                assert len(data.shape) == 3, (
                    f"If prediction_parameter is 'samples', and mode is 'off-grid', data must "
                    f"have shape (N_samples, N_var, N_targets). Got {data.shape}."
                )

        if self.forecasting_mode:
            assert (
                lead_times is not None
            ), "If forecasting_mode is True, lead_times must be provided."

            # Lead times are indexed per variable, and the variable axis moves
            # from 0 to 1 when a leading sample axis is present.
            n_vars = data.shape[1] if is_samples else data.shape[0]
            msg = (
                "If forecasting_mode is True, lead_times must be of equal length to the "
                "number of variables in the data (the first dimension, or the second "
                f"dimension for samples). Got {lead_times=} of length "
                f"{len(lead_times)} lead times and data shape {data.shape}."
            )
            assert len(lead_times) == n_vars, msg

        if is_samples:
            for sample_i, sample in enumerate(data):
                self._assign_per_variable(
                    f"sample_{sample_i}", date, sample, lead_times
                )
        else:
            self._assign_per_variable(prediction_parameter, date, data, lead_times)

    def _assign_per_variable(self, entry, date, data, lead_times):
        """Write ``data[i]`` into ``entry`` of the i-th target variable at ``date``."""
        for i, (var_ID, pred) in enumerate(zip(self.target_var_IDs, data)):
            if self.forecasting_mode:
                index = (lead_times[i], date)
            else:
                index = date

            if self.mode == "on-grid":
                # Writes through the array behind the ``.loc`` selection.
                # NOTE(review): relies on that selection being a view of the
                # stored data rather than a copy -- confirm if xarray changes.
                self[var_ID][entry].loc[index].data[
                    self.X_t_mask.data
                ] = pred.ravel()
            elif self.mode == "off-grid":
                self[var_ID].loc[index, entry] = pred
2✔
220

221

222
def create_empty_spatiotemporal_xarray(
    X: Union[xr.Dataset, xr.DataArray],
    dates: List[Timestamp],
    coord_names: dict = None,
    data_vars: List[str] = None,
    prepend_dims: Optional[List[str]] = None,
    prepend_coords: Optional[dict] = None,
):
    """Build an unfilled ``xarray.Dataset`` on the spatial grid of ``X``.

    The dataset has one ``float32`` variable per entry of ``data_vars``, each
    with dimensions ``(*prepend_dims, "time", x1, x2)``. The variables are
    created without data, so they hold no prediction values yet.

    Args:
        X (:class:`xarray.Dataset` | :class:`xarray.DataArray`):
            Object whose ``coord_names["x1"]`` and ``coord_names["x2"]``
            coordinates define the spatial grid.
        dates (List[...]):
            Values for the ``time`` coordinate; converted with
            ``pandas.to_datetime``.
        coord_names (dict, optional):
            Dict mapping from normalised coord names to raw coord names,
            by default {"x1": "x1", "x2": "x2"}
        data_vars (List[str], optional):
            Names of the data variables to create, by default ["var"]
        prepend_dims (List[str], optional):
            Names of extra dimensions placed before ``time``, by default None
        prepend_coords (dict, optional):
            Coordinates for the dimensions in ``prepend_dims``, by default None

    Returns:
        :class:`xarray.Dataset`
            The unfilled dataset.

    Raises:
        ValueError
            If ``data_vars`` contains duplicate values.
        ValueError
            If ``prepend_dims`` contains duplicate values.
        ValueError
            If ``prepend_dims`` and ``prepend_coords`` are not the same length.
    """
    if coord_names is None:
        coord_names = {"x1": "x1", "x2": "x2"}
    if data_vars is None:
        data_vars = ["var"]

    if prepend_dims is None:
        prepend_dims = []
    if prepend_coords is None:
        prepend_coords = {}

    # Check for any repeated data_vars
    if len(data_vars) != len(set(data_vars)):
        raise ValueError(
            f"Duplicate data_vars found in data_vars: {data_vars}. "
            "This would cause the xarray.Dataset to have fewer variables than expected."
        )

    x1_predict = X.coords[coord_names["x1"]]
    x2_predict = X.coords[coord_names["x2"]]

    # Repeated dimension names were already rejected here, but under a message
    # that described a length mismatch; report them accurately.
    if len(prepend_dims) != len(set(prepend_dims)):
        raise ValueError(
            f"Duplicate dimension names found in prepend_dims: {prepend_dims}."
        )

    # The documented contract: one coordinate entry per prepended dimension.
    if len(prepend_dims) != len(prepend_coords):
        # TODO unit test
        raise ValueError(
            f"Length of prepend_dims ({len(prepend_dims)}) must be equal to length of "
            f"prepend_coords ({len(prepend_coords)})."
        )

    dims = [*prepend_dims, "time", coord_names["x1"], coord_names["x2"]]
    coords = {
        **prepend_coords,
        "time": pd.to_datetime(dates),
        coord_names["x1"]: x1_predict,
        coord_names["x2"]: x2_predict,
    }

    pred_ds = xr.Dataset(
        {data_var: xr.DataArray(dims=dims, coords=coords) for data_var in data_vars}
    ).astype("float32")

    # Convert time coord to pandas timestamps
    pred_ds = pred_ds.assign_coords(time=pd.to_datetime(pred_ds.time.values))

    return pred_ds
2✔
304

305

306
def increase_spatial_resolution(
    X_t_normalised,
    resolution_factor,
    coord_names: dict = None,
):
    """Resample the two spatial coordinates of ``X_t_normalised`` onto a finer or coarser grid.

    Each spatial coordinate is replaced by evenly spaced ``float64`` values
    running from its first to its last value, with
    ``int(size * resolution_factor)`` points, and the object is interpolated
    onto the new coordinates using nearest-neighbour lookup.

    Args:
        X_t_normalised (:class:`xarray.DataArray` | :class:`xarray.Dataset`):
            Object carrying the spatial coordinates to resample.
        resolution_factor (float | int):
            Multiplier applied to the number of points along each spatial
            coordinate.
        coord_names (dict, optional):
            Dict mapping from normalised coord names to raw coord names,
            by default {"x1": "x1", "x2": "x2"}

    Returns:
        :class:`xarray.DataArray` | :class:`xarray.Dataset`
            ``X_t_normalised`` interpolated onto the resampled coordinates.

    Raises:
        AssertionError
            If ``resolution_factor`` is not a float/int or ``X_t_normalised``
            is not an xarray object.
    """
    # TODO wasteful to interpolate X_t_normalised
    assert isinstance(resolution_factor, (float, int))
    assert isinstance(X_t_normalised, (xr.DataArray, xr.Dataset))

    names = coord_names if coord_names is not None else {"x1": "x1", "x2": "x2"}
    raw_names = (names["x1"], names["x2"])
    axes = [X_t_normalised.coords[raw_name] for raw_name in raw_names]

    # Same end points, rescaled number of evenly spaced points per axis.
    resampled = {
        raw_name: np.linspace(
            axis[0], axis[-1], int(axis.size * resolution_factor), dtype="float64"
        )
        for raw_name, axis in zip(raw_names, axes)
    }
    return X_t_normalised.interp(**resampled, method="nearest")
×
342

343

344
def infer_prediction_modality_from_X_t(
    X_t: Union[xr.DataArray, xr.Dataset, pd.DataFrame, pd.Series, pd.Index, np.ndarray],
) -> str:
    """Classify target locations as on-grid or off-grid from their container type.

    Args:
        X_t (Union[xr.DataArray, xr.Dataset, pd.DataFrame, pd.Series, pd.Index, np.ndarray]):
            Target locations.

    Returns:
        str: "on-grid" if X_t is an xarray object, "off-grid" if X_t is a pandas or numpy object.

    Raises:
        ValueError
            If X_t is not an xarray, pandas or numpy object.
    """
    if isinstance(X_t, (xr.DataArray, xr.Dataset)):
        return "on-grid"
    if isinstance(X_t, (pd.DataFrame, pd.Series, pd.Index, np.ndarray)):
        return "off-grid"
    raise ValueError(
        f"X_t must be an xarray, pandas or numpy object. Got {type(X_t)}."
    )
2✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc