• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

Ouranosinc / figanos / 18505281621

14 Oct 2025 05:45PM UTC coverage: 8.147% (+0.07%) from 8.075%
18505281621

Pull #365

github

web-flow
Merge 4379e1fba into 46b46ffc4
Pull Request #365: Make get_var_group usable outside

3 of 5 new or added lines in 1 file covered. (60.0%)

236 existing lines in 1 file now uncovered.

157 of 1927 relevant lines covered (8.15%)

0.41 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

11.5
/src/figanos/matplotlib/utils.py
1
"""Utility functions for figanos figure-creation."""
2

3
from __future__ import annotations
5✔
4
import json
5✔
5
import math
5✔
6
import pathlib
5✔
7
import re
5✔
8
import warnings
5✔
9
from collections.abc import Callable
5✔
10
from copy import deepcopy
5✔
11
from pathlib import Path
5✔
12
from tempfile import NamedTemporaryFile
5✔
13
from typing import Any
5✔
14

15
import cairosvg
5✔
16
import cartopy.crs as ccrs
5✔
17
import cartopy.feature as cfeature
5✔
18
import geopandas as gpd
5✔
19
import matplotlib as mpl
5✔
20
import matplotlib.axes
5✔
21
import matplotlib.colors as mcolors
5✔
22
import matplotlib.pyplot as plt
5✔
23
import numpy as np
5✔
24
import pandas as pd
5✔
25
import seaborn
5✔
26
import xarray as xr
5✔
27
import yaml
5✔
28
from matplotlib.lines import Line2D
5✔
29
from skimage.transform import resize
5✔
30
from xclim.core.options import METADATA_LOCALES
5✔
31
from xclim.core.options import OPTIONS as XC_OPTIONS
5✔
32

33
from .._logo import Logos
5✔
34

35

36
TERMS: dict = {}
"""
A translation dictionary for special terms to appear on the plots.

Keys are terms to translate and they map to "locale": "translation" dictionaries.
The "official" figanos terms are based on figanos/data/terms.yml.
"""


# Load terms translations.
# Decode explicitly as UTF-8: translations may contain non-ASCII characters, and the
# platform default encoding is not guaranteed to be UTF-8.
with (pathlib.Path(__file__).resolve().parents[1] / "data" / "terms.yml").open(
    encoding="utf-8"
) as _terms_file:
    # An empty YAML document loads as None; keep TERMS a dict so lookups stay valid.
    TERMS = yaml.safe_load(_terms_file) or {}
48

49

50
def get_localized_term(term, locale=None):
    """
    Translate `term` into `locale` using the :py:data:`TERMS` dictionary.

    Parameters
    ----------
    term : str
        The word or short phrase to translate.
    locale : str, optional
        Two-letter code of the target locale.
        When None (or empty), the first entry of xclim's "metadata_locales" option is used,
        falling back to "en" when that option is empty.

    Returns
    -------
    str
        The translated term. `term` itself is returned when the target locale is "en",
        or (with a warning) when no translation is available.
    """
    if not locale:
        locale = (XC_OPTIONS[METADATA_LOCALES] or ["en"])[0]

    # English is the source language of the terms: nothing to translate.
    if locale == "en":
        return term

    if term not in TERMS:
        warnings.warn(f"No translation known for term '{term}'.", stacklevel=2)
        return term

    translations = TERMS[term]
    if locale not in translations:
        warnings.warn(f"No {locale} translation known for term '{term}'.", stacklevel=2)
        return term

    return translations[locale]
82

83

84
def empty_dict(param) -> dict:
    """
    Return a deep copy of `param`, or a new empty dict when `param` is None.

    Working on a copy lets callers pop items from the result without modifying
    the dictionary that was passed in.
    """
    return {} if param is None else deepcopy(param)
89

90

91
def _to_standard_calendar(obj):
    """Return `obj` converted to a 'standard' calendar if its "time" index is a CFTimeIndex, else `obj` itself."""
    if "time" not in obj.dims or not isinstance(obj.get_index("time"), xr.CFTimeIndex):
        return obj
    converted = obj.convert_calendar("standard", use_cftime=None, align_on="year")
    # stacklevel=3 so the warning points at the caller of check_timeindex, not at this helper.
    warnings.warn(
        "CFTimeIndex converted to pandas DatetimeIndex with a 'standard' calendar.", stacklevel=3
    )
    return converted


def check_timeindex(
    xr_objs: xr.DataArray | xr.Dataset | dict[str, Any],
) -> xr.DataArray | xr.Dataset | dict[str, Any]:
    """
    Check if the time index of Xarray objects is CFtime and convert to pd.DatetimeIndex if True.

    Objects without a "time" dimension, or whose time index is not a CFTimeIndex, are left untouched.
    Converted objects go through ``convert_calendar("standard", use_cftime=None, align_on="year")``,
    and a warning is emitted for each conversion.

    Parameters
    ----------
    xr_objs : xr.DataArray or xr.Dataset or dict
        A single Xarray object, or a dictionary of Xarray DataArrays or Datasets.

    Returns
    -------
    xr.DataArray or xr.Dataset or dict
        The converted object, or, for dict input, the same dictionary updated in place
        with the converted objects.
    """
    if isinstance(xr_objs, dict):
        # Replacing values of existing keys does not change the dict's size, so this is safe while iterating.
        for name, obj in xr_objs.items():
            xr_objs[name] = _to_standard_calendar(obj)
        return xr_objs

    return _to_standard_calendar(xr_objs)
131

132

133
def get_array_categ(array: xr.DataArray | xr.Dataset) -> str:
    """
    Classify an Xarray object to decide how it should be plotted.

    Parameters
    ----------
    array : Dataset or DataArray
        The object to classify.

    Returns
    -------
    str
        One of:

        - ENS_PCT_VAR_DS: Dataset with at least two variables whose names contain "_p<1-2 digits>"
          (ensemble percentiles stored as variables)
        - ENS_STATS_VAR_DS: Dataset with at least two variables whose names contain "_max"/"_min"
          (either capitalization of the first letter; ensemble statistics stored as variables)
        - ENS_PCT_DIM_DS: Dataset with a "percentiles" dimension
        - ENS_REALS_DS: Dataset with a "realization" dimension
        - DS: any other Dataset
        - ENS_PCT_DIM_DA: DataArray with a "percentiles" dimension
        - ENS_REALS_DA: DataArray with a "realization" dimension
        - DA: any other DataArray

    Raises
    ------
    TypeError
        If `array` is neither a Dataset nor a DataArray.
    """
    if isinstance(array, xr.Dataset):

        def count_matching(pattern):
            # Number of data variables whose name contains `pattern`.
            return sum(re.search(pattern, var) is not None for var in array.data_vars)

        if count_matching("_p[0-9]{1,2}") >= 2:
            return "ENS_PCT_VAR_DS"
        if count_matching("_[Mm]ax|_[Mm]in") >= 2:
            return "ENS_STATS_VAR_DS"
        if "percentiles" in array.dims:
            return "ENS_PCT_DIM_DS"
        if "realization" in array.dims:
            return "ENS_REALS_DS"
        return "DS"

    if isinstance(array, xr.DataArray):
        if "percentiles" in array.dims:
            return "ENS_PCT_DIM_DA"
        if "realization" in array.dims:
            return "ENS_REALS_DA"
        return "DA"

    raise TypeError("Array is not an Xarray Dataset or DataArray")
187

188

189
def get_attributes(
    string: str, xr_obj: xr.DataArray | xr.Dataset, locale: str | None = None
) -> str:
    """
    Fetch attributes or dims corresponding to keys from Xarray objects.

    Searches DataArray attributes first, then the first variable (DataArray) of the Dataset, then Dataset attributes.
    If a locale is activated in xclim's options or a locale is passed, a localized version is given if available:
    for a non-English locale, ``f"{string}_{locale}"`` is looked up in all sources before `string` itself.

    Parameters
    ----------
    string : str
        String corresponding to an attribute name.
    xr_obj : DataArray or Dataset
        The Xarray object containing the attributes.
    locale : str, optional
        A 2-letter locale name to translate to.
        Default is None, which will pull the locale
        from xclim's "metadata_locales" option (taking the first).

    Returns
    -------
    str
        Xarray attribute value as string or empty string if not found (a warning is emitted in that case).
    """
    locale = locale or (XC_OPTIONS[METADATA_LOCALES] or ["en"])[0]
    names = [string] if locale == "en" else [f"{string}_{locale}", string]

    # Attribute mappings to search, in priority order.
    if isinstance(xr_obj, xr.DataArray):
        sources = [xr_obj.attrs]
    elif isinstance(xr_obj, xr.Dataset):
        sources = []
        # A Dataset may have no data variables; only then skip the first-variable lookup
        # (indexing an empty list of variables would raise IndexError).
        if xr_obj.data_vars:
            sources.append(xr_obj[next(iter(xr_obj.data_vars))].attrs)
        sources.append(xr_obj.attrs)
    else:
        sources = []

    for name in names:
        for attrs in sources:
            if name in attrs:
                return attrs[name]

    warnings.warn(f'Attribute "{string}" not found.', stacklevel=2)
    return ""
235

236

237
def _label_with_units(
    attr_dict: dict[str, Any],
    label_key: str,
    units_key: str,
    xr_obj: xr.DataArray | xr.Dataset,
) -> str:
    """Build a wrapped axis label from attributes, appending " (units)" when a non-empty units attribute is found."""
    label = get_attributes(attr_dict[label_key], xr_obj)
    if units_key in attr_dict:
        # Fetch the units once; skipping empty units avoids a bare "()" suffix on the label.
        units = get_attributes(attr_dict[units_key], xr_obj)
        if units:
            label = f"{label} ({units})"
    return wrap_text(label)


def set_plot_attrs(
    attr_dict: dict[str, Any],
    xr_obj: xr.DataArray | xr.Dataset,
    ax: matplotlib.axes.Axes | None = None,
    title_loc: str = "center",
    facetgrid: seaborn.axisgrid.FacetGrid | None = None,
    wrap_kw: dict[str, Any] | None = None,
) -> matplotlib.axes.Axes | seaborn.axisgrid.FacetGrid:
    """
    Set plot elements according to Dataset or DataArray attributes.

    Uses get_attributes() to check for and get the string.

    Parameters
    ----------
    attr_dict : dict
        Dictionary mapping plot elements ("title", "ylabel", "yunits", "xlabel", "xunits",
        "cbar_label", "cbar_units", "suptitle") to attribute names. Other keys trigger a warning.
    xr_obj : Dataset or DataArray
        The Xarray object containing the attributes.
    ax : matplotlib axis
        The matplotlib axis of the plot. Required when "title", "ylabel" or "xlabel" is requested.
    title_loc : str
        Location of the title.
    facetgrid : seaborn.axisgrid.FacetGrid, optional
        If given, "suptitle" is applied to this grid's figure and the grid is returned instead of `ax`.
    wrap_kw : dict, optional
        Arguments to pass to the wrap_text function for the title.

    Returns
    -------
    matplotlib.axes.Axes or seaborn.axisgrid.FacetGrid
        `facetgrid` if one was given, otherwise `ax`.
    """
    wrap_kw = empty_dict(wrap_kw)

    supported = (
        "title",
        "ylabel",
        "yunits",
        "xlabel",
        "xunits",
        "cbar_label",
        "cbar_units",
        "suptitle",
    )
    for key in attr_dict:
        if key not in supported:
            warnings.warn(f'Use_attrs element "{key}" not supported', stacklevel=2)

    if "title" in attr_dict:
        title = get_attributes(attr_dict["title"], xr_obj)
        ax.set_title(wrap_text(title, **wrap_kw), loc=title_loc)

    if "ylabel" in attr_dict:
        ax.set_ylabel(_label_with_units(attr_dict, "ylabel", "yunits", xr_obj))

    if "xlabel" in attr_dict:
        ax.set_xlabel(_label_with_units(attr_dict, "xlabel", "xunits", xr_obj))

    # "cbar_label" and "cbar_units" are accepted here but must be applied by the calling plot function.

    if facetgrid:
        if "suptitle" in attr_dict:
            suptitle = get_attributes(attr_dict["suptitle"], xr_obj)
            facetgrid.fig.suptitle(suptitle, y=1.05)
            facetgrid.set_titles(template="{value}")
        return facetgrid

    return ax
335

336

337
def get_suffix(string: str) -> str:
    """
    Get suffix of typical Xclim variable names.

    Parameters
    ----------
    string : str
        A variable name such as "tas_p90" or "pr_max".

    Returns
    -------
    str
        The trailing one- or two-digit number (e.g. "90" for "tas_p90"), or the
        statistic name "max"/"min"/"mean" (first letter may be capitalized) when
        the name ends in "_max", "_min" or "_mean".

    Raises
    ------
    ValueError
        If `string` ends with neither digits nor one of those statistics.
    """
    # A single pattern with capture groups replaces the former check-then-extract pair of regexes,
    # which had to be kept in sync by hand.
    match = re.search(r"([0-9]{1,2})$|_([Mm]ax|[Mm]in|[Mm]ean)$", string)
    if match is None:
        raise ValueError(f"Mean, min or max not found in {string}")
    return match.group(1) or match.group(2)
344

345

346
def sort_lines(array_dict: dict[str, Any]) -> dict[str, str]:
    """
    Label arrays as 'middle', 'upper' and 'lower' for ensemble plotting.

    Roles are assigned from the name suffix (see get_suffix): "max" -> upper, "min" -> lower,
    "mean" -> middle (first letter may be capitalized); percentiles above 50 -> upper,
    below 50 -> lower, exactly 50 -> middle.

    Parameters
    ----------
    array_dict : dict
        Dictionary of format {'name': array...}.

    Returns
    -------
    dict
        Dictionary of {'middle': 'name', 'upper': 'name', 'lower': 'name'}.

    Raises
    ------
    ValueError
        If `array_dict` does not contain exactly three arrays, or if their names do not
        yield exactly one lower, one middle and one upper line.
    """
    if len(array_dict) != 3:
        raise ValueError("Ensembles must contain exactly three arrays")

    stat_roles = {"max": "upper", "min": "lower", "mean": "middle"}
    sorted_lines = {}

    for name in array_dict:
        suffix = get_suffix(name)

        if suffix.isdigit():
            percentile = int(suffix)
            if percentile > 50:
                role = "upper"
            elif percentile < 50:
                role = "lower"
            else:
                role = "middle"
        elif suffix.lower() in stat_roles:
            role = stat_roles[suffix.lower()]
        else:
            raise ValueError('Arrays names must end in format "_mean" or "_p50" ')

        sorted_lines[role] = name

    # Duplicate roles (e.g. two "_max" arrays) overwrite each other and leave a role unassigned;
    # fail here with a clear message instead of a KeyError further down the plotting code.
    missing = {"lower", "middle", "upper"} - sorted_lines.keys()
    if missing:
        raise ValueError(
            "Ensemble arrays must provide one lower, one middle and one upper line "
            f"(e.g. '_p10', '_p50', '_p90' or '_min', '_mean', '_max'); missing: {sorted(missing)}."
        )

    return sorted_lines
385

386

387
def loc_mpl(
5✔
388
    loc: str | tuple[int | float, int | float] | int,
389
) -> tuple[tuple[float, float], tuple[int | float, int | float], str, str]:
390
    """
391
    Find coordinates and alignment associated to loc string.
392

393
    Parameters
394
    ----------
395
    loc : string, int, or tuple[float, float]
396
        Location of text, replicating https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html.
397
        If a tuple, must be in axes coordinates.
398

399
    Returns
400
    -------
401
    tuple(float, float), tuple(float, float), str, str
402
    """
UNCOV
403
    ha = "left"
×
UNCOV
404
    va = "bottom"
×
405

406
    loc_strings = [
×
407
        "upper right",
408
        "upper left",
409
        "lower left",
410
        "lower right",
411
        "right",
412
        "center left",
413
        "center right",
414
        "lower center",
415
        "upper center",
416
        "center",
417
    ]
418

UNCOV
419
    if isinstance(loc, int):
×
UNCOV
420
        try:
×
UNCOV
421
            loc = loc_strings[loc - 1]
×
422
        except IndexError as err:
×
423
            raise ValueError("loc must be between 1 and 10, inclusively") from err
×
424

425
    if loc in loc_strings:
×
426
        # ha
UNCOV
427
        if "left" in loc:
×
428
            ha = "left"
×
UNCOV
429
        elif "right" in loc:
×
430
            ha = "right"
×
431
        else:
432
            ha = "center"
×
433

434
        # va
435
        if "lower" in loc:
×
UNCOV
436
            va = "bottom"
×
UNCOV
437
        elif "upper" in loc:
×
438
            va = "top"
×
439
        else:
440
            va = "center"
×
441

442
        # transAxes
443
        if loc == "upper right":
×
UNCOV
444
            loc = (0.97, 0.97)
×
UNCOV
445
            box_a = (1, 1)
×
446
        elif loc == "upper left":
×
447
            loc = (0.03, 0.97)
×
448
            box_a = (0, 1)
×
449
        elif loc == "lower left":
×
450
            loc = (0.03, 0.03)
×
451
            box_a = (0, 0)
×
452
        elif loc == "lower right":
×
453
            loc = (0.97, 0.03)
×
454
            box_a = (1, 0)
×
455
        elif loc == "right":
×
456
            loc = (0.97, 0.5)
×
457
            box_a = (1, 0.5)
×
458
        elif loc == "center left":
×
459
            loc = (0.03, 0.5)
×
460
            box_a = (0, 0.5)
×
461
        elif loc == "center right":
×
462
            loc = (0.97, 0.5)
×
463
            box_a = (0.97, 0.5)
×
464
        elif loc == "lower center":
×
465
            loc = (0.5, 0.03)
×
466
            box_a = (0.5, 0)
×
467
        elif loc == "upper center":
×
468
            loc = (0.5, 0.97)
×
469
            box_a = (0.5, 1)
×
470
        else:
471
            loc = (0.5, 0.5)
×
472
            box_a = (0.5, 0.5)
×
473

474
    elif isinstance(loc, tuple):
×
475
        box_a = []
×
UNCOV
476
        for i in loc:
×
477
            if i > 1 or i < 0:
×
478
                raise ValueError(
×
479
                    "Text location coordinates must be between 0 and 1, inclusively"
480
                )
481
            elif i > 0.5:
×
UNCOV
482
                box_a.append(1)
×
483
            else:
484
                box_a.append(0)
×
485
        box_a = tuple(box_a)
×
486
    else:
487
        raise ValueError(
×
488
            "loc must be a string, int or tuple. "
489
            "See https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html"
490
        )
491

UNCOV
492
    return loc, box_a, ha, va
×
493

494

495
def plot_coords(
    ax: matplotlib.axes.Axes | None,
    xr_obj: xr.DataArray | xr.Dataset,
    loc: str | tuple[float, float] | int,
    param: str | None = None,
    backgroundalpha: float = 1,
) -> matplotlib.axes.Axes | None:
    """
    Place coordinates on plot area.

    Parameters
    ----------
    ax : matplotlib.axes.Axes or None
        Matplotlib axes object on which to place the text.
        If None, will use plt.figtext instead (should be used for facetgrids).
    xr_obj : xr.DataArray or xr.Dataset
        The xarray object from which to fetch the text content.
    loc : string, int or tuple
        Location of text, replicating https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html.
        If a tuple, must be in axes coordinates.
    param : {"location", "time"}, optional
        The parameter used: "location" writes the "lat"/"lon" coordinates, "time" writes
        the "time" coordinate formatted as %Y-%m-%d. A warning is emitted if the coordinate is missing.
    backgroundalpha : float
        Transparency of the text background. 1 is opaque, 0 is transparent.

    Returns
    -------
    matplotlib.axes.Axes or None
        `ax` (None when `ax` is None).
    """
    text = None
    if param == "location":
        if "lat" in xr_obj.coords and "lon" in xr_obj.coords:
            text = "lat={:.2f}, lon={:.2f}".format(
                float(xr_obj["lat"]), float(xr_obj["lon"])
            )
        else:
            warnings.warn(
                'show_lat_lon set to True, but "lat" and/or "lon" not found in coords', stacklevel=2
            )
    elif param == "time":
        if "time" in xr_obj.coords:
            text = str(xr_obj.time.dt.strftime("%Y-%m-%d").values)
        else:
            warnings.warn('show_time set to True, but "time" not found in coords', stacklevel=2)

    # Resolve loc even when there is nothing to draw, so invalid values are always reported.
    loc, box_a, ha, va = loc_mpl(loc)

    if not text:
        return ax

    if ax is not None:
        t = mpl.offsetbox.TextArea(
            text, textprops=dict(transform=ax.transAxes, ha=ha, va=va)
        )
        tt = mpl.offsetbox.AnnotationBbox(
            t,
            loc,
            xycoords="axes fraction",
            box_alignment=box_a,
            pad=0.05,
            bboxprops=dict(
                facecolor="white",
                alpha=backgroundalpha,
                edgecolor="w",
                boxstyle="Square, pad=0.5",
            ),
        )
        ax.add_artist(tt)
        return ax

    # No axes (e.g. facetgrids): place the text in figure coordinates instead.
    plt.figtext(
        loc[0],
        loc[1],
        text,
        ha=ha,
        va=va,
        fontsize=12,
    )
    return None
591

592

593
def find_logo(logo: str | pathlib.Path) -> str:
    """
    Resolve the path of a logo managed by :py:class:`Logos`.

    Parameters
    ----------
    logo : str or Path
        Key looked up in a new ``Logos()`` instance. When falsy, the default logo is used.

    Returns
    -------
    str
        The logo path returned by ``Logos``.

    Raises
    ------
    ValueError
        If ``Logos`` returns no path for the requested (or default) logo.
    """
    logos = Logos()
    logo_path = logos[logo] if logo else logos.default

    if logo_path is not None:
        return logo_path
    raise ValueError(
        "No logo found. Please install one with the figanos.Logos().set_logo() method."
    )
606

607

608
def load_image(
5✔
609
    im: str | pathlib.Path,
610
    height: float | None,
611
    width: float | None,
612
    keep_ratio: bool = True,
613
) -> np.ndarray:
614
    """
615
    Scale an image to a specified height and width.
616

617
    Parameters
618
    ----------
619
    im : str or Path
620
        The image to be scaled. PNG and SVG formats are supported.
621
    height : float, optional
622
        The desired height of the image. If None, the original height is used.
623
    width : float, optional
624
        The desired width of the image. If None, the original width is used.
625
    keep_ratio : bool
626
        If True, the aspect ratio of the original image is maintained. Default is True.
627

628
    Returns
629
    -------
630
    np.ndarray
631
        The scaled image.
632
    """
UNCOV
633
    if pathlib.Path(im).suffix == ".png":
×
UNCOV
634
        image = mpl.pyplot.imread(im)
×
UNCOV
635
        original_height, original_width = image.shape[:2]
×
636

637
        if height is None and width is None:
×
638
            return image
×
639

640
        warnings.warn(
×
641
            "The scikit-image library is used to resize PNG images. This may affect logo image quality.", stacklevel=2
642
        )
643
        if not keep_ratio:
×
UNCOV
644
            height = original_height or height
×
UNCOV
645
            width = original_width or width
×
646
        else:
647
            if width is not None:
×
648
                if height is not None:
×
UNCOV
649
                    warnings.warn("Both height and width provided, using height.", stacklevel=2)
×
650
                # Only width is provided, derive zoom factor for height based on aspect ratio
651
                height = (width / original_width) * original_height
×
652
            elif height is not None:
×
653
                # Only height is provided, derive zoom factor for width based on aspect ratio
654
                width = (height / original_height) * original_width
×
655

UNCOV
656
        return resize(image, (height, width, image.shape[2]), anti_aliasing=True)
×
657

UNCOV
658
    elif pathlib.Path(im).suffix == ".svg":
×
659
        cairo_kwargs = dict(url=im)
×
UNCOV
660
        if not keep_ratio:
×
661
            if height is not None and width is not None:
×
662
                cairo_kwargs.update(output_height=height, output_width=width)
×
663
        elif width is not None:
×
664
            if height is not None:
×
665
                warnings.warn("Both height and width provided, using height.", stacklevel=2)
×
666
            cairo_kwargs.update(output_width=width)
×
667
        elif height is not None:
×
668
            cairo_kwargs.update(output_height=height)
×
669

670
        with NamedTemporaryFile(suffix=".png") as png_file:
×
671
            cairo_kwargs.update(write_to=png_file.name)
×
UNCOV
672
            cairosvg.svg2png(**cairo_kwargs)
×
673
            return mpl.pyplot.imread(png_file.name)
×
674

675

676
def plot_logo(
    ax: matplotlib.axes.Axes,
    loc: str | tuple[float, float] | int,
    logo: str | pathlib.Path | Logos | None = None,
    height: float | None = None,
    width: float | None = None,
    keep_ratio: bool = True,
    **offset_image_kwargs,
) -> matplotlib.axes.Axes:
    r"""
    Place a logo on the plot area.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Matplotlib axes object on which to place the logo.
    loc : string, int or tuple
        Location of the logo, replicating https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html.
        If a tuple, must be in axes coordinates.
    logo : str, Path, figanos.Logos, optional
        A name (str) or Path to a logo, resolved with ``find_logo`` (i.e. through ``figanos.Logos``),
        or a ``Logos`` instance whose default logo is used.
        When None, the default logo is used. The logo file must be in PNG or SVG format.
    height : float, optional
        The desired height of the image. If None, the original height is used.
    width : float, optional
        The desired width of the image. If None, the original width is used.
    keep_ratio : bool, optional
        If True, the aspect ratio of the original image is maintained. Default is True.
    \*\*offset_image_kwargs
        Arguments to pass to matplotlib.offsetbox.OffsetImage().

    Returns
    -------
    matplotlib.axes.Axes
    """
    logo_path = logo.default if isinstance(logo, Logos) else find_logo(logo)

    image = load_image(logo_path, height, width, keep_ratio)
    imagebox = mpl.offsetbox.OffsetImage(image, **offset_image_kwargs)

    # Only the anchor point and box alignment matter for an image.
    xy, box_a, _, _ = loc_mpl(loc)
    ab = mpl.offsetbox.AnnotationBbox(
        imagebox,
        xy,
        frameon=False,
        xycoords="axes fraction",
        box_alignment=box_a,
        pad=0.05,
    )
    ax.add_artist(ab)
    return ax
735

736

737
def split_legend(
    ax: matplotlib.axes.Axes,
    in_plot: bool = False,
    axis_factor: float = 0.15,
    label_gap: float = 0.02,
) -> matplotlib.axes.Axes:
    """
    Draw line labels at the end of each line, or outside the plot.

    Only Line2D legend entries with data are labelled; other legend handles
    (e.g. shaded areas) are skipped.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axis containing the legend.
    in_plot : bool
        If True, prolong plot area to fit labels. If False, print labels outside of plot area. Default: False.
    axis_factor : float
        If in_plot is True, fraction of the x-axis length to add at the far right of the plot. Default: 0.15.
    label_gap : float
        If in_plot is True, fraction of the x-axis length to add as a gap between line and label. Default: 0.02.

    Returns
    -------
    matplotlib.axes.Axes
    """
    # TODO: check for and fix overlapping labels
    init_xbound = ax.get_xbound()
    x_span = init_xbound[1] - init_xbound[0]
    ax_bump = x_span * axis_factor
    label_bump = x_span * label_gap

    # create extra space
    if in_plot is True:
        ax.set_xbound(lower=init_xbound[0], upper=init_xbound[1] + ax_bump)

    handles, labels = ax.get_legend_handles_labels()
    for handle, label in zip(handles, labels, strict=False):
        # Only lines expose x/y data; other labelled artists have no get_xdata().
        if not isinstance(handle, Line2D):
            continue
        xdata = np.asarray(handle.get_xdata())
        ydata = np.asarray(handle.get_ydata())
        if xdata.size == 0 or ydata.size == 0:
            continue

        last_x = xdata[-1]
        last_y = ydata[-1]
        if isinstance(last_x, np.datetime64):
            # Convert to matplotlib's float date representation so the gap can be added.
            last_x = mpl.dates.date2num(last_x)

        color = handle.get_color()

        if in_plot is True:
            ax.text(
                last_x + label_bump,
                last_y,
                label,
                ha="left",
                va="center",
                color=color,
            )
        else:
            # x in axes fraction (just right of the plot), y in data coordinates.
            trans = mpl.transforms.blended_transform_factory(ax.transAxes, ax.transData)
            ax.text(
                1.01,
                last_y,
                label,
                ha="left",
                va="center",
                color=color,
                transform=trans,
            )

    return ax
806

807

808
def fill_between_label(
    sorted_lines: dict[str, Any],
    name: str,
    array_categ: dict[str, Any],
    legend: str,
) -> str:
    """
    Build the legend label for the shaded area around a line in line plots.

    Parameters
    ----------
    sorted_lines : dict
        Output of sort_lines(), mapping 'lower'/'middle'/'upper' to variable names.
    name : str
        Key of the plotted object in the 'data' argument of the timeseries() function.
    array_categ : dict
        Categories of the arrays, as returned by get_array_categ for each key.
    legend : str
        Legend mode; a label is only produced in "full" mode.

    Returns
    -------
    str
        A localized "percentiles" label for percentile ensembles, a localized
        "min-max range" label for statistics ensembles, and None otherwise.
    """
    if legend != "full":
        return None

    categ = array_categ[name]

    if categ == "ENS_STATS_VAR_DS":
        return get_localized_term("min-max range")

    if categ in ("ENS_PCT_VAR_DS", "ENS_PCT_DIM_DS", "ENS_PCT_DIM_DA"):
        template = get_localized_term("{}th-{}th percentiles")
        return template.format(
            get_suffix(sorted_lines["lower"]), get_suffix(sorted_lines["upper"])
        )

    return None
849

850

851
def get_var_group(
    da: xr.DataArray | xr.Dataset | None = None,
    unique_str: str | None = None,
    path_to_json: str | pathlib.Path = Path(__file__).parents[1]
    / "data/ipcc_colors/variable_groups.json",
) -> str:
    """
    Get IPCC variable group from DataArray or a string using a json file (figanos/data/ipcc_colors/variable_groups.json).

    Parameters
    ----------
    da : xr.DataArray or xr.Dataset, optional
        Object whose name (then, if nothing matched, whose `history` attribute) is searched
        for variable names. If a Dataset, the DataArray of its first data variable is used.
        Ignored when `unique_str` is given.
    unique_str : str, optional
        String in which to search for variable names. Takes precedence over `da`.
    path_to_json : str or pathlib.Path
        JSON file mapping variable names (inserted into a regular expression) to variable groups.

    Returns
    -------
    str
        The variable group. "misc" is returned (with a warning) when no group or more than one
        group is found.
    """
    with pathlib.Path(path_to_json).open(encoding="utf-8") as _f:
        var_dict = json.load(_f)

    def _groups_in(text: str) -> list:
        # A variable name only matches when it is not part of a longer alphabetic word.
        return [
            group
            for v, group in var_dict.items()
            if re.search(rf"(?:^|[^a-zA-Z])({v})(?:[^a-zA-Z]|$)", text)
        ]

    matches = []
    if unique_str:
        matches = _groups_in(unique_str)
    else:
        if isinstance(da, xr.Dataset):
            da = da[list(da.data_vars)[0]]
        # look in DataArray name
        name = getattr(da, "name", None)
        if isinstance(name, str):
            matches = _groups_in(name)
        # fall back on the history attribute; skip it if it is not a string
        history = getattr(da, "history", None)
        if not matches and isinstance(history, str):
            matches = _groups_in(history)

    matches = np.unique(matches)

    if len(matches) == 0:
        warnings.warn(
            "Colormap warning: Variable group not found. Use the cmap argument.", stacklevel=2
        )
        return "misc"
    if len(matches) >= 2:
        warnings.warn(
            "Colormap warning: More than one variable group found. Use the cmap argument.", stacklevel=2
        )
        return "misc"
    return matches[0]
905

906

907
def create_cmap(
    var_group: str | None = None,
    divergent: bool | int = False,
    filename: str | None = None,
) -> matplotlib.colors.Colormap:
    """
    Build a colormap from an IPCC colour file chosen by variable group.

    Parameters
    ----------
    var_group : str, optional
        Variable group from the IPCC scheme. Used only when `filename` is empty.
    divergent : bool or int
        Anything other than False selects the diverging ("_div") file; False selects the sequential one.
    filename : str, optional
        Name of an IPCC colormap file (".txt" optional). A "_r" suffix reverses the colormap.
        When given, 'var_group' and 'divergent' are not used.

    Returns
    -------
    matplotlib.colors.Colormap
    """
    folder = "continuous_colormaps_rgb_0-255"
    reverse = False

    if not filename:
        if divergent is False:
            # special sequential maps: "misc" -> Batlow, "misc2" -> freezing rain
            special = {"misc": "misc_seq_3", "misc2": "misc_seq_2"}
            filename = special[var_group] if var_group in special else var_group + "_seq"
        else:
            base = "misc" if var_group == "misc2" else var_group
            filename = base + "_div"
    else:
        filename = filename.replace(".txt", "")
        if filename.endswith("_r"):
            reverse = True
            filename = filename[:-2]

    # parent should be 'figanos/'
    txt_path = pathlib.Path(__file__).parents[1] / "data" / "ipcc_colors" / folder / f"{filename}.txt"

    # files hold 0-255 RGB triplets; matplotlib expects 0-1
    colours = np.loadtxt(txt_path) / 255
    colormap = mcolors.LinearSegmentedColormap.from_list("cmap", colours, N=256)
    return colormap.reversed() if reverse else colormap
973

974

975
def get_rotpole(xr_obj: xr.DataArray | xr.Dataset) -> ccrs.RotatedPole | None:
    """
    Create a Cartopy crs rotated pole projection/transform from DataArray or Dataset attributes.

    Parameters
    ----------
    xr_obj : xr.DataArray or xr.Dataset
        The xarray object from which to look for the attributes.

    Returns
    -------
    ccrs.RotatedPole or None
        The rotated pole projection, or None (with a warning) when the grid-mapping variable
        or its rotated-pole attributes cannot be found.
    """
    try:
        if isinstance(xr_obj, xr.Dataset):
            gridmap = xr_obj.cf.grid_mapping_names.get("rotated_latitude_longitude", [])

            if len(gridmap) > 1:
                warnings.warn(
                    f"There are conflicting grid_mapping attributes in the dataset. Assuming {gridmap[0]}.", stacklevel=2
                )

            coord_name = gridmap[0] if gridmap else "rotated_pole"
        else:
            # If it can't find grid_mapping, assume it's rotated_pole
            coord_name = xr_obj.attrs.get("grid_mapping", "rotated_pole")

        # A missing grid-mapping variable raises KeyError, a missing attribute raises AttributeError.
        grid_mapping = xr_obj[coord_name]
        rotpole = ccrs.RotatedPole(
            pole_longitude=grid_mapping.grid_north_pole_longitude,
            pole_latitude=grid_mapping.grid_north_pole_latitude,
            central_rotated_longitude=grid_mapping.north_pole_grid_longitude,
        )
    except (AttributeError, KeyError):
        warnings.warn("Rotated pole not found. Specify a transform if necessary.", stacklevel=2)
        return None
    return rotpole
1013

1014

1015
def wrap_text(text: str, min_line_len: int = 18, max_line_len: int = 30) -> str:
    """
    Insert line breaks into text so that lines fall between two lengths.

    Break points are searched in a window [line start + min_line_len, line start + max_line_len),
    preferring the first ". ", then the first ": ", then the last space. The space at the
    break point is replaced by a newline.

    Parameters
    ----------
    text : str
        The text to wrap.
    min_line_len : int
        Minimum length of each line.
    max_line_len : int
        Maximum length of each line.

    Returns
    -------
    str
        Wrapped text
    """
    lo, hi = min_line_len, max_line_len
    rest = len(text)

    while rest > max_line_len:
        window = text[lo:hi]
        if ". " in window:
            cut = text.find(". ", lo, hi) + 1
        elif ": " in window:
            cut = text.find(": ", lo, hi) + 1
        elif " " in window:
            cut = text.rfind(" ", lo, hi)
        else:
            warnings.warn("No spaces, points or colons to break line at.", stacklevel=2)
            break

        # swap the separator character at `cut` for a newline (length is unchanged)
        text = text[:cut] + "\n" + text[cut + 1 :]
        rest = len(text) - cut
        lo = cut + 1 + min_line_len
        hi = cut + 1 + max_line_len

    return text
1057

1058

1059
def gpd_to_ccrs(df: gpd.GeoDataFrame, proj: ccrs.CRS) -> gpd.GeoDataFrame:
    """
    Reproject a GeoDataFrame to a cartopy projection.

    Parameters
    ----------
    df : gpd.GeoDataFrame
        GeoDataFrame (geopandas) geometry to be added to axis.
    proj : ccrs.CRS
        Projection to use, taken from the cartopy.crs options; its PROJ.4 string is passed to `to_crs`.

    Returns
    -------
    gpd.GeoDataFrame
        GeoDataFrame adjusted to given projection
    """
    target = proj.proj4_init
    reprojected = df.to_crs(target)
    return reprojected
1077

1078

1079
def convert_scen_name(name: str) -> str:
    """
    Convert strings containing SSP, RCP or CMIP to their proper format.

    Every occurrence is converted (case-insensitive match), e.g. "ssp245" -> "SSP2-4.5",
    "rcp45" -> "RCP4.5", "cmip6" -> "CMIP6". The rest of the string is left untouched.

    Parameters
    ----------
    name : str
        String possibly containing scenario/project identifiers.

    Returns
    -------
    str
        The string with all identifiers reformatted.
    """

    def _format(match: re.Match) -> str:
        prefix, digits = match.group(1).upper(), match.group(2)
        if len(digits) == 3:
            return f"{prefix}{digits[0]}-{digits[1]}.{digits[2]}"  # ssp245 to SSP2-4.5
        if len(digits) == 2:
            return f"{prefix}{digits[0]}.{digits[1]}"  # rcp45 to RCP4.5
        return prefix + digits  # cmip5 to CMIP5

    # re.sub converts every match in place, so several identifiers in one name are all kept
    return re.sub(r"(SSP|RCP|CMIP)([0-9]{1,3})", _format, name, flags=re.I)
1101

1102

1103
def get_scen_color(
    name: str, path_to_dict: str | pathlib.Path
) -> tuple[float, ...] | None:
    """
    Get color corresponding to SSP, RCP, model or CMIP substring from a dictionary.

    Parameters
    ----------
    name : str
        Name in which to look for one of the dictionary keys (substring match).
    path_to_dict : str or pathlib.Path
        JSON file mapping substrings to 0-255 color components.

    Returns
    -------
    tuple of float or None
        Color of the first key (in file order) found in `name`, scaled to 0-1,
        or None if no key is found.
    """
    with pathlib.Path(path_to_dict).open(encoding="utf-8") as _f:
        color_dict = json.load(_f)

    for entry, components in color_dict.items():
        if entry in name:
            return tuple(i / 255 for i in components)
    return None
1116

1117

1118
def process_keys(dct: dict[str, Any], func: Callable) -> dict[str, Any]:
    """
    Apply function to dictionary keys, in place.

    All keys are renamed simultaneously from a snapshot. A new key that equals a not-yet-processed
    old key therefore no longer overwrites that key's value (e.g. renaming a->b and b->c keeps both values).

    Parameters
    ----------
    dct : dict
        Dictionary to modify in place.
    func : Callable
        Function applied to each key.

    Returns
    -------
    dict
        The same dictionary object, with renamed keys in the original order.
    """
    renamed = [(func(key), value) for key, value in dct.items()]
    dct.clear()
    dct.update(renamed)
    return dct
1125

1126

1127
def categorical_colors() -> dict[str, str]:
    """Load the categorical colors associated with substrings (SSP, RCP, CMIP) from figanos' IPCC data."""
    json_path = pathlib.Path(__file__).parents[1].joinpath(
        "data", "ipcc_colors", "categorical_colors.json"
    )
    with json_path.open(encoding="utf-8") as handle:
        colors = json.load(handle)
    return colors
1136

1137

1138
def get_mpl_styles() -> dict[str, pathlib.Path]:
    """Map each '*.mplstyle' file in the sibling 'style' folder (stem -> path), in sorted path order."""
    style_dir = pathlib.Path(__file__).parent / "style"
    return {path.stem: path for path in sorted(style_dir.glob("*.mplstyle"))}
1143

1144

1145
def set_mpl_style(*args: str, reset: bool = False) -> None:
    """
    Set the matplotlib style using one or more stylesheets.

    Parameters
    ----------
    args : str
        Name(s) of figanos matplotlib style ('ouranos', 'paper, 'poster') or path(s) to matplotlib stylesheet(s).
        Arguments ending in '.mplstyle' are passed to matplotlib as paths; unknown names trigger a warning.
    reset : bool
        If True, reset style to matplotlib default before applying the stylesheets.

    Returns
    -------
    None
    """
    if reset is True:
        mpl.style.use("default")

    # Look the figanos styles up once instead of re-globbing the style folder for every argument.
    figanos_styles = get_mpl_styles()
    for s in args:
        if s.endswith(".mplstyle"):
            mpl.style.use(s)
        elif s in figanos_styles:
            mpl.style.use(figanos_styles[s])
        else:
            warnings.warn(f"Style {s} not found.", stacklevel=2)
1169

1170

1171
def add_cartopy_features(
    ax: matplotlib.axes.Axes, features: list[str] | dict[str, dict[str, Any]]
) -> matplotlib.axes.Axes:
    """
    Add cartopy features to matplotlib axes.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes on which to add the features.
    features : list or dict
        List (or tuple) of features, or nested dictionary of format {'feature': {'kwarg':'value'}}.
        A 'scale' kwarg is applied with the feature's `with_scale`; the other kwargs go to `ax.add_feature`.
        The caller's dictionaries are not modified.

    Returns
    -------
    matplotlib.axes.Axes
        The axis with added features.
    """
    if isinstance(features, (list, tuple)):
        features = {f: {} for f in features}

    for feat, feat_kw in features.items():
        # copy so the caller's kwargs (including 'scale') are never mutated
        kwargs = dict(feat_kw or {})
        feature = getattr(cfeature, feat.upper())
        if "scale" in kwargs:
            feature = feature.with_scale(kwargs.pop("scale"))
        ax.add_feature(feature, **kwargs)
    return ax
1203

1204

1205
def _round_cmap_limits(vmin: int | float, vmax: int | float) -> tuple[int | float, int | float]:
    """Round vmin down and vmax up to a step chosen from the data range (10, 1, 0.1 or 0.01)."""
    span = vmax - vmin
    if span >= 25:
        step, ndigits = 10, None
    elif span >= 1:
        step, ndigits = 1, None
    elif span >= 0.1:
        step, ndigits = 0.1, 1
    else:
        step, ndigits = 0.01, 2

    rvmax = math.ceil(vmax / step) * step
    rvmin = math.floor(vmin / step) * step
    if ndigits is not None:
        # remove floating-point noise such as 0.7000000000000001
        rvmin, rvmax = round(rvmin, ndigits), round(rvmax, ndigits)
    return rvmin, rvmax


def custom_cmap_norm(
    cmap,
    vmin: int | float,
    vmax: int | float,
    levels: int | list[int | float] | None = None,
    divergent: bool | int | float = False,
    linspace_out: bool = False,
) -> matplotlib.colors.Normalize | np.ndarray:
    """
    Get matplotlib normalization according to main function arguments.

    Parameters
    ----------
    cmap: matplotlib.colormap
        Colormap to be used with the normalization (or the name of a registered colormap).
    vmin: int or float
        Minimum of the data to be plotted with the colormap.
    vmax: int or float
        Maximum of the data to be plotted with the colormap.
    levels : int or list, optional
        Number of levels or sequence (list, tuple or array) of level boundaries (in data units)
        to use to divide the colormap.
    divergent : bool or int or float
        If int or float, becomes center of cmap. Default center is 0.
    linspace_out: bool
        If True, return array created by np.linspace() instead of normalization instance
        (only applies when `levels` is a number).

    Returns
    -------
    matplotlib.colors.Normalize or np.ndarray

    Raises
    ------
    ValueError
        If `cmap` is an unknown colormap name, or if the center is not strictly between the rounded limits.
    """
    # get cmap if string
    if isinstance(cmap, str):
        if cmap in plt.colormaps():
            cmap = matplotlib.colormaps[cmap]
        else:
            raise ValueError("Colormap not found")

    # make vmin and vmax prettier
    rvmin, rvmax = _round_cmap_limits(vmin, vmax)

    # center
    center = None
    if divergent is True:
        center = 0
    elif divergent is not False and isinstance(divergent, (int, float)):
        center = divergent

    # build norm with options
    if center is not None and isinstance(levels, (int, np.integer)):
        if center <= rvmin or center >= rvmax:
            raise ValueError("vmin, center and vmax must be in ascending order.")
        # an odd number of levels cannot be split evenly around the center, so one level is added
        if levels % 2 == 1:
            half_levels = (levels + 1) // 2 + 1
        else:
            half_levels = levels // 2 + 1

        lin = np.concatenate(
            (
                np.linspace(rvmin, center, num=half_levels),
                np.linspace(center, rvmax, num=half_levels)[1:],
            )
        )
        if linspace_out:
            return lin
        return matplotlib.colors.BoundaryNorm(boundaries=lin, ncolors=cmap.N)

    if levels is not None:
        if isinstance(levels, (list, tuple, np.ndarray)):
            if center is not None:
                warnings.warn(
                    "Divergent argument ignored when levels is a list. Use levels as a number instead.", stacklevel=2
                )
            return matplotlib.colors.BoundaryNorm(boundaries=levels, ncolors=cmap.N)

        lin = np.linspace(rvmin, rvmax, num=levels + 1)
        if linspace_out:
            return lin
        return matplotlib.colors.BoundaryNorm(boundaries=lin, ncolors=cmap.N)

    if center is not None:
        return matplotlib.colors.TwoSlopeNorm(center, vmin=rvmin, vmax=rvmax)
    return matplotlib.colors.Normalize(rvmin, rvmax)
1304

1305

1306
def norm2range(
    data: np.ndarray, target_range: tuple, data_range: tuple | None = None
) -> np.ndarray:
    """
    Normalize data across a specific range.

    Parameters
    ----------
    data : np.ndarray or scalar
        Values to rescale.
    target_range : tuple
        (min, max) of the output range.
    data_range : tuple, optional
        (min, max) of the input range. If None, the NaN-ignoring min and max of `data` are used,
        which requires `data` to have more than one element.

    Returns
    -------
    np.ndarray
        `data` linearly mapped from `data_range` onto `target_range`.

    Raises
    ------
    ValueError
        If `data_range` is None and `data` has fewer than two elements.
    """
    if data_range is None:
        # np.size works for scalars too, where len() would raise a TypeError
        if np.size(data) > 1:
            data_range = (np.nanmin(data), np.nanmax(data))
        else:
            raise ValueError(
                "If data is not an array with more than one element, data_range must be specified."
            )

    norm = (data - data_range[0]) / (data_range[1] - data_range[0])

    return target_range[0] + (norm * (target_range[1] - target_range[0]))
1319

1320

1321
def size_legend_elements(
    data: np.ndarray, sizes: np.ndarray, marker: str, max_entries: int = 6
) -> list[matplotlib.lines.Line2D]:
    """
    Create handles to use in a point-size legend.

    Parameters
    ----------
    data : np.ndarray
        Data used to determine the point sizes.
    sizes : np.ndarray
        Array of point sizes.
    marker: str
        Marker to use in legend.
    max_entries : int
        Maximum number of entries in the legend. When exceeded, every other entry is kept.

    Returns
    -------
    list of matplotlib.lines.Line2D
    """
    # how many increments of 10 pts**2 are there in the sizes (at least one, to avoid dividing by zero)
    n = max(int(np.round(max(sizes) - min(sizes), -1) / 10), 1)

    data_min, data_max = min(data), max(data)

    # divide data in those increments
    lgd_data = np.linspace(data_min, data_max, n)

    # round according to the spacing between legend values
    ratio = abs((data_max - data_min) / n)
    rounding = next(
        (step for step in (1000, 100, 10, 5, 1, 0.1, 0.01) if ratio >= step), 0.001
    )

    lgd_data = np.unique(rounding * np.round(lgd_data / rounding))

    # convert back to sizes
    lgd_sizes = norm2range(
        data=lgd_data,
        data_range=(data_min, data_max),
        target_range=(min(sizes), max(sizes)),
    )

    legend_elements = []

    for s, d in zip(lgd_sizes, lgd_data, strict=False):
        if isinstance(d, float) and d.is_integer():
            label = str(int(d))
        else:
            label = str(d)

        legend_elements.append(
            Line2D(
                [0],
                [0],
                marker=marker,
                color="k",
                lw=0,
                markerfacecolor="w",
                label=label,
                markersize=np.sqrt(np.abs(s)),
            )
        )

    if len(legend_elements) > max_entries:
        return [legend_elements[i] for i in np.arange(0, max_entries + 1, 2)]
    return legend_elements
1402

1403

1404
def add_features_map(
    data,
    ax,
    use_attrs,
    projection,
    features,
    geometries_kw,
    frame,
) -> matplotlib.axes.Axes:
    """
    Add features such as cartopy, time label, and geometries to a map on a given matplotlib axis.

    Parameters
    ----------
    data : dict, DataArray or Dataset
        Input data to plot. If dictionary, must have only one entry.
    ax : matplotlib axis
        Matplotlib axis on which to plot, with the same projection as the one specified.
    use_attrs : dict
        Dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description').
        Default value is {'title': 'description', 'cbar_label': 'long_name', 'cbar_units': 'units'}.
        Only the keys found in the default dict can be used.
    projection : ccrs.Projection
        The projection to use, taken from the cartopy.crs options. Ignored if ax is not None.
    features : list or dict
        Features to use, as a list or a nested dict containing kwargs. Options are the predefined features from
        cartopy.feature: ['coastline', 'borders', 'lakes', 'land', 'ocean', 'rivers', 'states'].
    geometries_kw : dict
        Arguments passed to cartopy ax.add_geometries() which adds given geometries (GeoDataFrame geometry) to axis.
        Must contain 'geoms'; if it is missing, a warning is issued and no geometry is added.
        The caller's dictionary is not modified.
    frame : bool
        Show or hide frame. Default False.

    Returns
    -------
    matplotlib.axes.Axes
    """
    # add features
    if features:
        add_cartopy_features(ax, features)

    set_plot_attrs(use_attrs, data, ax)

    if frame is False:
        ax.spines["geo"].set_visible(False)

    # add geometries
    if geometries_kw:
        if "geoms" not in geometries_kw:
            warnings.warn(
                'geoms missing from geometries_kw (ex: {"geoms": df["geometry"]})', stacklevel=2
            )
            return ax

        # work on a copy so the caller's dict (and its geometries) is left untouched
        geometries_kw = dict(geometries_kw)
        geometries_kw["geoms"] = gpd_to_ccrs(
            geometries_kw["geoms"], geometries_kw.get("crs", projection)
        )
        geometries_kw = {
            "crs": projection,
            "facecolor": "none",
            "edgecolor": "black",
        } | geometries_kw

        ax.add_geometries(**geometries_kw)
    return ax
1469

1470

1471
def masknan_sizes_key(data, sizes) -> xr.Dataset:
    """
    Synchronize missing values between the hue variable and the marker-size variable.

    The hue variable is the first data variable other than `sizes`. After the call, both variables
    are NaN wherever either one was NaN. `data` is modified in place and returned.

    Parameters
    ----------
    data: xr.Dataset
        xr.Dataset used to plot
    sizes: str
        Variable used to plot markersize

    Returns
    -------
    xr.Dataset
    """
    names = list(data.keys())
    names.remove(sizes)
    hue_var = names[0]

    # hide hue values where the size is missing...
    data[hue_var] = data[hue_var].where(~np.isnan(data[sizes]))
    # ...then hide sizes wherever the (now masked) hue is missing
    data[sizes] = data[sizes].where(~np.isnan(data[hue_var]))
    return data
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc