
Ouranosinc / miranda · build 2423411918 (pending completion)

GitHub Pull Request #50: revise structure to fit newest database definition
Merge 3a5cf6f9b into 6d81d9443

4 of 37 new or added lines in 5 files covered (10.81%)
131 existing lines in 4 files now uncovered
661 of 3285 relevant lines covered (20.12%)
0.6 hits per line

Source file: /miranda/convert/_utils.py (10.84% covered)

import datetime
import json
import logging.config
import os
from pathlib import Path
from typing import Dict, Optional, Tuple, Union

import netCDF4
import numpy as np
import xarray as xr
import zarr
from clisops.core import subset
from xclim.core import units
from xclim.indices import tas

from miranda.scripting import LOGGING_CONFIG
from miranda.units import get_time_frequency

logging.config.dictConfig(LOGGING_CONFIG)

__all__ = [
    "get_chunks_on_disk",
    "add_ar6_regions",
    "daily_aggregation",
    "delayed_write",
    "variable_conversion",
]

LATLON_COORDINATE_PRECISION = dict()
LATLON_COORDINATE_PRECISION["era5-land"] = 4

# Date-based version stamp applied to converted datasets
VERSION = datetime.datetime.now().strftime("%Y.%m.%d")


def load_json_data_mappings(project: str) -> dict:
    """Load the JSON metadata definitions for a given project."""
    data_folder = Path(__file__).parent / "data"

    if project.startswith("era5"):
        metadata_definition = json.load(open(data_folder / "ecmwf_cf_attrs.json"))
    elif project in ["agcfsr", "agmerra2"]:  # This should handle the AG versions
        metadata_definition = json.load(open(data_folder / "nasa_cf_attrs.json"))
    elif project == "nrcan-gridded-10km":
        raise NotImplementedError()
    elif project == "wfdei-gem-capa":
        metadata_definition = json.load(open(data_folder / "usask_cf_attrs.json"))
    else:
        raise NotImplementedError()

    return metadata_definition

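# Usage sketch (illustrative, not part of the original module): "era5-land"
# is one of the project identifiers recognized above; the returned mapping
# carries a "Header" block and per-variable "variable_entry" definitions.
#
#     mappings = load_json_data_mappings("era5-land")
#     for name, entry in mappings["variable_entry"].items():
#         print(name, entry.get("units"))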

def get_chunks_on_disk(nc_file: Union[os.PathLike, str]) -> dict:
    """Determine the on-disk chunking of each variable in a NetCDF file.

    Parameters
    ----------
    nc_file: Path or str

    Returns
    -------
    dict
        Mapping of variable name to a {dimension: chunk size} dictionary.
    """
    # FIXME: This does not support zarr
    # TODO: This needs to be optimized for dask. See: https://github.com/Ouranosinc/miranda/pull/24/files#r840617216
    ds = netCDF4.Dataset(nc_file)
    chunks = dict()
    for v in ds.variables:
        chunks[v] = dict()
        chunking = ds[v].chunking()
        if chunking == "contiguous":
            # Unchunked (contiguous) variables report the string "contiguous".
            continue
        for ii, dim in enumerate(ds[v].dimensions):
            chunks[v][dim] = chunking[ii]
    return chunks

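# Usage sketch (hypothetical file and variable names):
#
#     chunks = get_chunks_on_disk("2m_temperature_era5.nc")
#     # e.g. {"t2m": {"time": 24, "latitude": 50, "longitude": 50}, ...}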

def add_ar6_regions(ds: xr.Dataset) -> xr.Dataset:
    """Add the IPCC AR6 regions to a dataset.

    Parameters
    ----------
    ds : xarray.Dataset

    Returns
    -------
    xarray.Dataset
    """
    try:
        import regionmask  # noqa
    except ImportError:
        raise ImportError(
            f"{add_ar6_regions.__name__} requires additional dependencies. "
            "Please install them with `pip install miranda[full]`."
        )

    mask = regionmask.defined_regions.ar6.all.mask(ds.lon, ds.lat)
    ds = ds.assign_coords(region=mask)
    return ds

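# Usage sketch (requires the optional `regionmask` dependency; the file
# name is illustrative). Each (lat, lon) grid cell gains a `region`
# coordinate holding its AR6 region id:
#
#     ds = add_ar6_regions(xr.open_dataset("tas_obs.nc"))
#     ds.region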

def variable_conversion(ds: xr.Dataset, project: str, output_format: str) -> xr.Dataset:
    """Convert variables to a CF-compliant format."""

    def _correct_units_names(d: xr.Dataset, p: str, m: Dict):
        key = "_corrected_units"
        for v in d.data_vars:
            if p in m["variable_entry"][v][key].keys():
                d[v].attrs["units"] = m["variable_entry"][v][key][p]

        if "time" in m["variable_entry"].keys():
            if p in m["variable_entry"]["time"][key].keys():
                d["time"].attrs["units"] = m["variable_entry"]["time"][key][p]

        return d

    # For de-accumulation or conversion to flux
    def _transform(d: xr.Dataset, p: str, m: Dict):
        key = "_transformation"
        d_out = xr.Dataset(coords=d.coords, attrs=d.attrs)
        for vv in d.data_vars:
            if p in m["variable_entry"][vv][key].keys():
                try:
                    offset, offset_meaning = get_time_frequency(d)
                except TypeError:
                    logging.error(
                        f"Unable to parse the time frequency for variable `{vv}`. "
                        "Verify data integrity before retrying."
                    )
                    raise

                if m["variable_entry"][vv][key][p] == "deaccumulate":
                    # Time-step-accumulated total to time-based flux (de-accumulation)
                    logging.info(f"De-accumulating units for variable `{vv}`.")
                    with xr.set_options(keep_attrs=True):
                        out = d[vv].diff(dim="time")
                        out = d[vv].where(
                            getattr(d[vv].time.dt, offset_meaning) == offset[0],
                            out.broadcast_like(d[vv]),
                        )
                        out = units.amount2rate(out)
                    d_out[out.name] = out
                elif m["variable_entry"][vv][key][p] == "amount2rate":
                    # Frequency-based totals to time-based flux
                    logging.info(
                        f"Performing amount-to-rate units conversion for variable `{vv}`."
                    )
                    out = units.amount2rate(
                        d[vv],
                        out_units=m["variable_entry"][vv]["units"],
                    )
                    d_out[out.name] = out
                else:
                    raise NotImplementedError(
                        f"Unknown transformation: {m['variable_entry'][vv][key][p]}"
                    )
            else:
                d_out[vv] = d[vv]
        return d_out
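    # De-accumulation sketch (toy numbers, not from the module): for
    # ERA5-style accumulations that reset at the start of each period,
    # totals such as [1, 3, 6, 1, 4] become per-step amounts [1, 2, 3, 1, 3]:
    # `diff` takes successive differences along time, `where` keeps the raw
    # value at each period start, and `amount2rate` then converts the
    # per-step amounts into a flux.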

    def _offset_time(d: xr.Dataset, p: str, m: Dict) -> xr.Dataset:
        key = "_offset_time"
        d_out = xr.Dataset(coords=d.coords, attrs=d.attrs)
        for vv in d.data_vars:
            if p in m["variable_entry"][vv][key].keys():
                try:
                    offset, offset_meaning = get_time_frequency(d)
                except TypeError:
                    logging.error(
                        f"Unable to parse the time frequency for variable `{vv}`. "
                        "Verify data integrity before retrying."
                    )
                    raise

                if m["variable_entry"][vv][key][p]:
                    # Offset time by the value of one time-step
                    logging.info(
                        f"Offsetting data for `{vv}` by `{offset[0]} {offset_meaning}(s)`."
                    )
                    with xr.set_options(keep_attrs=True):
                        out = d[vv]
                        out["time"] = out.time - np.timedelta64(offset[0], offset[1])
                        d_out[out.name] = out
                else:
                    logging.info(
                        f"No time offsetting needed for `{vv}` in `{p}` (explicitly set to False)."
                    )
                    d_out = d
            else:
                logging.info(f"No time offsetting needed for `{vv}` in project `{p}`.")
                d_out = d
        return d_out

    # For converting variable units to standard workflow units
    def _units_cf_conversion(d: xr.Dataset, m: Dict) -> xr.Dataset:
        descriptions = m["variable_entry"]

        if "time" in m["variable_entry"].keys():
            d["time"].attrs["units"] = m["variable_entry"]["time"]["units"]

        for v in d.data_vars:
            d[v] = units.convert_units_to(d[v], descriptions[v]["units"])

        return d

    # Add and update existing metadata fields
    def _metadata_conversion(d: xr.Dataset, p: str, o: str, m: Dict) -> xr.Dataset:
        logging.info("Converting metadata to CF-like conventions.")

        # Add global attributes
        d.attrs.update(m["Header"])
        d.attrs.update(dict(project=p, format=o))

        # Date-based versioning
        d.attrs.update(dict(version=f"v{VERSION}"))

        history = (
            f"{d.attrs['history']}\n[{datetime.datetime.now()}] Converted from original data to {o} "
            "with modified metadata for CF-like compliance."
        )
        d.attrs.update(dict(history=history))
        descriptions = m["variable_entry"]

        if "time" in m["variable_entry"].keys():
            descriptions["time"].pop("_corrected_units")

        # Add variable metadata, stripping the internal helper keys first
        for v in d.data_vars:
            descriptions[v].pop("_corrected_units")
            descriptions[v].pop("_offset_time")
            descriptions[v].pop("_transformation")
            d[v].attrs.update(descriptions[v])

        # Rename data variables to their CF names
        for v in d.data_vars:
            try:
                cf_name = descriptions[v]["_cf_variable_name"]
                d = d.rename({v: cf_name})
                d[cf_name].attrs.update(dict(original_variable=v))
                del d[cf_name].attrs["_cf_variable_name"]
            except (ValueError, IndexError):
                pass
        return d

    # For renaming lat and lon dims and wrapping longitudes to [-180, 180]
    def _dims_conversion(d: xr.Dataset):
        sort_dims = []
        for orig, new in dict(longitude="lon", latitude="lat").items():
            try:
                d = d.rename({orig: new})
                if new == "lon" and np.any(d.lon > 180):
                    lon1 = d.lon.where(d.lon <= 180.0, d.lon - 360.0)
                    d[new] = lon1
                sort_dims.append(new)
            except KeyError:
                pass
            if project in LATLON_COORDINATE_PRECISION.keys():
                d[new] = d[new].round(LATLON_COORDINATE_PRECISION[project])
        if sort_dims:
            d = d.sortby(sort_dims)
        return d

    metadata_definition = load_json_data_mappings(project)
    ds = _correct_units_names(ds, project, metadata_definition)
    ds = _transform(ds, project, metadata_definition)
    ds = _offset_time(ds, project, metadata_definition)
    ds = _units_cf_conversion(ds, metadata_definition)
    ds = _metadata_conversion(ds, project, output_format, metadata_definition)
    ds = _dims_conversion(ds)

    return ds

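# End-to-end sketch (hypothetical file name): the project string selects
# the JSON mappings, and `output_format` is recorded in the `format`
# global attribute:
#
#     ds = xr.open_dataset("era5-land_raw.nc")
#     ds = variable_conversion(ds, project="era5-land", output_format="netcdf")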

def daily_aggregation(ds: xr.Dataset) -> Dict[str, xr.Dataset]:
    """Aggregate sub-daily climate variables to daily datasets."""
    logging.info("Creating daily upscaled climate variables.")

    daily_dataset = dict()
    for variable in ds.data_vars:
        if variable in ["tas", "tdps"]:
            # Some looping to deal with memory consumption issues
            for v, func in {
                f"{variable}max": "max",
                f"{variable}min": "min",
                f"{variable}": "mean",
            }.items():
                ds_out = xr.Dataset()
                ds_out.attrs = ds.attrs.copy()
                ds_out.attrs["frequency"] = "day"

                method = (
                    f"time: {func}{'imum' if func != 'mean' else ''} (interval: 1 day)"
                )
                ds_out.attrs["cell_methods"] = method

                if v == "tas" and not hasattr(ds, "tas"):
                    ds_out[v] = tas(tasmax=ds.tasmax, tasmin=ds.tasmin)
                else:
                    # Thanks for the help, xclim contributors
                    r = ds[variable].resample(time="D")
                    ds_out[v] = getattr(r, func)(dim="time", keep_attrs=True)

                daily_dataset[v] = ds_out
                del ds_out

        elif variable in [
            "evspsblpot",
            "hfls",
            "hfss",
            "pr",
            "prsn",
            "rsds",
            "rlds",
            "snd",
            "snr",
            "snw",
        ]:
            ds_out = xr.Dataset()
            ds_out.attrs = ds.attrs.copy()
            ds_out.attrs["frequency"] = "day"
            ds_out.attrs["cell_methods"] = "time: mean (interval: 1 day)"
            logging.info(f"Converting {variable} to daily time step (daily mean).")
            ds_out[variable] = (
                ds[variable].resample(time="D").mean(dim="time", keep_attrs=True)
            )

            daily_dataset[variable] = ds_out
            del ds_out
        else:
            continue

    return daily_dataset

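# Sketch of the returned mapping for a dataset with hourly `tas`
# (temperature variables yield daily max/min/mean datasets):
#
#     daily = daily_aggregation(hourly_ds)
#     # {"tasmax": <Dataset>, "tasmin": <Dataset>, "tas": <Dataset>}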

def threshold_land_sea_mask(
    ds: Union[xr.Dataset, str, os.PathLike],
    *,
    land_sea_mask: Dict[str, Tuple[str, Union[os.PathLike, str]]],
    land_sea_percentage: int = 50,
    output_folder: Optional[Union[str, os.PathLike]] = None,
) -> Union[Path, xr.Dataset]:
    """Land-sea mask operations.

    Parameters
    ----------
    ds: Union[xr.Dataset, str, os.PathLike]
    land_sea_mask: dict
        Mapping of project name to (mask variable name, mask file path).
    land_sea_percentage: int
    output_folder: str or os.PathLike, optional

    Returns
    -------
    Path or xr.Dataset
    """
    file_name = ""
    if isinstance(ds, (str, os.PathLike)):
        if output_folder is not None:
            output_folder = Path(output_folder)
            file_name = f"{Path(ds).stem}_land-sea-masked.nc"
        ds = xr.open_dataset(ds)

    if output_folder is not None and file_name == "":
        logging.warning(
            "Cannot generate filenames from xarray.Dataset objects. Consider writing NetCDF manually."
        )

    try:
        project = ds.attrs["project"]
    except KeyError:
        raise ValueError("No 'project' field found for given dataset.")

    if project in land_sea_mask.keys():
        logging.info(
            f"Masking variable with land-sea mask at {land_sea_percentage} % cutoff."
        )
        land_sea_mask_variable, lsm_file = land_sea_mask[project]
        lsm_raw = xr.open_dataset(lsm_file)
        try:
            lsm_raw = lsm_raw.rename({"longitude": "lon", "latitude": "lat"})
        except ValueError:
            raise

        lon_bounds = np.array([ds.lon.min(), ds.lon.max()])
        lat_bounds = np.array([ds.lat.min(), ds.lat.max()])

        lsm = subset.subset_bbox(
            lsm_raw,
            lon_bnds=lon_bounds,
            lat_bnds=lat_bounds,
        ).load()
        lsm = lsm.where(lsm[land_sea_mask_variable] > float(land_sea_percentage) / 100)
        if project == "era5":
            ds = ds.where(lsm[land_sea_mask_variable].isel(time=0, drop=True).notnull())
            try:
                ds = ds.rename({"longitude": "lon", "latitude": "lat"})
            except ValueError:
                raise
        elif project in ["merra2", "cfsr"]:
            ds = ds.where(lsm[land_sea_mask_variable].notnull())

        ds.attrs["land_sea_cutoff"] = f"{land_sea_percentage} %"

        if len(file_name) > 0:
            out = output_folder / file_name
            ds.to_netcdf(out)
            return out
        return ds
    raise RuntimeError(f"Project `{project}` was not found in land-sea masks.")

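# Usage sketch (hypothetical mask file; "lsm" as the ERA5 mask variable
# name is an assumption for illustration):
#
#     masked = threshold_land_sea_mask(
#         "era5_tas.nc",
#         land_sea_mask={"era5": ("lsm", "era5_land_sea_mask.nc")},
#         land_sea_percentage=50,
#     )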

def delayed_write(
    ds: xr.Dataset,
    outfile: Path,
    target_chunks: dict,
    output_format: str,
    overwrite: bool,
):
    """Write a dataset to disk as a delayed object.

    Parameters
    ----------
    ds: xr.Dataset
    outfile: Path
    target_chunks: dict
    output_format: str
        Either "netcdf" or "zarr".
    overwrite: bool

    Returns
    -------
    dask.delayed.Delayed
    """
    # Set correct chunks in encoding options
    kwargs = dict()
    kwargs["encoding"] = dict()
    for name, da in ds.data_vars.items():
        chunks = list()
        for dim in da.dims:
            if dim in target_chunks.keys():
                chunks.append(target_chunks[str(dim)])
            else:
                chunks.append(len(da[dim]))

        if output_format == "netcdf":
            kwargs["encoding"][name] = {
                "chunksizes": chunks,
                "zlib": True,
            }
            kwargs["compute"] = False
            if not overwrite:
                kwargs["mode"] = "a"
        elif output_format == "zarr":
            ds = ds.chunk(target_chunks)
            kwargs["encoding"][name] = {
                "chunks": chunks,
                "compressor": zarr.Blosc(),
            }
            kwargs["compute"] = False
            if overwrite:
                kwargs["mode"] = "w"
    if kwargs["encoding"]:
        kwargs["encoding"]["time"] = {"dtype": "int32"}

    return getattr(ds, f"to_{output_format}")(outfile, **kwargs)
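# Usage sketch: with `compute=False` set above, the call returns a dask
# Delayed object, so the actual write happens at `.compute()` (file name
# and chunking are illustrative):
#
#     job = delayed_write(ds, Path("out.zarr"), {"time": 168}, "zarr", overwrite=True)
#     job.compute()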