• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

alan-turing-institute / deepsensor / 11455747995

22 Oct 2024 07:56AM UTC coverage: 81.626% (+0.3%) from 81.333%
11455747995

push

github

davidwilby
incorporate feedback

2048 of 2509 relevant lines covered (81.63%)

1.63 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

85.04
/deepsensor/data/processor.py
1
import numpy as np
2✔
2
import os
2✔
3
import json
2✔
4

5
import warnings
2✔
6
import xarray as xr
2✔
7
import pandas as pd
2✔
8

9
import pprint
2✔
10

11
from copy import deepcopy
2✔
12

13
from plum import dispatch
2✔
14
from typing import Union, Optional, List
2✔
15

16

17
class DataProcessor:
2✔
18
    """
19
    Normalise xarray and pandas data for use in deepsensor models
20

21
    Args:
22
        folder (str, optional):
23
            Folder to load normalisation params from. Defaults to None.
24
        x1_name (str, optional):
25
            Name of first spatial coord (e.g. "lat"). Defaults to "x1".
26
        x2_name (str, optional):
27
            Name of second spatial coord (e.g. "lon"). Defaults to "x2".
28
        x1_map (tuple, optional):
29
            2-tuple of raw x1 coords to linearly map to (0, 1),
30
            respectively. Defaults to (0, 1) (i.e. no normalisation).
31
        x2_map (tuple, optional):
32
            2-tuple of raw x2 coords to linearly map to (0, 1),
33
            respectively. Defaults to (0, 1) (i.e. no normalisation).
34
        deepcopy (bool, optional):
35
            Whether to make a deepcopy of raw data to ensure it is not
36
            changed by reference when normalising. Defaults to True.
37
        verbose (bool, optional):
38
            Whether to print verbose output. Defaults to False.
39
    """
40

41
    config_fname = "data_processor_config.json"
2✔
42

43
    def __init__(
2✔
44
        self,
45
        folder: Union[str, None] = None,
46
        time_name: str = "time",
47
        x1_name: str = "x1",
48
        x2_name: str = "x2",
49
        x1_map: Union[tuple, None] = None,
50
        x2_map: Union[tuple, None] = None,
51
        deepcopy: bool = True,
52
        verbose: bool = False,
53
    ):
54
        if folder is not None:
2✔
55
            fpath = os.path.join(folder, self.config_fname)
2✔
56
            if not os.path.exists(fpath):
2✔
57
                raise FileNotFoundError(
×
58
                    f"Could not find DataProcessor config file {fpath}"
59
                )
60
            with open(fpath, "r") as f:
2✔
61
                self.config = json.load(f)
2✔
62
                self.config["coords"]["x1"]["map"] = tuple(
2✔
63
                    self.config["coords"]["x1"]["map"]
64
                )
65
                self.config["coords"]["x2"]["map"] = tuple(
2✔
66
                    self.config["coords"]["x2"]["map"]
67
                )
68

69
            self.x1_name = self.config["coords"]["x1"]["name"]
2✔
70
            self.x2_name = self.config["coords"]["x2"]["name"]
2✔
71
            self.x1_map = self.config["coords"]["x1"]["map"]
2✔
72
            self.x2_map = self.config["coords"]["x2"]["map"]
2✔
73
        else:
74
            self.config = {}
2✔
75
            self.x1_name = x1_name
2✔
76
            self.x2_name = x2_name
2✔
77
            self.x1_map = x1_map
2✔
78
            self.x2_map = x2_map
2✔
79

80
            # rewrite below more concisely
81
            if self.x1_map is None and not self.x2_map is None:
2✔
82
                raise ValueError("Must provide both x1_map and x2_map, or neither.")
×
83
            elif not self.x1_map is None and self.x2_map is None:
2✔
84
                raise ValueError("Must provide both x1_map and x2_map, or neither.")
2✔
85
            elif not self.x1_map is None and not self.x2_map is None:
2✔
86
                x1_map, x2_map = self._validate_coord_mappings(x1_map, x2_map)
×
87

88
            if "coords" not in self.config:
2✔
89
                # Add coordinate normalisation info to config
90
                self.set_coord_params(time_name, x1_name, x1_map, x2_name, x2_map)
2✔
91

92
        self.raw_spatial_coord_names = [
2✔
93
            self.config["coords"][coord]["name"] for coord in ["x1", "x2"]
94
        ]
95

96
        self.deepcopy = deepcopy
2✔
97
        self.verbose = verbose
2✔
98

99
        # List of valid normalisation method names
100
        self.valid_methods = ["mean_std", "min_max", "positive_semidefinite"]
2✔
101

102
    def save(self, folder: str):
2✔
103
        """Save DataProcessor config to JSON in `folder`"""
104
        os.makedirs(folder, exist_ok=True)
2✔
105
        fpath = os.path.join(folder, self.config_fname)
2✔
106
        with open(fpath, "w") as f:
2✔
107
            json.dump(self.config, f, indent=4, sort_keys=False)
2✔
108

109
    def _validate_coord_mappings(self, x1_map, x2_map):
2✔
110
        """Ensure the maps are valid and of appropriate types."""
111
        try:
2✔
112
            x1_map = (float(x1_map[0]), float(x1_map[1]))
2✔
113
            x2_map = (float(x2_map[0]), float(x2_map[1]))
2✔
114
        except:
×
115
            raise TypeError(
×
116
                "Provided coordinate mappings can't be cast to 2D Tuple[float]"
117
            )
118

119
        # Check that map is not two of the same number
120
        if np.diff(x1_map) == 0:
2✔
121
            raise ValueError(
×
122
                f"x1_map must be a 2-tuple of different numbers, not {x1_map}"
123
            )
124
        if np.diff(x2_map) == 0:
2✔
125
            raise ValueError(
×
126
                f"x2_map must be a 2-tuple of different numbers, not {x2_map}"
127
            )
128
        if np.diff(x1_map) != np.diff(x2_map):
2✔
129
            warnings.warn(
×
130
                f"x1_map={x1_map} and x2_map={x2_map} have different ranges ({float(np.diff(x1_map))} "
131
                f"and {float(np.diff(x2_map))}, respectively). "
132
                "This can lead to stretching/squashing of data, which may "
133
                "impact model performance.",
134
                UserWarning,
135
            )
136

137
        return x1_map, x2_map
2✔
138

139
    def _validate_xr(self, data: Union[xr.DataArray, xr.Dataset]):
2✔
140
        def _validate_da(da: xr.DataArray):
2✔
141
            coord_names = [
2✔
142
                self.config["coords"][coord]["name"] for coord in ["time", "x1", "x2"]
143
            ]
144
            if coord_names[0] not in da.dims:
2✔
145
                # We don't have a time dimension.
146
                coord_names = coord_names[1:]
×
147
            if list(da.dims) != coord_names:
2✔
148
                raise ValueError(
2✔
149
                    f"Dimensions of {da.name} need to be {coord_names} but are {list(da.dims)}."
150
                )
151

152
        if isinstance(data, xr.DataArray):
2✔
153
            _validate_da(data)
2✔
154

155
        elif isinstance(data, xr.Dataset):
2✔
156
            for var_ID, da in data.data_vars.items():
2✔
157
                _validate_da(da)
2✔
158

159
    def _validate_pandas(self, df: Union[pd.DataFrame, pd.Series]):
2✔
160
        coord_names = [
2✔
161
            self.config["coords"][coord]["name"] for coord in ["time", "x1", "x2"]
162
        ]
163

164
        if coord_names[0] not in df.index.names:
2✔
165
            # We don't have a time dimension.
166
            if list(df.index.names)[:2] != coord_names[1:]:
×
167
                raise ValueError(
×
168
                    f"Indexes need to start with {coord_names} or {coord_names[1:]} but are {df.index.names}."
169
                )
170
        else:
171
            # We have a time dimension.
172
            if list(df.index.names)[:3] != coord_names:
2✔
173
                raise ValueError(
2✔
174
                    f"Indexes need to start with {coord_names} or {coord_names[1:]} but are {df.index.names}."
175
                )
176

177
    def __str__(self):
2✔
178
        s = "DataProcessor with normalisation params:\n"
×
179
        s += pprint.pformat(self.config)
×
180
        return s
×
181

182
    @classmethod
2✔
183
    def load_dask(cls, data: Union[xr.DataArray, xr.Dataset]):
2✔
184
        """
185
        Load dask data into memory.
186

187
        Args:
188
            data (:class:`xarray.DataArray` | :class:`xarray.Dataset`):
189
                Description of the parameter.
190

191
        Returns:
192
            [Type and description of the returned value(s) needed.]
193
        """
194
        if isinstance(data, xr.DataArray):
2✔
195
            data.load()
2✔
196
        elif isinstance(data, xr.Dataset):
2✔
197
            data.load()
×
198
        return data
2✔
199

200
    def set_coord_params(self, time_name, x1_name, x1_map, x2_name, x2_map) -> None:
2✔
201
        """
202
        Set coordinate normalisation params.
203

204
        Args:
205
            time_name:
206
                [Type] Description needed.
207
            x1_name:
208
                [Type] Description needed.
209
            x1_map:
210
                [Type] Description needed.
211
            x2_name:
212
                [Type] Description needed.
213
            x2_map:
214
                [Type] Description needed.
215

216
        Returns:
217
            None.
218
        """
219
        self.config["coords"] = {}
2✔
220
        self.config["coords"]["time"] = {"name": time_name}
2✔
221
        self.config["coords"]["x1"] = {}
2✔
222
        self.config["coords"]["x2"] = {}
2✔
223
        self.config["coords"]["x1"]["name"] = x1_name
2✔
224
        self.config["coords"]["x1"]["map"] = x1_map
2✔
225
        self.config["coords"]["x2"]["name"] = x2_name
2✔
226
        self.config["coords"]["x2"]["map"] = x2_map
2✔
227

228
    def check_params_computed(self, var_ID, method) -> bool:
2✔
229
        """
230
        Check if normalisation params computed for a given variable.
231

232
        Args:
233
            var_ID:
234
                [Type] Description needed.
235
            method:
236
                [Type] Description needed.
237

238
        Returns:
239
            bool:
240
                Whether normalisation params are computed for a given variable.
241
        """
242
        if (
2✔
243
            var_ID in self.config
244
            and self.config[var_ID]["method"] == method
245
            and "params" in self.config[var_ID]
246
        ):
247
            return True
2✔
248

249
        return False
2✔
250

251
    def add_to_config(self, var_ID, **kwargs):
2✔
252
        """Add `kwargs` to `config` dict for variable `var_ID`"""
253
        self.config[var_ID] = kwargs
2✔
254

255
    def get_config(self, var_ID, data, method=None):
2✔
256
        """
257
        Get pre-computed normalisation params or compute them for variable
258
        ``var_ID``.
259

260
        .. note::
261
            TODO do we need to pass var_ID? Can we just use the name of data?
262

263
        Args:
264
            var_ID:
265
                [Type] Description needed.
266
            data:
267
                [Type] Description needed.
268
            method (optional):
269
                [Type] Description needed. Defaults to None.
270

271
        Returns:
272
            [Type]:
273
                Description of the returned value(s) needed.
274
        """
275
        if method not in self.valid_methods:
2✔
276
            raise ValueError(
×
277
                f"Method {method} not recognised. Must be one of {self.valid_methods}"
278
            )
279

280
        if self.check_params_computed(var_ID, method):
2✔
281
            # Already have "params" in config with `"method": method` - load them
282
            params = self.config[var_ID]["params"]
2✔
283
        else:
284
            # Params not computed - compute them now
285
            if self.verbose:
2✔
286
                print(
×
287
                    f"Normalisation params for {var_ID} not computed. Computing now... ",
288
                    end="",
289
                    flush=True,
290
                )
291
            DataProcessor.load_dask(data)
2✔
292
            if method == "mean_std":
2✔
293
                params = {"mean": float(data.mean()), "std": float(data.std())}
2✔
294
            elif method == "min_max":
2✔
295
                params = {"min": float(data.min()), "max": float(data.max())}
2✔
296
            elif method == "positive_semidefinite":
2✔
297
                params = {"min": float(data.min()), "std": float(data.std())}
2✔
298
            if self.verbose:
2✔
299
                print(f"Done. {var_ID} {method} params={params}")
×
300
            self.add_to_config(
2✔
301
                var_ID,
302
                **{"method": method, "params": params},
303
            )
304
        return params
2✔
305

306
    def map_coord_array(self, coord_array: np.ndarray, unnorm: bool = False):
2✔
307
        """
308
        Normalise or unnormalise a coordinate array.
309

310
        Args:
311
            coord_array (:class:`numpy:numpy.ndarray`):
312
                Array of shape ``(2, N)`` containing coords.
313
            unnorm (bool, optional):
314
                Whether to unnormalise. Defaults to ``False``.
315

316
        Returns:
317
            [Type]:
318
                Description of the returned value(s) needed.
319
        """
320
        x1, x2 = self.map_x1_and_x2(coord_array[0], coord_array[1], unnorm=unnorm)
2✔
321
        new_coords = np.stack([x1, x2], axis=0)
2✔
322
        return new_coords
2✔
323

324
    def map_x1_and_x2(self, x1: np.ndarray, x2: np.ndarray, unnorm: bool = False):
2✔
325
        """
326
        Normalise or unnormalise spatial coords in an array.
327

328
        Args:
329
            x1 (:class:`numpy:numpy.ndarray`):
330
                Array of shape ``(N_x1,)`` containing spatial coords of x1.
331
            x2 (:class:`numpy:numpy.ndarray`):
332
                Array of shape ``(N_x2,)`` containing spatial coords of x2.
333
            unnorm (bool, optional):
334
                Whether to unnormalise. Defaults to ``False``.
335

336
        Returns:
337
            Tuple[:class:`numpy:numpy.ndarray`, :class:`numpy:numpy.ndarray`]:
338
                Normalised or unnormalised spatial coords of x1 and x2.
339
        """
340
        x11, x12 = self.config["coords"]["x1"]["map"]
2✔
341
        x21, x22 = self.config["coords"]["x2"]["map"]
2✔
342

343
        if not unnorm:
2✔
344
            new_coords_x1 = (x1 - x11) / (x12 - x11)
2✔
345
            new_coords_x2 = (x2 - x21) / (x22 - x21)
2✔
346
        else:
347
            new_coords_x1 = x1 * (x12 - x11) + x11
2✔
348
            new_coords_x2 = x2 * (x22 - x21) + x21
2✔
349

350
        return new_coords_x1, new_coords_x2
2✔
351

352
    def map_coords(
2✔
353
        self,
354
        data: Union[xr.DataArray, xr.Dataset, pd.DataFrame, pd.Series],
355
        unnorm=False,
356
    ):
357
        """
358
        Normalise spatial coords in a pandas or xarray object.
359

360
        Args:
361
            data (:class:`xarray.DataArray`, :class:`xarray.Dataset`, :class:`pandas.DataFrame`, or :class:`pandas.Series`):
362
                [Description Needed]
363
            unnorm (bool, optional):
364
                [Description Needed]. Defaults to [Default Value].
365

366
        Returns:
367
            [Type]:
368
                [Description Needed]
369
        """
370
        if isinstance(data, (pd.DataFrame, pd.Series)):
2✔
371
            # Reset index to get coords as columns
372
            indexes = list(data.index.names)
2✔
373
            data = data.reset_index()
2✔
374

375
        if unnorm:
2✔
376
            new_coord_IDs = [
2✔
377
                self.config["coords"][coord_ID]["name"]
378
                for coord_ID in ["time", "x1", "x2"]
379
            ]
380
            old_coord_IDs = ["time", "x1", "x2"]
2✔
381
        else:
382
            new_coord_IDs = ["time", "x1", "x2"]
2✔
383
            old_coord_IDs = [
2✔
384
                self.config["coords"][coord_ID]["name"]
385
                for coord_ID in ["time", "x1", "x2"]
386
            ]
387

388
        x1, x2 = (
2✔
389
            data[old_coord_IDs[1]],
390
            data[old_coord_IDs[2]],
391
        )
392

393
        # Infer x1 and x2 mappings from min/max of data coords if not provided by user
394
        if self.x1_map is None and self.x2_map is None:
2✔
395
            # Ensure scalings are the same for x1 and x2
396
            x1_range = x1.max() - x1.min()
2✔
397
            x2_range = x2.max() - x2.min()
2✔
398
            range = np.max([x1_range, x2_range])
2✔
399
            self.x1_map = (x1.min(), x1.min() + range)
2✔
400
            self.x2_map = (x2.min(), x2.min() + range)
2✔
401

402
            self.x1_map, self.x2_map = self._validate_coord_mappings(
2✔
403
                self.x1_map, self.x2_map
404
            )
405
            self.config["coords"]["x1"]["map"] = self.x1_map
2✔
406
            self.config["coords"]["x2"]["map"] = self.x2_map
2✔
407

408
            if self.verbose:
2✔
409
                print(
×
410
                    f"Inferring x1_map={self.x1_map} and x2_map={self.x2_map} from data min/max"
411
                )
412

413
        new_x1, new_x2 = self.map_x1_and_x2(x1, x2, unnorm=unnorm)
2✔
414

415
        if isinstance(data, (pd.DataFrame, pd.Series)):
2✔
416
            # Drop old spatial coord columns *before* adding new ones, in case
417
            # the old name is already x1.
418
            data = data.drop(columns=old_coord_IDs[1:])
2✔
419
            # Add coords to dataframe
420
            data[new_coord_IDs[1]] = new_x1
2✔
421
            data[new_coord_IDs[2]] = new_x2
2✔
422

423
            if old_coord_IDs[0] in data.columns:
2✔
424
                # Rename time dimension.
425
                rename = {old_coord_IDs[0]: new_coord_IDs[0]}
2✔
426
                data = data.rename(rename, axis=1)
2✔
427
            else:
428
                # We don't have a time dimension.
429
                old_coord_IDs = old_coord_IDs[1:]
2✔
430
                new_coord_IDs = new_coord_IDs[1:]
2✔
431

432
        elif isinstance(data, (xr.DataArray, xr.Dataset)):
2✔
433
            data = data.assign_coords(
2✔
434
                {old_coord_IDs[1]: new_x1, old_coord_IDs[2]: new_x2}
435
            )
436

437
            if old_coord_IDs[0] not in data.dims:
2✔
438
                # We don't have a time dimension.
439
                old_coord_IDs = old_coord_IDs[1:]
2✔
440
                new_coord_IDs = new_coord_IDs[1:]
2✔
441

442
            # Rename all dimensions.
443
            rename = {
2✔
444
                old: new for old, new in zip(old_coord_IDs, new_coord_IDs) if old != new
445
            }
446
            data = data.rename(rename)
2✔
447

448
        if isinstance(data, (pd.DataFrame, pd.Series)):
2✔
449
            # Set index back to original
450
            [indexes.remove(old_coord_ID) for old_coord_ID in old_coord_IDs]
2✔
451
            indexes = new_coord_IDs + indexes  # Put dims first
2✔
452
            data = data.set_index(indexes)
2✔
453

454
        return data
2✔
455

456
    def map_array(
2✔
457
        self,
458
        data: Union[xr.DataArray, xr.Dataset, pd.DataFrame, pd.Series, np.ndarray],
459
        var_ID: str,
460
        method: Optional[str] = None,
461
        unnorm: bool = False,
462
        add_offset: bool = True,
463
    ):
464
        """
465
        Normalise or unnormalise the data values in an xarray, pandas, or
466
        numpy object.
467

468
        Args:
469
            data (:class:`xarray.DataArray`, :class:`xarray.Dataset`, :class:`pandas.DataFrame`, :class:`pandas.Series`, or :class:`numpy:numpy.ndarray`):
470
                [Description Needed]
471
            var_ID (str):
472
                [Description Needed]
473
            method (str, optional):
474
                [Description Needed]. Defaults to None.
475
            unnorm (bool, optional):
476
                [Description Needed]. Defaults to False.
477
            add_offset (bool, optional):
478
                [Description Needed]. Defaults to True.
479

480
        Returns:
481
            [Type]:
482
                [Description Needed]
483
        """
484
        if not unnorm and method is None:
2✔
485
            raise ValueError("Must provide `method` if normalising data.")
×
486
        elif unnorm and method is not None and self.config[var_ID]["method"] != method:
2✔
487
            # User has provided a different method to the one used for normalising
488
            raise ValueError(
×
489
                f"Variable '{var_ID}' has been normalised with method '{self.config[var_ID]['method']}', "
490
                f"cannot unnormalise with method '{method}'. Pass `method=None` or"
491
                f"`method='{self.config[var_ID]['method']}'`"
492
            )
493
        if method is None and unnorm:
2✔
494
            # Determine normalisation method from config for unnormalising
495
            method = self.config[var_ID]["method"]
2✔
496
        elif method not in self.valid_methods:
2✔
497
            raise ValueError(
×
498
                f"Method {method} not recognised. Use one of {self.valid_methods}"
499
            )
500

501
        params = self.get_config(var_ID, data, method)
2✔
502

503
        # Linear transformation:
504
        # - Inverse normalisation: y_unnorm = m * y_norm + c
505
        # - Inverse normalisation: y_norm = (1/m) * y_unnorm - c/m
506
        if method == "mean_std":
2✔
507
            m = params["std"]
2✔
508
            c = params["mean"]
2✔
509
        elif method == "min_max":
2✔
510
            m = (params["max"] - params["min"]) / 2
2✔
511
            c = (params["max"] + params["min"]) / 2
2✔
512
        elif method == "positive_semidefinite":
2✔
513
            m = params["std"]
2✔
514
            c = params["min"]
2✔
515
        if not unnorm:
2✔
516
            c = -c / m
2✔
517
            m = 1 / m
2✔
518
        data = data * m
2✔
519
        if add_offset:
2✔
520
            data = data + c
2✔
521
        return data
2✔
522

523
    def map(
2✔
524
        self,
525
        data: Union[xr.DataArray, xr.Dataset, pd.DataFrame, pd.Series],
526
        method: Optional[str] = None,
527
        add_offset: bool = True,
528
        unnorm: bool = False,
529
        assert_computed: bool = False,
530
    ):
531
        """
532
        Normalise or unnormalise the data values and coords in an xarray or
533
        pandas object.
534

535
        Args:
536
            data (:class:`xarray.DataArray`, :class:`xarray.Dataset`, :class:`pandas.DataFrame`, or :class:`pandas.Series`):
537
                [Description Needed]
538
            method (str, optional):
539
                [Description Needed]. Defaults to None.
540
            add_offset (bool, optional):
541
                [Description Needed]. Defaults to True.
542
            unnorm (bool, optional):
543
                [Description Needed]. Defaults to False.
544

545
        Returns:
546
            [Type]:
547
                [Description Needed]
548
        """
549
        if self.deepcopy:
2✔
550
            data = deepcopy(data)
2✔
551

552
        if isinstance(data, (xr.DataArray, xr.Dataset)) and not unnorm:
2✔
553
            self._validate_xr(data)
2✔
554
        elif isinstance(data, (pd.DataFrame, pd.Series)) and not unnorm:
2✔
555
            self._validate_pandas(data)
2✔
556

557
        if isinstance(data, (xr.DataArray, pd.Series)):
2✔
558
            # Single var
559
            var_ID = data.name
2✔
560
            if assert_computed:
2✔
561
                assert self.check_params_computed(
×
562
                    var_ID, method
563
                ), f"{method} normalisation params for {var_ID} not computed."
564
            data = self.map_array(data, var_ID, method, unnorm, add_offset)
2✔
565
        elif isinstance(data, (xr.Dataset, pd.DataFrame)):
2✔
566
            # Multiple vars
567
            for var_ID in data:
2✔
568
                if assert_computed:
2✔
569
                    assert self.check_params_computed(
×
570
                        var_ID, method
571
                    ), f"{method} normalisation params for {var_ID} not computed."
572
                data[var_ID] = self.map_array(
2✔
573
                    data[var_ID], var_ID, method, unnorm, add_offset
574
                )
575

576
        data = self.map_coords(data, unnorm=unnorm)
2✔
577

578
        return data
2✔
579

580
    def __call__(
2✔
581
        self,
582
        data: Union[
583
            xr.DataArray,
584
            xr.Dataset,
585
            pd.DataFrame,
586
            List[Union[xr.DataArray, xr.Dataset, pd.DataFrame]],
587
        ],
588
        method: str = "mean_std",
589
        assert_computed: bool = False,
590
    ) -> Union[
591
        xr.DataArray,
592
        xr.Dataset,
593
        pd.DataFrame,
594
        List[Union[xr.DataArray, xr.Dataset, pd.DataFrame]],
595
    ]:
596
        """
597
        Normalise data.
598

599
        Args:
600
            data (:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`]):
601
                Data to be normalised. Can be an xarray DataArray, xarray
602
                Dataset, pandas DataFrame, or a list containing objects of
603
                these types.
604
            method (str, optional): Normalisation method. Options include:
605
                - "mean_std": Normalise to mean=0 and std=1 (default)
606
                - "min_max": Normalise to min=-1 and max=1
607
                - "positive_semidefinite": Normalise to min=0 and std=1
608

609
        Returns:
610
            :class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`]:
611
                Normalised data. Type or structure depends on the input.
612
        """
613
        if isinstance(data, list):
2✔
614
            return [
2✔
615
                self.map(d, method, unnorm=False, assert_computed=assert_computed)
616
                for d in data
617
            ]
618
        else:
619
            return self.map(data, method, unnorm=False, assert_computed=assert_computed)
2✔
620

621
    def unnormalise(
2✔
622
        self,
623
        data: Union[
624
            xr.DataArray,
625
            xr.Dataset,
626
            pd.DataFrame,
627
            List[Union[xr.DataArray, xr.Dataset, pd.DataFrame]],
628
        ],
629
        add_offset: bool = True,
630
    ) -> Union[
631
        xr.DataArray,
632
        xr.Dataset,
633
        pd.DataFrame,
634
        List[Union[xr.DataArray, xr.Dataset, pd.DataFrame]],
635
    ]:
636
        """
637
        Unnormalise data.
638

639
        Args:
640
            data (:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`]):
641
                Data to unnormalise.
642
            add_offset (bool, optional):
643
                Whether to add the offset to the data when unnormalising. Set
644
                to False to unnormalise uncertainty values (e.g. std dev).
645
                Defaults to True.
646

647
        Returns:
648
            :class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`]:
649
                Unnormalised data.
650
        """
651
        if isinstance(data, list):
2✔
652
            return [self.map(d, add_offset=add_offset, unnorm=True) for d in data]
2✔
653
        else:
654
            return self.map(data, add_offset=add_offset, unnorm=True)
2✔
655

656

657
def xarray_to_coord_array_normalised(da: Union[xr.Dataset, xr.DataArray]) -> np.ndarray:
    """Flatten the normalised ``x1``/``x2`` grid of an xarray object into a
    coordinate array.

    Args:
        da (:class:`xarray.Dataset` | :class:`xarray.DataArray`):
            Gridded object with ``x1`` and ``x2`` coordinate arrays
            (assumed normalised by ``DataProcessor``).

    Returns:
        :class:`numpy:numpy.ndarray`:
            A normalised coordinate array of shape ``(2, N)``, where
            ``N = len(x1) * len(x2)``.
    """
    grid_x1, grid_x2 = np.meshgrid(
        da["x1"].values, da["x2"].values, indexing="ij"
    )
    flat_rows = [grid_x1.ravel(), grid_x2.ravel()]
    return np.stack(flat_rows, axis=0)
674
def process_X_mask_for_X(X_mask: xr.DataArray, X: xr.DataArray) -> xr.DataArray:
    """Regrid a boolean mask onto the grid of ``X``.

    The mask is cast to float, nearest-neighbour interpolated onto the coords
    of ``X`` (points beyond the mask's extent are filled with 0, i.e. masked
    out), cast back to boolean, and loaded into memory.

    Args:
        X_mask (:class:`xarray.DataArray`):
            Boolean mask to regrid.
        X (:class:`xarray.DataArray`):
            Data array whose spatial coords define the target grid.

    Returns:
        :class:`xarray.DataArray`:
            Boolean mask on the grid of ``X``.
    """
    regridded = X_mask.astype(float).interp_like(
        X, method="nearest", kwargs={"fill_value": 0}
    )
    regridded.data = regridded.data.astype(bool)
    regridded.load()
    return regridded

697
def mask_coord_array_normalised(
    coord_arr: np.ndarray, mask_da: Union[xr.DataArray, xr.Dataset, None]
) -> np.ndarray:
    """Drop points of a ``(2, N)`` coordinate array that fall outside a
    gridded boolean mask.

    Each of the ``N`` points is looked up in ``mask_da`` with
    nearest-neighbour selection, giving a shape ``(N,)`` boolean array
    (True if the point is inside the mask, False if outside) used to
    filter the columns of ``coord_arr``.

    Args:
        coord_arr (:class:`numpy:numpy.ndarray`):
            Normalised coordinate array of shape ``(2, N)``.
        mask_da (:class:`xarray.DataArray` | :class:`xarray.Dataset` | None):
            Gridded boolean mask with ``x1``/``x2`` coords. If None, the
            coordinate array is returned unchanged.

    Returns:
        :class:`numpy:numpy.ndarray`:
            Filtered coordinate array of shape ``(2, M)`` with ``M <= N``.
    """
    if mask_da is None:
        # No mask provided - keep every point
        return coord_arr
    bool_mask = mask_da.astype(bool)
    x1_points = xr.DataArray(coord_arr[0])
    x2_points = xr.DataArray(coord_arr[1])
    keep = bool_mask.sel(x1=x1_points, x2=x2_points, method="nearest")
    return coord_arr[:, keep]
724

725
def da1_da2_same_grid(da1: xr.DataArray, da2: xr.DataArray) -> bool:
    """Check if ``da1`` and ``da2`` are on the same grid.

    .. note::
        ``da1`` and ``da2`` are assumed normalised by ``DataProcessor``.

    Args:
        da1 (:class:`xarray.DataArray`):
            First data array.
        da2 (:class:`xarray.DataArray`):
            Second data array.

    Returns:
        bool:
            True iff the ``x1`` and ``x2`` coordinate arrays of both objects
            are element-wise identical.
    """
    return all(
        np.array_equal(da1[coord].values, da2[coord].values)
        for coord in ("x1", "x2")
    )
746

747
def interp_da1_to_da2(da1: xr.DataArray, da2: xr.DataArray) -> xr.DataArray:
    """Interpolate ``da1`` onto the grid of ``da2`` (nearest-neighbour).

    .. note::
        ``da1`` and ``da2`` are assumed normalised by ``DataProcessor``.

    Args:
        da1 (:class:`xarray.DataArray`):
            Data array to interpolate.
        da2 (:class:`xarray.DataArray`):
            Data array whose ``x1``/``x2`` coords define the target grid.

    Returns:
        :class:`xarray.DataArray`:
            ``da1`` interpolated onto ``da2``'s grid.
    """
    target_coords = {"x1": da2["x1"], "x2": da2["x2"]}
    return da1.interp(method="nearest", **target_coords)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc