• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

alan-turing-institute / deepsensor / 11455747995

22 Oct 2024 07:56AM UTC coverage: 81.626% (+0.3%) from 81.333%
11455747995

push

github

davidwilby
incorporate feedback

2048 of 2509 relevant lines covered (81.63%)

1.63 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

84.43
/deepsensor/model/pred.py
1
from typing import Union, List, Optional
2✔
2

3
import numpy as np
2✔
4
import pandas as pd
2✔
5
import xarray as xr
2✔
6

7
# Date/time values accepted by this module's APIs, e.g. ``"2020-01-01"``,
# a ``pd.Timestamp`` or a ``np.datetime64``.
Timestamp = Union[
    str,
    pd.Timestamp,
    np.datetime64,
]
8

9

10
class Prediction(dict):
    """
    Object to store model predictions in a dictionary-like format.

    Maps from target variable IDs to xarray/pandas objects containing
    prediction parameters (depending on the output distribution of the model).

    For example, if the model outputs a Gaussian distribution, then the xarray/pandas
    objects in the ``Prediction`` will contain a ``mean`` and ``std``.

    If ``n_samples`` is positive, each xarray/pandas object additionally gets one
    entry per sample, named ``sample_0``, ``sample_1``, ..., which are filled by
    calling :meth:`assign` with ``prediction_parameter="samples"``.

    Entries can be looked up by target variable ID or by integer position in
    ``target_var_IDs`` (e.g. ``pred[0]``).

    Args:
        target_var_IDs (List[str])
            List of target variable IDs.
        pred_params (List[str])
            Names of the prediction parameters stored for every target variable
            (e.g. ``["mean", "std"]``).
        dates (List[Union[str, pd.Timestamp]])
            List of dates corresponding to the predictions.
        X_t (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`)
            Target locations to predict at. Can be an xarray object containing
            on-grid locations or a pandas object containing off-grid locations.
            For off-grid predictions the locations are read from ``X_t.index``.
        X_t_mask (:class:`xarray.Dataset` | :class:`xarray.DataArray`, optional)
            2D boolean mask for gridded ``X_t``: only cells where the mask is True
            are written by :meth:`assign`; the other cells stay NaN. No
            interpolation is done here, so the mask must already be on the same
            grid as ``X_t``. Default None (no mask).
        coord_names (dict, optional)
            Dict mapping the normalised coord names ``"x1"``/``"x2"`` to raw
            coord names. Default None, i.e. ``{"x1": "x1", "x2": "x2"}``.
        n_samples (int)
            Number of ``sample_{i}`` entries to allocate. If 0, no sample
            entries are created. Default 0.
        forecasting_mode (bool)
            If True, stores forecast predictions indexed by ``lead_time`` and
            ``init_time``. If False, stores predictions at t=0 only
            (i.e. spatial interpolation), with only a single ``time`` dimension.
            Default False.
        lead_times (List[pd.Timedelta], optional)
            List of lead times to store in predictions. Must be provided if
            forecasting_mode is True. Default None.
    """

    def __init__(
        self,
        target_var_IDs: List[str],
        pred_params: List[str],
        dates: List[Timestamp],
        X_t: Union[
            xr.Dataset,
            xr.DataArray,
            pd.DataFrame,
            pd.Series,
            pd.Index,
            np.ndarray,
        ],
        X_t_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None,
        coord_names: dict = None,
        n_samples: int = 0,
        forecasting_mode: bool = False,
        lead_times: Optional[List[pd.Timedelta]] = None,
    ):
        self.target_var_IDs = target_var_IDs
        self.X_t_mask = X_t_mask
        if coord_names is None:
            coord_names = {"x1": "x1", "x2": "x2"}
        self.x1_name = coord_names["x1"]
        self.x2_name = coord_names["x2"]

        self.forecasting_mode = forecasting_mode
        if forecasting_mode:
            assert (
                lead_times is not None
            ), "If forecasting_mode is True, lead_times must be provided."
        self.lead_times = lead_times

        self.mode = infer_prediction_modality_from_X_t(X_t)

        self.pred_params = pred_params
        if n_samples >= 1:
            # Samples are stored as extra named entries alongside pred_params
            self.pred_params = [
                *pred_params,
                *[f"sample_{i}" for i in range(n_samples)],
            ]

        # Create empty xarray/pandas objects to store predictions
        if self.mode == "on-grid":
            for var_ID in self.target_var_IDs:
                if self.forecasting_mode:
                    prepend_dims = ["lead_time"]
                    prepend_coords = {"lead_time": lead_times}
                else:
                    prepend_dims = None
                    prepend_coords = None
                self[var_ID] = create_empty_spatiotemporal_xarray(
                    X_t,
                    dates,
                    data_vars=self.pred_params,
                    coord_names=coord_names,
                    prepend_dims=prepend_dims,
                    prepend_coords=prepend_coords,
                )
                if self.forecasting_mode:
                    self[var_ID] = self[var_ID].rename(time="init_time")
            if self.X_t_mask is None:
                # Create 2D boolean array of True values to simplify indexing
                self.X_t_mask = (
                    create_empty_spatiotemporal_xarray(X_t, dates[0:1], coord_names)
                    .to_array()
                    .isel(time=0, variable=0)
                    .astype(bool)
                )
        elif self.mode == "off-grid":
            # Repeat target locs for each date to create multiindex
            if self.forecasting_mode:
                index_names = ["lead_time", "init_time", *X_t.index.names]
                idxs = [
                    (lt, date, *idxs)
                    for lt in lead_times
                    for date in dates
                    for idxs in X_t.index
                ]
            else:
                index_names = ["time", *X_t.index.names]
                idxs = [(date, *idxs) for date in dates for idxs in X_t.index]
            index = pd.MultiIndex.from_tuples(idxs, names=index_names)
            for var_ID in self.target_var_IDs:
                self[var_ID] = pd.DataFrame(index=index, columns=self.pred_params)

    def __getitem__(self, key):
        # Support self[i] syntax (positional lookup into target_var_IDs),
        # including numpy integer positions.
        if isinstance(key, (int, np.integer)):
            key = self.target_var_IDs[key]
        return super().__getitem__(key)

    def __str__(self):
        dict_repr = {var_ID: self.pred_params for var_ID in self.target_var_IDs}
        return f"Prediction({dict_repr}, mode={self.mode})"

    def assign(
        self,
        prediction_parameter: str,
        date: Union[str, pd.Timestamp],
        data: np.ndarray,
        lead_times: Optional[List[pd.Timedelta]] = None,
    ):
        """
        Write model output for a single date into the stored predictions.

        Args:
            prediction_parameter (str)
                Name of the parameter to write (one of ``pred_params``), or
                ``"samples"`` to fill the ``sample_{i}`` entries from ``data``.
            date (Union[str, pd.Timestamp])
                Date (the ``time``, or ``init_time`` in forecasting mode,
                coordinate) to write at.
            data (np.ndarray)
                If off-grid: Shape (N_var, N_targets) or (N_samples, N_var, N_targets).
                If on-grid: Shape (N_var, N_x1, N_x2) or (N_samples, N_var, N_x1, N_x2).
                The ``N_var`` axis is matched positionally against ``target_var_IDs``.
            lead_times (List[pd.Timedelta], optional)
                Lead time for each entry along the ``N_var`` axis of ``data``.
                Required if forecasting_mode is True. Default None.

        Raises:
            AssertionError
                If sample ``data`` has the wrong number of dimensions, or if
                ``lead_times`` is missing or of the wrong length in forecasting mode.
        """
        is_samples = prediction_parameter == "samples"

        if is_samples:
            if self.mode == "on-grid":
                expected_ndim, expected_shape = 4, "(N_samples, N_var, N_x1, N_x2)"
            else:
                expected_ndim, expected_shape = 3, "(N_samples, N_var, N_targets)"
            assert len(data.shape) == expected_ndim, (
                f"If prediction_parameter is 'samples', and mode is '{self.mode}', "
                f"data must have shape {expected_shape}. Got {data.shape}."
            )

        if self.forecasting_mode:
            assert (
                lead_times is not None
            ), "If forecasting_mode is True, lead_times must be provided."

            # lead_times[i] is paired with the i-th variable, which is axis 0 of
            # `data` normally but axis 1 when `data` holds samples.
            var_axis = 1 if is_samples else 0
            msg = (
                "If forecasting_mode is True, lead_times must be of equal length to "
                f"the number of variables in the data (axis {var_axis}). Got "
                f"{lead_times=} of length {len(lead_times)} and data shape {data.shape}."
            )
            assert len(lead_times) == data.shape[var_axis], msg

        if is_samples:
            for sample_i, sample in enumerate(data):
                self._assign_variables(f"sample_{sample_i}", date, sample, lead_times)
        else:
            self._assign_variables(prediction_parameter, date, data, lead_times)

    def _assign_variables(self, column, date, data, lead_times):
        """Write ``data[i]`` into entry ``column`` of the i-th target variable at ``date``."""
        for i, (var_ID, pred) in enumerate(zip(self.target_var_IDs, data)):
            index = (lead_times[i], date) if self.forecasting_mode else date
            if self.mode == "on-grid":
                # Relies on `.loc[index].data` being a view of the stored array so
                # the write persists; only masked-in cells are written.
                self[var_ID][column].loc[index].data[self.X_t_mask.data] = pred.ravel()
            elif self.mode == "off-grid":
                self[var_ID].loc[index, column] = pred
223

224

225
def create_empty_spatiotemporal_xarray(
    X: Union[xr.Dataset, xr.DataArray],
    dates: List[Timestamp],
    coord_names: dict = None,
    data_vars: List[str] = None,
    prepend_dims: Optional[List[str]] = None,
    prepend_coords: Optional[dict] = None,
):
    """
    Create an empty ``float32`` :class:`xarray.Dataset` on the spatial grid of ``X``.

    The returned dataset has one variable per entry in ``data_vars``, each with
    dimensions ``(*prepend_dims, "time", x1, x2)``, where ``x1``/``x2`` are the raw
    coordinate names given by ``coord_names``. No data is supplied, so the
    variables are left unfilled (NaN).

    Args:
        X (:class:`xarray.Dataset` | :class:`xarray.DataArray`):
            Gridded object whose ``x1``/``x2`` coordinates define the spatial grid.
        dates (List[...]):
            Dates for the ``time`` coordinate; converted with :func:`pandas.to_datetime`.
        coord_names (dict, optional):
            Dict mapping from normalised coord names to raw coord names,
            by default {"x1": "x1", "x2": "x2"}
        data_vars (List[str], optional):
            Names of the data variables to create, by default ["var"]
        prepend_dims (List[str], optional):
            Extra leading dimensions placed before ``time``, by default None
            (no extra dimensions).
        prepend_coords (dict, optional):
            Coordinates for the dimensions in ``prepend_dims``, keyed by
            dimension name, by default None.

    Returns:
        :class:`xarray.Dataset`
            The empty dataset.

    Raises:
        ValueError
            If ``data_vars`` contains duplicate values.
        ValueError
            If ``prepend_dims`` and ``prepend_coords`` are not the same length.
    """
    if coord_names is None:
        coord_names = {"x1": "x1", "x2": "x2"}
    if data_vars is None:
        data_vars = ["var"]

    if prepend_dims is None:
        prepend_dims = []
    if prepend_coords is None:
        prepend_coords = {}

    # Check for any repeated data_vars
    if len(data_vars) != len(set(data_vars)):
        raise ValueError(
            f"Duplicate data_vars found in data_vars: {data_vars}. "
            "This would cause the xarray.Dataset to have fewer variables than expected."
        )

    # Each prepended dimension needs a matching coordinate to define its size,
    # since no data array is passed to xarray below.
    if len(prepend_dims) != len(prepend_coords):
        # TODO unit test
        raise ValueError(
            f"Length of prepend_dims ({len(prepend_dims)}) must be equal to length of "
            f"prepend_coords ({len(prepend_coords)})."
        )

    x1_name, x2_name = coord_names["x1"], coord_names["x2"]
    x1_predict = X.coords[x1_name]
    x2_predict = X.coords[x2_name]

    dims = [*prepend_dims, "time", x1_name, x2_name]
    coords = {
        **prepend_coords,
        "time": pd.to_datetime(dates),
        x1_name: x1_predict,
        x2_name: x2_predict,
    }

    pred_ds = xr.Dataset(
        {data_var: xr.DataArray(dims=dims, coords=coords) for data_var in data_vars}
    ).astype("float32")

    # Convert time coord to pandas timestamps
    pred_ds = pred_ds.assign_coords(time=pd.to_datetime(pred_ds.time.values))

    return pred_ds
308

309

310
def increase_spatial_resolution(
    X_t_normalised,
    resolution_factor,
    coord_names: dict = None,
):
    """
    Resample ``X_t_normalised`` onto a regular grid scaled by ``resolution_factor``.

    Each of the two spatial coordinates is replaced by
    ``int(size * resolution_factor)`` evenly spaced float64 points running from
    the coordinate's first value to its last. The object is then interpolated
    onto that grid with ``method="nearest"``.

    ..
        # TODO wasteful to interpolate X_t_normalised

    Args:
        X_t_normalised (:class:`xarray.DataArray` | :class:`xarray.Dataset`):
            Gridded object to resample.
        resolution_factor (float | int):
            Multiplier applied to the number of points along each spatial axis.
        coord_names (dict, optional):
            Dict mapping from normalised coord names to raw coord names,
            by default {"x1": "x1", "x2": "x2"}

    Returns:
        The result of ``X_t_normalised.interp`` onto the new grid.
    """
    assert isinstance(resolution_factor, (float, int))
    assert isinstance(X_t_normalised, (xr.DataArray, xr.Dataset))

    names = coord_names if coord_names is not None else {"x1": "x1", "x2": "x2"}
    spatial_names = (names["x1"], names["x2"])

    # Read both original coordinates before building the new ones
    old_axes = [X_t_normalised.coords[name] for name in spatial_names]
    new_axes = {
        name: np.linspace(
            axis[0], axis[-1], int(axis.size * resolution_factor), dtype="float64"
        )
        for name, axis in zip(spatial_names, old_axes)
    }

    return X_t_normalised.interp(**new_axes, method="nearest")
347

348

349
def infer_prediction_modality_from_X_t(
    X_t: Union[xr.DataArray, xr.Dataset, pd.DataFrame, pd.Series, pd.Index, np.ndarray]
) -> str:
    """
    Infer from the type of ``X_t`` whether predictions are on-grid or off-grid.

    Args:
        X_t (Union[xr.DataArray, xr.Dataset, pd.DataFrame, pd.Series, pd.Index, np.ndarray]):
            Target locations to predict at.

    Returns:
        str: "on-grid" if X_t is an xarray object, "off-grid" if X_t is a pandas or numpy object.

    Raises:
        ValueError
            If X_t is not an xarray, pandas or numpy object.
    """
    if isinstance(X_t, (xr.DataArray, xr.Dataset)):
        return "on-grid"
    if isinstance(X_t, (pd.DataFrame, pd.Series, pd.Index, np.ndarray)):
        return "off-grid"
    raise ValueError(
        f"X_t must be an xarray, pandas or numpy object. Got {type(X_t)}."
    )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc