• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

Ouranosinc / figanos / 22629955296

03 Mar 2026 03:26PM UTC coverage: 9.11%. First build
22629955296

Pull #368

github

web-flow
Merge 5ea218e5f into 2dc2780f0
Pull Request #368: Enregistrement des cmaps de l'IPCC dans matplotlib

12 of 35 new or added lines in 2 files covered. (34.29%)

173 of 1899 relevant lines covered (9.11%)

0.64 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

14.14
/src/figanos/matplotlib/utils.py
1
"""Utility functions for figanos figure-creation."""
2

3
from __future__ import annotations
7✔
4
import importlib.resources
7✔
5
import json
7✔
6
import math
7✔
7
import pathlib
7✔
8
import re
7✔
9
import warnings
7✔
10
from collections.abc import Callable
7✔
11
from copy import deepcopy
7✔
12
from pathlib import Path
7✔
13
from tempfile import NamedTemporaryFile
7✔
14
from typing import Any
7✔
15

16
import cairosvg
7✔
17
import cartopy.crs as ccrs
7✔
18
import cartopy.feature as cfeature
7✔
19
import geopandas as gpd
7✔
20
import matplotlib as mpl
7✔
21
import matplotlib.axes
7✔
22
import matplotlib.colors as mcolors
7✔
23
import matplotlib.pyplot as plt
7✔
24
import numpy as np
7✔
25
import pandas as pd
7✔
26
import seaborn
7✔
27
import xarray as xr
7✔
28
import yaml
7✔
29
from matplotlib.lines import Line2D
7✔
30
from skimage.transform import resize
7✔
31
from xclim.core.options import METADATA_LOCALES
7✔
32
from xclim.core.options import OPTIONS as XC_OPTIONS
7✔
33

34
from .._logo import Logos
7✔
35

36

37
# file to map variable key words to variable group for IPCC color scheme
VARJSON = Path(__file__).parents[1] / "data/ipcc_colors/variable_groups.json"

TERMS: dict = {}
"""
A translation dictionary for special terms to appear on the plots.

Keys are terms to translate and they map to "locale": "translation" dictionaries.
The "official" figanos terms are based on figanos/data/terms.yml.
"""

# Load terms translations.
# The encoding is given explicitly so the file is decoded the same way whatever the platform's
# default encoding is; this mirrors how the JSON data file is opened in get_var_group().
with (pathlib.Path(__file__).resolve().parents[1] / "data" / "terms.yml").open(
    encoding="utf-8"
) as f:
    TERMS = yaml.safe_load(f)
7✔
52

53

54
def get_localized_term(term, locale=None):
    """
    Translate `term` into `locale` using the :py:data:`TERMS` dictionary.

    Parameters
    ----------
    term : str
        The word or short phrase to translate.
    locale : str, optional
        Target 2-letter locale name.
        When omitted, the first entry of xclim's "metadata_locales" option is used,
        falling back on "en" if that option is empty.

    Returns
    -------
    str
        The translated term. `term` is returned unchanged when the locale is "en",
        and also (after a warning) when no translation is known for it.
    """
    if not locale:
        configured = XC_OPTIONS[METADATA_LOCALES]
        locale = configured[0] if configured else "en"

    # English is the source language of the terms: nothing to look up.
    if locale == "en":
        return term

    if term in TERMS:
        translations = TERMS[term]
        if locale in translations:
            return translations[locale]
        message = f"No {locale} translation known for term '{term}'."
    else:
        message = f"No translation known for term '{term}'."

    warnings.warn(message, stacklevel=2)
    return term
×
86

87

88
def empty_dict(param) -> dict:
    """Return a deep copy of `param`, or a new empty dict when `param` is None."""
    # Working on a copy lets callers pop items without altering the dict they were given.
    return deepcopy({} if param is None else param)
×
93

94

95
def check_timeindex(
    xr_objs: xr.DataArray | xr.Dataset | dict[str, Any],
) -> xr.DataArray | xr.Dataset | dict[str, Any]:
    """
    Convert Xarray objects whose "time" index is a CFTimeIndex to a 'standard' calendar.

    Parameters
    ----------
    xr_objs : xr.DataArray or xr.Dataset or dict
        A single Xarray object, or a dictionary of Xarray DataArrays or Datasets.

    Returns
    -------
    xr.DataArray or xr.Dataset or dict
        The input, where every object that had a CFTimeIndex on "time" has been replaced by the result of
        ``convert_calendar("standard", use_cftime=None, align_on="year")``.
        A dictionary input is updated in place and returned.
    """

    def _has_cftime_axis(obj) -> bool:
        # Only objects with a "time" dimension are inspected.
        return "time" in obj.dims and isinstance(
            obj.get_index("time"), xr.CFTimeIndex
        )

    message = "CFTimeIndex converted to pandas DatetimeIndex with a 'standard' calendar."

    if not isinstance(xr_objs, dict):
        if _has_cftime_axis(xr_objs):
            xr_objs = xr_objs.convert_calendar(
                "standard", use_cftime=None, align_on="year"
            )
            warnings.warn(message, stacklevel=2)
        return xr_objs

    # Existing keys are reassigned, so the dictionary size is stable while iterating.
    for name, obj in xr_objs.items():
        if _has_cftime_axis(obj):
            xr_objs[name] = obj.convert_calendar(
                "standard", use_cftime=None, align_on="year"
            )
            warnings.warn(message, stacklevel=2)

    return xr_objs
×
135

136

137
def get_array_categ(array: xr.DataArray | xr.Dataset) -> str:
    """
    Classify an array; the category determines how the array is plotted.

    Parameters
    ----------
    array : Dataset or DataArray
        The array being categorized.

    Returns
    -------
    str
        ENS_PCT_VAR_DS: ensemble percentiles stored as variables
        ENS_PCT_DIM_DA: ensemble percentiles stored as dimension coordinates, DataArray
        ENS_PCT_DIM_DS: ensemble percentiles stored as dimension coordinates, DataSet
        ENS_STATS_VAR_DS: ensemble statistics (min, mean, max) stored as variables
        ENS_REALS_DA: ensemble with 'realization' dim, as DataArray
        ENS_REALS_DS: ensemble with 'realization' dim, as Dataset
        DS: any Dataset that is not recognized as an ensemble
        DA: DataArray

    Raises
    ------
    TypeError
        If `array` is neither a Dataset nor a DataArray.
    """

    def _count_matching_vars(pattern: str) -> int:
        # Number of data variables whose name contains `pattern`.
        return sum(
            re.search(pattern, var) is not None for var in array.data_vars
        )

    if isinstance(array, xr.Dataset):
        # Variable-name conventions are checked before dimensions; at least two matches are required.
        if _count_matching_vars("_p[0-9]{1,2}") >= 2:
            return "ENS_PCT_VAR_DS"
        if _count_matching_vars("_[Mm]ax|_[Mm]in") >= 2:
            return "ENS_STATS_VAR_DS"
        if "percentiles" in array.dims:
            return "ENS_PCT_DIM_DS"
        if "realization" in array.dims:
            return "ENS_REALS_DS"
        return "DS"

    if isinstance(array, xr.DataArray):
        if "percentiles" in array.dims:
            return "ENS_PCT_DIM_DA"
        if "realization" in array.dims:
            return "ENS_REALS_DA"
        return "DA"

    raise TypeError("Array is not an Xarray Dataset or DataArray")
×
191

192

193
def get_attributes(
    string: str, xr_obj: xr.DataArray | xr.Dataset, locale: str | None = None
) -> str:
    """
    Fetch attributes or dims corresponding to keys from Xarray objects.

    Searches DataArray attributes first, then the first variable (DataArray) of the Dataset, then Dataset attributes.
    If a locale is activated in xclim's options or a locale is passed, a localized version is given if available.

    Parameters
    ----------
    string : str
        String corresponding to an attribute name.
    xr_obj : DataArray or Dataset
        The Xarray object containing the attributes.
    locale : str, optional
        A 2-letter locale name to translate to.
        Default is None, which will pull the locale
        from xclim's "metadata_locales" option (taking the first).

    Returns
    -------
    str
        Xarray attribute value, or an empty string (after a warning) if not found.
    """
    locale = locale or (XC_OPTIONS[METADATA_LOCALES] or ["en"])[0]
    # For a non-English locale the localized attribute ("<name>_<locale>") takes precedence.
    names = [string] if locale == "en" else [f"{string}_{locale}", string]

    # Attribute dictionaries to search, in priority order.
    if isinstance(xr_obj, xr.DataArray):
        sources = [xr_obj.attrs]
    elif isinstance(xr_obj, xr.Dataset):
        # A Dataset without data variables has no "first variable": only its own attributes
        # are searched instead of failing on an empty variable list.
        first_var = next(iter(xr_obj.data_vars), None)
        sources = [] if first_var is None else [xr_obj[first_var].attrs]
        sources.append(xr_obj.attrs)
    else:
        sources = []

    for name in names:
        for attrs in sources:
            if name in attrs:
                return attrs[name]

    warnings.warn(f'Attribute "{string}" not found.', stacklevel=2)
    return ""
×
239

240

241
def set_plot_attrs(
    attr_dict: dict[str, Any],
    xr_obj: xr.DataArray | xr.Dataset,
    ax: matplotlib.axes.Axes | None = None,
    title_loc: str = "center",
    facetgrid: seaborn.axisgrid.FacetGrid | None = None,
    wrap_kw: dict[str, Any] | None = None,
) -> matplotlib.axes.Axes | seaborn.axisgrid.FacetGrid:
    """
    Set plot elements according to Dataset or DataArray attributes.

    Uses get_attributes() to check for and get the string.

    Parameters
    ----------
    attr_dict : dict
        Dictionary containing specified attribute keys.
        Supported keys are "title", "ylabel", "yunits", "xlabel", "xunits", "cbar_label", "cbar_units"
        and "suptitle"; any other key triggers a warning.
    xr_obj : Dataset or DataArray
        The Xarray object containing the attributes.
    ax : matplotlib axis, optional
        The matplotlib axis of the plot. It is used when "title", "ylabel" or "xlabel" are in `attr_dict`.
    title_loc : str
        Location of the title.
    facetgrid : seaborn.axisgrid.FacetGrid, optional
        If given, "suptitle" is applied to its figure and the facetgrid is returned instead of `ax`.
    wrap_kw : dict, optional
        Arguments to pass to the wrap_text function for the title.

    Returns
    -------
    matplotlib.axes.Axes or seaborn.axisgrid.FacetGrid
        `facetgrid` if one was passed, `ax` otherwise.
    """
    wrap_kw = empty_dict(wrap_kw)

    supported_keys = (
        "title",
        "ylabel",
        "yunits",
        "xlabel",
        "xunits",
        "cbar_label",
        "cbar_units",
        "suptitle",
    )
    for key in attr_dict:
        if key not in supported_keys:
            warnings.warn(f'Use_attrs element "{key}" not supported', stacklevel=2)

    def _axis_label(label_key: str, units_key: str) -> str:
        # Build "<label> (<units>)", or just "<label>" when no units are requested or found.
        # Units are looked up once, and before the label, as in the original condition order.
        units = (
            get_attributes(attr_dict[units_key], xr_obj)
            if units_key in attr_dict
            else ""
        )
        label = get_attributes(attr_dict[label_key], xr_obj)
        if len(units) >= 1:  # avoids an empty '()' in the label
            label = label + " (" + units + ")"
        return wrap_text(label)

    if "title" in attr_dict:
        title = get_attributes(attr_dict["title"], xr_obj)
        ax.set_title(wrap_text(title, **wrap_kw), loc=title_loc)

    if "ylabel" in attr_dict:
        ax.set_ylabel(_axis_label("ylabel", "yunits"))

    if "xlabel" in attr_dict:
        ax.set_xlabel(_axis_label("xlabel", "xunits"))

    # "cbar_label" and "cbar_units" are accepted keys, but the colorbar label
    # has to be assigned in the main plotting function: nothing to do here.

    if facetgrid:
        if "suptitle" in attr_dict:
            suptitle = get_attributes(attr_dict["suptitle"], xr_obj)
            facetgrid.fig.suptitle(suptitle, y=1.05)
            facetgrid.set_titles(template="{value}")
        return facetgrid

    return ax
×
339

340

341
def get_suffix(string: str) -> str:
    """
    Get suffix of typical Xclim variable names.

    The name must end with one or two digits, or with "_max", "_min" or "_mean" (first letter
    in either case). The trailing digits, or the word without its underscore, are returned.

    Raises
    ------
    ValueError
        If `string` has none of these endings.
    """
    # Validation requires the underscore before the word; extraction then leaves it out.
    if not re.search("[0-9]{1,2}$|_[Mm]ax$|_[Mm]in$|_[Mm]ean$", string):
        raise ValueError(f"Mean, min or max not found in {string}")
    return re.search("[0-9]{1,2}$|[Mm]ax$|[Mm]in$|[Mm]ean$", string).group()
×
348

349

350
def sort_lines(array_dict: dict[str, Any]) -> dict[str, str]:
    """
    Label arrays as 'middle', 'upper' and 'lower' for ensemble plotting.

    The role of each array is derived from the suffix of its name, as given by get_suffix():
    max/min/mean words, or a percentile number compared to 50.

    Parameters
    ----------
    array_dict : dict
        Dictionary of format {'name': array...}.

    Returns
    -------
    dict
        Dictionary of {'middle': 'name', 'upper': 'name', 'lower': 'name'}.

    Raises
    ------
    ValueError
        If `array_dict` does not hold exactly three entries.
    """
    if len(array_dict) != 3:
        raise ValueError("Ensembles must contain exactly three arrays")

    role_of_word = {
        "max": "upper",
        "Max": "upper",
        "min": "lower",
        "Min": "lower",
        "mean": "middle",
        "Mean": "middle",
    }

    sorted_lines = {}
    for name in array_dict:
        suffix = get_suffix(name)

        if suffix.isalpha():
            role = role_of_word.get(suffix)
            if role is not None:
                sorted_lines[role] = name
        elif suffix.isdigit():
            percentile = int(suffix)
            if percentile > 50:
                sorted_lines["upper"] = name
            elif percentile < 50:
                sorted_lines["lower"] = name
            else:
                sorted_lines["middle"] = name
        else:
            raise ValueError('Arrays names must end in format "_mean" or "_p50" ')

    return sorted_lines
×
389

390

391
def loc_mpl(
    loc: str | tuple[int | float, int | float] | int,
) -> tuple[tuple[float, float], tuple[int | float, int | float], str, str]:
    """
    Find coordinates and alignment associated to loc string.

    Parameters
    ----------
    loc : string, int, or tuple[float, float]
        Location of text, replicating https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html.
        If an int, must be a location code between 1 and 10.
        If a tuple, must be in axes coordinates.

    Returns
    -------
    tuple(float, float), tuple(float, float), str, str
        The location in axes coordinates, the box alignment, the horizontal alignment
        and the vertical alignment. For a tuple `loc`, the coordinates are returned as given
        with "left"/"bottom" alignments.

    Raises
    ------
    ValueError
        If an int `loc` is not between 1 and 10, if a tuple `loc` has a coordinate outside [0, 1],
        or if `loc` is anything else than a known location string, an int or a tuple.
    """
    # loc string -> (axes coordinates, box alignment, horizontal alignment, vertical alignment).
    # Insertion order matters: the 1-based position of a key is its integer location code.
    placements = {
        "upper right": ((0.97, 0.97), (1, 1), "right", "top"),
        "upper left": ((0.03, 0.97), (0, 1), "left", "top"),
        "lower left": ((0.03, 0.03), (0, 0), "left", "bottom"),
        "lower right": ((0.97, 0.03), (1, 0), "right", "bottom"),
        "right": ((0.97, 0.5), (1, 0.5), "right", "center"),
        "center left": ((0.03, 0.5), (0, 0.5), "left", "center"),
        # Same placement as "right": the box is anchored on its right edge.
        "center right": ((0.97, 0.5), (1, 0.5), "right", "center"),
        "lower center": ((0.5, 0.03), (0.5, 0), "center", "bottom"),
        "upper center": ((0.5, 0.97), (0.5, 1), "center", "top"),
        "center": ((0.5, 0.5), (0.5, 0.5), "center", "center"),
    }

    if isinstance(loc, int):
        codes = list(placements)
        # Explicit range check: 0 and negative codes would otherwise wrap around
        # through negative indexing instead of being rejected.
        if not 1 <= loc <= len(codes):
            raise ValueError("loc must be between 1 and 10, inclusively")
        loc = codes[loc - 1]

    if isinstance(loc, str) and loc in placements:
        return placements[loc]

    if isinstance(loc, tuple):
        box_a = []
        for i in loc:
            if i > 1 or i < 0:
                raise ValueError(
                    "Text location coordinates must be between 0 and 1, inclusively"
                )
            # Anchor the box on the side of the axes the coordinate is closest to.
            box_a.append(1 if i > 0.5 else 0)
        return loc, tuple(box_a), "left", "bottom"

    raise ValueError(
        "loc must be a string, int or tuple. "
        "See https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html"
    )
×
497

498

499
def plot_coords(
    ax: matplotlib.axes.Axes | None,
    xr_obj: xr.DataArray | xr.Dataset,
    loc: str | tuple[float, float] | int,
    param: str | None = None,
    backgroundalpha: float = 1,
) -> matplotlib.axes.Axes | None:
    """
    Place coordinates on plot area.

    Parameters
    ----------
    ax : matplotlib.axes.Axes or None
        Matplotlib axes object on which to place the text.
        If None, will use plt.figtext instead (should be used for facetgrids).
    xr_obj : xr.DataArray or xr.Dataset
        The xarray object from which to fetch the text content.
    loc : string, int or tuple
        Location of text, replicating https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html.
        If a tuple, must be in axes coordinates.
    param : {"location", "time"}, optional
        The parameter used: "location" writes the "lat" and "lon" coordinates,
        "time" writes the "time" coordinate formatted as YYYY-MM-DD.
    backgroundalpha : float
        Transparency of the text background. 1 is opaque, 0 is transparent.

    Returns
    -------
    matplotlib.axes.Axes or None
        `ax`, whether or not a text was added to it (None if `ax` is None).
    """
    text = None
    if param == "location":
        if "lat" in xr_obj.coords and "lon" in xr_obj.coords:
            text = f"lat={float(xr_obj['lat']):.2f}, lon={float(xr_obj['lon']):.2f}"
        else:
            warnings.warn(
                'show_lat_lon set to True, but "lat" and/or "lon" not found in coords', stacklevel=2
            )
    elif param == "time":
        if "time" in xr_obj.coords:
            text = str(xr_obj.time.dt.strftime("%Y-%m-%d").values)
        else:
            warnings.warn('show_time set to True, but "time" not found in coords', stacklevel=2)

    # Resolved even when there is no text, so that an invalid `loc` is always reported.
    loc, box_a, ha, va = loc_mpl(loc)

    if not text:
        # Nothing to write: hand the axes back untouched rather than returning None.
        return ax

    if ax is None:
        # No axes (e.g. facetgrids): place the text in figure coordinates.
        plt.figtext(
            loc[0],
            loc[1],
            text,
            ha=ha,
            va=va,
            fontsize=12,
        )
        return None

    t = mpl.offsetbox.TextArea(
        text, textprops=dict(transform=ax.transAxes, ha=ha, va=va)
    )
    tt = mpl.offsetbox.AnnotationBbox(
        t,
        loc,
        xycoords="axes fraction",
        box_alignment=box_a,
        pad=0.05,
        bboxprops=dict(
            facecolor="white",
            alpha=backgroundalpha,
            edgecolor="w",
            boxstyle="Square, pad=0.5",
        ),
    )
    ax.add_artist(tt)
    return ax
×
595

596

597
def find_logo(logo: str | pathlib.Path) -> str:
    """
    Return the entry of ``Logos()`` for `logo`, or its ``default`` entry when `logo` is falsy.

    Raises
    ------
    ValueError
        If the entry found is None.
    """
    catalogue = Logos()
    logo_path = catalogue[logo] if logo else catalogue.default

    if logo_path is not None:
        return logo_path

    raise ValueError(
        "No logo found. Please install one with the figanos.Logos().set_logo() method."
    )
×
610

611

612
def load_image(
    im: str | pathlib.Path,
    height: float | None,
    width: float | None,
    keep_ratio: bool = True,
) -> np.ndarray:
    """
    Scale an image to a specified height and width.

    Parameters
    ----------
    im : str or Path
        The image to be scaled. PNG and SVG formats are supported.
    height : float, optional
        The desired height of the image. If None, the original height is used.
    width : float, optional
        The desired width of the image. If None, the original width is used.
    keep_ratio : bool
        If True, the aspect ratio of the original image is maintained. Default is True.
        In that case, if both `height` and `width` are given, `width` takes precedence.

    Returns
    -------
    np.ndarray
        The scaled image.

    Raises
    ------
    ValueError
        If the file suffix is neither ".png" nor ".svg".
    """
    suffix = pathlib.Path(im).suffix.lower()

    if suffix == ".png":
        image = mpl.pyplot.imread(im)
        if height is None and width is None:
            return image

        original_height, original_width = image.shape[:2]
        warnings.warn(
            "The scikit-image library is used to resize PNG images. This may affect logo image quality.", stacklevel=2
        )
        if not keep_ratio:
            # Requested dimensions win; the original size only fills in the missing one.
            height = original_height if height is None else height
            width = original_width if width is None else width
        elif width is not None:
            if height is not None:
                warnings.warn("Both height and width provided, using width.", stacklevel=2)
            # Derive the height from the width based on the aspect ratio
            height = (width / original_width) * original_height
        else:
            # Only height is provided, derive the width based on the aspect ratio
            width = (height / original_height) * original_width

        # Trailing dimensions (colour channels, if any) are kept as they are,
        # so 2-D grayscale images are handled as well.
        return resize(image, (height, width, *image.shape[2:]), anti_aliasing=True)

    if suffix == ".svg":
        cairo_kwargs = dict(url=im)
        if not keep_ratio:
            # Without ratio preservation, both dimensions are needed to set the output size.
            if height is not None and width is not None:
                cairo_kwargs.update(output_height=height, output_width=width)
        elif width is not None:
            if height is not None:
                warnings.warn("Both height and width provided, using width.", stacklevel=2)
            cairo_kwargs.update(output_width=width)
        elif height is not None:
            cairo_kwargs.update(output_height=height)

        # The SVG is rasterized to a temporary PNG, which is read back as an array.
        # NOTE(review): the temporary file is re-opened by name while still open,
        # which may not work on every platform (e.g. Windows) - confirm.
        with NamedTemporaryFile(suffix=".png") as png_file:
            cairo_kwargs.update(write_to=png_file.name)
            cairosvg.svg2png(**cairo_kwargs)
            return mpl.pyplot.imread(png_file.name)

    raise ValueError(
        f"Unsupported image format '{suffix}'. PNG and SVG formats are supported."
    )
×
678

679

680
def plot_logo(
    ax: matplotlib.axes.Axes,
    loc: str | tuple[float, float] | int,
    logo: str | pathlib.Path | Logos | None = None,
    height: float | None = None,
    width: float | None = None,
    keep_ratio: bool = True,
    **offset_image_kwargs,
) -> matplotlib.axes.Axes:
    r"""
    Place a logo on the plot area.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Matplotlib axes object on which to place the logo.
    loc : string, int or tuple
        Location of the logo, replicating https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html.
        If a tuple, must be in axes coordinates.
    logo : str, Path, figanos.Logos, optional
        A name (str) or Path to a logo file, or a name of an already-installed logo.
        If an existing is not found, the logo will be installed and accessible via the filename.
        The default logo is the Figanos logo. To install the Ouranos (or another) logo consult the Usage page.
        If a `Logos` object is passed, its default logo is used.
        The file is read with load_image(), which handles the 'png' and 'svg' formats.
    height : float, optional
        The desired height of the image. If None, the original height is used.
    width : float, optional
        The desired width of the image. If None, the original width is used.
    keep_ratio : bool, optional
        If True, the aspect ratio of the original image is maintained. Default is True.
    \*\*offset_image_kwargs
        Arguments to pass to matplotlib.offsetbox.OffsetImage().

    Returns
    -------
    matplotlib.axes.Axes
    """
    logo_path = logo.default if isinstance(logo, Logos) else find_logo(logo)

    image = load_image(logo_path, height, width, keep_ratio)
    imagebox = mpl.offsetbox.OffsetImage(image, **offset_image_kwargs)

    # Text alignments returned by loc_mpl() do not apply to an image box.
    loc, box_a, _, _ = loc_mpl(loc)
    ab = mpl.offsetbox.AnnotationBbox(
        imagebox,
        loc,
        frameon=False,
        xycoords="axes fraction",
        box_alignment=box_a,
        pad=0.05,
    )
    ax.add_artist(ab)
    return ax
×
739

740

741
def split_legend(
7✔
742
    ax: matplotlib.axes.Axes,
743
    in_plot: bool = False,
744
    axis_factor: float = 0.15,
745
    label_gap: float = 0.02,
746
) -> matplotlib.axes.Axes:
747
    #  TODO: check for and fix overlapping labels
748
    """
749
    Draw line labels at the end of each line, or outside the plot.
750

751
    Parameters
752
    ----------
753
    ax : matplotlib.axes.Axes
754
        The axis containing the legend.
755
    in_plot : bool
756
        If True, prolong plot area to fit labels. If False, print labels outside of plot area. Default: False.
757
    axis_factor : float
758
        If in_plot is True, fraction of the x-axis length to add at the far right of the plot. Default: 0.15.
759
    label_gap : float
760
        If in_plot is True, fraction of the x-axis length to add as a gap between line and label. Default: 0.02.
761

762
    Returns
763
    -------
764
    matplotlib.axes.Axes
765
    """
766
    # create extra space
767
    init_xbound = ax.get_xbound()
×
768

769
    ax_bump = (init_xbound[1] - init_xbound[0]) * axis_factor
×
770
    label_bump = (init_xbound[1] - init_xbound[0]) * label_gap
×
771

772
    if in_plot is True:
×
773
        ax.set_xbound(lower=init_xbound[0], upper=init_xbound[1] + ax_bump)
×
774

775
    # get legend and plot
776

777
    handles, labels = ax.get_legend_handles_labels()
×
778
    for handle, label in zip(handles, labels, strict=False):
×
779
        last_x = handle.get_xdata()[-1]
×
780
        last_y = handle.get_ydata()[-1]
×
781

782
        if isinstance(last_x, np.datetime64):
×
783
            last_x = mpl.dates.date2num(last_x)
×
784

785
        color = handle.get_color()
×
786
        # ls = handle.get_linestyle()
787

788
        if in_plot is True:
×
789
            ax.text(
×
790
                last_x + label_bump,
791
                last_y,
792
                label,
793
                ha="left",
794
                va="center",
795
                color=color,
796
            )
797
        else:
798
            trans = mpl.transforms.blended_transform_factory(ax.transAxes, ax.transData)
×
799
            ax.text(
×
800
                1.01,
801
                last_y,
802
                label,
803
                ha="left",
804
                va="center",
805
                color=color,
806
                transform=trans,
807
            )
808

809
    return ax
×
810

811

812
def fill_between_label(
    sorted_lines: dict[str, Any],
    name: str,
    array_categ: dict[str, Any],
    legend: str,
) -> str:
    """
    Build the legend label of the shading drawn around a line in line plots.

    Parameters
    ----------
    sorted_lines : dict
        Dictionary created by the sort_lines() function.
    name : str
        Key associated with the object being plotted in the 'data' argument of the timeseries() function.
    array_categ : dict
        The categories of the array, as created by the get_array_categ function.
    legend : str
        Legend mode. A label is only produced for the "full" mode.

    Returns
    -------
    str
        Label to be applied to the legend element representing the shading.
        None when the legend mode is not "full" or the category is not a percentile or min-max ensemble.
    """
    if legend != "full":
        return None

    category = array_categ[name]

    if category in ("ENS_PCT_VAR_DS", "ENS_PCT_DIM_DS", "ENS_PCT_DIM_DA"):
        template = get_localized_term("{}th-{}th percentiles")
        return template.format(
            get_suffix(sorted_lines["lower"]), get_suffix(sorted_lines["upper"])
        )

    if category == "ENS_STATS_VAR_DS":
        return get_localized_term("min-max range")

    return None
×
853

854

855
def get_var_group(
    da: xr.DataArray | None = None,
    unique_str: str | None = None,
    path_to_json: str | pathlib.Path | None = None,
) -> str:
    """
    Get IPCC variable group from DataArray or a string using a json file (figanos/data/ipcc_colors/variable_groups.json).

    If `da` is a Dataset, look in the DataArray of the first variable.

    Parameters
    ----------
    da : xr.DataArray, optional
        Object whose name, and then `history` attribute if the name gave nothing, are searched for
        variable keywords. Only used when `unique_str` is empty or None.
    unique_str : str, optional
        String to search for variable keywords. Takes precedence over `da`.
    path_to_json : str or pathlib.Path, optional
        JSON file mapping variable keywords to variable groups. Defaults to VARJSON.

    Returns
    -------
    str
        The variable group, or "misc" (with a warning) when zero or several distinct groups are found.
    """
    json_path = VARJSON if path_to_json is None else path_to_json
    with pathlib.Path(json_path).open(encoding="utf-8") as stream:
        keyword_to_group = json.load(stream)

    def _groups_mentioned_in(text):
        """Return the group of every keyword that appears in `text` outside of a longer word."""
        return [
            keyword_to_group[keyword]
            for keyword in keyword_to_group
            if re.search(rf"(?:^|[^a-zA-Z])({keyword})(?:[^a-zA-Z]|$)", text)
        ]

    if unique_str:
        found = _groups_mentioned_in(unique_str)
    else:
        found = []
        if isinstance(da, xr.Dataset):
            da = da[list(da.data_vars)[0]]
        # first choice: the name of the DataArray
        if hasattr(da, "name") and isinstance(da.name, str):
            found.extend(_groups_mentioned_in(da.name))
        # fallback: the history attribute, only when the name gave nothing
        if hasattr(da, "history") and len(found) == 0:
            found.extend(_groups_mentioned_in(da.history))

    groups = np.unique(found)

    if len(groups) == 1:
        return groups[0]

    if len(groups) == 0:
        message = "Colormap warning: Variable group not found. Use the cmap argument."
    else:
        message = "Colormap warning: More than one variable group found. Use the cmap argument."
    warnings.warn(message, stacklevel=2)
    return "misc"
×
911

912

913
def get_ipcc_cmap_name(
    var_group: str | None = None,
    divergent: bool | int = False,
    filename: str | None = None,
    reverse: bool = False,
) -> str:
    """
    Get colormap name according to variable group or filename.

    Parameters
    ----------
    var_group : str, optional
        Variable group from IPCC scheme. Required when `filename` is not given.
    divergent : bool or int
        Diverging colormap. If False, use sequential colormap.
    filename : str, optional
        Name of IPCC colormap file. If not None, 'var_group' and 'divergent' are not used.
    reverse : bool
        If True, the name of the reverse color order colormap is returned.

    Returns
    -------
    str
        Name of the colormap in `matplotlib.colormaps`.

    Raises
    ------
    ValueError
        If neither `filename` nor `var_group` is given.
    """
    if filename:
        cmap_name = filename.removesuffix(".txt")
    else:
        if var_group is None:
            raise ValueError("Either 'var_group' or 'filename' must be given.")

        if divergent is not False:
            # "misc2" has no diverging colormap of its own and shares the one of "misc"
            group = "misc" if var_group == "misc2" else var_group
            cmap_name = group + "_div"
        elif var_group == "misc":
            cmap_name = var_group + "_seq_3"  # Batlow
        elif var_group == "misc2":
            cmap_name = "misc_seq_2"  # freezing rain
        else:
            cmap_name = var_group + "_seq"

    if reverse:
        # Reversed colormaps carry an "_r" suffix; reversing an already reversed name
        # gives back the base colormap.
        cmap_name = cmap_name[:-2] if cmap_name.endswith("_r") else cmap_name + "_r"

    return cmap_name
×
954

955

956
def create_ipcc_cmap(filename: str, reverse: bool = False):
    """
    Create an IPCC colormap from a filename.

    Parameters
    ----------
    filename : str
        Name of the colormap table. Only its stem is used (directories and extension are dropped).
        A stem ending in "_r" selects the reversed colormap of the table named without that suffix.
    reverse : bool
        If True, return the colormap in reverse color order.

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
        Colormap with 256 levels built from the "<stem>.txt" RGB table of the
        figanos.data.ipcc_colors.continuous_colormaps_rgb_0-255 package.
    """
    # Keep the stem only (no extension, no parent directories)
    stem = pathlib.Path(filename).stem
    if stem.endswith("_r"):
        stem, reverse = stem[:-2], True

    package = 'figanos.data.ipcc_colors.continuous_colormaps_rgb_0-255'
    with importlib.resources.path(package, f'{stem}.txt') as table_path:
        # the table is stored as 0-255 RGB values; matplotlib expects 0-1
        rgb = np.loadtxt(table_path) / 255

    cmap = mcolors.LinearSegmentedColormap.from_list(stem, rgb, N=256)

    # reversed() also appends "_r" to the colormap name
    return cmap.reversed() if reverse is True else cmap
7✔
975

976

977
# Register every packaged IPCC colormap, in both color orders, with matplotlib at import time.
_cmap_dir = importlib.resources.files('figanos.data.ipcc_colors.continuous_colormaps_rgb_0-255')
for name in sorted(entry.name for entry in _cmap_dir.iterdir()):
    # Only the ".txt" RGB tables are colormaps; anything else found in the
    # package directory must not be handed to create_ipcc_cmap().
    if not name.endswith('.txt'):
        continue
    name = name.replace('.txt', '')
    for reverse in [True, False]:
        mpl.colormaps.register(create_ipcc_cmap(name, reverse))
7✔
982

983

984
def get_rotpole(xr_obj: xr.DataArray | xr.Dataset) -> ccrs.RotatedPole | None:
    """
    Create a Cartopy crs rotated pole projection/transform from DataArray or Dataset attributes.

    For a Dataset, the grid mapping variable is looked up through the `cf` accessor
    ("rotated_latitude_longitude"); for a DataArray, through its "grid_mapping" attribute.
    In both cases the name "rotated_pole" is assumed when nothing is found.

    Parameters
    ----------
    xr_obj : xr.DataArray or xr.Dataset
        The xarray object from which to look for the attributes.

    Returns
    -------
    ccrs.RotatedPole or None
        The rotated pole projection, or None (with a warning) when the grid mapping
        variable or one of its pole attributes is missing.
    """
    try:
        if isinstance(xr_obj, xr.Dataset):
            gridmap = xr_obj.cf.grid_mapping_names.get("rotated_latitude_longitude", [])

            if len(gridmap) > 1:
                warnings.warn(
                    f"There are conflicting grid_mapping attributes in the dataset. Assuming {gridmap[0]}.", stacklevel=2
                )

            coord_name = gridmap[0] if gridmap else "rotated_pole"
        else:
            # If it can't find grid_mapping, assume it's rotated_pole
            coord_name = xr_obj.attrs.get("grid_mapping", "rotated_pole")

        grid_mapping = xr_obj[coord_name]
        return ccrs.RotatedPole(
            pole_longitude=grid_mapping.grid_north_pole_longitude,
            pole_latitude=grid_mapping.grid_north_pole_latitude,
            central_rotated_longitude=grid_mapping.north_pole_grid_longitude,
        )

    except (AttributeError, KeyError):
        # KeyError: the grid mapping variable itself is absent from the object.
        # AttributeError: the variable exists but lacks one of the pole attributes.
        warnings.warn("Rotated pole not found. Specify a transform if necessary.", stacklevel=2)
        return None
×
1022

1023

1024
def wrap_text(text: str, min_line_len: int = 18, max_line_len: int = 30) -> str:
    """
    Wrap text by turning selected spaces into line breaks.

    Each break is searched for between `min_line_len` and `max_line_len` characters after the
    start of the current line. The space following the first ". " is preferred, then the one
    following the first ": ", then the last plain space of that window.

    Parameters
    ----------
    text : str
        The text to wrap.
    min_line_len : int
        Minimum length of each line.
    max_line_len : int
        Maximum length of each line.

    Returns
    -------
    str
        Wrapped text. If a window holds no space at all, a warning is issued and
        the remainder of the text is left unwrapped.
    """
    window_lo, window_hi = min_line_len, max_line_len
    unwrapped = len(text)

    while unwrapped > max_line_len:
        sentence_end = text.find(". ", window_lo, window_hi)
        colon = text.find(": ", window_lo, window_hi)
        last_space = text.rfind(" ", window_lo, window_hi)

        # `cut` is the index of the space that becomes the line break
        if sentence_end != -1:
            cut = sentence_end + 1
        elif colon != -1:
            cut = colon + 1
        elif last_space != -1:
            cut = last_space
        else:
            warnings.warn("No spaces, points or colons to break line at.", stacklevel=2)
            break

        text = text[:cut] + "\n" + text[cut + 1 :]

        # the next window is measured from the character following the break
        unwrapped = len(text) - cut
        window_lo = cut + 1 + min_line_len
        window_hi = cut + 1 + max_line_len

    return text
×
1066

1067

1068
def gpd_to_ccrs(df: gpd.GeoDataFrame, proj: ccrs.CRS) -> gpd.GeoDataFrame:
    """
    Reproject a GeoDataFrame to a cartopy projection.

    Parameters
    ----------
    df : gpd.GeoDataFrame
        GeoDataFrame (geopandas) geometry to be added to axis.
    proj : ccrs.CRS
        Projection to use, taken from the cartopy.crs options.

    Returns
    -------
    gpd.GeoDataFrame
        Result of `df.to_crs()` called with the `proj4_init` definition of the projection.
    """
    return df.to_crs(proj.proj4_init)
×
1086

1087

1088
def convert_scen_name(name: str) -> str:
    """
    Convert strings containing SSP, RCP or CMIP to their proper format.

    Every occurrence of "ssp", "rcp" or "cmip" (any case) followed by one to three digits
    is upper-cased, and its digits are punctuated according to their count:
    three digits give "X-Y.Z" (ssp245 -> SSP2-4.5), two digits give "X.Y" (rcp45 -> RCP4.5)
    and a single digit is kept as is (cmip5 -> CMIP5).

    Parameters
    ----------
    name : str
        String possibly containing scenario or CMIP identifiers.

    Returns
    -------
    str
        The string with all identifiers converted; unchanged if there are none.
    """

    def _format_scenario(match: re.Match) -> str:
        """Format a single matched identifier."""
        prefix, digits = match.group(1).upper(), match.group(2)
        if len(digits) == 3:
            return f"{prefix}{digits[0]}-{digits[1]}.{digits[2]}"
        if len(digits) == 2:
            return f"{prefix}{digits[0]}.{digits[1]}"
        return prefix + digits

    # Substituting match by match converts every identifier of the string,
    # not only the last one found.
    return re.sub(r"(SSP|RCP|CMIP)([0-9]{1,3})", _format_scenario, name, flags=re.I)
×
1110

1111

1112
def get_scen_color(name: str, path_to_dict: str | pathlib.Path) -> tuple[float, ...] | None:
    """
    Get color corresponding to SSP, RCP, model or CMIP substring from a dictionary.

    Parameters
    ----------
    name : str
        String in which to look for one of the keys of the dictionary.
    path_to_dict : str or pathlib.Path
        JSON file mapping substrings to colors given as sequences of 0-255 values.

    Returns
    -------
    tuple of float or None
        Color of the first key (in file order) contained in `name`, rescaled to the 0-1 range,
        or None when no key is contained in `name`.
    """
    with pathlib.Path(path_to_dict).open(encoding="utf-8") as _f:
        color_dict = json.load(_f)

    for entry, rgb in color_dict.items():
        if entry in name:
            return tuple(channel / 255 for channel in rgb)

    return None
×
1125

1126

1127
def process_keys(dct: dict[str, Any], func: Callable) -> dict[str, Any]:
    """
    Apply function to dictionary keys.

    The dictionary is modified in place: each entry is removed and stored again
    under `func(key)`. The same dictionary object is returned.
    """
    # iterate over a snapshot, since the dictionary changes during the loop
    for original in tuple(dct):
        renamed = func(original)
        value = dct.pop(original)
        dct[renamed] = value
    return dct
×
1134

1135

1136
def categorical_colors() -> dict[str, str]:
    """
    Get the categorical colors associated with certain substrings (SSP,RCP,CMIP).

    Returns
    -------
    dict
        Content of the data/ipcc_colors/categorical_colors.json file of the package.
    """
    json_file = pathlib.Path(__file__).parents[1].joinpath(
        "data/ipcc_colors/categorical_colors.json"
    )
    return json.loads(json_file.read_text(encoding="utf-8"))
×
1145

1146

1147
def get_mpl_styles() -> dict[str, pathlib.Path]:
    """
    Get the available matplotlib styles and their paths as a dictionary.

    Returns
    -------
    dict
        Maps the stem of every "*.mplstyle" file of the "style" directory
        located next to this module to the path of that file, sorted by path.
    """
    style_dir = pathlib.Path(__file__).parent / "style"
    return {sheet.stem: sheet for sheet in sorted(style_dir.glob("*.mplstyle"))}
×
1152

1153

1154
def set_mpl_style(*args: str, reset: bool = False) -> None:
    """
    Set the matplotlib style using one or more stylesheets.

    Stylesheets are applied in the order given. A name that is neither a ".mplstyle" path
    nor a known figanos style only triggers a warning.

    Parameters
    ----------
    args : str
        Name(s) of figanos matplotlib style ('ouranos', 'paper', 'poster') or path(s) to matplotlib stylesheet(s).
    reset : bool
        If True, reset style to matplotlib default before applying the stylesheets.

    Returns
    -------
    None
    """
    if reset is True:
        mpl.style.use("default")

    for style in args:
        # anything ending in ".mplstyle" is treated as a path to a stylesheet
        if style.endswith(".mplstyle"):
            mpl.style.use(style)
            continue

        packaged = get_mpl_styles()
        if style in packaged:
            mpl.style.use(packaged[style])
        else:
            warnings.warn(f"Style {style} not found.", stacklevel=2)
×
1178

1179

1180
def add_cartopy_features(
    ax: matplotlib.axes.Axes, features: list[str] | dict[str, dict[str, Any]]
) -> matplotlib.axes.Axes:
    """
    Add cartopy features to matplotlib axes.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes on which to add the features.
    features : list or dict
        List of features, or nested dictionary of format {'feature': {'kwarg':'value'}}.
        Feature names are upper-cased and looked up in cartopy.feature. A 'scale' kwarg is
        passed to the feature's with_scale() method; the other kwargs go to ax.add_feature().
        The dictionaries are left untouched.

    Returns
    -------
    matplotlib.axes.Axes
        The axis with added features.
    """
    if isinstance(features, list):
        features = {f: {} for f in features}

    for feat, feat_kwargs in features.items():
        feature = getattr(cfeature, feat.upper())

        # Work on a copy: 'scale' is not an add_feature() argument, and the caller's
        # dictionary must stay intact even if adding the feature fails.
        draw_kwargs = dict(feat_kwargs)
        if "scale" in draw_kwargs:
            feature = feature.with_scale(draw_kwargs.pop("scale"))

        ax.add_feature(feature, **draw_kwargs)

    return ax
×
1212

1213

1214
def custom_cmap_norm(
    cmap,
    vmin: int | float,
    vmax: int | float,
    levels: int | list[int | float] | None = None,
    divergent: bool | int | float = False,
    linspace_out: bool = False,
) -> matplotlib.colors.Normalize | np.ndarray:
    """
    Get matplotlib normalization according to main function arguments.

    The data limits are first widened to round values: to multiples of 10 when the span
    (vmax - vmin) is at least 25, of 1 from 1 up to 25, of 0.1 from 0.1 up to 1, and of 0.01 below.

    Parameters
    ----------
    cmap : matplotlib.colormap or str
        Colormap to be used with the normalization. A string must be a registered colormap name.
    vmin : int or float
        Minimum of the data to be plotted with the colormap.
    vmax : int or float
        Maximum of the data to be plotted with the colormap.
    levels : int or list, optional
        Number of  levels or list of level boundaries (in data units) to use to divide the colormap.
    divergent : bool or int or float
        If int or float, becomes center of cmap. Default center is 0.
    linspace_out : bool
        If True, return array created by np.linspace() instead of normalization instance.
        Only has an effect when `levels` is a number.

    Returns
    -------
    matplotlib.colors.Normalize or np.ndarray
        BoundaryNorm when `levels` is given, TwoSlopeNorm when only a center is set,
        Normalize otherwise; or the array of level boundaries if `linspace_out` applies.

    Raises
    ------
    ValueError
        If `cmap` is an unknown colormap name, or if the center does not lie strictly
        between the rounded limits while `levels` is an integer.
    """
    # resolve colormap names through the matplotlib registry
    if isinstance(cmap, str):
        if cmap not in plt.colormaps():
            raise ValueError("Colormap not found")
        cmap = matplotlib.colormaps[cmap]

    # make vmin and vmax prettier: the rounding step follows the span of the data
    span = vmax - vmin
    if span >= 25:
        divisor, step = 10.0, 10
    elif span >= 1:
        divisor, step = 1, 1
    elif span >= 0.1:
        divisor, step = 0.1, 0.1
    else:
        divisor, step = 0.01, 0.01
    rvmax = math.ceil(vmax / divisor) * step
    rvmin = math.floor(vmin / divisor) * step

    # center of a diverging colormap, if any
    if divergent is True:
        center = 0
    elif divergent is not False and isinstance(divergent, int | float):
        center = divergent
    else:
        center = None

    # diverging colormap cut into a number of levels: same number of levels on each side of the center
    if center is not None and isinstance(levels, int):
        if center <= rvmin or center >= rvmax:
            raise ValueError("vmin, center and vmax must be in ascending order.")
        # odd numbers of levels are rounded up to the next even number
        per_side = (levels + 1) // 2 + 1
        below = np.linspace(rvmin, center, num=per_side)
        above = np.linspace(center, rvmax, num=per_side)[1:]  # the center is already in `below`
        bounds = np.concatenate((below, above))
        norm = matplotlib.colors.BoundaryNorm(boundaries=bounds, ncolors=cmap.N)
        return bounds if linspace_out else norm

    if levels is not None:
        # explicit boundaries
        if isinstance(levels, list):
            if center is not None:
                warnings.warn(
                    "Divergent argument ignored when levels is a list. Use levels as a number instead.", stacklevel=2
                )
            return matplotlib.colors.BoundaryNorm(boundaries=levels, ncolors=cmap.N)

        # evenly spaced boundaries between the rounded limits
        bounds = np.linspace(rvmin, rvmax, num=levels + 1)
        norm = matplotlib.colors.BoundaryNorm(boundaries=bounds, ncolors=cmap.N)
        return bounds if linspace_out else norm

    if center is not None:
        return matplotlib.colors.TwoSlopeNorm(center, vmin=rvmin, vmax=rvmax)

    return matplotlib.colors.Normalize(rvmin, rvmax)
×
1313

1314

1315
def norm2range(
    data: np.ndarray, target_range: tuple, data_range: tuple | None = None
) -> np.ndarray:
    """
    Normalize data across a specific range.

    The data is mapped linearly so that `data_range[0]` lands on `target_range[0]`
    and `data_range[1]` on `target_range[1]`.

    Parameters
    ----------
    data : np.ndarray
        Values to rescale.
    target_range : tuple
        (start, end) of the output range.
    data_range : tuple, optional
        (start, end) of the input range. Defaults to the NaN-ignoring minimum and
        maximum of `data`, which requires `data` to have more than one element.

    Returns
    -------
    np.ndarray
        The rescaled data.

    Raises
    ------
    ValueError
        If `data_range` is not given and `data` does not have more than one element.
    """
    if data_range is None:
        if len(data) <= 1:
            raise ValueError(" if data is not an array, data_range must be specified")
        data_range = (np.nanmin(data), np.nanmax(data))

    source_start = data_range[0]
    source_span = data_range[1] - data_range[0]
    target_start = target_range[0]
    target_span = target_range[1] - target_range[0]

    # position of each value within the source range (0 at its start, 1 at its end)
    fraction = (data - source_start) / source_span

    return target_start + (fraction * target_span)
×
1328

1329

1330
def size_legend_elements(
    data: np.ndarray, sizes: np.ndarray, marker: str, max_entries: int = 6
) -> list[matplotlib.lines.Line2D]:
    """
    Create handles to use in a point-size legend.

    Parameters
    ----------
    data : np.ndarray
        Data used to determine the point sizes.
    sizes : np.ndarray
        Array of point sizes.
    marker : str
        Marker to use in legend.
    max_entries : int
        Maximum number of entries in the legend.

    Returns
    -------
    list of matplotlib.lines.Line2D
        One handle per legend value, labelled with the data value and drawn with the
        matching marker size. The list is empty when the sizes span less than one
        increment of 10 pts**2.
    """
    size_min, size_max = min(sizes), max(sizes)
    data_min, data_max = min(data), max(data)

    # how many increments of 10 pts**2 are there in the sizes
    n = int(np.round(size_max - size_min, -1) / 10)
    if n < 1:
        # no increment to show, and nothing to divide the data range by
        return []

    # divide data in those increments
    lgd_data = np.linspace(data_min, data_max, n)

    # round according to range: `ratio` is the data range covered by one increment
    ratio = abs((data_max - data_min) / n)

    rounding = 0.001
    for threshold in (1000, 100, 10, 5, 1, 0.1, 0.01):
        if ratio >= threshold:
            rounding = threshold
            break

    lgd_data = np.unique(rounding * np.round(lgd_data / rounding))

    # convert back to sizes
    lgd_sizes = norm2range(
        data=lgd_data,
        data_range=(data_min, data_max),
        target_range=(size_min, size_max),
    )

    legend_elements = []

    for s, d in zip(lgd_sizes, lgd_data, strict=False):
        # show whole numbers without a trailing ".0"
        if isinstance(d, float) and d.is_integer():
            label = str(int(d))
        else:
            label = str(d)

        legend_elements.append(
            Line2D(
                [0],
                [0],
                marker=marker,
                color="k",
                lw=0,
                markerfacecolor="w",
                label=label,
                # sizes are areas (pts**2) while markersize is a length
                markersize=np.sqrt(np.abs(s)),
            )
        )

    # too many entries: keep every other one, up to index `max_entries`
    if len(legend_elements) > max_entries:
        return [legend_elements[i] for i in np.arange(0, max_entries + 1, 2)]

    return legend_elements
×
1411

1412

1413
def add_features_map(
    data,
    ax,
    use_attrs,
    projection,
    features,
    geometries_kw,
    frame,
) -> matplotlib.axes.Axes:
    """
    Add features such as cartopy, time label, and geometries to a map on a given matplotlib axis.

    Parameters
    ----------
    data : dict, DataArray or Dataset
        Input data do plot. If dictionary, must have only one entry.
    ax : matplotlib axis
        Matplotlib axis on which to plot, with the same projection as the one specified.
    use_attrs : dict
        Dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description').
        Default value is {'title': 'description', 'cbar_label': 'long_name', 'cbar_units': 'units'}.
        Only the keys found in the default dict can be used.
    projection : ccrs.Projection
        The projection to use, taken from the cartopy.crs options. Ignored if ax is not None.
    features : list or dict
        Features to use, as a list or a nested dict containing kwargs. Options are the predefined features from
        cartopy.feature: ['coastline', 'borders', 'lakes', 'land', 'ocean', 'rivers', 'states'].
    geometries_kw : dict
        Arguments passed to cartopy ax.add_geometries() which adds given geometries (GeoDataFrame geometry) to axis.
        Must contain "geoms"; otherwise a warning is issued and no geometry is added.
        The geometries are reprojected to the "crs" entry if there is one, to `projection` otherwise.
        The dictionary itself is not modified.
    frame : bool
        Show or hide frame. Default False.

    Returns
    -------
    matplotlib.axes.Axes
    """
    # add features
    if features:
        add_cartopy_features(ax, features)

    set_plot_attrs(use_attrs, data, ax)

    if frame is False:
        ax.spines["geo"].set_visible(False)

    # add geometries
    if geometries_kw:
        if "geoms" not in geometries_kw:
            warnings.warn(
                'geoms missing from geometries_kw (ex: {"geoms": df["geometry"]})', stacklevel=2
            )
            # nothing to draw without geometries
            return ax

        geoms = gpd_to_ccrs(geometries_kw["geoms"], geometries_kw.get("crs", projection))

        # defaults first, then the user's arguments, then the reprojected geometries;
        # building a new dict keeps the caller's `geometries_kw` intact
        ax.add_geometries(
            **{
                "crs": projection,
                "facecolor": "none",
                "edgecolor": "black",
                **geometries_kw,
                "geoms": geoms,
            }
        )
    return ax
×
1478

1479

1480
def masknan_sizes_key(data, sizes) -> xr.Dataset:
    """
    Mask the np.Nan values between variables used to plot hue and markersize in xr.plot.scatter().

    After the call, both variables are NaN wherever either of them was NaN.
    `data` is modified in place and returned.

    Parameters
    ----------
    data : xr.Dataset
        xr.Dataset used to plot
    sizes : str
        Variable used to plot markersize

    Returns
    -------
    xr.Dataset
    """
    # the hue variable is the first key of the dataset that is not `sizes`
    remaining = list(data.keys())
    remaining.remove(sizes)
    key = remaining[0]

    # drop the hue values that have no marker size ...
    data[key] = data[key].where(~np.isnan(data[sizes]))

    # ... then drop the marker sizes that have no hue value (including the ones just masked)
    data[sizes] = data[sizes].where(~np.isnan(data[key]))

    return data
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc