• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

Ouranosinc / figanos / 13903624835

17 Mar 2025 03:40PM UTC coverage: 8.111% (-0.01%) from 8.124%
13903624835

push

github

web-flow
fix get_rotpole (#308)

<!-- Please ensure the PR fulfills the following requirements! -->
<!-- If this is your first PR, make sure to add your details to the
AUTHORS.rst! -->
### Pull Request Checklist:
- [ ] This PR addresses an already opened issue (for bug fixes /
features)
  - This PR fixes #xyz
- [ ] (If applicable) Documentation has been added / updated (for bug
fixes / features).
- [ ] (If applicable) Tests have been added.
- [x] CHANGELOG.rst has been updated (with summary of main changes).
- [x] Link to issue (:issue:`number`) and pull request (:pull:`number`)
has been added.

### What kind of change does this PR introduce?

* Instead of assuming the rotated pole information is stored in
`rotated_pole`, use the `grid_mapping` attrs to guess the name. Inspired
by `xscen.spatial.get_grid_mapping`.
* Small fix in the documentation as well.

### Does this PR introduce a breaking change?
no

### Other information:
This is useful for CRCM5, which uses the attribute `crs`. It was decided
in the Outils Communs meeting that we would change the code, not the
data.

Note that another way to get this information is:

```
ds.cf.grid_mapping_names # Mapping grid mapping type -> [ nom_de_la_var ]
ds[DATAVAR].cf.grid_mapping_name  # Nom du type de grid mapping
# donc pour obtenir la variable contenant les informations
# sur le grid mapping de la variable DATAVAR: 
ds[ds.cf.grid_mapping_names[ds[DATAVAR].cf.grid_mapping_name][0]]
```
but this doesn't work if you only have a DataArray.

Is it worth also handling the case where a Dataset has no
`grid_mapping` attribute but does have the right CF attributes? Does that even happen?

0 of 9 new or added lines in 2 files covered. (0.0%)

1 existing line in 1 file now uncovered.

155 of 1911 relevant lines covered (8.11%)

0.49 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

11.27
/src/figanos/matplotlib/utils.py
1
"""Utility functions for figanos figure-creation."""
2

3
from __future__ import annotations
6✔
4

5
import json
6✔
6
import math
6✔
7
import pathlib
6✔
8
import re
6✔
9
import warnings
6✔
10
from copy import deepcopy
6✔
11
from tempfile import NamedTemporaryFile
6✔
12
from typing import Any, Callable
6✔
13

14
import cairosvg
6✔
15
import cartopy.crs as ccrs
6✔
16
import cartopy.feature as cfeature
6✔
17
import geopandas as gpd
6✔
18
import matplotlib as mpl
6✔
19
import matplotlib.axes
6✔
20
import matplotlib.colors as mcolors
6✔
21
import matplotlib.pyplot as plt
6✔
22
import numpy as np
6✔
23
import pandas as pd
6✔
24
import seaborn
6✔
25
import xarray as xr
6✔
26
import yaml
6✔
27
from matplotlib.lines import Line2D
6✔
28
from skimage.transform import resize
6✔
29
from xclim.core.options import METADATA_LOCALES
6✔
30
from xclim.core.options import OPTIONS as XC_OPTIONS
6✔
31

32
from .._logo import Logos
6✔
33

34
TERMS: dict = {}
"""
A translation dictionary for special terms to appear on the plots.

Keys are terms to translate and they map to "locale": "translation" dictionaries.
The "official" figanos terms are based on figanos/data/terms.yml.
"""


# Load terms translations.
# The encoding is given explicitly so that reading the file does not depend on the
# platform's default encoding (which is not UTF-8 everywhere, e.g. on Windows);
# the translations may contain non-ASCII characters.
with (pathlib.Path(__file__).resolve().parents[1] / "data" / "terms.yml").open(
    encoding="utf-8"
) as f:
    TERMS = yaml.safe_load(f)
6✔
46

47

48
def get_localized_term(term, locale=None):
    """Look up the translation of `term` for `locale` in :py:data:`TERMS`.

    Parameters
    ----------
    term : str
        Word or short phrase to translate.
    locale : str, optional
        Two-letter code of the target locale. When None (the default), the first entry of
        xclim's "metadata_locales" option is used, or "en" if that option is empty.

    Returns
    -------
    str
        The translated term. `term` itself is returned unchanged when the target locale is
        "en", or when no translation is known (a warning is emitted in that case).
    """
    target = locale or (XC_OPTIONS[METADATA_LOCALES] or ["en"])[0]

    # English is the source language of the terms: nothing to translate.
    if target == "en":
        return term

    if term not in TERMS:
        warnings.warn(f"No translation known for term '{term}'.")
        return term

    translations = TERMS[term]
    if target in translations:
        return translations[target]

    warnings.warn(f"No {target} translation known for term '{term}'.")
    return term
×
79

80

81
def empty_dict(param) -> dict:
    """Return a deep copy of `param`, or a fresh empty dict when `param` is None.

    The copy lets callers pop items from the result without modifying the original input.
    """
    return {} if param is None else deepcopy(param)
×
86

87

88
def check_timeindex(
    xr_objs: xr.DataArray | xr.Dataset | dict[str, Any],
) -> xr.DataArray | xr.Dataset | dict[str, Any]:
    """Convert CFTime-indexed xarray objects to a pandas DatetimeIndex ('standard' calendar).

    Objects with a "time" dimension whose index is an ``xr.CFTimeIndex`` are converted with
    ``convert_calendar("standard", use_cftime=None, align_on="year")`` and a warning is
    emitted for each conversion. Other objects are returned unchanged.

    Parameters
    ----------
    xr_objs : xr.DataArray or xr.Dataset or dict
        A single xarray object, or a dictionary of xarray DataArrays or Datasets.

    Returns
    -------
    xr.DataArray or xr.Dataset or dict
        The converted object, or, for dict input, a new dict with the same keys holding the
        converted objects. The input dict itself is not modified.
    """

    def _to_standard_calendar(obj):
        # Only objects indexed by cftime dates along "time" need converting.
        if "time" in obj.dims and isinstance(obj.get_index("time"), xr.CFTimeIndex):
            converted = obj.convert_calendar(
                "standard", use_cftime=None, align_on="year"
            )
            warnings.warn(
                "CFTimeIndex converted to pandas DatetimeIndex with a 'standard' calendar."
            )
            return converted
        return obj

    if isinstance(xr_objs, dict):
        # Build a new dict rather than overwriting the caller's entries in place.
        return {name: _to_standard_calendar(obj) for name, obj in xr_objs.items()}

    return _to_standard_calendar(xr_objs)
×
127

128

129
def get_array_categ(array: xr.DataArray | xr.Dataset) -> str:
    """Classify an xarray object to decide how it should be plotted.

    Parameters
    ----------
    array : Dataset or DataArray
        The object to classify.

    Returns
    -------
    str
        One of:

        - ENS_PCT_VAR_DS: ensemble percentiles stored as variables (Dataset with at least
          two variable names matching "_p" followed by 1-2 digits)
        - ENS_STATS_VAR_DS: ensemble statistics (min, mean, max) stored as variables
          (Dataset with at least two variable names containing "_max"/"_min", any case of
          the first letter)
        - ENS_PCT_DIM_DS: Dataset with a 'percentiles' dimension
        - ENS_REALS_DS: Dataset with a 'realization' dimension
        - DS: any other Dataset
        - ENS_PCT_DIM_DA: DataArray with a 'percentiles' dimension
        - ENS_REALS_DA: DataArray with a 'realization' dimension
        - DA: any other DataArray

    Raises
    ------
    TypeError
        If `array` is neither a Dataset nor a DataArray.
    """
    if isinstance(array, xr.DataArray):
        if "percentiles" in array.dims:
            return "ENS_PCT_DIM_DA"
        if "realization" in array.dims:
            return "ENS_REALS_DA"
        return "DA"

    if not isinstance(array, xr.Dataset):
        raise TypeError("Array is not an Xarray Dataset or DataArray")

    def _n_vars_matching(pattern):
        # Number of data variable names in which `pattern` is found.
        return sum(re.search(pattern, var) is not None for var in array.data_vars)

    if _n_vars_matching(r"_p[0-9]{1,2}") >= 2:
        return "ENS_PCT_VAR_DS"
    if _n_vars_matching(r"_[Mm]ax|_[Mm]in") >= 2:
        return "ENS_STATS_VAR_DS"
    if "percentiles" in array.dims:
        return "ENS_PCT_DIM_DS"
    if "realization" in array.dims:
        return "ENS_REALS_DS"
    return "DS"
×
182

183

184
def get_attributes(
    string: str, xr_obj: xr.DataArray | xr.Dataset, locale: str | None = None
) -> str:
    """Fetch an attribute value from an Xarray object.

    For a DataArray, its own attributes are searched. For a Dataset, the attributes of its
    first data variable are searched first (when it has any data variables), then the
    Dataset's global attributes. If the resolved locale is not "en", the localized name
    ``f"{string}_{locale}"`` is looked up in all those places before `string` itself.

    Parameters
    ----------
    string : str
        String corresponding to an attribute name.
    xr_obj : DataArray or Dataset
        The Xarray object containing the attributes.
    locale : str, optional
        A 2-letter locale name to translate to.
        Default is None, which will pull the locale
        from xclim's "metadata_locales" option (taking the first).

    Returns
    -------
    str
        Xarray attribute value, or an empty string (with a warning) if not found.
    """
    locale = locale or (XC_OPTIONS[METADATA_LOCALES] or ["en"])[0]
    names = [string] if locale == "en" else [f"{string}_{locale}", string]

    # Attribute dictionaries to search, in priority order.
    if isinstance(xr_obj, xr.DataArray):
        sources = [xr_obj.attrs]
    elif isinstance(xr_obj, xr.Dataset):
        data_vars = list(xr_obj.data_vars)
        # A Dataset without data variables only has its global attributes to offer;
        # indexing data_vars[0] would raise IndexError.
        sources = [xr_obj[data_vars[0]].attrs] if data_vars else []
        sources.append(xr_obj.attrs)
    else:
        sources = []

    # Names are the outer loop so the localized name wins over the plain one everywhere.
    for name in names:
        for attrs in sources:
            if name in attrs:
                return attrs[name]

    warnings.warn(f'Attribute "{string}" not found.')
    return ""
×
229

230

231
def set_plot_attrs(
    attr_dict: dict[str, Any],
    xr_obj: xr.DataArray | xr.Dataset,
    ax: matplotlib.axes.Axes | None = None,
    title_loc: str = "center",
    facetgrid: seaborn.axisgrid.FacetGrid | None = None,
    wrap_kw: dict[str, Any] | None = None,
) -> matplotlib.axes.Axes:
    """Label a plot (title, axis labels, suptitle) from Dataset or DataArray attributes.

    Every value of `attr_dict` is an attribute name, resolved with get_attributes().

    Parameters
    ----------
    attr_dict : dict
        Maps plot elements to attribute names. Supported keys are "title", "ylabel",
        "yunits", "xlabel", "xunits", "cbar_label", "cbar_units" and "suptitle"; any
        other key triggers a warning. "cbar_label" and "cbar_units" are accepted but not
        applied here.
    xr_obj : Dataset or DataArray
        The Xarray object containing the attributes.
    ax : matplotlib axis, optional
        The axis that receives the title and the x/y labels.
    title_loc : str
        Location of the title.
    facetgrid : seaborn.axisgrid.FacetGrid, optional
        When given, it receives the "suptitle" (if requested) and is returned instead of `ax`.
    wrap_kw : dict, optional
        Arguments to pass to the wrap_text function for the title.

    Returns
    -------
    matplotlib.axes.Axes or seaborn.axisgrid.FacetGrid
        `facetgrid` when it is given (truthy), otherwise `ax`.
    """
    title_wrap_kw = empty_dict(wrap_kw)

    supported_keys = {
        "title",
        "ylabel",
        "yunits",
        "xlabel",
        "xunits",
        "cbar_label",
        "cbar_units",
        "suptitle",
    }
    for unknown in (key for key in attr_dict if key not in supported_keys):
        warnings.warn(f'Use_attrs element "{unknown}" not supported')

    def _axis_label(label_key: str, units_key: str) -> str:
        # Units are appended as "label (units)" only when requested and non-empty,
        # so that a missing units attribute does not produce an empty "()".
        if units_key in attr_dict:
            units = get_attributes(attr_dict[units_key], xr_obj)
            if len(units) >= 1:
                return wrap_text(
                    get_attributes(attr_dict[label_key], xr_obj) + " (" + units + ")"
                )
        return wrap_text(get_attributes(attr_dict[label_key], xr_obj))

    if "title" in attr_dict:
        title_text = get_attributes(attr_dict["title"], xr_obj)
        ax.set_title(wrap_text(title_text, **title_wrap_kw), loc=title_loc)

    if "ylabel" in attr_dict:
        y_text = _axis_label("ylabel", "yunits")
        ax.set_ylabel(y_text)

    if "xlabel" in attr_dict:
        x_text = _axis_label("xlabel", "xunits")
        ax.set_xlabel(x_text)

    # "cbar_label" / "cbar_units": the colorbar label has to be assigned by the calling
    # plot function, so nothing is done with them here.

    if not facetgrid:
        return ax

    if "suptitle" in attr_dict:
        suptitle_text = get_attributes(attr_dict["suptitle"], xr_obj)
        facetgrid.fig.suptitle(suptitle_text, y=1.05)
        facetgrid.set_titles(template="{value}")
    return facetgrid
×
328

329

330
def get_suffix(string: str) -> str:
    """Return the trailing statistic suffix of a typical xclim variable name.

    The suffix is either a trailing number of one or two digits (e.g. "10" in "tg_p10"),
    or "max", "min" or "mean" (first letter in either case) preceded by an underscore;
    the returned value excludes the underscore.

    Raises
    ------
    ValueError
        If no such suffix is present.
    """
    has_suffix = re.search(r"[0-9]{1,2}$|_[Mm]ax$|_[Mm]in$|_[Mm]ean$", string)
    if not has_suffix:
        raise ValueError(f"Mean, min or max not found in {string}")
    return re.search(r"[0-9]{1,2}$|[Mm]ax$|[Mm]in$|[Mm]ean$", string).group()
×
337

338

339
def sort_lines(array_dict: dict[str, Any]) -> dict[str, str]:
    """Assign each of three ensemble arrays to the 'lower', 'middle' or 'upper' line.

    The position is derived from the name suffix (see get_suffix): "max" -> upper,
    "min" -> lower, "mean" -> middle; for numeric suffixes, 50 -> middle, above 50 ->
    upper, below 50 -> lower.

    Parameters
    ----------
    array_dict : dict
        Dictionary of format {'name': array...}.

    Returns
    -------
    dict
        Dictionary of {'middle': 'name', 'upper': 'name', 'lower': 'name'}.

    Raises
    ------
    ValueError
        If `array_dict` does not hold exactly three entries, or a name has no recognized suffix.
    """
    if len(array_dict) != 3:
        raise ValueError("Ensembles must contain exactly three arrays")

    word_positions = {
        "max": "upper",
        "Max": "upper",
        "min": "lower",
        "Min": "lower",
        "mean": "middle",
        "Mean": "middle",
    }

    positions = {}
    for array_name in array_dict:
        suffix = get_suffix(array_name)
        if suffix.isalpha():
            slot = word_positions.get(suffix)
        elif suffix.isdigit():
            pct = int(suffix)
            if pct == 50:
                slot = "middle"
            else:
                slot = "upper" if pct > 50 else "lower"
        else:
            raise ValueError('Arrays names must end in format "_mean" or "_p50" ')

        if slot is not None:
            positions[slot] = array_name
    return positions
×
377

378

379
def loc_mpl(
6✔
380
    loc: str | tuple[int | float, int | float] | int,
381
) -> tuple[tuple[float, float], tuple[int | float, int | float], str, str]:
382
    """Find coordinates and alignment associated to loc string.
383

384
    Parameters
385
    ----------
386
    loc : string, int, or tuple[float, float]
387
        Location of text, replicating https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html.
388
        If a tuple, must be in axes coordinates.
389

390
    Returns
391
    -------
392
    tuple(float, float), tuple(float, float), str, str
393
    """
394
    ha = "left"
×
395
    va = "bottom"
×
396

397
    loc_strings = [
×
398
        "upper right",
399
        "upper left",
400
        "lower left",
401
        "lower right",
402
        "right",
403
        "center left",
404
        "center right",
405
        "lower center",
406
        "upper center",
407
        "center",
408
    ]
409

410
    if isinstance(loc, int):
×
411
        try:
×
412
            loc = loc_strings[loc - 1]
×
413
        except IndexError:
×
414
            raise ValueError("loc must be between 1 and 10, inclusively")
×
415

416
    if loc in loc_strings:
×
417
        # ha
418
        if "left" in loc:
×
419
            ha = "left"
×
420
        elif "right" in loc:
×
421
            ha = "right"
×
422
        else:
423
            ha = "center"
×
424

425
        # va
426
        if "lower" in loc:
×
427
            va = "bottom"
×
428
        elif "upper" in loc:
×
429
            va = "top"
×
430
        else:
431
            va = "center"
×
432

433
        # transAxes
434
        if loc == "upper right":
×
435
            loc = (0.97, 0.97)
×
436
            box_a = (1, 1)
×
437
        elif loc == "upper left":
×
438
            loc = (0.03, 0.97)
×
439
            box_a = (0, 1)
×
440
        elif loc == "lower left":
×
441
            loc = (0.03, 0.03)
×
442
            box_a = (0, 0)
×
443
        elif loc == "lower right":
×
444
            loc = (0.97, 0.03)
×
445
            box_a = (1, 0)
×
446
        elif loc == "right":
×
447
            loc = (0.97, 0.5)
×
448
            box_a = (1, 0.5)
×
449
        elif loc == "center left":
×
450
            loc = (0.03, 0.5)
×
451
            box_a = (0, 0.5)
×
452
        elif loc == "center right":
×
453
            loc = (0.97, 0.5)
×
454
            box_a = (0.97, 0.5)
×
455
        elif loc == "lower center":
×
456
            loc = (0.5, 0.03)
×
457
            box_a = (0.5, 0)
×
458
        elif loc == "upper center":
×
459
            loc = (0.5, 0.97)
×
460
            box_a = (0.5, 1)
×
461
        else:
462
            loc = (0.5, 0.5)
×
463
            box_a = (0.5, 0.5)
×
464

465
    elif isinstance(loc, tuple):
×
466
        box_a = []
×
467
        for i in loc:
×
468
            if i > 1 or i < 0:
×
469
                raise ValueError(
×
470
                    "Text location coordinates must be between 0 and 1, inclusively"
471
                )
472
            elif i > 0.5:
×
473
                box_a.append(1)
×
474
            else:
475
                box_a.append(0)
×
476
        box_a = tuple(box_a)
×
477
    else:
478
        raise ValueError(
×
479
            "loc must be a string, int or tuple. "
480
            "See https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html"
481
        )
482

483
    return loc, box_a, ha, va
×
484

485

486
def plot_coords(
    ax: matplotlib.axes.Axes | None,
    xr_obj: xr.DataArray | xr.Dataset,
    loc: str | tuple[float, float] | int,
    param: str | None = None,
    backgroundalpha: float = 1,
) -> matplotlib.axes.Axes:
    """Place coordinates (location or time) as text on the plot area.

    Parameters
    ----------
    ax : matplotlib.axes.Axes or None
        Matplotlib axes object on which to place the text.
        If None, will use plt.figtext instead (should be used for facetgrids).
    xr_obj : xr.DataArray or xr.Dataset
        The xarray object from which to fetch the text content.
    loc : string, int or tuple
        Location of text, replicating https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html.
        If a tuple, must be in axes coordinates.
    param : {"location", "time"}, optional
        "location" writes the "lat"/"lon" coordinates, "time" writes the "time" coordinate
        formatted as %Y-%m-%d. A warning is emitted if the needed coordinates are missing.
    backgroundalpha : float
        Transparency of the text background. 1 is opaque, 0 is transparent.

    Returns
    -------
    matplotlib.axes.Axes or None
        `ax` (None when `ax` is None, i.e. when the text was placed with plt.figtext).
        `ax` is also returned unchanged when there is no text to place.
    """
    text = None
    if param == "location":
        if "lat" in xr_obj.coords and "lon" in xr_obj.coords:
            text = "lat={:.2f}, lon={:.2f}".format(
                float(xr_obj["lat"]), float(xr_obj["lon"])
            )
        else:
            warnings.warn(
                'show_lat_lon set to True, but "lat" and/or "lon" not found in coords'
            )
    if param == "time":
        if "time" in xr_obj.coords:
            text = str(xr_obj.time.dt.strftime("%Y-%m-%d").values)
        else:
            warnings.warn('show_time set to True, but "time" not found in coords')

    # Validated even without text, so an invalid `loc` is always reported.
    loc, box_a, ha, va = loc_mpl(loc)

    if not text:
        # Nothing to draw: hand the axes back (previously None was returned here).
        return ax

    if ax:
        t = mpl.offsetbox.TextArea(
            text, textprops=dict(transform=ax.transAxes, ha=ha, va=va)
        )
        tt = mpl.offsetbox.AnnotationBbox(
            t,
            loc,
            xycoords="axes fraction",
            box_alignment=box_a,
            pad=0.05,
            bboxprops=dict(
                facecolor="white",
                alpha=backgroundalpha,
                edgecolor="w",
                boxstyle="Square, pad=0.5",
            ),
        )
        ax.add_artist(tt)
        return ax

    # No axes (e.g. facetgrids): place the text in figure coordinates.
    plt.figtext(
        loc[0],
        loc[1],
        text,
        ha=ha,
        va=va,
        fontsize=12,
    )
    return None
×
581

582

583
def find_logo(logo: str | pathlib.Path) -> str:
    """Return the path of a logo known to figanos.Logos.

    Parameters
    ----------
    logo : str or Path
        Name or path of the logo to look up in ``Logos()``. A falsy value selects
        ``Logos().default``.

    Returns
    -------
    str
        The logo path found by ``Logos``.

    Raises
    ------
    ValueError
        If ``Logos`` returns None for the requested (or default) logo.
    """
    registry = Logos()
    logo_path = registry[logo] if logo else registry.default

    if logo_path is not None:
        return logo_path
    raise ValueError(
        "No logo found. Please install one with the figanos.Logos().set_logo() method."
    )
×
596

597

598
def load_image(
    im: str | pathlib.Path,
    height: float | None,
    width: float | None,
    keep_ratio: bool = True,
) -> np.ndarray:
    """Load a PNG or SVG image, optionally scaled to a specified height and/or width.

    Parameters
    ----------
    im : str or Path
        The image to be loaded. PNG and SVG formats are supported (the file extension is
        matched case-insensitively).
    height : float, optional
        The desired height of the image, in pixels. If None, the original height is used,
        or it is derived from `width` when `keep_ratio` is True.
    width : float, optional
        The desired width of the image, in pixels. If None, the original width is used,
        or it is derived from `height` when `keep_ratio` is True.
    keep_ratio : bool
        If True, the aspect ratio of the original image is maintained; when both `height`
        and `width` are given, `height` is used and `width` is ignored (with a warning).
        Default is True.

    Returns
    -------
    np.ndarray
        The (scaled) image, as read by matplotlib.pyplot.imread.

    Raises
    ------
    ValueError
        If the file extension is neither ".png" nor ".svg".
    """
    import io

    suffix = pathlib.Path(im).suffix.lower()

    if suffix == ".png":
        image = mpl.pyplot.imread(im)
        original_height, original_width = image.shape[:2]

        if height is None and width is None:
            return image

        warnings.warn(
            "The scikit-image library is used to resize PNG images. This may affect logo image quality."
        )
        if not keep_ratio:
            # Each dimension that was not requested keeps its original size.
            height = height or original_height
            width = width or original_width
        elif height is not None:
            if width is not None:
                warnings.warn("Both height and width provided, using height.")
            # Height wins: derive the width from the aspect ratio.
            width = (height / original_height) * original_width
        else:
            # Only width is provided: derive the height from the aspect ratio.
            height = (width / original_width) * original_height

        # Integer output shape; any trailing channel axis (RGB/RGBA) is kept as-is,
        # and 2-D (grayscale) images are supported.
        output_shape = (
            max(1, int(round(height))),
            max(1, int(round(width))),
        ) + tuple(image.shape[2:])
        return resize(image, output_shape, anti_aliasing=True)

    if suffix == ".svg":
        cairo_kwargs = dict(url=str(im))
        if not keep_ratio:
            if height is not None and width is not None:
                cairo_kwargs.update(output_height=height, output_width=width)
        elif height is not None:
            if width is not None:
                warnings.warn("Both height and width provided, using height.")
            cairo_kwargs.update(output_height=height)
        elif width is not None:
            cairo_kwargs.update(output_width=width)

        # Render to memory instead of a NamedTemporaryFile: reopening an open temporary
        # file by name is not possible on Windows.
        png_bytes = cairosvg.svg2png(**cairo_kwargs)
        return mpl.pyplot.imread(io.BytesIO(png_bytes))

    raise ValueError(
        f"Unsupported image format '{pathlib.Path(im).suffix}' for {im}. "
        "Only PNG and SVG images are supported."
    )
×
663

664

665
def plot_logo(
    ax: matplotlib.axes.Axes,
    loc: str | tuple[float, float] | int,
    logo: str | pathlib.Path | Logos | None = None,
    height: float | None = None,
    width: float | None = None,
    keep_ratio: bool = True,
    **offset_image_kwargs,
) -> matplotlib.axes.Axes:
    r"""Place a logo on the plot area.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Matplotlib axes object on which to place the logo.
    loc : string, int or tuple
        Location of the logo, replicating https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html.
        If a tuple, must be in axes coordinates.
    logo : str, Path, figanos.Logos, optional
        A name (str) or Path to a logo file, or a name of an already-installed logo; the
        lookup is delegated to figanos.Logos (see find_logo). If a Logos instance is
        given, its default logo is used. If None, the default logo is used.
        To install the Ouranos (or another) logo consult the Usage page.
        The logo must be in 'png' or 'svg' format.
    height : float, optional
        The desired height of the image. If None, the original height is used.
    width : float, optional
        The desired width of the image. If None, the original width is used.
    keep_ratio : bool, optional
        If True, the aspect ratio of the original image is maintained. Default is True.
    \*\*offset_image_kwargs
        Arguments to pass to matplotlib.offsetbox.OffsetImage().

    Returns
    -------
    matplotlib.axes.Axes
    """
    # `**offset_image_kwargs` is always a dict (possibly empty), so no None check is needed.
    if isinstance(logo, Logos):
        logo_path = logo.default
    else:
        logo_path = find_logo(logo)

    image = load_image(logo_path, height, width, keep_ratio)
    imagebox = mpl.offsetbox.OffsetImage(image, **offset_image_kwargs)

    # Text alignments returned by loc_mpl are irrelevant for an image.
    xy, box_a, _, _ = loc_mpl(loc)
    ab = mpl.offsetbox.AnnotationBbox(
        imagebox,
        xy,
        frameon=False,
        xycoords="axes fraction",
        box_alignment=box_a,
        pad=0.05,
    )
    ax.add_artist(ab)
    return ax
×
723

724

725
def split_legend(
    ax: matplotlib.axes.Axes,
    in_plot: bool = False,
    axis_factor: float = 0.15,
    label_gap: float = 0.02,
) -> matplotlib.axes.Axes:
    """Draw line labels at the end of each line, or outside the plot.

    Only legend entries that are lines (``matplotlib.lines.Line2D``) with data are
    labelled; other handles (e.g. shaded areas) and empty lines are skipped.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axis containing the legend.
    in_plot : bool
        If True, prolong plot area to fit labels. If False, print labels outside of plot area. Default: False.
    axis_factor : float
        If in_plot is True, fraction of the x-axis length to add at the far right of the plot. Default: 0.15.
    label_gap : float
        If in_plot is True, fraction of the x-axis length to add as a gap between line and label. Default: 0.02.

    Returns
    -------
    matplotlib.axes.Axes
    """
    #  TODO: check for and fix overlapping labels
    import datetime

    # create extra space
    x_lower, x_upper = ax.get_xbound()
    x_span = x_upper - x_lower
    label_bump = x_span * label_gap

    if in_plot is True:
        ax.set_xbound(lower=x_lower, upper=x_upper + x_span * axis_factor)

    # Outside labels: x in axes fraction (just right of the axes), y in data units.
    outside_trans = mpl.transforms.blended_transform_factory(ax.transAxes, ax.transData)

    handles, labels = ax.get_legend_handles_labels()
    for handle, label in zip(handles, labels):
        # Patches/collections (e.g. fill_between) have no x/y data to anchor a label.
        if not isinstance(handle, Line2D):
            continue
        xdata, ydata = handle.get_xdata(), handle.get_ydata()
        if len(xdata) == 0 or len(ydata) == 0:
            continue

        last_x = xdata[-1]
        last_y = ydata[-1]

        # Dates must be converted to matplotlib's numeric date units before adding a gap.
        if isinstance(last_x, (np.datetime64, datetime.datetime)):
            last_x = mpl.dates.date2num(last_x)

        color = handle.get_color()

        if in_plot is True:
            ax.text(
                last_x + label_bump,
                last_y,
                label,
                ha="left",
                va="center",
                color=color,
            )
        else:
            ax.text(
                1.01,
                last_y,
                label,
                ha="left",
                va="center",
                color=color,
                transform=outside_trans,
            )

    return ax
×
793

794

795
def fill_between_label(
    sorted_lines: dict[str, Any],
    name: str,
    array_categ: dict[str, Any],
    legend: str,
) -> str:
    """Build the legend label for the shaded area around a line in line plots.

    Parameters
    ----------
    sorted_lines : dict
        Dictionary created by the sort_lines() function.
    name : str
        Key associated with the object being plotted in the 'data' argument of the timeseries() function.
    array_categ : dict
        The categories of the array, as created by the get_array_categ function.
    legend : str
        Legend mode; a label is only produced in "full" mode.

    Returns
    -------
    str or None
        A localized "{lower}th-{upper}th percentiles" label for percentile ensembles,
        a localized "min-max range" label for min/max statistics, otherwise None.
    """
    if legend != "full":
        return None

    category = array_categ[name]
    if category in ("ENS_PCT_VAR_DS", "ENS_PCT_DIM_DS", "ENS_PCT_DIM_DA"):
        template = get_localized_term("{}th-{}th percentiles")
        return template.format(
            get_suffix(sorted_lines["lower"]), get_suffix(sorted_lines["upper"])
        )
    if category == "ENS_STATS_VAR_DS":
        return get_localized_term("min-max range")
    return None
×
835

836

837
def get_var_group(
    path_to_json: str | pathlib.Path,
    da: xr.DataArray | xr.Dataset | None = None,
    unique_str: str | None = None,
) -> str:
    """Get IPCC variable group from DataArray or a string using a json file (figanos/data/ipcc_colors/variable_groups.json).

    If `da` is a Dataset, look in the DataArray of the first variable.

    Parameters
    ----------
    path_to_json : str or pathlib.Path
        Path to a JSON file mapping variable names (used as regex fragments) to variable groups.
    da : xr.DataArray or xr.Dataset, optional
        Object whose name is searched for variable names. If nothing matches, its ``history``
        attribute is searched instead. Ignored when `unique_str` is given.
    unique_str : str, optional
        String to search for variable names instead of `da`.

    Returns
    -------
    str
        The matching variable group, or "misc" (with a warning) if zero or several groups are found.
    """
    with pathlib.Path(path_to_json).open(encoding="utf-8") as _f:
        var_dict = json.load(_f)

    def _find_groups(text: str) -> list:
        # Return the groups of all variables found in `text`; a variable only
        # matches when it is not embedded inside a longer word.
        return [
            group
            for v, group in var_dict.items()
            if re.search(rf"(?:^|[^a-zA-Z])({v})(?:[^a-zA-Z]|$)", text)
        ]

    matches = []

    if unique_str:
        matches = _find_groups(unique_str)
    else:
        if isinstance(da, xr.Dataset):
            data_vars = list(da.data_vars)
            # An empty Dataset has nothing to inspect.
            da = da[data_vars[0]] if data_vars else None

        # look in DataArray name
        name = getattr(da, "name", None)
        if isinstance(name, str):
            matches = _find_groups(name)

        # look in history (only if the name gave nothing)
        history = getattr(da, "history", None)
        if not matches and isinstance(history, str):
            matches = _find_groups(history)

    matches = np.unique(matches)

    if len(matches) == 0:
        warnings.warn(
            "Colormap warning: Variable group not found. Use the cmap argument."
        )
        return "misc"
    if len(matches) >= 2:
        warnings.warn(
            "Colormap warning: More than one variable group found. Use the cmap argument."
        )
        return "misc"
    return matches[0]
×
889

890

891
def create_cmap(
    var_group: str | None = None,
    divergent: bool | int = False,
    filename: str | None = None,
) -> matplotlib.colors.Colormap:
    """Create colormap according to variable group.

    Parameters
    ----------
    var_group : str, optional
        Variable group from IPCC scheme. Required if `filename` is not given.
    divergent : bool or int
        Diverging colormap. If False, use sequential colormap.
    filename : str, optional
        Name of IPCC colormap file. If not None, 'var_group' and 'divergent' are not used.
        A trailing "_r" (before the optional ".txt") reverses the colormap.

    Returns
    -------
    matplotlib.colors.Colormap

    Raises
    ------
    ValueError
        If neither `filename` nor `var_group` is provided.
    """
    reverse = False

    if filename:
        filename = filename.replace(".txt", "")
        if filename.endswith("_r"):
            reverse = True
            filename = filename[:-2]
    else:
        if var_group is None:
            raise ValueError(
                "Either 'var_group' or 'filename' must be provided to create a colormap."
            )
        if divergent is not False:
            if var_group == "misc2":
                var_group = "misc"
            filename = var_group + "_div"
        elif var_group == "misc":
            filename = var_group + "_seq_3"  # Batlow
        elif var_group == "misc2":
            filename = "misc_seq_2"  # freezing rain
        else:
            filename = var_group + "_seq"

    # parent should be 'figanos/'
    path = (
        pathlib.Path(__file__).parents[1]
        / "data"
        / "ipcc_colors"
        / "continuous_colormaps_rgb_0-255"
        / (filename + ".txt")
    )

    # The files hold 0-255 RGB values; matplotlib expects 0-1.
    rgb_data = np.loadtxt(path) / 255

    cmap = mcolors.LinearSegmentedColormap.from_list("cmap", rgb_data, N=256)
    if reverse:
        cmap = cmap.reversed()

    return cmap
×
956

957

958
def get_rotpole(xr_obj: xr.DataArray | xr.Dataset) -> ccrs.RotatedPole | None:
    """Create a Cartopy crs rotated pole projection/transform from DataArray or Dataset attributes.

    The grid-mapping variable is located as follows:

    - Dataset: the variable(s) that cf_xarray lists under the "rotated_latitude_longitude"
      grid mapping. If there are several, the first is used and a warning is raised.
      If there are none, "rotated_pole" is used.
    - DataArray: the variable named by its "grid_mapping" attribute, or "rotated_pole" if
      that attribute is absent. The variable must be accessible through ``xr_obj[name]``.

    Parameters
    ----------
    xr_obj : xr.DataArray or xr.Dataset
        The xarray object from which to look for the attributes.

    Returns
    -------
    ccrs.RotatedPole or None
        The rotated pole CRS. None (with a warning) if the grid-mapping variable or its
        pole attributes cannot be found.
    """
    try:
        if isinstance(xr_obj, xr.Dataset):
            gridmap = xr_obj.cf.grid_mapping_names.get("rotated_latitude_longitude", [])

            if len(gridmap) > 1:
                warnings.warn(
                    f"There are conflicting grid_mapping attributes in the dataset. Assuming {gridmap[0]}."
                )

            coord_name = gridmap[0] if gridmap else "rotated_pole"
        else:
            # If it can't find grid_mapping, assume it's rotated_pole
            coord_name = xr_obj.attrs.get("grid_mapping", "rotated_pole")

        # A missing grid-mapping variable raises KeyError here, not AttributeError,
        # so both are caught below.
        gm_attrs = xr_obj[coord_name].attrs
        pole_longitude = gm_attrs["grid_north_pole_longitude"]
        pole_latitude = gm_attrs["grid_north_pole_latitude"]
        # CF makes north_pole_grid_longitude optional, with a default of 0.
        central_rotated_longitude = gm_attrs.get("north_pole_grid_longitude", 0.0)
    except (AttributeError, KeyError):
        warnings.warn("Rotated pole not found. Specify a transform if necessary.")
        return None

    return ccrs.RotatedPole(
        pole_longitude=pole_longitude,
        pole_latitude=pole_latitude,
        central_rotated_longitude=central_rotated_longitude,
    )
×
995

996

997
def wrap_text(text: str, min_line_len: int = 18, max_line_len: int = 30) -> str:
    """Insert line breaks into a string.

    Each break is placed in the window [min_line_len, max_line_len) after the previous break.
    The first ". " in the window is preferred, then the first ": ", then the last space.
    The chosen separating space is replaced by a newline. If a window contains none of these,
    a warning is emitted and wrapping stops.

    Parameters
    ----------
    text : str
        The text to wrap.
    min_line_len : int
        Minimum length of each line.
    max_line_len : int
        Maximum length of each line.

    Returns
    -------
    str
        Wrapped text
    """
    if len(text) < max_line_len:
        return text

    lo, hi = min_line_len, max_line_len
    remaining = len(text)

    while remaining > max_line_len:
        window = text[lo:hi]
        if ". " in window:
            cut = text.find(". ", lo, hi) + 1
        elif ": " in window:
            cut = text.find(": ", lo, hi) + 1
        elif " " in window:
            cut = text.rfind(" ", lo, hi)
        else:
            warnings.warn("No spaces, points or colons to break line at.")
            break

        # Replace the separating space at `cut` with a newline.
        text = text[:cut] + "\n" + text[cut + 1 :]

        remaining = len(text) - cut
        lo = cut + 1 + min_line_len
        hi = cut + 1 + max_line_len

    return text
×
1038

1039

1040
def gpd_to_ccrs(df: gpd.GeoDataFrame, proj: ccrs.CRS) -> gpd.GeoDataFrame:
    """Reproject a GeoDataFrame to a cartopy projection.

    The projection's PROJ.4 string (``proj.proj4_init``) is passed to ``df.to_crs``.

    Parameters
    ----------
    df : gpd.GeoDataFrame
        GeoDataFrame (geopandas) geometry to be added to axis.
    proj : ccrs.CRS
        Projection to use, taken from the cartopy.crs options.

    Returns
    -------
    gpd.GeoDataFrame
        GeoDataFrame adjusted to given projection
    """
    return df.to_crs(proj.proj4_init)
×
1057

1058

1059
def convert_scen_name(name: str) -> str:
    """Convert strings containing SSP, RCP or CMIP to their proper format.

    Every occurrence (case-insensitive) is converted in place:

    - three digits: ``ssp245`` -> ``SSP2-4.5``
    - two digits: ``rcp45`` -> ``RCP4.5``
    - otherwise: ``cmip5`` -> ``CMIP5``

    Strings without such substrings are returned unchanged.
    """

    def _format(match: re.Match) -> str:
        s = match.group(0)
        n_digits = sum(c.isdigit() for c in s)
        if n_digits == 3:
            # ssp245 to SSP2-4.5
            return (s[:-3] + s[-3] + "-" + s[-2] + "." + s[-1]).upper()
        if n_digits == 2:
            # rcp45 to RCP4.5
            return (s[:-2] + s[-2] + "." + s[-1]).upper()
        # cmip5 to CMIP5
        return s.upper()

    # re.sub rewrites each match where it is, so several scenarios in one name are all converted.
    return re.sub(r"(?:SSP|RCP|CMIP)[0-9]{1,3}", _format, name, flags=re.I)
×
1081

1082

1083
def get_scen_color(name: str, path_to_dict: str | pathlib.Path) -> str:
    """Look up the color of the first dictionary key found as a substring of `name`.

    The JSON file at `path_to_dict` maps substrings (SSP, RCP, model or CMIP names) to colors.
    Keys are tried in file order; None is returned when no key occurs in `name`.
    """
    with pathlib.Path(path_to_dict).open(encoding="utf-8") as fh:
        colors = json.load(fh)

    return next((colors[key] for key in colors if key in name), None)
×
1095

1096

1097
def process_keys(dct: dict[str, Any], func: Callable) -> dict[str, Any]:
    """Apply function to dictionary keys.

    The dictionary is modified in place (and also returned). Key order is preserved.
    All new keys are computed from the original keys before the dictionary is rewritten,
    so renaming one key onto another not-yet-processed key cannot overwrite its value.
    If two old keys map to the same new key, the later one wins.
    """
    renamed = {func(key): value for key, value in dct.items()}
    dct.clear()
    dct.update(renamed)
    return dct
×
1104

1105

1106
def categorical_colors() -> dict[str, str]:
    """Load the categorical colors associated with certain substrings (SSP, RCP, CMIP).

    The mapping is read from ``data/ipcc_colors/categorical_colors.json``, located one level
    above this module's directory.
    """
    json_path = pathlib.Path(__file__).parents[1].joinpath(
        "data", "ipcc_colors", "categorical_colors.json"
    )
    with json_path.open(encoding="utf-8") as fh:
        colors = json.load(fh)
    return colors
×
1115

1116

1117
def get_mpl_styles() -> dict[str, pathlib.Path]:
    """Map each ``*.mplstyle`` file in the module's ``style`` folder to its path, keyed by file stem."""
    style_dir = pathlib.Path(__file__).parent / "style"
    return {path.stem: path for path in sorted(style_dir.glob("*.mplstyle"))}
×
1122

1123

1124
def set_mpl_style(*args: str, reset: bool = False) -> None:
    """Apply one or more matplotlib stylesheets, in order.

    Parameters
    ----------
    args : str
        Name(s) of figanos matplotlib style ('ouranos', 'paper', 'poster') or path(s) to matplotlib stylesheet(s).
        Arguments ending in ".mplstyle" are used as paths. Any other argument is looked up among
        the styles returned by get_mpl_styles(), and a warning is issued if it is not found.
    reset : bool
        If True, reset style to matplotlib default before applying the stylesheets.

    Returns
    -------
    None
    """
    if reset is True:
        mpl.style.use("default")

    figanos_styles = None  # looked up lazily, only when a style name is given
    for style in args:
        if style.endswith(".mplstyle"):
            mpl.style.use(style)
            continue
        if figanos_styles is None:
            figanos_styles = get_mpl_styles()
        if style in figanos_styles:
            mpl.style.use(figanos_styles[style])
        else:
            warnings.warn(f"Style {style} not found.")
×
1147

1148

1149
def add_cartopy_features(
    ax: matplotlib.axes.Axes, features: list[str] | dict[str, dict[str, Any]]
) -> matplotlib.axes.Axes:
    """Add cartopy features to matplotlib axes.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes on which to add the features.
    features : list or dict
        List of features, or nested dictionary of format {'feature': {'kwarg':'value'}}.
        Feature names are upper-cased and looked up in ``cartopy.feature``. A 'scale' kwarg is
        passed to the feature's ``with_scale()``; the other kwargs go to ``ax.add_feature()``.
        The given dictionary is not modified.

    Returns
    -------
    matplotlib.axes.Axes
        The axis with added features.
    """
    if isinstance(features, list):
        features = {f: {} for f in features}

    for feat, feat_kw in features.items():
        # Work on a copy so the caller's kwargs (including 'scale') are left untouched.
        kwargs = dict(feat_kw)
        feature = getattr(cfeature, feat.upper())
        if "scale" in kwargs:
            feature = feature.with_scale(kwargs.pop("scale"))
        ax.add_feature(feature, **kwargs)
    return ax
×
1180

1181

1182
def custom_cmap_norm(
    cmap,
    vmin: int | float,
    vmax: int | float,
    levels: int | list[int | float] | None = None,
    divergent: bool | int | float = False,
    linspace_out: bool = False,
) -> matplotlib.colors.Normalize | np.ndarray:
    """Get matplotlib normalization according to main function arguments.

    Parameters
    ----------
    cmap: matplotlib.colormap or str
        Colormap to be used with the normalization. A string is looked up among the
        registered matplotlib colormaps.
    vmin: int or float
        Minimum of the data to be plotted with the colormap.
    vmax: int or float
        Maximum of the data to be plotted with the colormap.
    levels : int or sequence, optional
        Number of levels (Python or NumPy integer), or a list/tuple/array of level boundaries
        (in data units) to use to divide the colormap.
    divergent : bool or int or float
        If True, the center of the colormap is 0. If int or float, becomes center of cmap.
    linspace_out: bool
        If True and `levels` is a number, return the array of boundaries created by np.linspace()
        instead of the normalization instance.

    Returns
    -------
    matplotlib.colors.Normalize or np.ndarray

    Raises
    ------
    ValueError
        If `cmap` is an unknown colormap name. Also if `levels` is a number and the center
        is not strictly between the rounded vmin and vmax.
    """
    # get cmap if string
    if isinstance(cmap, str):
        if cmap in plt.colormaps():
            cmap = matplotlib.colormaps[cmap]
        else:
            raise ValueError("Colormap not found")

    # make vmin and vmax prettier
    if (vmax - vmin) >= 25:
        rvmax = math.ceil(vmax / 10.0) * 10
        rvmin = math.floor(vmin / 10.0) * 10
    elif 1 <= (vmax - vmin) < 25:
        rvmax = math.ceil(vmax / 1) * 1
        rvmin = math.floor(vmin / 1) * 1
    elif 0.1 <= (vmax - vmin) < 1:
        rvmax = math.ceil(vmax / 0.1) * 0.1
        rvmin = math.floor(vmin / 0.1) * 0.1
    else:
        rvmax = math.ceil(vmax / 0.01) * 0.01
        rvmin = math.floor(vmin / 0.01) * 0.01

    # center
    center = None
    if divergent is not False:
        if divergent is True:
            center = 0
        elif isinstance(divergent, (int, float)):
            center = divergent

    # NumPy integers (e.g. from np.arange or array shapes) count as a number of levels too.
    levels_is_number = isinstance(levels, (int, np.integer))
    levels_is_sequence = isinstance(levels, (list, tuple, np.ndarray))

    # build norm with options
    if center is not None and levels_is_number:
        if center <= rvmin or center >= rvmax:
            raise ValueError("vmin, center and vmax must be in ascending order.")
        # Same number of boundaries on each side of the center; an odd number of
        # levels is rounded up so the center stays a boundary.
        if levels % 2 == 1:
            half_levels = int((levels + 1) / 2) + 1
        else:
            half_levels = int(levels / 2) + 1

        lin = np.concatenate(
            (
                np.linspace(rvmin, center, num=half_levels),
                np.linspace(center, rvmax, num=half_levels)[1:],
            )
        )
        if linspace_out:
            return lin
        return matplotlib.colors.BoundaryNorm(boundaries=lin, ncolors=cmap.N)

    if levels_is_sequence:
        if center is not None:
            warnings.warn(
                "Divergent argument ignored when levels is a list. Use levels as a number instead."
            )
        return matplotlib.colors.BoundaryNorm(boundaries=levels, ncolors=cmap.N)

    if levels is not None:
        lin = np.linspace(rvmin, rvmax, num=levels + 1)
        if linspace_out:
            return lin
        return matplotlib.colors.BoundaryNorm(boundaries=lin, ncolors=cmap.N)

    if center is not None:
        return matplotlib.colors.TwoSlopeNorm(center, vmin=rvmin, vmax=rvmax)
    return matplotlib.colors.Normalize(rvmin, rvmax)
×
1280

1281

1282
def norm2range(
    data: np.ndarray, target_range: tuple, data_range: tuple | None = None
) -> np.ndarray:
    """Normalize data across a specific range.

    Linearly maps `data` from `data_range` onto `target_range`.

    Parameters
    ----------
    data : np.ndarray or scalar
        Values to rescale.
    target_range : tuple
        (min, max) of the output range.
    data_range : tuple, optional
        (min, max) of the input range. If None, the NaN-ignoring min and max of `data` are
        used, which requires `data` to hold more than one value.

    Returns
    -------
    np.ndarray or scalar

    Raises
    ------
    ValueError
        If `data_range` is None and `data` has at most one value (including scalars).
    """
    if data_range is None:
        # np.size also works on scalars, which have no len().
        if np.size(data) > 1:
            data_range = (np.nanmin(data), np.nanmax(data))
        else:
            raise ValueError(" if data is not an array, data_range must be specified")

    norm = (data - data_range[0]) / (data_range[1] - data_range[0])

    return target_range[0] + (norm * (target_range[1] - target_range[0]))
×
1295

1296

1297
def size_legend_elements(
    data: np.ndarray, sizes: np.ndarray, marker: str, max_entries: int = 6
) -> list[matplotlib.lines.Line2D]:
    """Create handles to use in a point-size legend.

    Parameters
    ----------
    data : np.ndarray
        Data used to determine the point sizes.
    sizes : np.ndarray
        Array of point sizes.
    marker: str
        Marker to use in legend.
    max_entries : int
        Maximum number of entries in the legend.

    Returns
    -------
    list of matplotlib.lines.Line2D
    """
    # How many increments of 10 pts**2 are there in the sizes. Use at least one, so that
    # the legend always has an entry and the step computation below cannot divide by zero.
    n = max(int(np.round(max(sizes) - min(sizes), -1) / 10), 1)

    # divide data in those increments
    lgd_data = np.linspace(min(data), max(data), n)

    # Round according to the data step between consecutive legend entries.
    ratio = abs((max(data) - min(data)) / n)
    rounding = next(
        (step for step in (1000, 100, 10, 5, 1, 0.1, 0.01) if ratio >= step), 0.001
    )

    lgd_data = np.unique(rounding * np.round(lgd_data / rounding))

    # convert back to sizes
    lgd_sizes = norm2range(
        data=lgd_data,
        data_range=(min(data), max(data)),
        target_range=(min(sizes), max(sizes)),
    )

    legend_elements = []

    for s, d in zip(lgd_sizes, lgd_data):
        # Show whole numbers without a trailing ".0".
        if isinstance(d, float) and d.is_integer():
            label = str(int(d))
        else:
            label = str(d)

        legend_elements.append(
            Line2D(
                [0],
                [0],
                marker=marker,
                color="k",
                lw=0,
                markerfacecolor="w",
                label=label,
                markersize=np.sqrt(np.abs(s)),
            )
        )

    if len(legend_elements) > max_entries:
        return [legend_elements[i] for i in np.arange(0, max_entries + 1, 2)]
    return legend_elements
×
1377

1378

1379
def add_features_map(
    data,
    ax,
    use_attrs,
    projection,
    features,
    geometries_kw,
    frame,
) -> matplotlib.axes.Axes:
    """Add features such as cartopy, time label, and geometries to a map on a given matplotlib axis.

    Parameters
    ----------
    data : dict, DataArray or Dataset
        Input data do plot. If dictionary, must have only one entry.
    ax : matplotlib axis
        Matplotlib axis on which to plot, with the same projection as the one specified.
    use_attrs : dict
        Dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description').
        Default value is {'title': 'description', 'cbar_label': 'long_name', 'cbar_units': 'units'}.
        Only the keys found in the default dict can be used.
    projection : ccrs.Projection
        The projection to use, taken from the cartopy.crs options. Ignored if ax is not None.
    features : list or dict
        Features to use, as a list or a nested dict containing kwargs. Options are the predefined features from
        cartopy.feature: ['coastline', 'borders', 'lakes', 'land', 'ocean', 'rivers', 'states'].
    geometries_kw : dict
        Arguments passed to cartopy ax.add_geometry() which adds given geometries (GeoDataFrame geometry) to axis.
        Must contain 'geoms'. Otherwise a warning is issued and no geometries are added.
        The geometries are reprojected to geometries_kw['crs'] if given, else to `projection`.
        The dict itself is not modified.
    frame : bool
        Show or hide frame. Default False.

    Returns
    -------
    matplotlib.axes.Axes
    """
    # add features
    if features:
        add_cartopy_features(ax, features)

    set_plot_attrs(use_attrs, data, ax)

    if frame is False:
        ax.spines["geo"].set_visible(False)

    # add geometries
    if geometries_kw:
        if "geoms" not in geometries_kw:
            warnings.warn(
                'geoms missing from geometries_kw (ex: {"geoms": df["geometry"]})'
            )
            return ax

        # Shallow copy: reprojecting must not overwrite the caller's geometries.
        geometries_kw = dict(geometries_kw)
        geometries_kw["geoms"] = gpd_to_ccrs(
            geometries_kw["geoms"], geometries_kw.get("crs", projection)
        )
        geometries_kw = {
            "crs": projection,
            "facecolor": "none",
            "edgecolor": "black",
        } | geometries_kw

        ax.add_geometries(**geometries_kw)
    return ax
×
1443

1444

1445
def masknan_sizes_key(data, sizes) -> xr.Dataset:
    """Share missing values between the hue and marker-size variables used by xr.plot.scatter().

    The hue variable is the first variable of `data` other than `sizes`. First the hue values are
    masked where `sizes` is NaN, then the `sizes` values are masked where the (updated) hue is NaN.
    `data` is modified in place and returned.

    Parameters
    ----------
    data: xr.Dataset
        xr.Dataset used to plot
    sizes: str
        Variable used to plot markersize

    Returns
    -------
    xr.Dataset
    """
    var_names = [*data.keys()]
    var_names.remove(sizes)
    hue_var = var_names[0]

    # Hue becomes NaN wherever the size is missing...
    data[hue_var] = data[hue_var].where(~np.isnan(data[sizes]))
    # ...and size becomes NaN wherever the (now masked) hue is missing.
    data[sizes] = data[sizes].where(~np.isnan(data[hue_var]))
    return data
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc