Ouranosinc / xscen / build 5893860722

17 Aug 2023 05:08PM UTC. Coverage: 51.927% (+2.1%) from 49.851%.

Pull Request #238: Add diagnostics.health_checks (merge dcb42bd64 into 9f2266370)

117 of 117 new or added lines in 1 file covered (100.0%).
72 existing lines in 1 file now uncovered.
1792 of 3451 relevant lines covered (51.93%).
1.56 hits per line.

Source file: /xscen/diagnostics.py (57.31% covered)
# noqa: D100
import logging
import warnings
from collections.abc import Sequence
from copy import deepcopy
from pathlib import Path, PosixPath
from types import ModuleType
from typing import Optional, Tuple, Union

import numpy as np
import xarray as xr
import xclim as xc
import xclim.core.dataflags
from xclim.core.indicator import Indicator

from .config import parse_config
from .indicators import load_xclim_module
from .utils import change_units, clean_up, standardize_periods, unstack_fill_nan

logger = logging.getLogger(__name__)

__all__ = [
    "health_checks",
    "properties_and_measures",
    "measures_heatmap",
    "measures_improvement",
]


def health_checks(
    ds: Union[xr.Dataset, xr.DataArray],
    *,
    structure: dict = None,
    calendar: str = None,
    start_date: str = None,
    end_date: str = None,
    variables_and_units: dict = None,
    cfchecks: dict = None,
    freq: str = None,
    missing: Union[dict, str, list] = None,
    flags: dict = None,
    flags_kwargs: dict = None,
    return_flags: bool = False,
    raise_on: list = None,
) -> Union[None, xr.Dataset]:
    """
    Perform a series of health checks on the dataset. Be aware that missing data checks and flag checks can be slow.

    Parameters
    ----------
    ds : xr.Dataset | xr.DataArray
        Dataset to check.
    structure : dict
        Dictionary with keys "dims" and "coords" containing the expected dimensions and coordinates.
    calendar : str
        Expected calendar. Synonyms should be detected correctly (e.g. "standard" and "gregorian").
    start_date : str
        To check if the dataset starts at least at this date.
    end_date : str
        To check if the dataset ends at least at this date.
    variables_and_units : dict
        Dictionary containing the expected variables and units.
    cfchecks : dict
        Dictionary where the key is the variable to check and the values are the cfchecks.
        The cfchecks themselves must be a dictionary with the keys being the cfcheck names and the values being the arguments to pass to the cfcheck.
        See `xclim.core.cfchecks` for more details.
    freq : str
        Expected frequency, written as the result of xr.infer_freq(ds.time).
    missing : dict | str | list
        String, list of strings, or dictionary where the key is the method to check for missing data and the values are the arguments to pass to the method.
        The methods are: "missing_any", "at_least_n_valid", "missing_pct", "missing_wmo". See :py:func:`xclim.core.missing` for more details.
    flags : dict
        Dictionary where the key is the variable to check and the values are the flags.
        The flags themselves must be a dictionary with the keys being the data_flags names and the values being the arguments to pass to the data_flags.
        If `None` is passed instead of a dictionary, then xclim's default flags for the given variable are run. See :py:data:`xclim.core.utils.VARIABLES`.
        See :py:func:`xclim.core.dataflags.data_flags` for the list of possible flags.
    flags_kwargs : dict
        Additional keyword arguments to pass to the data_flags ("dims" and "freq").
    return_flags : bool
        Whether to return the Dataset created by data_flags.
    raise_on : list
        Whether to raise an error if a check fails, else there will only be a warning. The possible values are the names of the checks.
        Use ["all"] to raise on all checks.

    Returns
    -------
    xr.Dataset | None
        Dataset containing the flags if return_flags is True and raise_on is False for the "flags" check.
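
    Examples
    --------
    A minimal sketch of a call. The daily dataset ``ds`` and its ``tas`` variable
    are illustrative assumptions, not provided by this module:

    >>> health_checks(  # doctest: +SKIP
    ...     ds,
    ...     calendar="standard",
    ...     start_date="1950-01-01",
    ...     end_date="2000-12-31",
    ...     variables_and_units={"tas": "K"},
    ...     freq="D",
    ...     missing="missing_any",
    ...     raise_on=["structure", "variables_and_units"],
    ... )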
    """
90
    if isinstance(ds, xr.DataArray):
3✔
91
        ds = ds.to_dataset()
3✔
92
    raise_on = raise_on or []
3✔
93
    if "all" in raise_on:
3✔
94
        raise_on = [
3✔
95
            "structure",
96
            "calendar",
97
            "start_date",
98
            "end_date",
99
            "variables_and_units",
100
            "cfchecks",
101
            "freq",
102
            "missing",
103
            "flags",
104
        ]
105

106
    # Check the dimensions and coordinates
107
    if structure is not None:
3✔
108
        if "dims" in structure:
3✔
109
            for dim in structure["dims"]:
3✔
110
                if dim not in ds.dims:
3✔
111
                    err = f"The dimension '{dim}' is missing."
3✔
112
                    if "structure" in raise_on:
3✔
113
                        raise ValueError(err)
3✔
114
                    else:
115
                        warnings.warn(err, UserWarning, stacklevel=1)
3✔
116
        if "coords" in structure:
3✔
117
            for coord in structure["coords"]:
3✔
118
                if coord not in ds.coords:
3✔
119
                    if coord in ds.data_vars:
3✔
120
                        err = f"'{coord}' is detected as a data variable, not a coordinate."
3✔
121
                        if "structure" in raise_on:
3✔
122
                            raise ValueError(err)
3✔
123
                        else:
124
                            warnings.warn(err, UserWarning, stacklevel=1)
3✔
125
                    else:
126
                        err = f"The coordinate '{coord}' is missing."
3✔
127
                        if "structure" in raise_on:
3✔
128
                            raise ValueError(err)
3✔
129
                        else:
130
                            warnings.warn(err, UserWarning, stacklevel=1)
3✔
131

132
    # Check the calendar
133
    if calendar is not None:
3✔
134
        cal = xc.core.calendar.get_calendar(ds.time)
3✔
135
        if xc.core.calendar.common_calendar([calendar]).replace(
3✔
136
            "default", "standard"
137
        ) != xc.core.calendar.common_calendar([cal]).replace("default", "standard"):
138
            err = f"The calendar is not '{calendar}'. Received '{cal}'."
3✔
139
            if "calendar" in raise_on:
3✔
140
                raise ValueError(err)
3✔
141
            else:
142
                warnings.warn(err, UserWarning, stacklevel=1)
3✔
143

144
    # Check the start/end dates
145
    if (start_date is not None) or (end_date is not None):
3✔
146
        if isinstance(ds.time.min().values, np.datetime64):
3✔
147
            ds_start = xr.cftime_range(
3✔
148
                start=str(ds.time.min().values.astype("datetime64[D]")),
149
                periods=1,
150
                freq="D",
151
                calendar=ds.time.dt.calendar,
152
            )[0]
153
            ds_end = xr.cftime_range(
3✔
154
                start=str(ds.time.max().values.astype("datetime64[D]")),
155
                periods=1,
156
                freq="D",
157
                calendar=ds.time.dt.calendar,
158
            )[0]
159
        else:
160
            ds_start = ds.time.min().values
3✔
161
            ds_end = ds.time.max().values
3✔
162
    if start_date is not None:
3✔
163
        # Create cf_time objects to compare the dates
164
        start_date = xr.cftime_range(
3✔
165
            start=start_date, periods=1, freq="D", calendar=ds.time.dt.calendar
166
        )[0]
167
        if not ((ds_start <= start_date) and (ds_end > start_date)):
3✔
168
            err = f"The start date is not at least {start_date}. Received {ds.time.min().values.astype('datetime64[m]')}."
3✔
169
            if "start_date" in raise_on:
3✔
170
                raise ValueError(err)
3✔
171
            else:
172
                warnings.warn(err, UserWarning, stacklevel=1)
3✔
173
    if end_date is not None:
3✔
174
        # Create cf_time objects to compare the dates
175
        end_date = xr.cftime_range(
3✔
176
            start=end_date, periods=1, freq="D", calendar=ds.time.dt.calendar
177
        )[0]
178

179
        if not ((ds_start < end_date) and (ds_end >= end_date)):
3✔
180
            err = f"The end date is not at least {end_date}. Received {ds.time.max().values.astype('datetime64[m]')}."
3✔
181
            if "end_date" in raise_on:
3✔
182
                raise ValueError(err)
3✔
183
            else:
184
                warnings.warn(err, UserWarning, stacklevel=1)
3✔
185

186
    # Check variables
187
    if variables_and_units is not None:
3✔
188
        for v in variables_and_units:
3✔
189
            if v not in ds:
3✔
190
                raise ValueError(f"The variable '{v}' is missing.")
3✔
191
            if ds[v].attrs.get("units", None) != variables_and_units[v]:
3✔
192
                xc.core.units.check_units(
3✔
193
                    ds[v], variables_and_units[v]
194
                )  # Will always raise an error if the units are not compatible
195
                err = f"The variable '{v}' does not have the expected units '{variables_and_units[v]}'. Received '{ds[v].attrs['units']}'."
3✔
196
                if "variables_and_units" in raise_on:
3✔
197
                    raise ValueError(err)
3✔
198
                else:
199
                    warnings.warn(err, UserWarning, stacklevel=1)
3✔
200

201
    # Check CF conventions
202
    if cfchecks is not None:
3✔
203
        cfchecks = deepcopy(cfchecks)
3✔
204
        for v in cfchecks:
3✔
205
            for check in cfchecks[v]:
3✔
206
                if check == "check_valid":
3✔
207
                    cfchecks[v][check]["var"] = ds[v]
3✔
208
                elif check == "cfcheck_from_name":
3✔
209
                    cfchecks[v][check].setdefault("varname", v)
3✔
210
                    cfchecks[v][check]["vardata"] = ds[v]
3✔
211
                else:
212
                    raise ValueError(f"Check '{check}' is not in xclim.")
3✔
213
                if "cfchecks" in raise_on:
3✔
214
                    warnings.simplefilter("error")
3✔
215
                    try:
3✔
216
                        getattr(xc.core.cfchecks, check)(**cfchecks[v][check])
3✔
217
                    except UserWarning as e:
3✔
218
                        raise ValueError(e)
3✔
219
                    warnings.resetwarnings()
3✔
220
                else:
221
                    getattr(xc.core.cfchecks, check)(**cfchecks[v][check])
3✔
222

223
    if freq is not None:
3✔
224
        inferred_freq = xr.infer_freq(ds.time)
3✔
225
        if inferred_freq is None:
3✔
226
            err = "The timesteps are irregular or cannot be inferred by xarray."
3✔
227
            if "freq" in raise_on:
3✔
228
                raise ValueError(err)
3✔
229
            else:
230
                warnings.warn(err, UserWarning, stacklevel=1)
3✔
231
        elif freq.replace("YS", "AS-JAN") != inferred_freq:
3✔
232
            err = f"The frequency is not '{freq}'. Received '{inferred_freq}'."
3✔
233
            if "freq" in raise_on:
3✔
234
                raise ValueError(err)
3✔
235
            else:
236
                warnings.warn(err, UserWarning, stacklevel=1)
3✔
237

238
    if missing is not None:
3✔
239
        inferred_freq = xr.infer_freq(ds.time)
3✔
240
        if inferred_freq not in ["M", "MS", "D", "H"]:
3✔
241
            warnings.warn(
3✔
242
                f"Frequency {inferred_freq} is not supported for missing data checks. That check will be skipped.",
243
                UserWarning,
244
                stacklevel=1,
245
            )
246
        else:
247
            if isinstance(missing, str):
3✔
248
                missing = {missing: {}}
3✔
249
            elif isinstance(missing, list):
3✔
250
                missing = {m: {} for m in missing}
3✔
251
            for method, kwargs in missing.items():
3✔
252
                kwargs.setdefault("freq", "YS")
3✔
253
                for v in ds.data_vars:
3✔
254
                    ms = getattr(xc.core.missing, method)(ds[v], **kwargs)
3✔
255
                    if ms.any():
3✔
256
                        err = f"The variable '{v}' has missing values according to the '{method}' method."
3✔
257
                        if "missing" in raise_on:
3✔
258
                            raise ValueError(err)
3✔
259
                        else:
260
                            warnings.warn(err, UserWarning, stacklevel=1)
3✔
261

262
    if flags is not None:
3✔
263
        if return_flags:
3✔
264
            out = xr.Dataset()
3✔
265
        for v in flags:
3✔
266
            dsflags = xc.core.dataflags.data_flags(
3✔
267
                ds[v],
268
                ds,
269
                flags=flags[v],
270
                raise_flags="flags" in raise_on,
271
                **(flags_kwargs or {}),
272
            )
273
            if (("flags" in raise_on) is False) and (
3✔
274
                np.any([dsflags[dv] for dv in dsflags.data_vars])
275
            ):
276
                bad_checks = [dv for dv in dsflags.data_vars if dsflags[dv].any()]
3✔
277
                warnings.warn(
3✔
278
                    f"Data quality flags indicate suspicious values for the variable '{v}'. Flags raised are: {bad_checks}.",
279
                    UserWarning,
280
                    stacklevel=1,
281
                )
282
            if return_flags:
3✔
283
                dsflags = dsflags.rename({dv: f"{v}_{dv}" for dv in dsflags.data_vars})
3✔
284
                out = xr.merge([out, dsflags])
3✔
285
        if return_flags:
3✔
286
            return out
3✔
287

288

289
# TODO: just measures?
@parse_config
def properties_and_measures(
    ds: xr.Dataset,
    properties: Union[
        str, PosixPath, Sequence[Indicator], Sequence[tuple[str, Indicator]], ModuleType
    ],
    period: list = None,
    unstack: bool = False,
    rechunk: dict = None,
    dref_for_measure: Optional[xr.Dataset] = None,
    change_units_arg: Optional[dict] = None,
    to_level_prop: str = "diag-properties",
    to_level_meas: str = "diag-measures",
):
    """Calculate properties and measures of a dataset.

    Parameters
    ----------
    ds : xr.Dataset
        Input dataset.
    properties : Union[str, PosixPath, Sequence[Indicator], Sequence[Tuple[str, Indicator]], ModuleType]
        Path to a YAML file that instructs on how to calculate properties.
        Can be the indicator module directly, or a sequence of indicators or a sequence of
        tuples (indicator name, indicator) as returned by `iter_indicators()`.
    period : list
        [start, end] of the period to be evaluated. The period will be selected on ds
        and dref_for_measure if it is given.
    unstack : bool
        Whether to unstack ds before computing the properties.
    rechunk : dict
        Dictionary of chunks to use for a rechunk before computing the properties.
    dref_for_measure : xr.Dataset
        Dataset of properties to be used as the ref argument in the computation of the measure.
        Ideally, this is the first output (prop) of a previous call to this function.
        Only measures on properties that are provided both in this dataset and in the properties list will be computed.
        If None, the second output of the function (meas) will be an empty Dataset.
    change_units_arg : dict
        If not None, calls `xscen.utils.change_units` on ds before computing properties using
        this dictionary for the `variables_and_units` argument.
        It can be useful to convert units before computing the properties, because it is sometimes
        easier to convert the units of the variables than the units of the properties (e.g. variance).
    to_level_prop : str
        processing_level to give the first output (prop).
    to_level_meas : str
        processing_level to give the second output (meas).

    Returns
    -------
    prop : xr.Dataset
        Dataset of properties of ds.
    meas : xr.Dataset
        Dataset of measures between prop and dref_for_measure.

    See Also
    --------
    xclim.sdba.properties, xclim.sdba.measures, xclim.core.indicator.build_indicator_module_from_yaml
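
    Examples
    --------
    A minimal sketch, assuming a YAML file ``properties.yml`` describing xclim
    properties and datasets ``ds_sim`` and ``ds_ref`` (all hypothetical):

    >>> prop_ref, _ = properties_and_measures(  # doctest: +SKIP
    ...     ds_ref, properties="properties.yml", period=["1981", "2010"]
    ... )
    >>> prop_sim, meas = properties_and_measures(  # doctest: +SKIP
    ...     ds_sim,
    ...     properties="properties.yml",
    ...     period=["1981", "2010"],
    ...     dref_for_measure=prop_ref,
    ... )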
    """
UNCOV
347
    if isinstance(properties, (str, Path)):
×
348
        logger.debug("Loading properties module.")
×
UNCOV
349
        module = load_xclim_module(properties)
×
UNCOV
350
        properties = module.iter_indicators()
×
UNCOV
351
    elif hasattr(properties, "iter_indicators"):
×
UNCOV
352
        properties = properties.iter_indicators()
×
353

UNCOV
354
    try:
×
355
        N = len(properties)
×
356
    except TypeError:
×
UNCOV
357
        N = None
×
358
    else:
359
        logger.info(f"Computing {N} properties.")
×
360

361
    period = standardize_periods(period, multiple=False)
×
362
    # select period for ds
UNCOV
363
    if period is not None and "time" in ds:
×
364
        ds = ds.sel({"time": slice(period[0], period[1])})
×
365

366
    # select periods for ref_measure
367
    if (
×
368
        dref_for_measure is not None
369
        and period is not None
370
        and "time" in dref_for_measure
371
    ):
372
        dref_for_measure = dref_for_measure.sel({"time": slice(period[0], period[1])})
×
373

374
    if unstack:
×
375
        ds = unstack_fill_nan(ds)
×
376

377
    if rechunk:
×
378
        ds = ds.chunk(rechunk)
×
379

UNCOV
380
    if change_units_arg:
×
381
        ds = change_units(ds, variables_and_units=change_units_arg)
×
382

UNCOV
383
    prop = xr.Dataset()  # dataset with all properties
×
UNCOV
384
    meas = xr.Dataset()  # dataset with all measures
×
UNCOV
385
    for i, ind in enumerate(properties, 1):
×
386
        if isinstance(ind, tuple):
×
387
            iden, ind = ind
×
388
        else:
UNCOV
389
            iden = ind.identifier
×
390
        # Make the call to xclim
391
        logger.info(f"{i} - Computing {iden}.")
×
392
        out = ind(ds=ds)
×
393
        vname = out.name
×
394
        prop[vname] = out
×
395

UNCOV
396
        if period is not None:
×
397
            prop[vname].attrs["period"] = f"{period[0]}-{period[1]}"
×
398

399
        # calculate the measure if a reference dataset is given for the measure
400
        if dref_for_measure and vname in dref_for_measure:
×
401
            meas[vname] = ind.get_measure()(
×
402
                sim=prop[vname], ref=dref_for_measure[vname]
403
            )
404
            # create a merged long_name
UNCOV
405
            prop_ln = prop[vname].attrs.get("long_name", "").replace(".", "")
×
UNCOV
406
            meas_ln = meas[vname].attrs.get("long_name", "").lower()
×
UNCOV
407
            meas[vname].attrs["long_name"] = f"{prop_ln} {meas_ln}"
×
408

UNCOV
409
    for ds1 in [prop, meas]:
×
UNCOV
410
        ds1.attrs = ds.attrs
×
UNCOV
411
        ds1.attrs["cat:xrfreq"] = "fx"
×
UNCOV
412
        ds1.attrs.pop("cat:variable", None)
×
UNCOV
413
        ds1.attrs["cat:frequency"] = "fx"
×
414

415
        # to be able to save in zarr, convert object to string
UNCOV
416
        if "season" in ds1:
×
UNCOV
417
            ds1["season"] = ds1.season.astype("str")
×
418

UNCOV
419
    prop.attrs["cat:processing_level"] = to_level_prop
×
UNCOV
420
    meas.attrs["cat:processing_level"] = to_level_meas
×
421

UNCOV
422
    return prop, meas
×
423

424

425
def measures_heatmap(meas_datasets: Union[list, dict], to_level: str = "diag-heatmap"):
    """Create a heatmap to compare the performance of the different datasets.

    The columns are properties and the rows are datasets.
    Each point is the absolute value of the mean of the measure over the whole domain.
    Each column is normalized from 0 (best) to 1 (worst).

    Parameters
    ----------
    meas_datasets : list or dict
        List or dictionary of datasets of measures of properties.
        If it is a dictionary, the keys will be used to name the rows.
        If it is a list, the rows will be given a number.
    to_level : str
        processing_level to assign to the output.

    Returns
    -------
    xr.Dataset
        Dataset containing the heatmap, with dimensions `realization` and `properties`.
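
    Examples
    --------
    A minimal sketch, assuming ``meas_scen1`` and ``meas_scen2`` are measure
    datasets returned by `properties_and_measures` (hypothetical names):

    >>> hmap = measures_heatmap(  # doctest: +SKIP
    ...     {"scen1": meas_scen1, "scen2": meas_scen2}
    ... )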
    """
UNCOV
445
    name_of_datasets = None
×
UNCOV
446
    if isinstance(meas_datasets, dict):
×
447
        name_of_datasets = list(meas_datasets.keys())
×
UNCOV
448
        meas_datasets = list(meas_datasets.values())
×
449

UNCOV
450
    hmap = []
×
UNCOV
451
    for meas in meas_datasets:
×
UNCOV
452
        row = []
×
453
        # iterate through all available properties
UNCOV
454
        for var_name in meas:
×
UNCOV
455
            da = meas[var_name]
×
456
            # mean the absolute value of the bias over all positions and add to heat map
UNCOV
457
            if "xclim.sdba.measures.RATIO" in da.attrs["history"]:
×
458
                # if ratio, best is 1, this moves "best to 0 to compare with bias
459
                row.append(abs(da - 1).mean().values)
×
460
            else:
UNCOV
461
                row.append(abs(da).mean().values)
×
462
        # append all properties
UNCOV
463
        hmap.append(row)
×
464

465
    # plot heatmap of biases (1 column per properties, 1 row per dataset)
UNCOV
466
    hmap = np.array(hmap)
×
467
    # normalize to 0-1 -> best-worst
UNCOV
468
    hmap = np.array(
×
469
        [
470
            (c - np.min(c)) / (np.max(c) - np.min(c))
471
            if np.max(c) != np.min(c)
472
            else [0.5] * len(c)
473
            for c in hmap.T
474
        ]
475
    ).T
476

477
    name_of_datasets = name_of_datasets or list(range(1, hmap.shape[0] + 1))
×
478
    ds_hmap = xr.DataArray(
×
479
        hmap,
480
        coords={
481
            "realization": name_of_datasets,
482
            "properties": list(meas_datasets[0].data_vars),
483
        },
484
        dims=["realization", "properties"],
485
    )
UNCOV
486
    ds_hmap = ds_hmap.to_dataset(name="heatmap")
×
487

UNCOV
488
    ds_hmap.attrs = xr.core.merge.merge_attrs(
×
489
        [ds.attrs for ds in meas_datasets], combine_attrs="drop_conflicts"
490
    )
UNCOV
491
    ds_hmap = clean_up(
×
492
        ds=ds_hmap,
493
        common_attrs_only=meas_datasets,
494
    )
UNCOV
495
    ds_hmap.attrs["cat:processing_level"] = to_level
×
UNCOV
496
    ds_hmap.attrs.pop("cat:variable", None)
×
UNCOV
497
    ds_hmap["heatmap"].attrs["long_name"] = "Ranking of measure performance"
×
498

UNCOV
499
    return ds_hmap
×
500

501

502
def measures_improvement(
    meas_datasets: Union[list, dict], to_level: str = "diag-improved"
):
    """
    Calculate the fraction of improved grid points for each property between two datasets of measures.

    Parameters
    ----------
    meas_datasets : list | dict
        List of 2 datasets: initial dataset of measures and final (improved) dataset of measures.
        Both datasets must have the same variables.
        It is also possible to pass a dictionary where the values are the datasets and the keys are not used.
    to_level : str
        processing_level to assign to the output dataset.

    Returns
    -------
    xr.Dataset
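
    Examples
    --------
    A minimal sketch, assuming ``meas_before`` and ``meas_after`` are two measure
    datasets with the same variables (hypothetical names):

    >>> improved = measures_improvement([meas_before, meas_after])  # doctest: +SKIP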
    """
522
    if isinstance(meas_datasets, dict):
×
523
        meas_datasets = list(meas_datasets.values())
×
524

525
    if len(meas_datasets) != 2:
×
526
        warnings.warn(
×
527
            "meas_datasets has more than 2 datasets."
528
            " Only the first 2 will be compared."
529
        )
UNCOV
530
    ds1 = meas_datasets[0]
×
UNCOV
531
    ds2 = meas_datasets[1]
×
532
    percent_better = []
×
UNCOV
533
    for var in ds2.data_vars:
×
534
        if "xclim.sdba.measures.RATIO" in ds1[var].attrs["history"]:
×
UNCOV
535
            diff_bias = abs(ds1[var] - 1) - abs(ds2[var] - 1)
×
536
        else:
537
            diff_bias = abs(ds1[var]) - abs(ds2[var])
×
538
        diff_bias = diff_bias.values.ravel()
×
539
        diff_bias = diff_bias[~np.isnan(diff_bias)]
×
540

541
        total = ds2[var].values.ravel()
×
UNCOV
542
        total = total[~np.isnan(total)]
×
543

UNCOV
544
        improved = diff_bias >= 0
×
UNCOV
545
        percent_better.append(np.sum(improved) / len(total))
×
546

UNCOV
547
    ds_better = xr.DataArray(
×
548
        percent_better, coords={"properties": list(ds2.data_vars)}, dims="properties"
549
    )
550

UNCOV
551
    ds_better = ds_better.to_dataset(name="improved_grid_points")
×
552

UNCOV
553
    ds_better["improved_grid_points"].attrs[
×
554
        "long_name"
555
    ] = "Fraction of improved grid cells"
UNCOV
556
    ds_better.attrs = ds2.attrs
×
UNCOV
557
    ds_better.attrs["cat:processing_level"] = to_level
×
UNCOV
558
    ds_better.attrs.pop("cat:variable", None)
×
559

UNCOV
560
    return ds_better
×
561

562

563
def measures_improvement_2d(dict_input: dict, to_level: str = "diag-improved-2d"):
    """
    Create a 2D dataset with dimension `realization` showing the fraction of improved grid cells.

    Parameters
    ----------
    dict_input : dict
        If a dict of datasets, the datasets should be the output of `measures_improvement`.
        If a dict of dicts or lists, those should be the input to `measures_improvement`.
        The keys will be the values of the dimension `realization`.
    to_level : str
        Processing_level to assign to the output dataset.

    Returns
    -------
    xr.Dataset
        Dataset with an extra `realization` coordinate.
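
    Examples
    --------
    A minimal sketch, assuming each value is a [before, after] pair of measure
    datasets as accepted by `measures_improvement` (hypothetical names):

    >>> out = measures_improvement_2d(  # doctest: +SKIP
    ...     {
    ...         "scen1": [meas1_before, meas1_after],
    ...         "scen2": [meas2_before, meas2_after],
    ...     }
    ... )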
    """
UNCOV
581
    merge = {}
×
UNCOV
582
    for name, value in dict_input.items():
×
583
        # if dataset, assume the value is already the output of `measures_improvement`
UNCOV
584
        if isinstance(value, xr.Dataset):
×
UNCOV
585
            out = value.expand_dims(dim={"realization": [name]})
×
586
        # else, compute the `measures_improvement`
587
        else:
UNCOV
588
            out = measures_improvement(value).expand_dims(dim={"realization": [name]})
×
UNCOV
589
        merge[name] = out
×
590
    # put everything in one dataset with dim datasets
UNCOV
591
    ds_merge = xr.concat(list(merge.values()), dim="realization")
×
UNCOV
592
    ds_merge["realization"] = ds_merge["realization"].astype(str)
×
UNCOV
593
    ds_merge = clean_up(
×
594
        ds=ds_merge,
595
        common_attrs_only=merge,
596
    )
UNCOV
597
    ds_merge.attrs["cat:processing_level"] = to_level
×
598

UNCOV
599
    return ds_merge
×