• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

Ouranosinc / figanos / 20276532443

16 Dec 2025 05:13PM UTC coverage: 9.11% (+0.9%) from 8.187%
20276532443

Pull #368

github

web-flow
Merge 9c88f65b3 into 07bbe06e5
Pull Request #368: Enregistrement des cmaps de l'IPCC dans matplotlib

12 of 35 new or added lines in 2 files covered. (34.29%)

1 existing line in 1 file now uncovered.

173 of 1899 relevant lines covered (9.11%)

0.55 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

14.14
/src/figanos/matplotlib/utils.py
1
"""Utility functions for figanos figure-creation."""
2

3
from __future__ import annotations
6✔
4
import importlib.resources
6✔
5
import json
6✔
6
import math
6✔
7
import pathlib
6✔
8
import re
6✔
9
import warnings
6✔
10
from collections.abc import Callable
6✔
11
from copy import deepcopy
6✔
12
from pathlib import Path
6✔
13
from tempfile import NamedTemporaryFile
6✔
14
from typing import Any
6✔
15

16
import cairosvg
6✔
17
import cartopy.crs as ccrs
6✔
18
import cartopy.feature as cfeature
6✔
19
import geopandas as gpd
6✔
20
import matplotlib as mpl
6✔
21
import matplotlib.axes
6✔
22
import matplotlib.colors as mcolors
6✔
23
import matplotlib.pyplot as plt
6✔
24
import numpy as np
6✔
25
import pandas as pd
6✔
26
import seaborn
6✔
27
import xarray as xr
6✔
28
import yaml
6✔
29
from matplotlib.lines import Line2D
6✔
30
from skimage.transform import resize
6✔
31
from xclim.core.options import METADATA_LOCALES
6✔
32
from xclim.core.options import OPTIONS as XC_OPTIONS
6✔
33

34
from .._logo import Logos
6✔
35

36

37
# File mapping variable keywords to their IPCC variable group (used for the IPCC color scheme).
VARJSON = Path(__file__).parents[1] / "data/ipcc_colors/variable_groups.json"

TERMS: dict = {}
"""
A translation dictionary for special terms to appear on the plots.

Keys are terms to translate and they map to "locale": "translation" dictionaries.
The "official" figanos terms are based on figanos/data/terms.yml.
"""


# Load terms translations. The file is read as UTF-8 explicitly so that non-ASCII
# translations load identically on every platform, independent of the default encoding.
with (pathlib.Path(__file__).resolve().parents[1] / "data" / "terms.yml").open(
    encoding="utf-8"
) as f:
    # safe_load returns None for an empty file; keep TERMS a dict in that case.
    TERMS = yaml.safe_load(f) or {}
52

53

54
def get_localized_term(term, locale=None):
    """
    Translate `term` into `locale` using the :py:data:`TERMS` dictionary.

    Parameters
    ----------
    term : str
        Word or short phrase to translate.
    locale : str, optional
        Two-letter locale code to translate to. When None (or empty), the first entry
        of xclim's "metadata_locales" option is used, falling back to "en".

    Returns
    -------
    str
        The translated term. `term` itself is returned unchanged for English, or,
        with a warning, when no translation is known for the term or the locale.
    """
    if not locale:
        configured_locales = XC_OPTIONS[METADATA_LOCALES] or ["en"]
        locale = configured_locales[0]

    # English is the source language of the terms: nothing to translate.
    if locale == "en":
        return term

    if term not in TERMS:
        warnings.warn(f"No translation known for term '{term}'.", stacklevel=2)
        return term

    translations = TERMS[term]
    if locale in translations:
        return translations[locale]

    warnings.warn(f"No {locale} translation known for term '{term}'.", stacklevel=2)
    return term
86

87

88
def empty_dict(param) -> dict:
    """
    Return a deep copy of `param`, or a new empty dict when `param` is None.

    Copying lets callers pop items from the result without modifying the original input.
    """
    return {} if param is None else deepcopy(param)
93

94

95
def check_timeindex(
    xr_objs: xr.DataArray | xr.Dataset | dict[str, Any],
) -> xr.DataArray | xr.Dataset | dict[str, Any]:
    """
    Convert CFTime time indexes to pandas DatetimeIndex ('standard' calendar).

    Parameters
    ----------
    xr_objs : xr.DataArray or xr.Dataset or dict
        An Xarray object, or a dictionary of Xarray DataArrays or Datasets.

    Returns
    -------
    xr.DataArray or xr.Dataset or dict
        The converted object. A dictionary is updated in place and returned.
        A warning is emitted for each converted object.
    """
    notice = "CFTimeIndex converted to pandas DatetimeIndex with a 'standard' calendar."

    if not isinstance(xr_objs, dict):
        if _has_cftime_index(xr_objs):
            xr_objs = xr_objs.convert_calendar(
                "standard", use_cftime=None, align_on="year"
            )
            warnings.warn(notice, stacklevel=2)
        return xr_objs

    for key, value in xr_objs.items():
        if _has_cftime_index(value):
            # Replacing the value of an existing key does not resize the dict,
            # so this is safe while iterating.
            xr_objs[key] = value.convert_calendar(
                "standard", use_cftime=None, align_on="year"
            )
            warnings.warn(notice, stacklevel=2)
    return xr_objs


def _has_cftime_index(obj) -> bool:
    """Return True when `obj` has a "time" dimension indexed by an xarray CFTimeIndex."""
    return "time" in obj.dims and isinstance(obj.get_index("time"), xr.CFTimeIndex)
135

136

137
def get_array_categ(array: xr.DataArray | xr.Dataset) -> str:
    """
    Classify an Xarray object to decide how it should be plotted.

    Parameters
    ----------
    array : Dataset or DataArray
        The object to classify.

    Returns
    -------
    str
        ENS_PCT_VAR_DS: ensemble percentiles stored as variables
        ENS_PCT_DIM_DA: ensemble percentiles stored as dimension coordinates, DataArray
        ENS_PCT_DIM_DS: ensemble percentiles stored as dimension coordinates, DataSet
        ENS_STATS_VAR_DS: ensemble statistics (min, mean, max) stored as variables
        ENS_REALS_DA: ensemble with 'realization' dim, as DataArray
        ENS_REALS_DS: ensemble with 'realization' dim, as Dataset
        DS: any Dataset that is not  recognized as an ensemble
        DA: DataArray

    Raises
    ------
    TypeError
        If `array` is neither an Xarray Dataset nor a DataArray.
    """
    if isinstance(array, xr.DataArray):
        if "percentiles" in array.dims:
            return "ENS_PCT_DIM_DA"
        if "realization" in array.dims:
            return "ENS_REALS_DA"
        return "DA"

    if not isinstance(array, xr.Dataset):
        raise TypeError("Array is not an Xarray Dataset or DataArray")

    var_names = list(array.data_vars)

    def n_matching(pattern: str) -> int:
        # Number of data variable names in which `pattern` occurs.
        return sum(re.search(pattern, var) is not None for var in var_names)

    # Variable-based ensembles need at least two matching variables.
    if n_matching("_p[0-9]{1,2}") >= 2:
        return "ENS_PCT_VAR_DS"
    if n_matching("_[Mm]ax|_[Mm]in") >= 2:
        return "ENS_STATS_VAR_DS"
    if "percentiles" in array.dims:
        return "ENS_PCT_DIM_DS"
    if "realization" in array.dims:
        return "ENS_REALS_DS"
    return "DS"
191

192

193
def get_attributes(
    string: str, xr_obj: xr.DataArray | xr.Dataset, locale: str | None = None
) -> str:
    """
    Fetch attributes or dims corresponding to keys from Xarray objects.

    Searches DataArray attributes first, then the first variable (DataArray) of the Dataset, then Dataset attributes.
    If a locale is activated in xclim's options or a locale is passed, a localized version is given if available.

    Parameters
    ----------
    string : str
        String corresponding to an attribute name.
    xr_obj : DataArray or Dataset
        The Xarray object containing the attributes.
    locale : str, optional
        A 2-letter locale name to translate to.
        Default is None, which will pull the locale
        from xclim's "metadata_locales" option (taking the first).

    Returns
    -------
    str
        Xarray attribute value as string or empty string if not found.
        For a non-English locale, the localized attribute ``f"{string}_{locale}"`` is
        preferred over `string`. A warning is emitted when nothing is found.
    """
    locale = locale or (XC_OPTIONS[METADATA_LOCALES] or ["en"])[0]
    names = [f"{string}_{locale}", string] if locale != "en" else [string]

    # For Datasets, the attributes of the first data variable are searched before the
    # Dataset's own attributes. An empty Dataset has no such variable: skip that step
    # instead of failing with an IndexError.
    first_var_attrs = {}
    if isinstance(xr_obj, xr.Dataset) and len(xr_obj.data_vars) > 0:
        first_var_attrs = xr_obj[next(iter(xr_obj.data_vars))].attrs

    for name in names:
        if isinstance(xr_obj, xr.DataArray) and name in xr_obj.attrs:
            return xr_obj.attrs[name]

        if name in first_var_attrs:  # DataArray of first variable
            return first_var_attrs[name]

        if isinstance(xr_obj, xr.Dataset) and name in xr_obj.attrs:
            return xr_obj.attrs[name]

    warnings.warn(f'Attribute "{string}" not found.', stacklevel=2)
    return ""
239

240

241
def set_plot_attrs(
    attr_dict: dict[str, Any],
    xr_obj: xr.DataArray | xr.Dataset,
    ax: matplotlib.axes.Axes | None = None,
    title_loc: str = "center",
    facetgrid: seaborn.axisgrid.FacetGrid | None = None,
    wrap_kw: dict[str, Any] | None = None,
) -> matplotlib.axes.Axes | seaborn.axisgrid.FacetGrid:
    """
    Set plot elements according to Dataset or DataArray attributes.

    Uses get_attributes() to check for and get the strings.

    Parameters
    ----------
    attr_dict : dict
        Mapping of plot elements to attribute names. Supported keys are "title", "ylabel",
        "yunits", "xlabel", "xunits", "cbar_label", "cbar_units" and "suptitle"; other keys
        trigger a warning. "cbar_label" and "cbar_units" are accepted but not handled here:
        the colorbar label has to be assigned by the calling plot function.
    xr_obj : Dataset or DataArray
        The Xarray object containing the attributes.
    ax : matplotlib axis, optional
        The matplotlib axis of the plot. Needed for "title", "ylabel" and "xlabel";
        when None, those elements are skipped with a warning.
    title_loc : str
        Location of the title.
    facetgrid : seaborn.axisgrid.FacetGrid, optional
        If given, "suptitle" is applied to the grid's figure, facet titles are set to the
        facet value, and the grid is returned instead of `ax`.
    wrap_kw : dict, optional
        Arguments to pass to the wrap_text function for the title.

    Returns
    -------
    matplotlib.axes.Axes or seaborn.axisgrid.FacetGrid
        `facetgrid` if it was given, otherwise `ax`.
    """
    wrap_kw = empty_dict(wrap_kw)

    supported = (
        "title",
        "ylabel",
        "yunits",
        "xlabel",
        "xunits",
        "cbar_label",
        "cbar_units",
        "suptitle",
    )
    for key in attr_dict:
        if key not in supported:
            warnings.warn(f'Use_attrs element "{key}" not supported', stacklevel=2)

    axis_elements = [key for key in ("title", "ylabel", "xlabel") if key in attr_dict]
    if ax is None:
        if axis_elements:
            warnings.warn(
                f"No axis given: use_attrs element(s) {axis_elements} cannot be set and are ignored.",
                stacklevel=2,
            )
    else:
        if "title" in attr_dict:
            title = get_attributes(attr_dict["title"], xr_obj)
            ax.set_title(wrap_text(title, **wrap_kw), loc=title_loc)
        if "ylabel" in attr_dict:
            ax.set_ylabel(_label_with_units(attr_dict, "ylabel", "yunits", xr_obj))
        if "xlabel" in attr_dict:
            ax.set_xlabel(_label_with_units(attr_dict, "xlabel", "xunits", xr_obj))

    if facetgrid is not None:
        if "suptitle" in attr_dict:
            suptitle = get_attributes(attr_dict["suptitle"], xr_obj)
            facetgrid.fig.suptitle(suptitle, y=1.05)
            facetgrid.set_titles(template="{value}")
        return facetgrid

    return ax


def _label_with_units(
    attr_dict: dict[str, Any],
    label_key: str,
    units_key: str,
    xr_obj: xr.DataArray | xr.Dataset,
) -> str:
    """Build a wrapped axis label from attributes, appending "(units)" when units are available."""
    label = get_attributes(attr_dict[label_key], xr_obj)
    if units_key in attr_dict:
        units = get_attributes(attr_dict[units_key], xr_obj)
        # Skip empty units to avoid a "()" label; f-string also handles non-str attributes.
        if units:
            return wrap_text(f"{label} ({units})")
    return wrap_text(label)
339

340

341
def get_suffix(string: str) -> str:
    """
    Extract the statistic suffix of an xclim-style variable name.

    A trailing one- or two-digit number (e.g. "tas_p50" -> "50") or a trailing
    "_max"/"_min"/"_mean" (either capitalization of the first letter, e.g.
    "tas_Max" -> "Max") is recognized.

    Raises
    ------
    ValueError
        If `string` ends with none of these suffixes.
    """
    found = re.search(r"([0-9]{1,2})$|_([Mm]ax|[Mm]in|[Mm]ean)$", string)
    if found is None:
        raise ValueError(f"Mean, min or max not found in {string}")
    # Exactly one of the two groups participates in the match.
    return found.group(1) or found.group(2)
348

349

350
def sort_lines(array_dict: dict[str, Any]) -> dict[str, str]:
    """
    Label arrays as 'middle', 'upper' and 'lower' for ensemble plotting.

    Roles are derived from the suffix returned by get_suffix():
    "max"/"Max" and percentiles above 50 are 'upper', "min"/"Min" and percentiles
    below 50 are 'lower', and "mean"/"Mean" and the 50th percentile are 'middle'.

    Parameters
    ----------
    array_dict : dict
        Dictionary of format {'name': array...}. Only the keys (names) are used.

    Returns
    -------
    dict
        Dictionary of {'middle': 'name', 'upper': 'name', 'lower': 'name'}.

    Raises
    ------
    ValueError
        If `array_dict` does not contain exactly three entries, if a name has no
        recognized suffix, or if the names do not yield exactly one lower, one middle
        and one upper array (e.g. two "_max" arrays).
    """
    if len(array_dict) != 3:
        raise ValueError("Ensembles must contain exactly three arrays")

    word_roles = {"max": "upper", "min": "lower", "mean": "middle"}
    sorted_lines = {}

    for name in array_dict:
        suffix = get_suffix(name)  # raises ValueError for unrecognized names
        if suffix.isdigit():
            percentile = int(suffix)
            if percentile > 50:
                role = "upper"
            elif percentile < 50:
                role = "lower"
            else:
                role = "middle"
        else:
            # get_suffix only returns digits or max/min/mean (either capitalization).
            role = word_roles[suffix.lower()]
        sorted_lines[role] = name

    # With exactly three arrays, a duplicated role necessarily leaves another one empty;
    # report it here rather than letting callers hit a KeyError later.
    missing = [role for role in ("lower", "middle", "upper") if role not in sorted_lines]
    if missing:
        raise ValueError(
            f"Could not assign the {', '.join(missing)} line(s) among {list(array_dict)}. "
            "Ensembles need one lower, one middle and one upper array "
            '(e.g. suffixes "_p10", "_p50", "_p90" or "_min", "_mean", "_max").'
        )
    return sorted_lines
389

390

391
def loc_mpl(
    loc: str | tuple[int | float, int | float] | int,
) -> tuple[tuple[float, float], tuple[int | float, int | float], str, str]:
    """
    Find coordinates and alignment associated to loc string.

    Parameters
    ----------
    loc : string, int, or tuple[float, float]
        Location of text, replicating https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html.
        Integers 1 to 10 map to the named locations in matplotlib's legend order
        ("best", code 0, is not supported).
        If a tuple, must be two coordinates in axes coordinates (between 0 and 1).

    Returns
    -------
    tuple(float, float), tuple(float, float), str, str
        Anchor position in axes fraction, box alignment, horizontal alignment and
        vertical alignment.

    Raises
    ------
    ValueError
        If `loc` is an int outside 1-10, an unknown string, a tuple that is not two
        coordinates between 0 and 1, or of any other type.
    """
    # Named location -> (anchor position in axes fraction, box alignment).
    # Insertion order follows matplotlib's legend location codes 1 to 10.
    named_locs = {
        "upper right": ((0.97, 0.97), (1, 1)),
        "upper left": ((0.03, 0.97), (0, 1)),
        "lower left": ((0.03, 0.03), (0, 0)),
        "lower right": ((0.97, 0.03), (1, 0)),
        "right": ((0.97, 0.5), (1, 0.5)),
        "center left": ((0.03, 0.5), (0, 0.5)),
        "center right": ((0.97, 0.5), (1, 0.5)),
        "lower center": ((0.5, 0.03), (0.5, 0)),
        "upper center": ((0.5, 0.97), (0.5, 1)),
        "center": ((0.5, 0.5), (0.5, 0.5)),
    }
    docs_url = "https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html"

    if isinstance(loc, int):
        # Explicit range check: negative indexing would otherwise silently accept 0 and -1..-10.
        if not 1 <= loc <= len(named_locs):
            raise ValueError("loc must be between 1 and 10, inclusively")
        loc = list(named_locs)[loc - 1]

    if isinstance(loc, str):
        if loc not in named_locs:
            raise ValueError(
                f"Unknown loc {loc!r}. Valid strings are: {', '.join(named_locs)}. See {docs_url}"
            )
        xy, box_a = named_locs[loc]
        if "left" in loc:
            ha = "left"
        elif "right" in loc:
            ha = "right"
        else:
            ha = "center"
        if "lower" in loc:
            va = "bottom"
        elif "upper" in loc:
            va = "top"
        else:
            va = "center"
        return xy, box_a, ha, va

    if isinstance(loc, tuple):
        if len(loc) != 2:
            raise ValueError(
                "Text location tuple must contain exactly two coordinates (x, y)"
            )
        if any(i > 1 or i < 0 for i in loc):
            raise ValueError(
                "Text location coordinates must be between 0 and 1, inclusively"
            )
        # Align the box towards the nearest edge along each axis.
        box_a = tuple(1 if i > 0.5 else 0 for i in loc)
        return loc, box_a, "left", "bottom"

    raise ValueError(f"loc must be a string, int or tuple. See {docs_url}")
497

498

499
def plot_coords(
    ax: matplotlib.axes.Axes | None,
    xr_obj: xr.DataArray | xr.Dataset,
    loc: str | tuple[float, float] | int,
    param: str | None = None,
    backgroundalpha: float = 1,
) -> matplotlib.axes.Axes | None:
    """
    Place coordinates on plot area.

    Parameters
    ----------
    ax : matplotlib.axes.Axes or None
        Matplotlib axes object on which to place the text.
        If None, will use plt.figtext instead (should be used for facetgrids).
    xr_obj : xr.DataArray or xr.Dataset
        The xarray object from which to fetch the text content.
    loc : string, int or tuple
        Location of text, replicating https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html.
        If a tuple, must be in axes coordinates.
    param : {"location", "time"}, optional
        The parameter used. "location" shows the "lat"/"lon" coordinates, "time" shows
        the "time" coordinate formatted as %Y-%m-%d. A warning is emitted if the
        coordinates are missing.
    backgroundalpha : float
        Transparency of the text background. 1 is opaque, 0 is transparent.

    Returns
    -------
    matplotlib.axes.Axes or None
        `ax` (also when there is no text to place), or None when `ax` is None.
    """
    text = None
    if param == "location":
        if "lat" in xr_obj.coords and "lon" in xr_obj.coords:
            text = "lat={:.2f}, lon={:.2f}".format(
                float(xr_obj["lat"]), float(xr_obj["lon"])
            )
        else:
            warnings.warn(
                'show_lat_lon set to True, but "lat" and/or "lon" not found in coords', stacklevel=2
            )
    elif param == "time":
        if "time" in xr_obj.coords:
            text = str(xr_obj.time.dt.strftime("%Y-%m-%d").values)
        else:
            warnings.warn('show_time set to True, but "time" not found in coords', stacklevel=2)

    # Resolve the location even when there is no text, so an invalid `loc` is always reported.
    xy, box_a, ha, va = loc_mpl(loc)

    if not text:
        return ax

    if ax is None:
        # No axes (e.g. facetgrids): place the text in figure coordinates instead.
        plt.figtext(
            xy[0],
            xy[1],
            text,
            ha=ha,
            va=va,
            fontsize=12,
        )
        return None

    textarea = mpl.offsetbox.TextArea(
        text, textprops=dict(transform=ax.transAxes, ha=ha, va=va)
    )
    annotation = mpl.offsetbox.AnnotationBbox(
        textarea,
        xy,
        xycoords="axes fraction",
        box_alignment=box_a,
        pad=0.05,
        bboxprops=dict(
            facecolor="white",
            alpha=backgroundalpha,
            edgecolor="w",
            boxstyle="Square, pad=0.5",
        ),
    )
    ax.add_artist(annotation)
    return ax
595

596

597
def find_logo(logo: str | pathlib.Path) -> str:
    """
    Return the path of a logo known to :py:class:`figanos.Logos`.

    A falsy `logo` selects the registry's default logo; otherwise `logo` is looked up
    in the registry.

    Raises
    ------
    ValueError
        If no matching logo is available.
    """
    registry = Logos()
    logo_path = registry[logo] if logo else registry.default

    if logo_path is None:
        raise ValueError(
            "No logo found. Please install one with the figanos.Logos().set_logo() method."
        )
    return logo_path
610

611

612
def load_image(
    im: str | pathlib.Path,
    height: float | None,
    width: float | None,
    keep_ratio: bool = True,
) -> np.ndarray:
    """
    Load an image file, optionally scaling it to a specified height and/or width.

    Parameters
    ----------
    im : str or Path
        The image to be loaded. PNG and SVG formats are supported (the file suffix is
        matched case-insensitively).
    height : float, optional
        The desired height of the image. If None, it is derived from `width` when
        `keep_ratio` is True, or the original height is used.
    width : float, optional
        The desired width of the image. If None, it is derived from `height` when
        `keep_ratio` is True, or the original width is used.
    keep_ratio : bool
        If True, the aspect ratio of the original image is maintained; when both
        `height` and `width` are given, `height` is used (with a warning) and `width`
        is derived from it. For SVG files with `keep_ratio` False, the size is only
        changed when both `height` and `width` are given. Default is True.

    Returns
    -------
    np.ndarray
        The (possibly scaled) image, as read by matplotlib's ``imread``.

    Raises
    ------
    ValueError
        If the file is neither a PNG nor an SVG image.
    """
    suffix = pathlib.Path(im).suffix.lower()

    if suffix == ".png":
        image = mpl.pyplot.imread(im)
        original_height, original_width = image.shape[:2]

        if height is None and width is None:
            return image

        warnings.warn(
            "The scikit-image library is used to resize PNG images. This may affect logo image quality.", stacklevel=2
        )
        if not keep_ratio:
            # Use the requested size, falling back to the original size per dimension.
            height = original_height if height is None else height
            width = original_width if width is None else width
        elif height is not None:
            if width is not None:
                warnings.warn("Both height and width provided, using height.", stacklevel=2)
            # Derive the width from the height, keeping the aspect ratio.
            width = (height / original_height) * original_width
        else:
            # Only width is provided: derive the height, keeping the aspect ratio.
            height = (width / original_width) * original_height

        # resize needs integer pixel counts; keep any channel axis (absent for grayscale).
        output_shape = (
            max(1, int(round(height))),
            max(1, int(round(width))),
            *image.shape[2:],
        )
        return resize(image, output_shape, anti_aliasing=True)

    if suffix == ".svg":
        import io

        cairo_kwargs = dict(url=str(im))
        if not keep_ratio:
            if height is not None and width is not None:
                cairo_kwargs.update(output_height=height, output_width=width)
        elif height is not None:
            if width is not None:
                warnings.warn("Both height and width provided, using height.", stacklevel=2)
            cairo_kwargs.update(output_height=height)
        elif width is not None:
            cairo_kwargs.update(output_width=width)

        # Render to in-memory PNG bytes: re-opening a still-open NamedTemporaryFile by
        # name (as cairosvg would have to do) fails on Windows.
        png_bytes = cairosvg.svg2png(**cairo_kwargs)
        return mpl.pyplot.imread(io.BytesIO(png_bytes), format="png")

    raise ValueError(
        f"Unsupported image format '{pathlib.Path(im).suffix}' for {im}. "
        "Only PNG and SVG images are supported."
    )
678

679

680
def plot_logo(
    ax: matplotlib.axes.Axes,
    loc: str | tuple[float, float] | int,
    logo: str | pathlib.Path | Logos | None = None,
    height: float | None = None,
    width: float | None = None,
    keep_ratio: bool = True,
    **offset_image_kwargs,
) -> matplotlib.axes.Axes:
    r"""
    Place a logo on the plot area.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Matplotlib axes object on which to place the logo.
    loc : string, int or tuple
        Location of the logo, replicating https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html.
        If a tuple, must be in axes coordinates.
    logo : str, Path, figanos.Logos, optional
        A name (str) or Path to a logo file, or a name of an already-installed logo,
        resolved through :py:class:`figanos.Logos`. If a `Logos` instance is given, its
        default logo is used; if None, the default installed logo is used.
        To install the Ouranos (or another) logo consult the Usage page.
        PNG and SVG files are supported.
    height : float, optional
        The desired height of the image. If None, the original height is used.
    width : float, optional
        The desired width of the image. If None, the original width is used.
    keep_ratio : bool, optional
        If True, the aspect ratio of the original image is maintained. Default is True.
    \*\*offset_image_kwargs
        Arguments to pass to matplotlib.offsetbox.OffsetImage().

    Returns
    -------
    matplotlib.axes.Axes

    Raises
    ------
    ValueError
        If no logo can be found.
    """
    if isinstance(logo, Logos):
        logo_path = logo.default
        # Match find_logo(): fail clearly instead of passing None on to load_image.
        if logo_path is None:
            raise ValueError(
                "No logo found. Please install one with the figanos.Logos().set_logo() method."
            )
    else:
        logo_path = find_logo(logo)

    image = load_image(logo_path, height, width, keep_ratio)
    imagebox = mpl.offsetbox.OffsetImage(image, **offset_image_kwargs)

    xy, box_a, _, _ = loc_mpl(loc)
    annotation = mpl.offsetbox.AnnotationBbox(
        imagebox,
        xy,
        frameon=False,
        xycoords="axes fraction",
        box_alignment=box_a,
        pad=0.05,
    )
    ax.add_artist(annotation)
    return ax
739

740

741
def split_legend(
    ax: matplotlib.axes.Axes,
    in_plot: bool = False,
    axis_factor: float = 0.15,
    label_gap: float = 0.02,
) -> matplotlib.axes.Axes:
    """
    Draw line labels at the end of each line, or outside the plot.

    Only line legend entries (``Line2D``) with data are labelled; other legend entries,
    such as shaded areas, are skipped.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axis containing the legend.
    in_plot : bool
        If True, prolong plot area to fit labels. If False, print labels outside of plot area. Default: False.
    axis_factor : float
        If in_plot is True, fraction of the x-axis length to add at the far right of the plot. Default: 0.15.
    label_gap : float
        If in_plot is True, fraction of the x-axis length to add as a gap between line and label. Default: 0.02.

    Returns
    -------
    matplotlib.axes.Axes
    """
    # TODO: check for and fix overlapping labels
    x_lower, x_upper = ax.get_xbound()
    x_span = x_upper - x_lower
    label_bump = x_span * label_gap

    if in_plot:
        # Create extra space on the right for the labels.
        ax.set_xbound(lower=x_lower, upper=x_upper + x_span * axis_factor)

    # Labels outside the plot use axes coordinates horizontally and data coordinates vertically.
    outside_transform = mpl.transforms.blended_transform_factory(ax.transAxes, ax.transData)

    handles, labels = ax.get_legend_handles_labels()
    for handle, label in zip(handles, labels):
        # Only lines expose the data needed to find where the label goes.
        if not isinstance(handle, Line2D):
            continue
        xdata = handle.get_xdata()
        ydata = handle.get_ydata()
        if len(xdata) == 0 or len(ydata) == 0:
            continue

        last_x = xdata[-1]
        last_y = ydata[-1]
        if isinstance(last_x, np.datetime64):
            last_x = mpl.dates.date2num(last_x)

        color = handle.get_color()
        if in_plot:
            ax.text(
                last_x + label_bump,
                last_y,
                label,
                ha="left",
                va="center",
                color=color,
            )
        else:
            ax.text(
                1.01,
                last_y,
                label,
                ha="left",
                va="center",
                color=color,
                transform=outside_transform,
            )

    return ax
810

811

812
def fill_between_label(
    sorted_lines: dict[str, Any],
    name: str,
    array_categ: dict[str, Any],
    legend: str,
) -> str:
    """
    Build the legend label for the shading drawn around a line in line plots.

    Parameters
    ----------
    sorted_lines : dict
        Dictionary created by the sort_lines() function.
    name : str
        Key associated with the object being plotted in the 'data' argument of the timeseries() function.
    array_categ : dict
        The categories of the array, as created by the get_array_categ function.
    legend : str
        Legend mode. Only "full" produces a label.

    Returns
    -------
    str or None
        A localized percentile range (e.g. built from "{}th-{}th percentiles") for
        percentile ensembles, a localized "min-max range" for min/max statistics
        ensembles, and None otherwise.
    """
    if legend != "full":
        return None

    category = array_categ[name]
    if category in ("ENS_PCT_VAR_DS", "ENS_PCT_DIM_DS", "ENS_PCT_DIM_DA"):
        template = get_localized_term("{}th-{}th percentiles")
        lower_pct = get_suffix(sorted_lines["lower"])
        upper_pct = get_suffix(sorted_lines["upper"])
        return template.format(lower_pct, upper_pct)

    if category == "ENS_STATS_VAR_DS":
        return get_localized_term("min-max range")

    return None
853

854

855
def _match_var_groups(text: str, var_dict: dict[str, str]) -> list[str]:
    """Return the group of every `var_dict` key found in `text` as a standalone token."""
    groups = []
    for var, group in var_dict.items():
        # The key must not touch a letter on either side, so e.g. "pr" does not match inside "prsn".
        if re.search(rf"(?:^|[^a-zA-Z])({var})(?:[^a-zA-Z]|$)", text):
            groups.append(group)
    return groups


def get_var_group(
    da: xr.DataArray | None = None,
    unique_str: str | None = None,
    path_to_json: str | pathlib.Path | None = None,
) -> str:
    """
    Get IPCC variable group from DataArray or a string using a json file (figanos/data/ipcc_colors/variable_groups.json).

    If `da` is a Dataset, look in the DataArray of the first variable.

    Parameters
    ----------
    da : xr.DataArray or xr.Dataset, optional
        Object whose name is searched for variable keywords. If nothing matches, its `history`
        attribute is searched as well. Ignored when `unique_str` is given.
    unique_str : str, optional
        String searched for variable keywords instead of `da`.
    path_to_json : str or pathlib.Path, optional
        JSON file mapping variable keywords to variable groups. Defaults to the figanos file.

    Returns
    -------
    str
        The matched variable group, or "misc" (with a warning) when zero or several distinct groups match.
    """
    if path_to_json is None:
        path_to_json = VARJSON

    with pathlib.Path(path_to_json).open(encoding="utf-8") as _f:
        var_dict = json.load(_f)

    matches = []
    if unique_str:
        matches = _match_var_groups(unique_str, var_dict)
    else:
        if isinstance(da, xr.Dataset):
            da = da[list(da.data_vars)[0]]
        # look in DataArray name
        if hasattr(da, "name") and isinstance(da.name, str):
            matches = _match_var_groups(da.name, var_dict)
        # fall back on the history only when the name gave nothing
        if hasattr(da, "history") and not matches:
            matches = _match_var_groups(da.history, var_dict)

    groups = np.unique(matches)

    if len(groups) == 0:
        warnings.warn(
            "Colormap warning: Variable group not found. Use the cmap argument.", stacklevel=2
        )
        return "misc"
    if len(groups) >= 2:
        warnings.warn(
            "Colormap warning: More than one variable group found. Use the cmap argument.", stacklevel=2
        )
        return "misc"
    # np.unique yields numpy strings; hand back a plain str as annotated.
    return str(groups[0])
×
911

912

913
def get_ipcc_cmap_name(
6✔
914
    var_group: str | None = None,
915
    divergent: bool | int = False,
916
    filename: str | None = None,
917
    reverse: bool = False
918
) -> matplotlib.colors.Colormap:
919
    """
920
    Get colormap name according to variable group or filename.
921

922
    Parameters
923
    ----------
924
    var_group : str, optional
925
        Variable group from IPCC scheme.
926
    divergent : bool or int
927
        Diverging colormap. If False, use sequential colormap.
928
    filename : str, optional
929
        Name of IPCC colormap file. If not None, 'var_group' and 'divergent' are not used.
930
    reverse: bool
931
        If True, the name of the reverse color order colormap is returned.
932

933
    Returns
934
    -------
935
    str : Name of the colormap in `matplotlib.colormaps`.
936
    """
937
    if filename:
×
UNCOV
938
        filename = filename.replace(".txt", "")
×
939
    else:
940
        # filename
941
        if divergent is not False:
×
942
            if var_group == "misc2":
×
943
                var_group = "misc"
×
944
            filename = var_group + "_div"
×
945
        else:
946
            if var_group == "misc":
×
947
                filename = var_group + "_seq_3"  # Batlow
×
948
            elif var_group == "misc2":
×
949
                filename = "misc_seq_2"  # freezing rain
×
950
            else:
951
                filename = var_group + "_seq"
×
NEW
952
    return filename
×
953

954

955
def create_ipcc_cmap(filename: str, reverse: bool = False):
    """
    Create an IPCC colormap from one of the packaged RGB text files.

    Parameters
    ----------
    filename : str
        Name of the colormap file. Any directory part and extension are dropped, and a trailing "_r"
        requests the reversed colormap.
    reverse : bool
        If True, return the reversed colormap. Its name gets a "_r" suffix.

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
        Colormap named after the file stem, built from the file's 0-255 RGB rows.
    """
    # Ensure filename is only the stem (no ext, no parents)
    filename = pathlib.Path(filename).stem
    if filename.endswith("_r"):
        reverse = True
        filename = filename[:-2]

    # files()/as_file() replace the legacy importlib.resources.path(), deprecated in Python 3.11.
    resource = importlib.resources.files(
        "figanos.data.ipcc_colors.continuous_colormaps_rgb_0-255"
    ).joinpath(f"{filename}.txt")
    with importlib.resources.as_file(resource) as path:
        rgb_data = np.loadtxt(path)

    # convert to 0-1 RGB
    cmap = mcolors.LinearSegmentedColormap.from_list(filename, rgb_data / 255, N=256)
    if reverse:
        # this also adds "_r" to the name
        cmap = cmap.reversed()

    return cmap
6✔
974

975

976
# Register cmaps
def _register_ipcc_cmaps() -> None:
    """Register every packaged IPCC colormap, and its reversed version, in matplotlib."""
    # iterdir() replaces importlib.resources.contents(), which is deprecated since Python 3.11.
    cmap_dir = importlib.resources.files("figanos.data.ipcc_colors.continuous_colormaps_rgb_0-255")
    for entry in cmap_dir.iterdir():
        # Only the .txt RGB tables are colormaps; skip anything else in the data folder.
        if not entry.name.endswith(".txt"):
            continue
        stem = entry.name.removesuffix(".txt")
        for reverse in (True, False):
            cmap = create_ipcc_cmap(stem, reverse)
            # Registering an existing name raises (e.g. on module reload), so skip known names.
            if cmap.name not in mpl.colormaps:
                mpl.colormaps.register(cmap)


_register_ipcc_cmaps()
6✔
981

982

983
def get_rotpole(xr_obj: xr.DataArray | xr.Dataset) -> ccrs.RotatedPole | None:
    """
    Create a Cartopy crs rotated pole projection/transform from DataArray or Dataset attributes.

    For a Dataset, the grid mapping variable is found through the ``.cf`` accessor's
    ``grid_mapping_names`` ("rotated_latitude_longitude"). For a DataArray, it is found through its
    ``grid_mapping`` attribute. In both cases "rotated_pole" is assumed when nothing is found.

    Parameters
    ----------
    xr_obj : xr.DataArray or xr.Dataset
        The xarray object from which to look for the attributes.

    Returns
    -------
    ccrs.RotatedPole or None
        None (with a warning) when the grid mapping variable or its required attributes are missing.
    """
    try:
        if isinstance(xr_obj, xr.Dataset):
            gridmap = xr_obj.cf.grid_mapping_names.get("rotated_latitude_longitude", [])

            if len(gridmap) > 1:
                warnings.warn(
                    f"There are conflicting grid_mapping attributes in the dataset. Assuming {gridmap[0]}.", stacklevel=2
                )

            coord_name = gridmap[0] if gridmap else "rotated_pole"
        else:
            # If it can't find grid_mapping, assume it's rotated_pole
            coord_name = xr_obj.attrs.get("grid_mapping", "rotated_pole")

        grid_mapping = xr_obj[coord_name]
        return ccrs.RotatedPole(
            pole_longitude=grid_mapping.grid_north_pole_longitude,
            pole_latitude=grid_mapping.grid_north_pole_latitude,
            # Optional in the CF conventions, where it defaults to 0.
            central_rotated_longitude=grid_mapping.attrs.get("north_pole_grid_longitude", 0.0),
        )

    except (AttributeError, KeyError):
        # KeyError: the grid mapping variable/coordinate is absent from xr_obj.
        # AttributeError: a required pole attribute (or the .cf accessor) is missing.
        warnings.warn("Rotated pole not found. Specify a transform if necessary.", stacklevel=2)
        return None
×
1021

1022

1023
def wrap_text(text: str, min_line_len: int = 18, max_line_len: int = 30) -> str:
    """
    Insert line breaks so that lines stay roughly between two lengths.

    Each break is searched for in the window ``[min_line_len, max_line_len)`` counted from the start
    of the current line. The first ". " in the window is preferred, then the first ": ", then the
    last space. The break replaces the character that follows the period or colon (or the space itself).

    Parameters
    ----------
    text : str
        The text to wrap.
    min_line_len : int
        Minimum length of each line.
    max_line_len : int
        Maximum length of each line.

    Returns
    -------
    str
        Wrapped text. If a window contains no break point, a warning is issued and wrapping stops.
    """
    if len(text) < max_line_len:
        return text

    lo, hi = min_line_len, max_line_len
    tail = len(text)
    while tail > max_line_len:
        window = text[lo:hi]
        if ". " in window:
            cut = text.find(". ", lo, hi) + 1
        elif ": " in window:
            cut = text.find(": ", lo, hi) + 1
        elif " " in window:
            cut = text.rfind(" ", lo, hi)
        else:
            warnings.warn("No spaces, points or colons to break line at.", stacklevel=2)
            break

        # Swap the character at `cut` for a newline.
        text = f"{text[:cut]}\n{text[cut + 1:]}"
        tail = len(text) - cut
        lo = cut + 1 + min_line_len
        hi = cut + 1 + max_line_len

    return text
×
1065

1066

1067
def gpd_to_ccrs(df: gpd.GeoDataFrame, proj: ccrs.CRS) -> gpd.GeoDataFrame:
    """
    Reproject a GeoDataFrame to a cartopy projection.

    The projection's PROJ.4 definition string (``proj.proj4_init``) is handed to ``GeoDataFrame.to_crs``.

    Parameters
    ----------
    df : gpd.GeoDataFrame
        GeoDataFrame (geopandas) geometry to be added to axis.
    proj : ccrs.CRS
        Target projection, taken from the cartopy.crs options.

    Returns
    -------
    gpd.GeoDataFrame
        The result of ``df.to_crs`` for the given projection.
    """
    return df.to_crs(proj.proj4_init)
×
1085

1086

1087
def _format_scen_token(match: re.Match) -> str:
    """Format one SSP/RCP/CMIP token: 'ssp245' -> 'SSP2-4.5', 'rcp45' -> 'RCP4.5', 'cmip6' -> 'CMIP6'."""
    token = match.group(0)
    prefix = token.rstrip("0123456789")
    digits = token[len(prefix):]
    if len(digits) == 3:
        return f"{prefix}{digits[0]}-{digits[1]}.{digits[2]}".upper()
    if len(digits) == 2:
        return f"{prefix}{digits[0]}.{digits[1]}".upper()
    return token.upper()


def convert_scen_name(name: str) -> str:
    """
    Convert strings containing SSP, RCP or CMIP to their proper format.

    Every occurrence is converted (case-insensitively), e.g. "ssp245_cmip6" -> "SSP2-4.5_CMIP6".
    Strings without such tokens are returned unchanged.
    """
    # re.sub rewrites all tokens in one pass, so earlier conversions are not lost when several match.
    return re.sub(r"(?:SSP|RCP|CMIP)[0-9]{1,3}", _format_scen_token, name, flags=re.I)
×
1109

1110

1111
def get_scen_color(name: str, path_to_dict: str | pathlib.Path) -> str:
6✔
1112
    """Get color corresponding to SSP,RCP, model or CMIP substring from a dictionary."""
1113
    with pathlib.Path(path_to_dict).open(encoding="utf-8") as _f:
×
1114
        color_dict = json.load(_f)
×
1115

1116
    color = None
×
1117
    for entry in color_dict:
×
1118
        if entry in name:
×
1119
            color = color_dict[entry]
×
1120
            color = tuple([i / 255 for i in color])
×
1121
            break
×
1122

1123
    return color
×
1124

1125

1126
def process_keys(dct: dict[str, Any], func: Callable) -> dict[str, Any]:
    """
    Apply a function to every key of a dictionary, in place.

    Parameters
    ----------
    dct : dict
        Dictionary whose keys are transformed. It is modified in place.
    func : callable
        Function mapping an old key to its new key.

    Returns
    -------
    dict
        `dct` itself, with each key replaced by ``func(key)`` in the original order.
    """
    # Build the renamed mapping first. Renaming key by key could overwrite a value whose key
    # has not been processed yet (e.g. moving "a" to "b" before the original "b" is moved).
    renamed = {func(key): value for key, value in dct.items()}
    dct.clear()
    dct.update(renamed)
    return dct
×
1133

1134

1135
def categorical_colors() -> dict[str, str]:
    """
    Load the categorical colors associated with certain substrings (SSP, RCP, CMIP).

    The colors are read from figanos' ``data/ipcc_colors/categorical_colors.json`` file.

    Returns
    -------
    dict
        The parsed JSON content.
    """
    json_path = pathlib.Path(__file__).parents[1].joinpath(
        "data", "ipcc_colors", "categorical_colors.json"
    )
    with json_path.open(encoding="utf-8") as handle:
        colors = json.load(handle)
    return colors
×
1144

1145

1146
def get_mpl_styles() -> dict[str, pathlib.Path]:
    """
    List the matplotlib stylesheets bundled with figanos.

    Returns
    -------
    dict
        Mapping of style name (file stem) to the path of each ``*.mplstyle`` file found in the
        ``style`` folder next to this module, in sorted path order.
    """
    style_dir = pathlib.Path(__file__).parent / "style"
    return {sheet.stem: sheet for sheet in sorted(style_dir.glob("*.mplstyle"))}
×
1151

1152

1153
def set_mpl_style(*args: str, reset: bool = False) -> None:
    """
    Set the matplotlib style using one or more stylesheets.

    Parameters
    ----------
    args : str or path-like
        Name(s) of figanos matplotlib style ('ouranos', 'paper', 'poster') or path(s) to matplotlib
        stylesheet(s) ending in ".mplstyle". Stylesheets are applied in the given order.
    reset : bool
        If True, reset style to matplotlib default before applying the stylesheets.

    Returns
    -------
    None
    """
    if reset:
        mpl.style.use("default")
    # Look up the bundled styles once instead of globbing the style folder for every argument.
    figanos_styles = get_mpl_styles()
    for arg in args:
        style = str(arg)  # also accept pathlib.Path objects
        if style.endswith(".mplstyle"):
            mpl.style.use(style)
        elif style in figanos_styles:
            mpl.style.use(figanos_styles[style])
        else:
            warnings.warn(f"Style {style} not found.", stacklevel=2)
×
1177

1178

1179
def add_cartopy_features(
    ax: matplotlib.axes.Axes, features: list[str] | dict[str, dict[str, Any]]
) -> matplotlib.axes.Axes:
    """
    Add cartopy features to matplotlib axes.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes on which to add the features.
    features : list or dict
        List (or tuple) of features, or nested dictionary of format {'feature': {'kwarg':'value'}}.
        Feature names are upper-cased and looked up in cartopy.feature (e.g. 'coastline' -> COASTLINE).
        A 'scale' kwarg is applied through the feature's with_scale() instead of being passed to
        add_feature(). The given dictionary is not modified.

    Returns
    -------
    matplotlib.axes.Axes
        The axis with added features.
    """
    if isinstance(features, list | tuple):
        features = {f: {} for f in features}

    for feat, feat_kw in features.items():
        # Work on a copy so the caller's dict is never mutated, even if add_feature() raises.
        kwargs = dict(feat_kw or {})
        feature = getattr(cfeature, feat.upper())
        if "scale" in kwargs:
            feature = feature.with_scale(kwargs.pop("scale"))
        ax.add_feature(feature, **kwargs)
    return ax
×
1211

1212

1213
def custom_cmap_norm(
6✔
1214
    cmap,
1215
    vmin: int | float,
1216
    vmax: int | float,
1217
    levels: int | list[int | float] | None = None,
1218
    divergent: bool | int | float = False,
1219
    linspace_out: bool = False,
1220
) -> matplotlib.colors.Normalize | np.ndarray:
1221
    """
1222
    Get matplotlib normalization according to main function arguments.
1223

1224
    Parameters
1225
    ----------
1226
    cmap: matplotlib.colormap
1227
        Colormap to be used with the normalization.
1228
    vmin: int or float
1229
        Minimum of the data to be plotted with the colormap.
1230
    vmax: int or float
1231
        Maximum of the data to be plotted with the colormap.
1232
    levels : int or list, optional
1233
        Number of  levels or list of level boundaries (in data units) to use to divide the colormap.
1234
    divergent : bool or int or float
1235
        If int or float, becomes center of cmap. Default center is 0.
1236
    linspace_out: bool
1237
        If True, return array created by np.linspace() instead of normalization instance.
1238

1239
    Returns
1240
    -------
1241
    matplotlib.colors.Normalize
1242
    """
1243
    # get cmap if string
1244
    if isinstance(cmap, str):
×
1245
        if cmap in plt.colormaps():
×
1246
            cmap = matplotlib.colormaps[cmap]
×
1247
        else:
1248
            raise ValueError("Colormap not found")
×
1249

1250
    # make vmin and vmax prettier
1251
    if (vmax - vmin) >= 25:
×
1252
        rvmax = math.ceil(vmax / 10.0) * 10
×
1253
        rvmin = math.floor(vmin / 10.0) * 10
×
1254
    elif 1 <= (vmax - vmin) < 25:
×
1255
        rvmax = math.ceil(vmax / 1) * 1
×
1256
        rvmin = math.floor(vmin / 1) * 1
×
1257
    elif 0.1 <= (vmax - vmin) < 1:
×
1258
        rvmax = math.ceil(vmax / 0.1) * 0.1
×
1259
        rvmin = math.floor(vmin / 0.1) * 0.1
×
1260
    else:
1261
        rvmax = math.ceil(vmax / 0.01) * 0.01
×
1262
        rvmin = math.floor(vmin / 0.01) * 0.01
×
1263

1264
    # center
1265
    center = None
×
1266
    if divergent is not False:
×
1267
        if divergent is True:
×
1268
            center = 0
×
1269
        elif isinstance(divergent, int | float):
×
1270
            center = divergent
×
1271

1272
    # build norm with options
1273
    if center is not None and isinstance(levels, int):
×
1274
        if center <= rvmin or center >= rvmax:
×
1275
            raise ValueError("vmin, center and vmax must be in ascending order.")
×
1276
        if levels % 2 == 1:
×
1277
            half_levels = int((levels + 1) / 2) + 1
×
1278
        else:
1279
            half_levels = int(levels / 2) + 1
×
1280

1281
        lin = np.concatenate(
×
1282
            (
1283
                np.linspace(rvmin, center, num=half_levels),
1284
                np.linspace(center, rvmax, num=half_levels)[1:],
1285
            )
1286
        )
1287
        norm = matplotlib.colors.BoundaryNorm(boundaries=lin, ncolors=cmap.N)
×
1288

1289
        if linspace_out:
×
1290
            return lin
×
1291

1292
    elif levels is not None:
×
1293
        if isinstance(levels, list):
×
1294
            if center is not None:
×
1295
                warnings.warn(
×
1296
                    "Divergent argument ignored when levels is a list. Use levels as a number instead.", stacklevel=2
1297
                )
1298
            norm = matplotlib.colors.BoundaryNorm(boundaries=levels, ncolors=cmap.N)
×
1299
        else:
1300
            lin = np.linspace(rvmin, rvmax, num=levels + 1)
×
1301
            norm = matplotlib.colors.BoundaryNorm(boundaries=lin, ncolors=cmap.N)
×
1302

1303
            if linspace_out:
×
1304
                return lin
×
1305

1306
    elif center is not None:
×
1307
        norm = matplotlib.colors.TwoSlopeNorm(center, vmin=rvmin, vmax=rvmax)
×
1308
    else:
1309
        norm = matplotlib.colors.Normalize(rvmin, rvmax)
×
1310

1311
    return norm
×
1312

1313

1314
def norm2range(
6✔
1315
    data: np.ndarray, target_range: tuple, data_range: tuple | None = None
1316
) -> np.ndarray:
1317
    """Normalize data across a specific range."""
1318
    if data_range is None:
×
1319
        if len(data) > 1:
×
1320
            data_range = (np.nanmin(data), np.nanmax(data))
×
1321
        else:
1322
            raise ValueError(" if data is not an array, data_range must be specified")
×
1323

1324
    norm = (data - data_range[0]) / (data_range[1] - data_range[0])
×
1325

1326
    return target_range[0] + (norm * (target_range[1] - target_range[0]))
×
1327

1328

1329
def size_legend_elements(
    data: np.ndarray, sizes: np.ndarray, marker: str, max_entries: int = 6
) -> list[matplotlib.lines.Line2D]:
    """
    Create handles to use in a point-size legend.

    Parameters
    ----------
    data : np.ndarray
        Data used to determine the point sizes.
    sizes : np.ndarray
        Array of point sizes.
    marker: str
        Marker to use in legend.
    max_entries : int
        Maximum number of entries in the legend. When there are more candidate entries, every
        second one among the first ``max_entries + 1`` is kept.

    Returns
    -------
    list of matplotlib.lines.Line2D
        Empty when the range of `sizes` rounds to less than one 10 pts**2 increment.
    """
    # how many increments of 10 pts**2 are there in the sizes
    n = int(np.round(max(sizes) - min(sizes), -1) / 10)
    if n < 1:
        # Nothing to show; this also avoids dividing by zero when computing the spacing below.
        return []

    # divide data in those increments
    lgd_data = np.linspace(min(data), max(data), n)

    # round according to the spacing between legend values
    ratio = abs((max(data) - min(data)) / n)
    for threshold in (1000, 100, 10, 5, 1, 0.1, 0.01):
        if ratio >= threshold:
            rounding = threshold
            break
    else:
        rounding = 0.001

    lgd_data = np.unique(rounding * np.round(lgd_data / rounding))

    # convert back to sizes
    lgd_sizes = norm2range(
        data=lgd_data,
        data_range=(min(data), max(data)),
        target_range=(min(sizes), max(sizes)),
    )

    legend_elements = []

    for s, d in zip(lgd_sizes, lgd_data, strict=False):
        if isinstance(d, float) and d.is_integer():
            label = str(int(d))
        else:
            label = str(d)

        legend_elements.append(
            Line2D(
                [0],
                [0],
                marker=marker,
                color="k",
                lw=0,
                markerfacecolor="w",
                label=label,
                # markersize is a diameter, while scatter sizes are areas
                markersize=np.sqrt(np.abs(s)),
            )
        )

    if len(legend_elements) > max_entries:
        return [legend_elements[i] for i in np.arange(0, max_entries + 1, 2)]
    return legend_elements
×
1410

1411

1412
def add_features_map(
    data,
    ax,
    use_attrs,
    projection,
    features,
    geometries_kw,
    frame,
) -> matplotlib.axes.Axes:
    """
    Add features such as cartopy, time label, and geometries to a map on a given matplotlib axis.

    Parameters
    ----------
    data : dict, DataArray or Dataset
        Input data do plot. If dictionary, must have only one entry.
    ax : matplotlib axis
        Matplotlib axis on which to plot, with the same projection as the one specified.
    use_attrs : dict
        Dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description').
        Default value is {'title': 'description', 'cbar_label': 'long_name', 'cbar_units': 'units'}.
        Only the keys found in the default dict can be used.
    projection : ccrs.Projection
        The projection to use, taken from the cartopy.crs options. Ignored if ax is not None.
    features : list or dict
        Features to use, as a list or a nested dict containing kwargs. Options are the predefined features from
        cartopy.feature: ['coastline', 'borders', 'lakes', 'land', 'ocean', 'rivers', 'states'].
    geometries_kw : dict
        Arguments passed to cartopy ax.add_geometries() which adds given geometries (GeoDataFrame geometry) to axis.
        Must contain 'geoms'. The geometries are reprojected to 'crs' if given, else to `projection`.
        Defaults crs=projection, facecolor='none' and edgecolor='black' fill in missing keys.
        The given dictionary is not modified.
    frame : bool
        Show or hide frame. Default False.

    Returns
    -------
    matplotlib.axes.Axes
    """
    # add features
    if features:
        add_cartopy_features(ax, features)

    set_plot_attrs(use_attrs, data, ax)

    if frame is False:
        ax.spines["geo"].set_visible(False)

    # add geometries
    if geometries_kw:
        if "geoms" not in geometries_kw:
            warnings.warn(
                'geoms missing from geometries_kw (ex: {"geoms": df["geometry"]})', stacklevel=2
            )
            # Without geometries there is nothing to add (the old code raised KeyError here).
            return ax
        # Copy so the caller's dict keeps its original (unprojected) geometries.
        geom_kw = dict(geometries_kw)
        geom_kw["geoms"] = gpd_to_ccrs(geom_kw["geoms"], geom_kw.get("crs", projection))
        geom_kw = {
            "crs": projection,
            "facecolor": "none",
            "edgecolor": "black",
        } | geom_kw

        ax.add_geometries(**geom_kw)
    return ax
×
1477

1478

1479
def masknan_sizes_key(data, sizes) -> xr.Dataset:
    """
    Share missing values between the hue variable and the marker-size variable for xr.plot.scatter().

    After the call, a point is NaN in both variables as soon as it is NaN in either one.

    Parameters
    ----------
    data: xr.Dataset
        Dataset used to plot. It is modified in place and also returned.
    sizes: str
        Name of the variable used for the marker size. The hue variable is the first other variable of `data`.

    Returns
    -------
    xr.Dataset
    """
    names = [*data.keys()]
    names.remove(sizes)
    hue_var = names[0]

    # Blank hue values where the size is missing, then blank sizes where the (updated) hue is missing.
    data[hue_var] = data[hue_var].where(~np.isnan(data[sizes]))
    data[sizes] = data[sizes].where(~np.isnan(data[hue_var]))
    return data
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc