ghiggi / gpm_api / build 13679277988 (push · github · web-flow)

05 Mar 2025 03:17PM UTC coverage: 89.223% (-5.2%) from 94.43%
Update PMW Tutorial (#74)

* Fix gridlines removal for cartopy artist update

* Add TC-PRIMED tutorial

* Update PMW 1C tutorial

65 of 181 new or added lines in 10 files covered. (35.91%)

909 existing lines in 41 files now uncovered.

14911 of 16712 relevant lines covered (89.22%)

0.89 hits per line

Source File

/gpm/utils/xarray.py (79.32% coverage)
# -----------------------------------------------------------------------------.
# MIT License

# Copyright (c) 2024 GPM-API developers
#
# This file is part of GPM-API.

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# -----------------------------------------------------------------------------.
"""This module contains general utilities for xarray objects."""
import functools

import numpy as np
import xarray as xr

####-------------------------------------------------------------------
#################
#### Checker ####
#################

def check_is_xarray(x):
    if not isinstance(x, (xr.DataArray, xr.Dataset)):
        raise TypeError("Expecting an xarray.Dataset or xarray.DataArray.")


def check_is_xarray_dataarray(x):
    if not isinstance(x, xr.DataArray):
        raise TypeError("Expecting an xarray.DataArray.")


def check_is_xarray_dataset(x):
    if not isinstance(x, xr.Dataset):
        raise TypeError("Expecting an xarray.Dataset.")


def check_variable_availabilty(ds, variable, argname):
    """Check variable availability in an xarray Dataset."""
    if variable is None:
        raise ValueError("Please specify a dataset variable.")
    if variable not in ds:
        raise ValueError(
            f"{variable} is not a variable of the xarray.Dataset. Invalid {argname} argument.",
        )


####-------------------------------------------------------------------
###################
#### Utilities ####
###################

def get_dataset_variables(ds, sort=False):
    """Get the list of xarray.Dataset variables."""
    variables = list(ds.data_vars)
    if sort:
        variables = sorted(variables)
    return variables


def get_xarray_variable(xr_obj, variable=None):
    """Return the variable DataArray from an xarray object.

    If ``variable`` is an xarray.DataArray, it is returned as is.
    If ``variable`` is None and the input is an xarray.DataArray, the input is returned.
    If the input is an xarray.Dataset, the specified variable is returned.
    """
    check_is_xarray(xr_obj)
    if isinstance(variable, xr.DataArray):
        return variable
    if isinstance(xr_obj, xr.Dataset):
        check_variable_availabilty(xr_obj, variable, argname="variable")
        da = xr_obj[variable]
    else:
        da = xr_obj
    return da
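# Usage sketch (illustrative; the toy Dataset and the "precip" variable name are hypothetical):
#
#   >>> ds = xr.Dataset({"precip": ("time", [0.5, 1.0, 2.0])})
#   >>> get_xarray_variable(ds, variable="precip")        # returns the "precip" DataArray
#   >>> get_xarray_variable(ds["precip"])                 # a DataArray input is returned as is
#   >>> get_xarray_variable(ds, variable=ds["precip"])    # a DataArray 'variable' is returned directly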

def get_default_variable(ds: xr.Dataset, possible_variables) -> str:
    """Return one of the possible default variables.

    Check if one of the variables in ``possible_variables`` is present in the xarray.Dataset.
    If none of the variables is present, raise an error.
    If more than one is present, raise an error.
    Otherwise, return the name of the single available variable in the xarray.Dataset.

    Parameters
    ----------
    ds : xarray.Dataset
        The xarray dataset to inspect.
    possible_variables : str or list of str
        The variable names to look for.

    Returns
    -------
    str
        The name of the variable found in the xarray.Dataset.
    """
    if isinstance(possible_variables, str):
        possible_variables = [possible_variables]
    found_vars = [v for v in possible_variables if v in ds.data_vars]
    if len(found_vars) == 0:
        raise ValueError(f"None of {possible_variables} variables were found in the dataset.")
    if len(found_vars) > 1:
        raise ValueError(f"Multiple variables found: {found_vars}. Please specify which to use.")
    return found_vars[0]
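# Usage sketch (illustrative; the variable names are hypothetical placeholders):
#
#   >>> ds = xr.Dataset({"precipRateNearSurface": ("time", [1.0, 2.0])})
#   >>> get_default_variable(ds, possible_variables=["precipRateNearSurface", "precipRate"])
#   'precipRateNearSurface'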

def get_dimensions_without(xr_obj, dims):
    """Return the dimensions of the xarray object without the specified dimensions."""
    if isinstance(dims, str):
        dims = [dims]
    data_dims = np.array(list(xr_obj.dims))
    return data_dims[np.isin(data_dims, dims, invert=True)].tolist()


def has_unique_chunking(ds):
    """Check if a dataset has unique chunking."""
    if not isinstance(ds, xr.Dataset):
        raise ValueError("Input must be an xarray Dataset.")

    # Create a dictionary to store unique chunk shapes for each dimension
    unique_chunks_per_dim = {}

    # Iterate through each variable's chunks
    for var_name in ds.variables:
        if hasattr(ds[var_name].data, "chunks"):  # is dask array
            var_chunks = ds[var_name].data.chunks
            for dim, chunks in zip(ds[var_name].dims, var_chunks, strict=False):
                if dim not in unique_chunks_per_dim:
                    unique_chunks_per_dim[dim] = set()
                    unique_chunks_per_dim[dim].add(chunks)
                if chunks not in unique_chunks_per_dim[dim]:
                    return False

    # All variables share the same chunking along each dimension
    return True


def ensure_unique_chunking(ds):
    """Ensure the dataset has unique chunking.

    Conversion to :py:class:`dask.dataframe.DataFrame` requires unique chunking.
    If the xarray.Dataset does not have unique chunking, perform ``ds.unify_chunks``.

    Variable chunks can be visualized with::

        for var in ds.data_vars:
            print(var, ds[var].chunks)

    """
    if not has_unique_chunking(ds):
        ds = ds.unify_chunks()
    return ds
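# Usage sketch (illustrative; assumes dask is installed, and the toy Dataset is hypothetical):
#
#   >>> ds = xr.Dataset({"a": ("x", np.arange(10)), "b": ("x", np.arange(10))})
#   >>> ds["a"] = ds["a"].chunk({"x": 5})
#   >>> ds["b"] = ds["b"].chunk({"x": 2})
#   >>> has_unique_chunking(ds)
#   False
#   >>> has_unique_chunking(ensure_unique_chunking(ds))
#   True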

def _xr_first_data_array(da, dim):
    """Return the first valid value of a DataArray along a dimension."""
    mask = da.notnull()
    first_valid_idx = mask.argmax(dim=dim)
    first_valid_value = da.isel({dim: first_valid_idx})
    first_valid_value = first_valid_value.where(mask.any(dim=dim))
    return first_valid_value


def xr_first(xr_obj, dim):
    """Return the first valid (non-NaN) value along the specified dimension."""
    check_is_xarray(xr_obj)
    if isinstance(xr_obj, xr.Dataset):
        for var in xr_obj.data_vars:
            if dim in xr_obj[var].dims:
                xr_obj[var] = _xr_first_data_array(xr_obj[var], dim=dim)
        return xr_obj
    return _xr_first_data_array(xr_obj, dim=dim)
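# Usage sketch (illustrative; the toy DataArray is hypothetical):
#
#   >>> da = xr.DataArray([[np.nan, 2.0], [1.0, np.nan]], dims=("y", "x"))
#   >>> xr_first(da, dim="x")              # -> values [2.0, 1.0] along "y"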

def _drop_constant_dimension_datarray(da):
    """Drop DataArray dimensions over which all numeric values are equal."""
    if not np.issubdtype(da.dtype, np.number):
        return da

    for dim in list(da.dims):
        if dim not in da.dims:
            continue
        # If the variable is constant along this dimension, collapse this dimension.
        if (da.diff(dim=dim) == 0).all():
            da = xr_first(da, dim=dim)
    return da

def xr_drop_constant_dimension(xr_obj):
    """Drop the dimensions along which the numeric values of the xarray object are constant."""
    check_is_xarray(xr_obj)
    if isinstance(xr_obj, xr.Dataset):
        for var in xr_obj.data_vars:
            xr_obj[var] = _drop_constant_dimension_datarray(xr_obj[var])
        return xr_obj
    return _drop_constant_dimension_datarray(xr_obj)
218
def broadcast_like(xr_obj, other, add_coords=True):
1✔
219
    """Broadcast an xarray object against another one."""
UNCOV
220
    xr_obj = xr_obj.broadcast_like(other)
×
UNCOV
221
    if add_coords:
×
UNCOV
222
        xr_obj = xr_obj.assign_coords(other.coords)
×
UNCOV
223
    return xr_obj
×
224

225

226
def xr_sorted_distribution(da, values, dim):
1✔
227
    """
228
    Compute the ranked frequency distribution of integer values along a given dimension.
229

230
    Parameters
231
    ----------
232
    da : xarray.DataArray
233
        The input data array containing integer values along the specified dimension.
234
    values : array-like
235
        An array of the expected values (e.g. np.arange(1, 13) for months,
236
        np.arange(0, 24) for hours, etc.).
237
    dim : str
238
        The name of the dimension along which to compute the ranked distribution (e.g., "year").
239

240
    Returns
241
    -------
242
    ds_out : xarray.Dataset
243
        A dataset with three DataArrays along a new dimension "rank":
244
          - sorted_values: The provided values sorted in descending order of occurrence.
245
          - occurrence: The count of occurrences for each sorted value.
246
          - percentage: The percentage occurrence (relative to the size along `dim`).
247

248
        For each pixel (or location), index along "rank" to retrieve, for example,
249
        the most frequent value at rank 0, the second most at rank 1, etc.
250
    """
UNCOV
251
    values = np.asarray(values)
×
252

UNCOV
253
    def _np_sorted_distribution(arr, values):
×
254
        # Convert to integer type if not already.
UNCOV
255
        arr = arr.astype(np.int64)
×
256
        # Count the occurrences for each expected value.
UNCOV
257
        counts = np.array([np.count_nonzero(arr == v) for v in values])
×
258
        # Sort the expected values in descending order of counts.
UNCOV
259
        sort_idx = np.argsort(counts)[::-1]
×
UNCOV
260
        sorted_values = values[sort_idx].copy()
×
UNCOV
261
        sorted_counts = counts[sort_idx]
×
UNCOV
262
        total = arr.size
×
UNCOV
263
        sorted_percentage = sorted_counts / total * 100.0
×
UNCOV
264
        return sorted_values, sorted_counts, sorted_percentage
×
265

266
    # Define dask_gufunc_kwargs
UNCOV
267
    dask_gufunc_kwargs = {}
×
UNCOV
268
    dask_gufunc_kwargs["output_sizes"] = {"rank": len(values)}
×
269

270
    # Apply the distribution function along the specified dimension.
UNCOV
271
    sorted_vals, occurrence, percentage = xr.apply_ufunc(
×
272
        _np_sorted_distribution,
273
        da,
274
        input_core_dims=[[dim]],
275
        output_core_dims=[["rank"], ["rank"], ["rank"]],
276
        vectorize=True,
277
        dask="parallelized",
278
        kwargs={"values": values},
279
        output_dtypes=[int, int, float],
280
        dask_gufunc_kwargs=dask_gufunc_kwargs,
281
    )
282

UNCOV
283
    ds_out = xr.Dataset(
×
284
        {
285
            "sorted_values": sorted_vals,
286
            "occurrence": occurrence,
287
            "percentage": percentage,
288
        },
289
    )
UNCOV
290
    return ds_out
×
291

292
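# Usage sketch (illustrative; the "pixel"/"time" dimensions and the month values are hypothetical):
#
#   >>> months = xr.DataArray([[1, 1, 2, 3], [2, 2, 2, 1]], dims=("pixel", "time"))
#   >>> out = xr_sorted_distribution(months, values=np.arange(1, 13), dim="time")
#   >>> out["sorted_values"].isel(pixel=0, rank=0)   # most frequent value at pixel 0 -> 1
#   >>> out["percentage"].isel(pixel=0, rank=0)      # its relative frequency -> 50.0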

####-------------------------------------------------------------------
#### Unstacking dimension


def _check_coord_handling(coord_handling):
    if coord_handling not in {"keep", "drop", "unstack"}:
        raise ValueError("coord_handling must be one of 'keep', 'drop', or 'unstack'.")


def _unstack_coordinates(xr_obj, dim, prefix, suffix):
    # Identify coordinates that share the target dimension
    coords_with_dim = _get_non_dimensional_coordinates(xr_obj, dim=dim)
    ds = xr.Dataset()
    for coord_name in coords_with_dim:
        coord_da = xr_obj[coord_name]
        # Split the coordinate DataArray along the target dimension, drop the coordinate and merge
        split_ds = unstack_datarray_dimension(coord_da, coord_handling="drop", dim=dim, prefix=prefix, suffix=suffix)
        ds.update(split_ds)
    return ds


def _handle_unstack_non_dim_coords(ds, source_xr_obj, coord_handling, dim, prefix, suffix):
    # Deal with coordinates sharing the target dimension
    if coord_handling == "keep":
        return ds
    if coord_handling == "unstack":
        ds_coords = _unstack_coordinates(source_xr_obj, dim=dim, prefix=prefix, suffix=suffix)
        ds.update(ds_coords)
    # Remove non-dimensional coordinates ('unstack' and 'drop' coord_handling)
    ds = ds.drop_vars(_get_non_dimensional_coordinates(ds, dim=dim))
    return ds


def _get_non_dimensional_coordinates(xr_obj, dim):
    return [coord_name for coord_name, coord_da in xr_obj.coords.items() if dim in coord_da.dims and coord_name != dim]

def unstack_datarray_dimension(da, dim, coord_handling="keep", prefix="", suffix=""):
    """
    Split a DataArray along a specified dimension into a Dataset with separate prefixed and suffixed variables.

    Parameters
    ----------
    da : xarray.DataArray
        The DataArray to split.
    dim : str
        The dimension along which to split the DataArray.
    coord_handling : str, optional
        Option to handle coordinates sharing the target dimension.
        Choices are 'keep', 'drop', or 'unstack'. Defaults to 'keep'.
    prefix : str, optional
        String to prepend to each new variable name.
    suffix : str, optional
        String to append to each new variable name.

    Returns
    -------
    xarray.Dataset
        A Dataset with each variable split along the specified dimension.
        The Dataset variables are named "{prefix}{name}{suffix}{dim_value}".
        Coordinates sharing the target dimension are handled based on `coord_handling`.
    """
    # Retrieve DataArray name
    name = da.name
    # Unstack variables
    ds = da.to_dataset(dim=dim)
    rename_dict = {dim_value: f"{prefix}{name}{suffix}{dim_value}" for dim_value in list(ds.data_vars)}
    ds = ds.rename_vars(rename_dict)
    # Deal with coordinates sharing the target dimension
    return _handle_unstack_non_dim_coords(
        ds=ds,
        source_xr_obj=da,
        coord_handling=coord_handling,
        dim=dim,
        prefix=prefix,
        suffix=suffix,
    )

def unstack_dataset_dimension(ds, dim, coord_handling="keep", prefix="", suffix=""):
    """
    Split Dataset variables with the specified dimension into separate prefixed and suffixed variables.

    Parameters
    ----------
    ds : xarray.Dataset
        The Dataset to split.
    dim : str
        The dimension along which to split the Dataset variables.
    coord_handling : str, optional
        Option to handle coordinates sharing the target dimension.
        Choices are 'keep', 'drop', or 'unstack'. Defaults to 'keep'.
    prefix : str, optional
        String to prepend to each new variable name.
    suffix : str, optional
        String to append to each new variable name.

    Returns
    -------
    xarray.Dataset
        A Dataset with each variable with dimension `dim` split into new variables.
        The new Dataset variables are named "{prefix}{name}{suffix}{dim_value}".
        Coordinates sharing the target dimension are handled based on `coord_handling`.
    """
    # Identify variables that have the target dimension
    variables_to_split = [var for var in ds.data_vars if dim in ds[var].dims]

    # Identify variables that do NOT have the target dimension
    variables_to_keep = [var for var in ds.data_vars if dim not in ds[var].dims]

    # Initialize the new Dataset with the variables to keep
    ds_unstacked = ds[variables_to_keep].copy()

    # Loop over the DataArrays to split
    for var in variables_to_split:
        ds_unstacked.update(
            unstack_datarray_dimension(ds[var], dim=dim, coord_handling="keep", prefix=prefix, suffix=suffix),
        )

    # Deal with coordinates sharing the target dimension
    ds_unstacked = _handle_unstack_non_dim_coords(
        ds=ds_unstacked,
        source_xr_obj=ds,
        dim=dim,
        coord_handling=coord_handling,
        prefix=prefix,
        suffix=suffix,
    )
    return ds_unstacked

def unstack_dimension(xr_obj, dim, coord_handling="keep", prefix="", suffix=""):
    """
    Split an xarray object with the specified dimension into separate prefixed and suffixed Dataset variables.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        The xarray object to split.
    dim : str
        The dimension along which to split the xarray object.
    coord_handling : str, optional
        Option to handle coordinates sharing the target dimension.
        Choices are 'keep', 'drop', or 'unstack'. Defaults to 'keep'.
    prefix : str, optional
        String to prepend to each new variable name.
    suffix : str, optional
        String to append to each new variable name.

    Returns
    -------
    xarray.Dataset
        A Dataset with each variable with dimension `dim` split into new variables.
        The new Dataset variables are named "{prefix}{name}{suffix}{dim_value}".
        Coordinates sharing the target dimension are handled based on `coord_handling`.
    """
    check_is_xarray(xr_obj)
    _check_coord_handling(coord_handling)
    if isinstance(xr_obj, xr.DataArray):
        return unstack_datarray_dimension(xr_obj, dim=dim, coord_handling=coord_handling, prefix=prefix, suffix=suffix)
    return unstack_dataset_dimension(xr_obj, dim=dim, coord_handling=coord_handling, prefix=prefix, suffix=suffix)
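# Usage sketch (illustrative; the "Tb" DataArray and the "band" dimension are hypothetical):
#
#   >>> da = xr.DataArray(
#   ...     [[1, 2], [3, 4]], dims=("time", "band"),
#   ...     coords={"band": ["V", "H"]}, name="Tb",
#   ... )
#   >>> unstack_dimension(da, dim="band", suffix="_")
#   <xarray.Dataset> with data variables "Tb_V" and "Tb_H", each with dimension "time"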

####-------------------------------------------------------------------
####################
#### Decorators ####
####################


def ensure_dim_order_dataarray(da, func, *args, **kwargs):
    """Ensure that the output DataArray has the same dimension order as the input.

    New dimensions are moved to the last positions.
    """
    # Get the original dimension order
    original_dims = da.dims
    dict_coord_dims = {coord: da[coord].dims for coord in list(da.coords)}

    # Apply the function to the DataArray
    da_out = func(da, *args, **kwargs)

    # Check output type
    if not isinstance(da_out, xr.DataArray):
        raise TypeError("The function does not return an xr.DataArray.")

    # Check which of the original dimensions are still present
    dim_order = [dim for dim in original_dims if dim in da_out.dims]

    # Transpose the result to ensure the same dimension order
    da_out = da_out.transpose(*dim_order, ...)

    # Transpose the coordinates to their original dimension order too
    for coord in list(da_out.coords):
        if coord in dict_coord_dims:
            dim_order = [dim for dim in dict_coord_dims[coord] if dim in da_out[coord].dims]
            da_out[coord] = da_out[coord].transpose(*dim_order, ...)
    return da_out


def ensure_dim_order_dataset(ds, func, *args, **kwargs):
    """Ensure that the output Dataset has the same dimension order as the input.

    New dimensions are moved to the last positions.
    """
    # Get the original dimension order
    dict_coord_dims = {coord: ds[coord].dims for coord in list(ds.coords)}
    dict_var_dims = {var: ds[var].dims for var in list(ds.data_vars)}

    # Apply the function to the Dataset
    ds_out = func(ds, *args, **kwargs)

    if not isinstance(ds_out, xr.Dataset):
        raise TypeError("The function does not return an xr.Dataset.")

    # Check which of the original variables and dimensions are still present and reorder
    for var in list(ds_out.data_vars):
        if var in dict_var_dims:
            dim_order = [dim for dim in dict_var_dims[var] if dim in ds_out[var].dims]
            ds_out[var] = ds_out[var].transpose(*dim_order, ...)
    for coord in list(ds_out.coords):
        if coord in dict_coord_dims:
            dim_order = [dim for dim in dict_coord_dims[coord] if dim in ds_out[coord].dims]
            ds_out[coord] = ds_out[coord].transpose(*dim_order, ...)

    return ds_out


def xr_ensure_dimension_order(func):
    """Decorator ensuring that the output xarray object has the same dimension order as the input.

    The decorated function must return the same type of xarray object as its input.

    The decorator can deal with functions that:
    - return an xarray object with new dimensions
    - return an xarray object with fewer dimensions than the original

    New dimensions are moved to the last positions.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        xr_obj = args[0]  # Assuming the first argument is the xarray object
        if isinstance(xr_obj, xr.Dataset):
            return ensure_dim_order_dataset(xr_obj, func, *args[1:], **kwargs)
        return ensure_dim_order_dataarray(xr_obj, func, *args[1:], **kwargs)

    return wrapper
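# Usage sketch (illustrative; `transpose_and_add` is a hypothetical function used only to show the decorator):
#
#   >>> @xr_ensure_dimension_order
#   ... def transpose_and_add(da):
#   ...     return da.transpose() + 1
#   ...
#   >>> da = xr.DataArray(np.zeros((2, 3)), dims=("y", "x"))
#   >>> transpose_and_add(da).dims         # the original ("y", "x") order is restored
#   ('y', 'x')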

@xr_ensure_dimension_order
def squeeze_unsqueeze_dataarray(da, func, *args, **kwargs):
    """Ensure that the output DataArray has the same dimensions as the input.

    Dimensions of size 1 are kept even if the function drops them.
    New dimensions are moved to the last positions.
    """
    # Retrieve the dimensions to be squeezed
    original_dims = set(da.dims)
    squeezed_dims = original_dims - set(da.squeeze().dims)

    # List the coordinates which are squeezed
    dict_squeezed = {dim: [] for dim in squeezed_dims}
    for dim in squeezed_dims:
        for coord in list(da.coords):
            if dim in da[coord].dims:
                dict_squeezed[dim].append(coord)
    # Squeeze
    da = da.squeeze()

    # Apply the function
    da = func(da, *args, **kwargs)  # Call the function with the squeezed DataArray

    # Check output type
    if not isinstance(da, xr.DataArray):
        raise TypeError("The function does not return an xr.DataArray.")

    # Unsqueeze back
    for dim, coords in dict_squeezed.items():
        if dim not in da.dims:
            da = da.expand_dims(dim=dim, axis=None)
        for coord in coords:
            if dim not in da[coord].dims:  # coords with the same name as the dimension are expanded automatically
                da[coord] = da[coord].expand_dims(dim=dim, axis=None)

    # Deal with coordinates named as a dimension but without such a dimension
    # for dim, coords in dict_squeezed.items():
    #     if len(coords) == 0 and dim in da.coords:
    #         scalar_coord_value = da[dim].data[0]
    #         da = da.drop_vars(dim)
    #         da = da.assign_coords({"___tmp_coord__": scalar_coord_value}).rename({"___tmp_coord__": dim})
    return da

@xr_ensure_dimension_order
def squeeze_unsqueeze_dataset(ds, func, *args, **kwargs):
    """Ensure that the output Dataset has the same dimensions as the input.

    Dimensions of size 1 are kept even if the function drops them.
    New dimensions are moved to the last positions.
    """
    # Retrieve the dimensions to be squeezed
    original_dims = set(ds.dims)
    squeezed_dims = original_dims - set(ds.squeeze().dims)

    # List the coordinates and variables which are squeezed
    dict_squeezed = {dim: [] for dim in squeezed_dims}
    for dim in squeezed_dims:
        for var in ds.variables:  # coords + variables
            if dim in ds[var].dims:
                dict_squeezed[dim].append(var)
    # Squeeze
    ds = ds.squeeze()

    # Apply the function
    ds = func(ds, *args, **kwargs)  # Call the function with the squeezed dataset

    # Check output type
    if not isinstance(ds, xr.Dataset):
        raise TypeError("The function does not return an xr.Dataset.")

    # Unsqueeze back
    for dim, variables in dict_squeezed.items():
        for var in variables:
            if dim not in ds[var].dims:
                ds[var] = ds[var].expand_dims(dim=dim, axis=None)  # not necessarily in the same order as at the start
    return ds

def xr_squeeze_unsqueeze(func):
    """Decorator that squeezes the xarray object before passing it to the function and unsqueezes the result.

    This decorator allows keeping the dimensions of the xarray object intact.
    Dimensions of size 1 are kept even if the function drops them.
    The dimension order of the arrays is preserved.
    New dimensions are moved to the last positions.

    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        xr_obj = args[0]  # Assuming the first argument is the xarray object
        if isinstance(xr_obj, xr.Dataset):
            return squeeze_unsqueeze_dataset(xr_obj, func, *args[1:], **kwargs)
        return squeeze_unsqueeze_dataarray(xr_obj, func, *args[1:], **kwargs)

    return wrapper
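# Usage sketch (illustrative; `add_one` is a hypothetical function used only to show the decorator):
#
#   >>> @xr_squeeze_unsqueeze
#   ... def add_one(da):
#   ...     return da + 1
#   ...
#   >>> da = xr.DataArray(np.zeros((1, 3)), dims=("y", "x"))
#   >>> add_one(da).dims                   # the size-1 "y" dimension is preserved
#   ('y', 'x')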