• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

alan-turing-institute / deepsensor / 14313171846

07 Apr 2025 03:26PM UTC coverage: 82.511% (+0.8%) from 81.663%
14313171846

Pull #135

github

web-flow
Merge fc0266596 into 38ec5ef26
Pull Request #135: Enable patchwise training and prediction

294 of 329 new or added lines in 4 files covered. (89.36%)

1 existing line in 1 file now uncovered.

2340 of 2836 relevant lines covered (82.51%)

1.65 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

87.5
/deepsensor/model/pred.py
1
import copy
2✔
2
from typing import Union, List, Optional, Tuple
2✔
3

4
import numpy as np
2✔
5
import pandas as pd
2✔
6
import xarray as xr
2✔
7

8
# Any value accepted as a prediction date: an ISO-like string, a pandas
# timestamp or a numpy datetime64.
Timestamp = Union[str, pd.Timestamp, np.datetime64]
9

10

11
class Prediction(dict):
    """Object to store model predictions in a dictionary-like format.

    Maps from target variable IDs to xarray/pandas objects containing
    prediction parameters (depending on the output distribution of the model).

    For example, if the model outputs a Gaussian distribution, then the xarray/pandas
    objects in the ``Prediction`` will contain a ``mean`` and ``std``.

    If ``n_samples`` is at least 1, each xarray/pandas object additionally gets one
    ``sample_<i>`` entry per sample.

    Args:
        target_var_IDs (List[str])
            List of target variable IDs.
        pred_params (List[str])
            Names of the prediction parameters to store for every target variable
            (e.g. ``["mean", "std"]``).
        dates (List[Union[str, pd.Timestamp]])
            List of dates corresponding to the predictions.
        X_t (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`)
            Target locations to predict at. Can be an xarray object containing
            on-grid locations or a pandas object containing off-grid locations.
        X_t_mask (:class:`xarray.Dataset` | :class:`xarray.DataArray`, optional)
            2D mask to apply to gridded ``X_t`` (zero/False will be NaNs). Will be interpolated
            to the same grid as ``X_t``. Default None (no mask).
        coord_names (dict, optional)
            Dict with keys ``"x1"`` and ``"x2"`` giving the coordinate names to use.
            Default None, which is treated as ``{"x1": "x1", "x2": "x2"}``.
        n_samples (int)
            Number of joint samples to draw from the model. If 0, will not
            draw samples. Default 0.
        forecasting_mode (bool)
            If True, stored forecast predictions with an init_time and lead_time dimension,
            and a valid_time coordinate. If False, stores prediction at t=0 only
            (i.e. spatial interpolation), with only a single time dimension. Default False.
        lead_times (List[pd.Timedelta], optional)
            List of lead times to store in predictions. Must be provided if
            forecasting_mode is True. Default None.
    """

    def __init__(
        self,
        target_var_IDs: List[str],
        pred_params: List[str],
        dates: List[Timestamp],
        X_t: Union[
            xr.Dataset,
            xr.DataArray,
            pd.DataFrame,
            pd.Series,
            pd.Index,
            np.ndarray,
        ],
        X_t_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None,
        coord_names: Optional[dict] = None,
        n_samples: int = 0,
        forecasting_mode: bool = False,
        lead_times: Optional[List[pd.Timedelta]] = None,
    ):
        super().__init__()
        self.target_var_IDs = target_var_IDs
        self.X_t_mask = X_t_mask
        if coord_names is None:
            coord_names = {"x1": "x1", "x2": "x2"}
        self.x1_name = coord_names["x1"]
        self.x2_name = coord_names["x2"]

        self.forecasting_mode = forecasting_mode
        if forecasting_mode:
            assert (
                lead_times is not None
            ), "If forecasting_mode is True, lead_times must be provided."
        self.lead_times = lead_times

        self.mode = infer_prediction_modality_from_X_t(X_t)

        self.pred_params = pred_params
        if n_samples >= 1:
            # One extra column/variable per requested sample.
            self.pred_params = [
                *pred_params,
                *[f"sample_{i}" for i in range(n_samples)],
            ]

        # Create empty xarray/pandas objects to store predictions
        if self.mode == "on-grid":
            # The extra leading dimension only exists when forecasting.
            if self.forecasting_mode:
                prepend_dims = ["lead_time"]
                prepend_coords = {"lead_time": lead_times}
            else:
                prepend_dims = None
                prepend_coords = None
            for var_ID in self.target_var_IDs:
                self[var_ID] = create_empty_spatiotemporal_xarray(
                    X_t,
                    dates,
                    data_vars=self.pred_params,
                    coord_names=coord_names,
                    prepend_dims=prepend_dims,
                    prepend_coords=prepend_coords,
                )
                if self.forecasting_mode:
                    self[var_ID] = self[var_ID].rename(time="init_time")
            if self.X_t_mask is None:
                # Create 2D boolean array of True values to simplify indexing
                self.X_t_mask = (
                    create_empty_spatiotemporal_xarray(X_t, dates[0:1], coord_names)
                    .to_array()
                    .isel(time=0, variable=0)
                    .astype(bool)
                )
        elif self.mode == "off-grid":
            # Repeat target locs for each date to create multiindex
            if self.forecasting_mode:
                index_names = ["lead_time", "init_time", *X_t.index.names]
                idxs = [
                    (lt, date, *loc)
                    for lt in lead_times
                    for date in dates
                    for loc in X_t.index
                ]
            else:
                index_names = ["time", *X_t.index.names]
                idxs = [(date, *loc) for date in dates for loc in X_t.index]
            index = pd.MultiIndex.from_tuples(idxs, names=index_names)
            for var_ID in self.target_var_IDs:
                self[var_ID] = pd.DataFrame(index=index, columns=self.pred_params)

    def __getitem__(self, key):
        # Support self[i] syntax
        if isinstance(key, int):
            key = self.target_var_IDs[key]
        return super().__getitem__(key)

    def __str__(self):
        dict_repr = {var_ID: self.pred_params for var_ID in self.target_var_IDs}
        return f"Prediction({dict_repr}), mode={self.mode}"

    def _assign_variables(self, column, date, var_data, lead_times):
        """Write one array per target variable into the entry named ``column``.

        ``var_data`` is iterated along its first axis in lockstep with
        ``self.target_var_IDs``.
        """
        for i, (var_ID, pred) in enumerate(zip(self.target_var_IDs, var_data)):
            if self.forecasting_mode:
                index = (lead_times[i], date)
            else:
                index = date
            if self.mode == "on-grid":
                # Only the unmasked grid cells receive values.
                self[var_ID][column].loc[index].data[self.X_t_mask.data] = pred.ravel()
            elif self.mode == "off-grid":
                self[var_ID].loc[index, column] = pred

    def assign(
        self,
        prediction_parameter: str,
        date: Union[str, pd.Timestamp],
        data: np.ndarray,
        lead_times: Optional[List[pd.Timedelta]] = None,
    ):
        """Store prediction data for a single date.

        Args:
            prediction_parameter (str)
                Name of the entry to write to. The special value ``"samples"``
                writes each sample to its own ``sample_<i>`` entry.
            date (Union[str, pd.Timestamp])
                Date (init time when forecasting) the data belongs to.
            data (np.ndarray)
                If off-grid: Shape (N_var, N_targets) or (N_samples, N_var, N_targets).
                If on-grid: Shape (N_var, N_x1, N_x2) or (N_samples, N_var, N_x1, N_x2).
            lead_times (List[pd.Timedelta], optional)
                Lead time of each variable in ``data``. Required if
                forecasting_mode is True. Default None.
        """
        is_samples = prediction_parameter == "samples"

        if is_samples:
            if self.mode == "on-grid":
                assert len(data.shape) == 4, (
                    "If prediction_parameter is 'samples', and mode is 'on-grid', data must "
                    f"have shape (N_samples, N_var, N_x1, N_x2). Got {data.shape}."
                )
            elif self.mode == "off-grid":
                assert len(data.shape) == 3, (
                    "If prediction_parameter is 'samples', and mode is 'off-grid', data must "
                    f"have shape (N_samples, N_var, N_targets). Got {data.shape}."
                )

        if self.forecasting_mode:
            assert (
                lead_times is not None
            ), "If forecasting_mode is True, lead_times must be provided."

            # Lead times are indexed per variable; for samples the variable
            # axis comes after the leading sample axis.
            var_axis = 1 if is_samples else 0
            msg = f"""
            If forecasting_mode is True, lead_times must be of equal length to the number of
            variables in the data (the first dimension, or the second for samples).
            Got {lead_times=} of length {len(lead_times)} lead times and data shape {data.shape}.
            """
            assert len(lead_times) == data.shape[var_axis], msg

        if is_samples:
            for sample_i, sample in enumerate(data):
                self._assign_variables(f"sample_{sample_i}", date, sample, lead_times)
        else:
            self._assign_variables(prediction_parameter, date, data, lead_times)
221

222

223
def create_empty_spatiotemporal_xarray(
    X: Union[xr.Dataset, xr.DataArray],
    dates: List[Timestamp],
    coord_names: dict = None,
    data_vars: List[str] = None,
    prepend_dims: Optional[List[str]] = None,
    prepend_coords: Optional[dict] = None,
):
    """Create an empty ``float32`` dataset on the spatial grid of ``X``.

    The result has dimensions ``(*prepend_dims, "time", x1, x2)`` and one
    variable per entry of ``data_vars``. No data is supplied for the
    variables, so they hold xarray's default fill value.

    Args:
        X (:class:`xarray.Dataset` | :class:`xarray.DataArray`):
            Object whose ``coord_names["x1"]`` and ``coord_names["x2"]``
            coordinates define the spatial grid.
        dates (List[...]):
            Dates used for the ``time`` coordinate (converted with
            ``pandas.to_datetime``).
        coord_names (dict, optional):
            Dict mapping from normalised coord names to raw coord names,
            by default {"x1": "x1", "x2": "x2"}
        data_vars (List[str], optional):
            Names of the data variables to create, by default ["var"]
        prepend_dims (List[str], optional):
            Names of extra dimensions placed before ``time``, by default None
        prepend_coords (dict, optional):
            Coordinates for the extra dimensions, by default None

    Returns:
        :class:`xarray.Dataset`
            Dataset with the dimensions, coordinates and variables described above.

    Raises:
        ValueError
            If ``data_vars`` contains duplicate values.
        ValueError
            If ``prepend_dims`` and ``prepend_coords`` are not the same length.
    """
    if coord_names is None:
        coord_names = {"x1": "x1", "x2": "x2"}
    if data_vars is None:
        data_vars = ["var"]
    if prepend_dims is None:
        prepend_dims = []
    if prepend_coords is None:
        prepend_coords = {}

    # Validate the arguments before touching ``X``.
    if len(data_vars) != len(set(data_vars)):
        raise ValueError(
            f"Duplicate data_vars found in data_vars: {data_vars}. "
            "This would cause the xarray.Dataset to have fewer variables than expected."
        )
    if len(prepend_dims) != len(prepend_coords):
        raise ValueError(
            f"Length of prepend_dims ({len(prepend_dims)}) must be equal to length of "
            f"prepend_coords ({len(prepend_coords)})."
        )

    x1_name, x2_name = coord_names["x1"], coord_names["x2"]
    x1_predict = X.coords[x1_name]
    x2_predict = X.coords[x2_name]

    dims = [*prepend_dims, "time", x1_name, x2_name]
    coords = {
        **prepend_coords,
        "time": pd.to_datetime(dates),
        x1_name: x1_predict,
        x2_name: x2_predict,
    }

    pred_ds = xr.Dataset(
        {data_var: xr.DataArray(dims=dims, coords=coords) for data_var in data_vars}
    ).astype("float32")

    # Convert time coord to pandas timestamps
    pred_ds = pred_ds.assign_coords(time=pd.to_datetime(pred_ds.time.values))

    return pred_ds
305

306

307
def increase_spatial_resolution(
    X_t_normalised,
    resolution_factor,
    coord_names: dict = None,
):
    """Resample a gridded xarray object onto a grid scaled by ``resolution_factor``.

    Each spatial axis is replaced by ``int(size * resolution_factor)`` evenly
    spaced points between its first and last coordinate value, and the object
    is interpolated onto the new grid with nearest-neighbour interpolation.

    TODO: wasteful to interpolate X_t_normalised

    Args:
        X_t_normalised (:class:`xarray.DataArray` | :class:`xarray.Dataset`):
            Gridded object to resample.
        resolution_factor (float | int):
            Multiplier applied to the number of points along each spatial axis.
        coord_names (dict, optional):
            Dict mapping from normalised coord names to raw coord names,
            by default {"x1": "x1", "x2": "x2"}

    Returns:
        :class:`xarray.DataArray` | :class:`xarray.Dataset`
            ``X_t_normalised`` interpolated onto the new grid.

    """
    assert isinstance(resolution_factor, (float, int))
    assert isinstance(X_t_normalised, (xr.DataArray, xr.Dataset))
    if coord_names is None:
        coord_names = {"x1": "x1", "x2": "x2"}
    x1_name, x2_name = coord_names["x1"], coord_names["x2"]
    x1, x2 = X_t_normalised.coords[x1_name], X_t_normalised.coords[x2_name]
    # Keep the original end points; only the number of points changes.
    x1 = np.linspace(x1[0], x1[-1], int(x1.size * resolution_factor), dtype="float64")
    x2 = np.linspace(x2[0], x2[-1], int(x2.size * resolution_factor), dtype="float64")
    X_t_normalised = X_t_normalised.interp(
        **{x1_name: x1, x2_name: x2}, method="nearest"
    )
    return X_t_normalised
343

344

345
def infer_prediction_modality_from_X_t(
    X_t: Union[xr.DataArray, xr.Dataset, pd.DataFrame, pd.Series, pd.Index, np.ndarray],
) -> str:
    """Infer whether target locations are gridded or not from their type.

    Args:
        X_t (Union[xr.DataArray, xr.Dataset, pd.DataFrame, pd.Series, pd.Index, np.ndarray]):
            Target locations.

    Returns:
        str: "on-grid" if X_t is an xarray object, "off-grid" if X_t is a pandas or numpy object.

    Raises:
        ValueError
            If X_t is not an xarray, pandas or numpy object.
    """
    # The two type groups are disjoint, so the order of the checks is irrelevant.
    if isinstance(X_t, (pd.DataFrame, pd.Series, pd.Index, np.ndarray)):
        return "off-grid"
    if isinstance(X_t, (xr.DataArray, xr.Dataset)):
        return "on-grid"
    raise ValueError(
        f"X_t must be an xarray, pandas or numpy object. Got {type(X_t)}."
    )
368

369

370
def _get_coordinate_extent(
2✔
371
    ds: Union[xr.DataArray, xr.Dataset],
372
    orig_x1_name: str,
373
    orig_x2_name: str,
374
    x1_ascend: bool,
375
    x2_ascend: bool,
376
) -> Tuple:
377
    """Get coordinate extent of dataset.
378
    Coordinate extent is defined as maximum and minimum value of x1 and x2.
379

380
    Parameters
381
    ----------
382
    ds : Data object
383
        The dataset or data array to determine coordinate extent for.
384

385
    x1_ascend : bool
386
        Whether the x1 coordinates ascend (increase) from top to bottom.
387

388
    x2_ascend : bool
389
        Whether the x2 coordinates ascend (increase) from left to right.
390

391
    Returns:
392
    -------
393
    tuple of tuples:
394
        Extents of x1 and x2 coordinates as ((min_x1, max_x1), (min_x2, max_x2)).
395
    """
396
    if x1_ascend:
2✔
397
        ds_x1_coords = (
2✔
398
            ds.coords[orig_x1_name].min().values,
399
            ds.coords[orig_x1_name].max().values,
400
        )
401
    else:
NEW
402
        ds_x1_coords = (
×
403
            ds.coords[orig_x1_name].max().values,
404
            ds.coords[orig_x1_name].min().values,
405
        )
406
    if x2_ascend:
2✔
407
        ds_x2_coords = (
2✔
408
            ds.coords[orig_x2_name].min().values,
409
            ds.coords[orig_x2_name].max().values,
410
        )
411
    else:
NEW
412
        ds_x2_coords = (
×
413
            ds.coords[orig_x2_name].max().values,
414
            ds.coords[orig_x2_name].min().values,
415
        )
416
    return ds_x1_coords, ds_x2_coords
2✔
417

418

419
def _get_index(
2✔
420
    *args,
421
    X_t: Union[
422
        xr.Dataset,
423
        xr.DataArray,
424
        pd.DataFrame,
425
        pd.Series,
426
        pd.Index,
427
        np.ndarray,
428
    ],
429
    orig_x1_name: str,
430
    orig_x2_name: str,
431
    x1: bool = True,
432
) -> Union[int, Tuple[List[int], List[int]]]:
433
    """Convert coordinates into pixel row/column (index).
434

435
    Parameters
436
    ----------
437
    args : tuple
438
        If one argument (numeric), it represents the coordinate value.
439
        If two arguments (lists), they represent lists of coordinate values.
440

441
    X_t : (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`)
442
        Target locations to predict at. Can be an xarray object
443
        containing on-grid locations or a pandas object containing off-grid locations.
444

445
    x1 : bool, optional
446
        If True, compute index for x1 (default is True).
447

448
    Returns:
449
    -------
450
        Union[int, Tuple[List[int], List[int]]]
451
        If one argument is provided and x1 is True or False, returns the index position.
452
        If two arguments are provided, returns a tuple containing two lists:
453
        - First list: indices corresponding to x1 coordinates.
454
        - Second list: indices corresponding to x2 coordinates.
455

456
    """
457
    if len(args) == 1:
2✔
458
        patch_coord = args
2✔
459
        if x1:
2✔
460
            coord_index = np.argmin(
2✔
461
                np.abs(X_t.coords[orig_x1_name].values - patch_coord)
462
            )
463
        else:
464
            coord_index = np.argmin(
2✔
465
                np.abs(X_t.coords[orig_x2_name].values - patch_coord)
466
            )
467
        return coord_index
2✔
468

469
    elif len(args) == 2:
2✔
470
        patch_x1, patch_x2 = args
2✔
471
        x1_index = [
2✔
472
            np.argmin(np.abs(X_t.coords[orig_x1_name].values - target_x1))
473
            for target_x1 in patch_x1
474
        ]
475
        x2_index = [
2✔
476
            np.argmin(np.abs(X_t.coords[orig_x2_name].values - target_x2))
477
            for target_x2 in patch_x2
478
        ]
479
        return (x1_index, x2_index)
2✔
480

481

482
def stitch_clipped_predictions(
    patch_preds: List[Prediction],
    patch_overlap: Tuple[int, int],
    patches_per_row: int,
    X_t: Union[
        xr.Dataset,
        xr.DataArray,
        pd.DataFrame,
        pd.Series,
        pd.Index,
        np.ndarray,
    ],
    orig_x1_name: str,
    orig_x2_name: str,
    x1_ascend: bool = True,
    x2_ascend: bool = True,
) -> Prediction:
    """Stitch patchwise predictions to form prediction at original extent of X_t.

    Each patch has its overlapping borders clipped off and is then written
    onto a NaN-filled dataset that spans the x1/x2 coordinates of ``X_t``.

    Parameters
    ----------
    patch_preds : list (class:`~.model.pred.Prediction`)
        List of patchwise predictions

    patch_overlap: tuple of int
        Overlap between adjacent patches in pixels, as (x1 overlap, x2 overlap).

    patches_per_row: int
        Number of patchwise predictions in each row.

    X_t : (:class:`xarray.Dataset` | :class:`xarray.DataArray` | :class:`pandas.DataFrame` | :class:`pandas.Series` | :class:`pandas.Index` | :class:`numpy:numpy.ndarray`)
        Target locations to predict at. Can be an xarray object
        containing on-grid locations or a pandas object containing off-grid locations.

    orig_x1_name : str
        x1 coordinate names of original unnormalised dataset

    orig_x2_name : str
        x2 coordinate names of original unnormalised dataset

    x1_ascend : bool
        Boolean defining whether the x1 coords ascend (increase) from top to bottom, default = True.

    x2_ascend : bool
        Boolean defining whether the x2 coords ascend (increase) from left to right, default = True.

    Returns:
    -------
    combined: :class:`~.model.pred.Prediction`
        Copy of the first patch prediction whose entries are replaced by the
        stitched model predictions.
    """
    # Get row/col index values of X_t.
    data_x1_coords, data_x2_coords = _get_coordinate_extent(
        X_t,
        orig_x1_name=orig_x1_name,
        orig_x2_name=orig_x2_name,
        x1_ascend=x1_ascend,
        x2_ascend=x2_ascend,
    )
    data_x1_index, data_x2_index = _get_index(
        data_x1_coords,
        data_x2_coords,
        X_t=X_t,
        orig_x1_name=orig_x1_name,
        orig_x2_name=orig_x2_name,
    )

    # Iterate through patchwise predictions and slice edges prior to stitching.
    patches_clipped = []
    for i, patch_pred in enumerate(patch_preds):
        # Get one variable name to use for coordinates and extent
        first_key = list(patch_pred.keys())[0]
        # Get row/col index values of each patch.
        patch_x1_coords, patch_x2_coords = _get_coordinate_extent(
            patch_pred[first_key],
            orig_x1_name=orig_x1_name,
            orig_x2_name=orig_x2_name,
            x1_ascend=x1_ascend,
            x2_ascend=x2_ascend,
        )
        patch_x1_index, patch_x2_index = _get_index(
            patch_x1_coords,
            patch_x2_coords,
            X_t=X_t,
            orig_x1_name=orig_x1_name,
            orig_x2_name=orig_x2_name,
        )

        # Calculate size of border to slice off each edge of patchwise predictions.
        # Initially set the size of all borders to the size of the overlap.
        b_x1_min, b_x1_max = patch_overlap[0], patch_overlap[0]
        b_x2_min, b_x2_max = patch_overlap[1], patch_overlap[1]

        # Do not remove the border for patches along the top and left of the
        # dataset, and use a bespoke overlap for the last patch in each row
        # and column.
        if patch_x2_index[0] == data_x2_index[0]:
            b_x2_min = 0
        elif patch_x2_index[1] == data_x2_index[1]:
            # End of row: nothing is clipped on the far side. The near side is
            # clipped by the actual overlap with the previous patch, i.e. the
            # index of the previous patch's last x2 pixel minus the index of
            # this patch's first x2 pixel. patch_overlap is then subtracted to
            # account for the clipping already applied to the previous patch.
            b_x2_max = 0
            prev_x2 = patch_preds[i - 1][first_key].coords[orig_x2_name]
            # The last pixel holds the largest coordinate on an ascending
            # axis and the smallest one on a descending axis.
            prev_x2_edge = prev_x2.max() if x2_ascend else prev_x2.min()
            prev_x2_edge_index = _get_index(
                prev_x2_edge,
                X_t=X_t,
                orig_x1_name=orig_x1_name,
                orig_x2_name=orig_x2_name,
                x1=False,
            )
            b_x2_min = (prev_x2_edge_index - patch_x2_index[0]) - patch_overlap[1]

        # Repeat process as above for x1 coordinates.
        # NOTE(review): the last row is detected with a tolerance of one pixel
        # here, whereas the x2 check above requires equality - confirm intended.
        if patch_x1_index[0] == data_x1_index[0]:
            b_x1_min = 0
        elif abs(patch_x1_index[1] - data_x1_index[1]) < 2:
            b_x1_max = 0
            prev_x1 = patch_preds[i - patches_per_row][first_key].coords[orig_x1_name]
            prev_x1_edge = prev_x1.max() if x1_ascend else prev_x1.min()
            prev_x1_edge_index = _get_index(
                prev_x1_edge,
                X_t=X_t,
                orig_x1_name=orig_x1_name,
                orig_x2_name=orig_x2_name,
                x1=True,
            )
            b_x1_min = (prev_x1_edge_index - patch_x1_index[0]) - patch_overlap[0]

        patch_clip_x1_min = int(b_x1_min)
        patch_clip_x1_max = int(patch_pred[first_key].sizes[orig_x1_name] - b_x1_max)
        patch_clip_x2_min = int(b_x2_min)
        patch_clip_x2_max = int(patch_pred[first_key].sizes[orig_x2_name] - b_x2_max)

        # Define slicing parameters
        slicing_params = {
            orig_x1_name: slice(patch_clip_x1_min, patch_clip_x1_max),
            orig_x2_name: slice(patch_clip_x2_min, patch_clip_x2_max),
        }

        # Slice patchwise predictions
        patch_clip = {
            key: dataset.isel(**slicing_params) for key, dataset in patch_pred.items()
        }

        patches_clipped.append(patch_clip)

    # Create blank prediction object to stitch prediction values onto.
    stitched_prediction = copy.deepcopy(patch_preds[0])
    # Set prediction object extent to the same as X_t.
    for var_name, data_array in stitched_prediction.items():
        blank_ds = xr.Dataset(
            coords={
                orig_x1_name: X_t[orig_x1_name],
                orig_x2_name: X_t[orig_x2_name],
                "time": stitched_prediction[0]["time"],
            }
        )

        # Set data variable names e.g. mean, std to those in patched prediction. Make all values NaN.
        for data_var in data_array.data_vars:
            blank_ds[data_var] = data_array[data_var]
            blank_ds[data_var][:] = np.nan
        stitched_prediction[var_name] = blank_ds

    # Restructure prediction objects for merging
    restructured_patches = {
        key: [item[key] for item in patches_clipped]
        for key in patches_clipped[0].keys()
    }

    # Merge patchwise predictions to create final stitched prediction.
    # Iterate over each variable (key) in the prediction dictionary
    for var_name, patches in restructured_patches.items():
        # Retrieve the blank dataset for the current variable
        prediction_array = stitched_prediction[var_name]

        # Merge each patch into the combined dataset
        for patch in patches:
            for var in patch.data_vars:
                # Reindex the patch to catch any slight rounding errors and misalignment with the combined dataset
                reindexed_patch = patch[var].reindex_like(
                    prediction_array[var], method="nearest", tolerance=1e-6
                )

                # Combine data, prioritizing non-NaN values from patches
                prediction_array[var] = prediction_array[var].where(
                    np.isnan(reindexed_patch), reindexed_patch
                )

        # Update the dictionary with the merged dataset
        stitched_prediction[var_name] = prediction_array
    return stitched_prediction
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc