• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

alan-turing-institute / deepsensor / 11455483170

22 Oct 2024 07:38AM UTC coverage: 81.626%. Remained the same
11455483170

push

github

davidwilby
update pre-commit ruff version

2048 of 2509 relevant lines covered (81.63%)

1.63 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

85.04
/deepsensor/data/processor.py
1
import numpy as np
2✔
2
import os
2✔
3
import json
2✔
4

5
import warnings
2✔
6
import xarray as xr
2✔
7
import pandas as pd
2✔
8

9
import pprint
2✔
10

11
from copy import deepcopy
2✔
12

13
from plum import dispatch
2✔
14
from typing import Union, Optional, List
2✔
15

16

17
class DataProcessor:
    """Normalise xarray and pandas data for use in deepsensor models.

    Args:
        folder (str, optional):
            Folder to load normalisation params from. Defaults to None.
        time_name (str, optional):
            Name of the time coord. Defaults to "time".
        x1_name (str, optional):
            Name of first spatial coord (e.g. "lat"). Defaults to "x1".
        x2_name (str, optional):
            Name of second spatial coord (e.g. "lon"). Defaults to "x2".
        x1_map (tuple, optional):
            2-tuple of raw x1 coords to linearly map to (0, 1),
            respectively. Defaults to (0, 1) (i.e. no normalisation).
        x2_map (tuple, optional):
            2-tuple of raw x2 coords to linearly map to (0, 1),
            respectively. Defaults to (0, 1) (i.e. no normalisation).
        deepcopy (bool, optional):
            Whether to make a deepcopy of raw data to ensure it is not
            changed by reference when normalising. Defaults to True.
        verbose (bool, optional):
            Whether to print verbose output. Defaults to False.
    """

    # Filename used by `save` to persist the normalisation config as JSON,
    # and by `__init__(folder=...)` to restore it.
    config_fname = "data_processor_config.json"
41

42
    def __init__(
2✔
43
        self,
44
        folder: Union[str, None] = None,
45
        time_name: str = "time",
46
        x1_name: str = "x1",
47
        x2_name: str = "x2",
48
        x1_map: Union[tuple, None] = None,
49
        x2_map: Union[tuple, None] = None,
50
        deepcopy: bool = True,
51
        verbose: bool = False,
52
    ):
53
        if folder is not None:
2✔
54
            fpath = os.path.join(folder, self.config_fname)
2✔
55
            if not os.path.exists(fpath):
2✔
56
                raise FileNotFoundError(
×
57
                    f"Could not find DataProcessor config file {fpath}"
58
                )
59
            with open(fpath, "r") as f:
2✔
60
                self.config = json.load(f)
2✔
61
                self.config["coords"]["x1"]["map"] = tuple(
2✔
62
                    self.config["coords"]["x1"]["map"]
63
                )
64
                self.config["coords"]["x2"]["map"] = tuple(
2✔
65
                    self.config["coords"]["x2"]["map"]
66
                )
67

68
            self.x1_name = self.config["coords"]["x1"]["name"]
2✔
69
            self.x2_name = self.config["coords"]["x2"]["name"]
2✔
70
            self.x1_map = self.config["coords"]["x1"]["map"]
2✔
71
            self.x2_map = self.config["coords"]["x2"]["map"]
2✔
72
        else:
73
            self.config = {}
2✔
74
            self.x1_name = x1_name
2✔
75
            self.x2_name = x2_name
2✔
76
            self.x1_map = x1_map
2✔
77
            self.x2_map = x2_map
2✔
78

79
            # rewrite below more concisely
80
            if self.x1_map is None and not self.x2_map is None:
2✔
81
                raise ValueError("Must provide both x1_map and x2_map, or neither.")
×
82
            elif not self.x1_map is None and self.x2_map is None:
2✔
83
                raise ValueError("Must provide both x1_map and x2_map, or neither.")
2✔
84
            elif not self.x1_map is None and not self.x2_map is None:
2✔
85
                x1_map, x2_map = self._validate_coord_mappings(x1_map, x2_map)
×
86

87
            if "coords" not in self.config:
2✔
88
                # Add coordinate normalisation info to config
89
                self.set_coord_params(time_name, x1_name, x1_map, x2_name, x2_map)
2✔
90

91
        self.raw_spatial_coord_names = [
2✔
92
            self.config["coords"][coord]["name"] for coord in ["x1", "x2"]
93
        ]
94

95
        self.deepcopy = deepcopy
2✔
96
        self.verbose = verbose
2✔
97

98
        # List of valid normalisation method names
99
        self.valid_methods = ["mean_std", "min_max", "positive_semidefinite"]
2✔
100

101
    def save(self, folder: str):
2✔
102
        """Save DataProcessor config to JSON in `folder`."""
103
        os.makedirs(folder, exist_ok=True)
2✔
104
        fpath = os.path.join(folder, self.config_fname)
2✔
105
        with open(fpath, "w") as f:
2✔
106
            json.dump(self.config, f, indent=4, sort_keys=False)
2✔
107

108
    def _validate_coord_mappings(self, x1_map, x2_map):
2✔
109
        """Ensure the maps are valid and of appropriate types."""
110
        try:
2✔
111
            x1_map = (float(x1_map[0]), float(x1_map[1]))
2✔
112
            x2_map = (float(x2_map[0]), float(x2_map[1]))
2✔
113
        except:
×
114
            raise TypeError(
×
115
                "Provided coordinate mappings can't be cast to 2D Tuple[float]"
116
            )
117

118
        # Check that map is not two of the same number
119
        if np.diff(x1_map) == 0:
2✔
120
            raise ValueError(
×
121
                f"x1_map must be a 2-tuple of different numbers, not {x1_map}"
122
            )
123
        if np.diff(x2_map) == 0:
2✔
124
            raise ValueError(
×
125
                f"x2_map must be a 2-tuple of different numbers, not {x2_map}"
126
            )
127
        if np.diff(x1_map) != np.diff(x2_map):
2✔
128
            warnings.warn(
×
129
                f"x1_map={x1_map} and x2_map={x2_map} have different ranges ({float(np.diff(x1_map))} "
130
                f"and {float(np.diff(x2_map))}, respectively). "
131
                "This can lead to stretching/squashing of data, which may "
132
                "impact model performance.",
133
                UserWarning,
134
            )
135

136
        return x1_map, x2_map
2✔
137

138
    def _validate_xr(self, data: Union[xr.DataArray, xr.Dataset]):
        """Check that the dims of ``data`` match the configured coord names
        (time optional), raising ValueError otherwise."""

        def _check_dims(da: xr.DataArray):
            expected = [
                self.config["coords"][c]["name"] for c in ["time", "x1", "x2"]
            ]
            if expected[0] not in da.dims:
                # No time dimension present; only the spatial dims are required.
                expected = expected[1:]
            if list(da.dims) != expected:
                raise ValueError(
                    f"Dimensions of {da.name} need to be {expected} but are {list(da.dims)}."
                )

        if isinstance(data, xr.DataArray):
            _check_dims(data)
        elif isinstance(data, xr.Dataset):
            for da in data.data_vars.values():
                _check_dims(da)
157

158
    def _validate_pandas(self, df: Union[pd.DataFrame, pd.Series]):
2✔
159
        coord_names = [
2✔
160
            self.config["coords"][coord]["name"] for coord in ["time", "x1", "x2"]
161
        ]
162

163
        if coord_names[0] not in df.index.names:
2✔
164
            # We don't have a time dimension.
165
            if list(df.index.names)[:2] != coord_names[1:]:
×
166
                raise ValueError(
×
167
                    f"Indexes need to start with {coord_names} or {coord_names[1:]} but are {df.index.names}."
168
                )
169
        else:
170
            # We have a time dimension.
171
            if list(df.index.names)[:3] != coord_names:
2✔
172
                raise ValueError(
2✔
173
                    f"Indexes need to start with {coord_names} or {coord_names[1:]} but are {df.index.names}."
174
                )
175

176
    def __str__(self):
2✔
177
        s = "DataProcessor with normalisation params:\n"
×
178
        s += pprint.pformat(self.config)
×
179
        return s
×
180

181
    @classmethod
    def load_dask(cls, data: Union[xr.DataArray, xr.Dataset]):
        """Load lazily-evaluated (dask-backed) xarray data into memory.

        Args:
            data (:class:`xarray.DataArray` | :class:`xarray.Dataset`):
                Object whose values should be loaded in place.

        Returns:
            :class:`xarray.DataArray` | :class:`xarray.Dataset`:
                The same object, loaded into memory. Non-xarray inputs are
                returned unchanged.
        """
        # The original if/elif branches performed the identical `.load()` call
        # for both types; collapse them into one isinstance check.
        if isinstance(data, (xr.DataArray, xr.Dataset)):
            data.load()
        return data
197

198
    def set_coord_params(self, time_name, x1_name, x1_map, x2_name, x2_map) -> None:
2✔
199
        """Set coordinate normalisation params.
200

201
        Args:
202
            time_name:
203
                [Type] Description needed.
204
            x1_name:
205
                [Type] Description needed.
206
            x1_map:
207
                [Type] Description needed.
208
            x2_name:
209
                [Type] Description needed.
210
            x2_map:
211
                [Type] Description needed.
212

213
        Returns:
214
            None.
215
        """
216
        self.config["coords"] = {}
2✔
217
        self.config["coords"]["time"] = {"name": time_name}
2✔
218
        self.config["coords"]["x1"] = {}
2✔
219
        self.config["coords"]["x2"] = {}
2✔
220
        self.config["coords"]["x1"]["name"] = x1_name
2✔
221
        self.config["coords"]["x1"]["map"] = x1_map
2✔
222
        self.config["coords"]["x2"]["name"] = x2_name
2✔
223
        self.config["coords"]["x2"]["map"] = x2_map
2✔
224

225
    def check_params_computed(self, var_ID, method) -> bool:
2✔
226
        """Check if normalisation params computed for a given variable.
227

228
        Args:
229
            var_ID:
230
                [Type] Description needed.
231
            method:
232
                [Type] Description needed.
233

234
        Returns:
235
            bool:
236
                Whether normalisation params are computed for a given variable.
237
        """
238
        if (
2✔
239
            var_ID in self.config
240
            and self.config[var_ID]["method"] == method
241
            and "params" in self.config[var_ID]
242
        ):
243
            return True
2✔
244

245
        return False
2✔
246

247
    def add_to_config(self, var_ID, **kwargs):
2✔
248
        """Add `kwargs` to `config` dict for variable `var_ID`."""
249
        self.config[var_ID] = kwargs
2✔
250

251
    def get_config(self, var_ID, data, method=None):
2✔
252
        """Get pre-computed normalisation params or compute them for variable
253
        ``var_ID``.
254

255
        .. note::
256
            TODO do we need to pass var_ID? Can we just use the name of data?
257

258
        Args:
259
            var_ID:
260
                [Type] Description needed.
261
            data:
262
                [Type] Description needed.
263
            method (optional):
264
                [Type] Description needed. Defaults to None.
265

266
        Returns:
267
            [Type]:
268
                Description of the returned value(s) needed.
269
        """
270
        if method not in self.valid_methods:
2✔
271
            raise ValueError(
×
272
                f"Method {method} not recognised. Must be one of {self.valid_methods}"
273
            )
274

275
        if self.check_params_computed(var_ID, method):
2✔
276
            # Already have "params" in config with `"method": method` - load them
277
            params = self.config[var_ID]["params"]
2✔
278
        else:
279
            # Params not computed - compute them now
280
            if self.verbose:
2✔
281
                print(
×
282
                    f"Normalisation params for {var_ID} not computed. Computing now... ",
283
                    end="",
284
                    flush=True,
285
                )
286
            DataProcessor.load_dask(data)
2✔
287
            if method == "mean_std":
2✔
288
                params = {"mean": float(data.mean()), "std": float(data.std())}
2✔
289
            elif method == "min_max":
2✔
290
                params = {"min": float(data.min()), "max": float(data.max())}
2✔
291
            elif method == "positive_semidefinite":
2✔
292
                params = {"min": float(data.min()), "std": float(data.std())}
2✔
293
            if self.verbose:
2✔
294
                print(f"Done. {var_ID} {method} params={params}")
×
295
            self.add_to_config(
2✔
296
                var_ID,
297
                **{"method": method, "params": params},
298
            )
299
        return params
2✔
300

301
    def map_coord_array(self, coord_array: np.ndarray, unnorm: bool = False):
2✔
302
        """Normalise or unnormalise a coordinate array.
303

304
        Args:
305
            coord_array (:class:`numpy:numpy.ndarray`):
306
                Array of shape ``(2, N)`` containing coords.
307
            unnorm (bool, optional):
308
                Whether to unnormalise. Defaults to ``False``.
309

310
        Returns:
311
            [Type]:
312
                Description of the returned value(s) needed.
313
        """
314
        x1, x2 = self.map_x1_and_x2(coord_array[0], coord_array[1], unnorm=unnorm)
2✔
315
        new_coords = np.stack([x1, x2], axis=0)
2✔
316
        return new_coords
2✔
317

318
    def map_x1_and_x2(self, x1: np.ndarray, x2: np.ndarray, unnorm: bool = False):
2✔
319
        """Normalise or unnormalise spatial coords in an array.
320

321
        Args:
322
            x1 (:class:`numpy:numpy.ndarray`):
323
                Array of shape ``(N_x1,)`` containing spatial coords of x1.
324
            x2 (:class:`numpy:numpy.ndarray`):
325
                Array of shape ``(N_x2,)`` containing spatial coords of x2.
326
            unnorm (bool, optional):
327
                Whether to unnormalise. Defaults to ``False``.
328

329
        Returns:
330
            Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`]:
331
                Normalised or unnormalised spatial coords of x1 and x2.
332
        """
333
        x11, x12 = self.config["coords"]["x1"]["map"]
2✔
334
        x21, x22 = self.config["coords"]["x2"]["map"]
2✔
335

336
        if not unnorm:
2✔
337
            new_coords_x1 = (x1 - x11) / (x12 - x11)
2✔
338
            new_coords_x2 = (x2 - x21) / (x22 - x21)
2✔
339
        else:
340
            new_coords_x1 = x1 * (x12 - x11) + x11
2✔
341
            new_coords_x2 = x2 * (x22 - x21) + x21
2✔
342

343
        return new_coords_x1, new_coords_x2
2✔
344

345
    def map_coords(
2✔
346
        self,
347
        data: Union[xr.DataArray, xr.Dataset, pd.DataFrame, pd.Series],
348
        unnorm=False,
349
    ):
350
        """Normalise spatial coords in a pandas or xarray object.
351

352
        Args:
353
            data (:class:`xarray.DataArray`, :class:`xarray.Dataset`, :class:`pandas.DataFrame`, or :class:`pandas.Series`):
354
                [Description Needed]
355
            unnorm (bool, optional):
356
                [Description Needed]. Defaults to [Default Value].
357

358
        Returns:
359
            [Type]:
360
                [Description Needed]
361
        """
362
        if isinstance(data, (pd.DataFrame, pd.Series)):
2✔
363
            # Reset index to get coords as columns
364
            indexes = list(data.index.names)
2✔
365
            data = data.reset_index()
2✔
366

367
        if unnorm:
2✔
368
            new_coord_IDs = [
2✔
369
                self.config["coords"][coord_ID]["name"]
370
                for coord_ID in ["time", "x1", "x2"]
371
            ]
372
            old_coord_IDs = ["time", "x1", "x2"]
2✔
373
        else:
374
            new_coord_IDs = ["time", "x1", "x2"]
2✔
375
            old_coord_IDs = [
2✔
376
                self.config["coords"][coord_ID]["name"]
377
                for coord_ID in ["time", "x1", "x2"]
378
            ]
379

380
        x1, x2 = (
2✔
381
            data[old_coord_IDs[1]],
382
            data[old_coord_IDs[2]],
383
        )
384

385
        # Infer x1 and x2 mappings from min/max of data coords if not provided by user
386
        if self.x1_map is None and self.x2_map is None:
2✔
387
            # Ensure scalings are the same for x1 and x2
388
            x1_range = x1.max() - x1.min()
2✔
389
            x2_range = x2.max() - x2.min()
2✔
390
            range = np.max([x1_range, x2_range])
2✔
391
            self.x1_map = (x1.min(), x1.min() + range)
2✔
392
            self.x2_map = (x2.min(), x2.min() + range)
2✔
393

394
            self.x1_map, self.x2_map = self._validate_coord_mappings(
2✔
395
                self.x1_map, self.x2_map
396
            )
397
            self.config["coords"]["x1"]["map"] = self.x1_map
2✔
398
            self.config["coords"]["x2"]["map"] = self.x2_map
2✔
399

400
            if self.verbose:
2✔
401
                print(
×
402
                    f"Inferring x1_map={self.x1_map} and x2_map={self.x2_map} from data min/max"
403
                )
404

405
        new_x1, new_x2 = self.map_x1_and_x2(x1, x2, unnorm=unnorm)
2✔
406

407
        if isinstance(data, (pd.DataFrame, pd.Series)):
2✔
408
            # Drop old spatial coord columns *before* adding new ones, in case
409
            # the old name is already x1.
410
            data = data.drop(columns=old_coord_IDs[1:])
2✔
411
            # Add coords to dataframe
412
            data[new_coord_IDs[1]] = new_x1
2✔
413
            data[new_coord_IDs[2]] = new_x2
2✔
414

415
            if old_coord_IDs[0] in data.columns:
2✔
416
                # Rename time dimension.
417
                rename = {old_coord_IDs[0]: new_coord_IDs[0]}
2✔
418
                data = data.rename(rename, axis=1)
2✔
419
            else:
420
                # We don't have a time dimension.
421
                old_coord_IDs = old_coord_IDs[1:]
2✔
422
                new_coord_IDs = new_coord_IDs[1:]
2✔
423

424
        elif isinstance(data, (xr.DataArray, xr.Dataset)):
2✔
425
            data = data.assign_coords(
2✔
426
                {old_coord_IDs[1]: new_x1, old_coord_IDs[2]: new_x2}
427
            )
428

429
            if old_coord_IDs[0] not in data.dims:
2✔
430
                # We don't have a time dimension.
431
                old_coord_IDs = old_coord_IDs[1:]
2✔
432
                new_coord_IDs = new_coord_IDs[1:]
2✔
433

434
            # Rename all dimensions.
435
            rename = {
2✔
436
                old: new for old, new in zip(old_coord_IDs, new_coord_IDs) if old != new
437
            }
438
            data = data.rename(rename)
2✔
439

440
        if isinstance(data, (pd.DataFrame, pd.Series)):
2✔
441
            # Set index back to original
442
            [indexes.remove(old_coord_ID) for old_coord_ID in old_coord_IDs]
2✔
443
            indexes = new_coord_IDs + indexes  # Put dims first
2✔
444
            data = data.set_index(indexes)
2✔
445

446
        return data
2✔
447

448
    def map_array(
2✔
449
        self,
450
        data: Union[xr.DataArray, xr.Dataset, pd.DataFrame, pd.Series, np.ndarray],
451
        var_ID: str,
452
        method: Optional[str] = None,
453
        unnorm: bool = False,
454
        add_offset: bool = True,
455
    ):
456
        """Normalise or unnormalise the data values in an xarray, pandas, or
457
        numpy object.
458

459
        Args:
460
            data (:class:`xarray.DataArray`, :class:`xarray.Dataset`, :class:`pandas.DataFrame`, :class:`pandas.Series`, or :class:`numpy:numpy.ndarray`):
461
                [Description Needed]
462
            var_ID (str):
463
                [Description Needed]
464
            method (str, optional):
465
                [Description Needed]. Defaults to None.
466
            unnorm (bool, optional):
467
                [Description Needed]. Defaults to False.
468
            add_offset (bool, optional):
469
                [Description Needed]. Defaults to True.
470

471
        Returns:
472
            [Type]:
473
                [Description Needed]
474
        """
475
        if not unnorm and method is None:
2✔
476
            raise ValueError("Must provide `method` if normalising data.")
×
477
        elif unnorm and method is not None and self.config[var_ID]["method"] != method:
2✔
478
            # User has provided a different method to the one used for normalising
479
            raise ValueError(
×
480
                f"Variable '{var_ID}' has been normalised with method '{self.config[var_ID]['method']}', "
481
                f"cannot unnormalise with method '{method}'. Pass `method=None` or"
482
                f"`method='{self.config[var_ID]['method']}'`"
483
            )
484
        if method is None and unnorm:
2✔
485
            # Determine normalisation method from config for unnormalising
486
            method = self.config[var_ID]["method"]
2✔
487
        elif method not in self.valid_methods:
2✔
488
            raise ValueError(
×
489
                f"Method {method} not recognised. Use one of {self.valid_methods}"
490
            )
491

492
        params = self.get_config(var_ID, data, method)
2✔
493

494
        # Linear transformation:
495
        # - Inverse normalisation: y_unnorm = m * y_norm + c
496
        # - Inverse normalisation: y_norm = (1/m) * y_unnorm - c/m
497
        if method == "mean_std":
2✔
498
            m = params["std"]
2✔
499
            c = params["mean"]
2✔
500
        elif method == "min_max":
2✔
501
            m = (params["max"] - params["min"]) / 2
2✔
502
            c = (params["max"] + params["min"]) / 2
2✔
503
        elif method == "positive_semidefinite":
2✔
504
            m = params["std"]
2✔
505
            c = params["min"]
2✔
506
        if not unnorm:
2✔
507
            c = -c / m
2✔
508
            m = 1 / m
2✔
509
        data = data * m
2✔
510
        if add_offset:
2✔
511
            data = data + c
2✔
512
        return data
2✔
513

514
    def map(
        self,
        data: Union[xr.DataArray, xr.Dataset, pd.DataFrame, pd.Series],
        method: Optional[str] = None,
        add_offset: bool = True,
        unnorm: bool = False,
        assert_computed: bool = False,
    ):
        """Normalise or unnormalise both the values and the coords of an
        xarray or pandas object.

        Args:
            data (:class:`xarray.DataArray`, :class:`xarray.Dataset`, :class:`pandas.DataFrame`, or :class:`pandas.Series`):
                Object to map.
            method (str, optional):
                Normalisation method name. Defaults to None.
            add_offset (bool, optional):
                Whether to apply the additive offset. Defaults to True.
            unnorm (bool, optional):
                Whether to unnormalise. Defaults to False.
            assert_computed (bool, optional):
                If True, require that normalisation params are already
                computed for every mapped variable. Defaults to False.

        Returns:
            Same type as ``data``:
                Object with mapped values and coords.
        """
        if self.deepcopy:
            # Avoid mutating the caller's object by reference.
            data = deepcopy(data)

        # Only raw (not-yet-normalised) inputs are validated against the
        # configured coord names.
        if not unnorm:
            if isinstance(data, (xr.DataArray, xr.Dataset)):
                self._validate_xr(data)
            elif isinstance(data, (pd.DataFrame, pd.Series)):
                self._validate_pandas(data)

        def _require_params(var_ID):
            assert self.check_params_computed(
                var_ID, method
            ), f"{method} normalisation params for {var_ID} not computed."

        if isinstance(data, (xr.DataArray, pd.Series)):
            # Single-variable object.
            var_ID = data.name
            if assert_computed:
                _require_params(var_ID)
            data = self.map_array(data, var_ID, method, unnorm, add_offset)
        elif isinstance(data, (xr.Dataset, pd.DataFrame)):
            # Multi-variable object: map each variable in turn.
            for var_ID in data:
                if assert_computed:
                    _require_params(var_ID)
                data[var_ID] = self.map_array(
                    data[var_ID], var_ID, method, unnorm, add_offset
                )

        return self.map_coords(data, unnorm=unnorm)
569

570
    def __call__(
        self,
        data: Union[
            xr.DataArray,
            xr.Dataset,
            pd.DataFrame,
            List[Union[xr.DataArray, xr.Dataset, pd.DataFrame]],
        ],
        method: str = "mean_std",
        assert_computed: bool = False,
    ) -> Union[
        xr.DataArray,
        xr.Dataset,
        pd.DataFrame,
        List[Union[xr.DataArray, xr.Dataset, pd.DataFrame]],
    ]:
        """Normalise data.

        Args:
            data (:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`]):
                Data to be normalised: an xarray DataArray/Dataset, a pandas
                DataFrame, or a list containing objects of these types.
            method (str, optional): Normalisation method. Options include:
                - "mean_std": Normalise to mean=0 and std=1 (default)
                - "min_max": Normalise to min=-1 and max=1
                - "positive_semidefinite": Normalise to min=0 and std=1
            assert_computed (bool, optional):
                If True, require pre-computed params for every variable.
                Defaults to False.

        Returns:
            :class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`]:
                Normalised data. Type or structure depends on the input.
        """
        if not isinstance(data, list):
            return self.map(data, method, unnorm=False, assert_computed=assert_computed)
        return [
            self.map(d, method, unnorm=False, assert_computed=assert_computed)
            for d in data
        ]
609

610
    def unnormalise(
        self,
        data: Union[
            xr.DataArray,
            xr.Dataset,
            pd.DataFrame,
            List[Union[xr.DataArray, xr.Dataset, pd.DataFrame]],
        ],
        add_offset: bool = True,
    ) -> Union[
        xr.DataArray,
        xr.Dataset,
        pd.DataFrame,
        List[Union[xr.DataArray, xr.Dataset, pd.DataFrame]],
    ]:
        """Unnormalise data.

        Args:
            data (:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`]):
                Data to unnormalise (a single object or a list of objects).
            add_offset (bool, optional):
                Whether to add the offset to the data when unnormalising. Set
                to False to unnormalise uncertainty values (e.g. std dev).
                Defaults to True.

        Returns:
            :class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`]:
                Unnormalised data. Type or structure depends on the input.
        """
        if not isinstance(data, list):
            return self.map(data, add_offset=add_offset, unnorm=True)
        return [self.map(d, add_offset=add_offset, unnorm=True) for d in data]
643

644

645
def xarray_to_coord_array_normalised(da: Union[xr.Dataset, xr.DataArray]) -> np.ndarray:
    """Flatten the x1/x2 grid of ``da`` into a normalised coordinate array.

    Args:
        da (:class:`xarray.Dataset` | :class:`xarray.DataArray`)
            Gridded object with ``x1`` and ``x2`` coords.

    Returns:
        :class:`numpy:numpy.ndarray`
            A normalised coordinate array of shape ``(2, N)`` where
            ``N = len(x1) * len(x2)``.
    """
    grid_x1, grid_x2 = np.meshgrid(
        da["x1"].values, da["x2"].values, indexing="ij"
    )
    return np.stack([grid_x1.ravel(), grid_x2.ravel()], axis=0)
659

660

661
def process_X_mask_for_X(X_mask: xr.DataArray, X: xr.DataArray) -> xr.DataArray:
    """Regrid a mask onto the grid of ``X`` and convert it to boolean.

    Both X_mask and X are xarray DataArrays with the same spatial coords.
    Points with no nearest neighbour are filled with 0 (i.e. False).

    Args:
        X_mask (:class:`xarray.DataArray`):
            Mask to regrid.
        X (:class:`xarray.DataArray`):
            Object providing the target grid.

    Returns:
        :class:`xarray.DataArray`
            Boolean mask on the grid of ``X``, loaded into memory.
    """
    regridded = X_mask.astype(float).interp_like(
        X, method="nearest", kwargs={"fill_value": 0}
    )
    regridded.data = regridded.data.astype(bool)
    regridded.load()
    return regridded
682

683

684
def mask_coord_array_normalised(
    coord_arr: np.ndarray, mask_da: Union[xr.DataArray, xr.Dataset, None]
) -> np.ndarray:
    """Remove points from (2, N) numpy array that are outside gridded xarray
    boolean mask.

    If `coord_arr` is shape `(2, N)`, then `mask_da` is a shape `(N,)` boolean
    array (True if point is inside mask, False if outside).

    Args:
        coord_arr (:class:`numpy:numpy.ndarray`):
            Coordinate array of shape ``(2, N)``.
        mask_da (:class:`xarray.Dataset` | :class:`xarray.DataArray`):
            Gridded boolean mask, or None to keep all points.

    Returns:
        :class:`numpy:numpy.ndarray`
            The filtered coordinate array.
    """
    # No mask means nothing to filter out.
    if mask_da is None:
        return coord_arr
    mask_da = mask_da.astype(bool)
    x1_pts, x2_pts = xr.DataArray(coord_arr[0]), xr.DataArray(coord_arr[1])
    selected = mask_da.sel(x1=x1_pts, x2=x2_pts, method="nearest")
    return coord_arr[:, selected]
709

710

711
def da1_da2_same_grid(da1: xr.DataArray, da2: xr.DataArray) -> bool:
    """Check if ``da1`` and ``da2`` are on the same grid.

    .. note::
        ``da1`` and ``da2`` are assumed normalised by ``DataProcessor``.

    Args:
        da1 (:class:`xarray.DataArray`):
            First gridded object.
        da2 (:class:`xarray.DataArray`):
            Second gridded object.

    Returns:
        bool
            Whether the ``x1`` and ``x2`` coord arrays match exactly.
    """
    return np.array_equal(
        da1["x1"].values, da2["x1"].values
    ) and np.array_equal(da1["x2"].values, da2["x2"].values)
730

731

732
def interp_da1_to_da2(da1: xr.DataArray, da2: xr.DataArray) -> xr.DataArray:
    """Nearest-neighbour interpolate ``da1`` onto the grid of ``da2``.

    .. note::
        ``da1`` and ``da2`` are assumed normalised by ``DataProcessor``.

    Args:
        da1 (:class:`xarray.DataArray`):
            Data to interpolate.
        da2 (:class:`xarray.DataArray`):
            Data providing the target ``x1``/``x2`` grid.

    Returns:
        :class:`xarray.DataArray`
            Interpolated xarray.
    """
    target_x1, target_x2 = da2["x1"], da2["x2"]
    return da1.interp(x1=target_x1, x2=target_x2, method="nearest")
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc