• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

Ouranosinc / figanos / 14839462019

05 May 2025 02:56PM UTC coverage: 8.19% (+0.08%) from 8.111%
14839462019

push

github

web-flow
Drop Python3.9, update several dependency pins (#322)

### What kind of change does this PR introduce?

* Drops Python3.9, enables Python3.13
* Updates several dependency pins
* Adjusts the CI to the new configurations

### Does this PR introduce a breaking change?

Yes. Python 3.9 is no longer supported (it reaches its end-of-life date in a few
months). Some base dependency pins are higher.

### Other information:

`figanos` is the only package in PAVICS that requires `numpy<2.0.0`.
This change should move us up for the next user images.

4 of 5 new or added lines in 2 files covered. (80.0%)

4 existing lines in 2 files now uncovered.

157 of 1917 relevant lines covered (8.19%)

0.41 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

11.37
/src/figanos/matplotlib/utils.py
1
"""Utility functions for figanos figure-creation."""
2

3
from __future__ import annotations
5✔
4

5
import json
5✔
6
import math
5✔
7
import pathlib
5✔
8
import re
5✔
9
import warnings
5✔
10
from collections.abc import Callable
5✔
11
from copy import deepcopy
5✔
12
from tempfile import NamedTemporaryFile
5✔
13
from typing import Any
5✔
14

15
import cairosvg
5✔
16
import cartopy.crs as ccrs
5✔
17
import cartopy.feature as cfeature
5✔
18
import geopandas as gpd
5✔
19
import matplotlib as mpl
5✔
20
import matplotlib.axes
5✔
21
import matplotlib.colors as mcolors
5✔
22
import matplotlib.pyplot as plt
5✔
23
import numpy as np
5✔
24
import pandas as pd
5✔
25
import seaborn
5✔
26
import xarray as xr
5✔
27
import yaml
5✔
28
from matplotlib.lines import Line2D
5✔
29
from skimage.transform import resize
5✔
30
from xclim.core.options import METADATA_LOCALES
5✔
31
from xclim.core.options import OPTIONS as XC_OPTIONS
5✔
32

33
from .._logo import Logos
5✔
34

35
TERMS: dict = {}
"""
A translation dictionary for special terms to appear on the plots.

Keys are terms to translate and they map to "locale": "translation" dictionaries.
The "official" figanos terms are based on figanos/data/terms.yml.
"""


# Load terms translations.
# The file is read explicitly as UTF-8 so that any non-ASCII translations decode
# identically regardless of the platform's default (locale) encoding.
with (pathlib.Path(__file__).resolve().parents[1] / "data" / "terms.yml").open(
    encoding="utf-8"
) as f:
    TERMS = yaml.safe_load(f)
47

48

49
def get_localized_term(term, locale=None):
    """Translate `term` into `locale` using the :py:data:`TERMS` dictionary.

    Parameters
    ----------
    term : str
        A word or short phrase to translate.
    locale : str, optional
        A 2-letter locale name to translate to.
        When not given (or empty), the first entry of xclim's "metadata_locales"
        option is used, falling back to "en" if that option is empty.

    Returns
    -------
    str
        The translated term, or `term` itself when the locale is "en" or when no
        translation is known (a warning is emitted in the latter case).
    """
    if not locale:
        locale = (XC_OPTIONS[METADATA_LOCALES] or ["en"])[0]

    # English is the source language of the terms: nothing to translate.
    if locale == "en":
        return term

    if term not in TERMS:
        warnings.warn(f"No translation known for term '{term}'.")
        return term

    translations = TERMS[term]
    if locale in translations:
        return translations[locale]

    warnings.warn(f"No {locale} translation known for term '{term}'.")
    return term
80

81

82
def empty_dict(param) -> dict:
    """Return a deep copy of `param`, or a new empty dict when `param` is None.

    The copy lets callers pop items freely without altering the mapping they were given.
    """
    return {} if param is None else deepcopy(param)
87

88

89
def check_timeindex(
    xr_objs: xr.DataArray | xr.Dataset | dict[str, Any],
) -> xr.DataArray | xr.Dataset | dict[str, Any]:
    """Convert CFTime-indexed Xarray objects to a 'standard' calendar.

    Any object with a "time" dimension whose index is a ``CFTimeIndex`` is converted with
    ``convert_calendar("standard", use_cftime=None, align_on="year")`` and a warning is issued.

    Parameters
    ----------
    xr_objs : xr.DataArray or xr.Dataset or dict
        A single Xarray object, or a dictionary of DataArrays or Datasets.
        A dictionary is updated in place and returned.

    Returns
    -------
    xr.DataArray or xr.Dataset or dict
        The converted object, or the dictionary holding converted objects.
    """
    message = "CFTimeIndex converted to pandas DatetimeIndex with a 'standard' calendar."

    def _has_cftime_index(obj) -> bool:
        # Only objects that actually carry a time dimension are inspected.
        return "time" in obj.dims and isinstance(obj.get_index("time"), xr.CFTimeIndex)

    if not isinstance(xr_objs, dict):
        if _has_cftime_index(xr_objs):
            xr_objs = xr_objs.convert_calendar(
                "standard", use_cftime=None, align_on="year"
            )
            warnings.warn(message)
        return xr_objs

    # Re-assigning existing keys during iteration is safe: the dict size never changes.
    for key, value in xr_objs.items():
        if _has_cftime_index(value):
            xr_objs[key] = value.convert_calendar(
                "standard", use_cftime=None, align_on="year"
            )
            warnings.warn(message)
    return xr_objs
128

129

130
def get_array_categ(array: xr.DataArray | xr.Dataset) -> str:
    """Classify an Xarray object to decide how it should be plotted.

    Parameters
    ----------
    array : Dataset or DataArray
        The object being categorized.

    Returns
    -------
    str
        ENS_PCT_VAR_DS: ensemble percentiles stored as variables
        ENS_PCT_DIM_DA: ensemble percentiles stored as dimension coordinates, DataArray
        ENS_PCT_DIM_DS: ensemble percentiles stored as dimension coordinates, DataSet
        ENS_STATS_VAR_DS: ensemble statistics (min, mean, max) stored as variables
        ENS_REALS_DA: ensemble with 'realization' dim, as DataArray
        ENS_REALS_DS: ensemble with 'realization' dim, as Dataset
        DS: any Dataset that is not  recognized as an ensemble
        DA: DataArray

    Raises
    ------
    TypeError
        If `array` is neither a Dataset nor a DataArray.
    """
    if isinstance(array, xr.DataArray):
        if "percentiles" in array.dims:
            return "ENS_PCT_DIM_DA"
        if "realization" in array.dims:
            return "ENS_REALS_DA"
        return "DA"

    if not isinstance(array, xr.Dataset):
        raise TypeError("Array is not an Xarray Dataset or DataArray")

    var_names = list(array.data_vars)

    def _hits(pattern: str) -> int:
        # Number of data variables whose name contains a match for `pattern`.
        return sum(1 for var in var_names if re.search(pattern, var))

    # At least two matching variables are required to recognize an ensemble layout.
    if _hits("_p[0-9]{1,2}") >= 2:
        return "ENS_PCT_VAR_DS"
    if _hits("_[Mm]ax|_[Mm]in") >= 2:
        return "ENS_STATS_VAR_DS"
    if "percentiles" in array.dims:
        return "ENS_PCT_DIM_DS"
    if "realization" in array.dims:
        return "ENS_REALS_DS"
    return "DS"
183

184

185
def get_attributes(
    string: str, xr_obj: xr.DataArray | xr.Dataset, locale: str | None = None
) -> str:
    """Fetch attributes or dims corresponding to keys from Xarray objects.

    Searches DataArray attributes first, then the first variable (DataArray) of the Dataset, then Dataset attributes.
    If a locale is activated in xclim's options or a locale is passed, a localized version is given if available.

    Parameters
    ----------
    string : str
        String corresponding to an attribute name.
    xr_obj : DataArray or Dataset
        The Xarray object containing the attributes.
    locale : str, optional
        A 2-letter locale name to translate to.
        Default is None, which will pull the locale
        from xclim's "metadata_locales" option (taking the first).

    Returns
    -------
    str
        Xarray attribute value as string or empty string if not found
    """
    locale = locale or (XC_OPTIONS[METADATA_LOCALES] or ["en"])[0]
    # The localized attribute (e.g. "long_name_fr") takes priority over the plain one.
    names = [f"{string}_{locale}", string] if locale != "en" else [string]

    # Attribute mappings to search, in priority order.
    sources = []
    if isinstance(xr_obj, xr.DataArray):
        sources.append(xr_obj.attrs)
    elif isinstance(xr_obj, xr.Dataset):
        data_vars = list(xr_obj.data_vars)
        # A Dataset may hold only coordinates: only then is there no first variable to inspect.
        if data_vars:
            sources.append(xr_obj[data_vars[0]].attrs)
        sources.append(xr_obj.attrs)

    for name in names:
        for attrs in sources:
            if name in attrs:
                return attrs[name]

    warnings.warn(f'Attribute "{string}" not found.')
    return ""
230

231

232
def _axis_label(
    attr_dict: dict[str, Any],
    xr_obj: xr.DataArray | xr.Dataset,
    label_key: str,
    units_key: str,
) -> str:
    """Build a wrapped axis label from attributes, appending units in parentheses when available."""
    units = (
        get_attributes(attr_dict[units_key], xr_obj) if units_key in attr_dict else ""
    )
    label = get_attributes(attr_dict[label_key], xr_obj)
    if len(units) >= 1:  # avoids an empty "()" in the label
        return wrap_text(label + " (" + units + ")")
    return wrap_text(label)


def set_plot_attrs(
    attr_dict: dict[str, Any],
    xr_obj: xr.DataArray | xr.Dataset,
    ax: matplotlib.axes.Axes | None = None,
    title_loc: str = "center",
    facetgrid: seaborn.axisgrid.FacetGrid | None = None,
    wrap_kw: dict[str, Any] | None = None,
) -> matplotlib.axes.Axes | seaborn.axisgrid.FacetGrid:
    """Set plot elements according to Dataset or DataArray attributes.

    Uses get_attributes() to check for and get the string.

    Parameters
    ----------
    attr_dict : dict
        Mapping of plot elements to attribute names of `xr_obj`. Supported keys are
        "title", "ylabel", "yunits", "xlabel", "xunits", "cbar_label", "cbar_units"
        and "suptitle"; any other key triggers a warning. "cbar_label" and
        "cbar_units" are accepted but not applied here.
    xr_obj : Dataset or DataArray
        The Xarray object containing the attributes.
    ax : matplotlib axis, optional
        The matplotlib axis of the plot. Required when "title", "ylabel" or "xlabel" is requested.
    title_loc : str
        Location of the title.
    facetgrid : seaborn.axisgrid.FacetGrid, optional
        If given, "suptitle" (when requested) becomes the figure suptitle, the subplot
        titles are set to the facet values, and the FacetGrid is returned instead of `ax`.
    wrap_kw : dict, optional
        Arguments to pass to the wrap_text function for the title.

    Returns
    -------
    matplotlib.axes.Axes or seaborn.axisgrid.FacetGrid
        `facetgrid` if it was provided, otherwise `ax`.
    """
    wrap_kw = empty_dict(wrap_kw)

    supported = (
        "title",
        "ylabel",
        "yunits",
        "xlabel",
        "xunits",
        "cbar_label",
        "cbar_units",
        "suptitle",
    )
    for key in attr_dict:
        if key not in supported:
            warnings.warn(f'Use_attrs element "{key}" not supported')

    if "title" in attr_dict:
        title = get_attributes(attr_dict["title"], xr_obj)
        ax.set_title(wrap_text(title, **wrap_kw), loc=title_loc)

    if "ylabel" in attr_dict:
        ax.set_ylabel(_axis_label(attr_dict, xr_obj, "ylabel", "yunits"))

    if "xlabel" in attr_dict:
        ax.set_xlabel(_axis_label(attr_dict, xr_obj, "xlabel", "xunits"))

    # "cbar_label" / "cbar_units" have to be assigned in the main plotting function.

    if facetgrid:
        if "suptitle" in attr_dict:
            suptitle = get_attributes(attr_dict["suptitle"], xr_obj)
            facetgrid.fig.suptitle(suptitle, y=1.05)
            facetgrid.set_titles(template="{value}")
        return facetgrid

    return ax
329

330

331
def get_suffix(string: str) -> str:
    """Return the statistic suffix of a typical xclim variable name.

    The suffix is either one or two trailing digits (e.g. "10" in "tas_p10"), or
    "max"/"Max", "min"/"Min", "mean"/"Mean" when preceded by an underscore.

    Raises
    ------
    ValueError
        If `string` ends with none of these suffixes.
    """
    has_suffix = re.search("[0-9]{1,2}$|_[Mm]ax$|_[Mm]in$|_[Mm]ean$", string)
    if not has_suffix:
        raise ValueError(f"Mean, min or max not found in {string}")

    # Second pattern drops the underscore so that only the suffix itself is returned.
    found = re.search("[0-9]{1,2}$|[Mm]ax$|[Mm]in$|[Mm]ean$", string)
    return found.group()
338

339

340
def sort_lines(array_dict: dict[str, Any]) -> dict[str, str]:
    """Label arrays as 'middle', 'upper' and 'lower' for ensemble plotting.

    Roles are derived from the name suffix (see get_suffix): "max" -> upper,
    "min" -> lower, "mean" -> middle; percentiles above 50 -> upper, below 50 -> lower,
    exactly 50 -> middle.

    Parameters
    ----------
    array_dict : dict
        Dictionary of format {'name': array...}.

    Returns
    -------
    dict
        Dictionary of {'middle': 'name', 'upper': 'name', 'lower': 'name'}.

    Raises
    ------
    ValueError
        If `array_dict` does not contain exactly three arrays, if a name has no
        recognized suffix, or if the names do not yield exactly one lower, one
        middle and one upper array.
    """
    if len(array_dict) != 3:
        raise ValueError("Ensembles must contain exactly three arrays")

    stat_roles = {"max": "upper", "min": "lower", "mean": "middle"}
    sorted_lines = {}

    for name in array_dict:
        suffix = get_suffix(name)

        if suffix.isalpha():
            role = stat_roles.get(suffix.lower())
        elif suffix.isdigit():
            pct = int(suffix)
            role = "upper" if pct > 50 else "lower" if pct < 50 else "middle"
        else:
            raise ValueError('Arrays names must end in format "_mean" or "_p50" ')

        if role is not None:
            sorted_lines[role] = name

    # Two names mapping to the same role would silently overwrite each other and
    # leave a role missing, which only fails later when the plot looks it up.
    missing = {"lower", "middle", "upper"} - sorted_lines.keys()
    if missing:
        raise ValueError(
            f"Could not identify the {', '.join(sorted(missing))} line(s) among "
            f"{list(array_dict)}: expected one lower, one middle and one upper array."
        )
    return sorted_lines
378

379

380
def loc_mpl(
    loc: str | tuple[int | float, int | float] | int,
) -> tuple[tuple[float, float], tuple[int | float, int | float], str, str]:
    """Find coordinates and alignment associated to loc string.

    Parameters
    ----------
    loc : string, int, or tuple[float, float]
        Location of text, replicating https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html.
        Integer codes 1 to 10 map to the matching legend location strings
        (code 0, "best", is not supported). If a tuple, must be in axes coordinates.

    Returns
    -------
    tuple(float, float), tuple(float, float), str, str
        The anchor point in axes coordinates, the box alignment (both within [0, 1]),
        and the horizontal and vertical text alignments.

    Raises
    ------
    ValueError
        If `loc` is an unknown string, an integer outside 1-10, a tuple with a
        coordinate outside [0, 1], or of any other type.
    """
    # Anchor point (axes fraction) and box alignment for each supported string,
    # listed in matplotlib's legend-code order (1 = "upper right", ..., 10 = "center").
    positions = {
        "upper right": ((0.97, 0.97), (1, 1)),
        "upper left": ((0.03, 0.97), (0, 1)),
        "lower left": ((0.03, 0.03), (0, 0)),
        "lower right": ((0.97, 0.03), (1, 0)),
        "right": ((0.97, 0.5), (1, 0.5)),
        "center left": ((0.03, 0.5), (0, 0.5)),
        "center right": ((0.97, 0.5), (1, 0.5)),
        "lower center": ((0.5, 0.03), (0.5, 0)),
        "upper center": ((0.5, 0.97), (0.5, 1)),
        "center": ((0.5, 0.5), (0.5, 0.5)),
    }
    codes = list(positions)

    if isinstance(loc, int):
        # Explicit bounds check: negative indices or 0 would otherwise wrap around.
        if not 1 <= loc <= len(codes):
            raise ValueError("loc must be between 1 and 10, inclusively")
        loc = codes[loc - 1]

    if isinstance(loc, str):
        if loc not in positions:
            raise ValueError(
                f"Unknown loc string {loc!r}. Valid options are: {', '.join(codes)}."
            )

        if "left" in loc:
            ha = "left"
        elif "right" in loc:
            ha = "right"
        else:
            ha = "center"

        if "lower" in loc:
            va = "bottom"
        elif "upper" in loc:
            va = "top"
        else:
            va = "center"

        xy, box_a = positions[loc]
        return xy, box_a, ha, va

    if isinstance(loc, tuple):
        box_a = []
        for coord in loc:
            if coord > 1 or coord < 0:
                raise ValueError(
                    "Text location coordinates must be between 0 and 1, inclusively"
                )
            # Align the box towards the nearest edge of the axes.
            box_a.append(1 if coord > 0.5 else 0)
        return loc, tuple(box_a), "left", "bottom"

    raise ValueError(
        "loc must be a string, int or tuple. "
        "See https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html"
    )
485

486

487
def plot_coords(
    ax: matplotlib.axes.Axes | None,
    xr_obj: xr.DataArray | xr.Dataset,
    loc: str | tuple[float, float] | int,
    param: str | None = None,
    backgroundalpha: float = 1,
) -> matplotlib.axes.Axes | None:
    """Place coordinates on plot area.

    Parameters
    ----------
    ax : matplotlib.axes.Axes or None
        Matplotlib axes object on which to place the text.
        If None, will use plt.figtext instead (should be used for facetgrids).
    xr_obj : xr.DataArray or xr.Dataset
        The xarray object from which to fetch the text content.
    loc : string, int or tuple
        Location of text, replicating https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html.
        If a tuple, must be in axes coordinates.
    param : {"location", "time"}, optional
        The parameter used.
        "location" writes the "lat"/"lon" coordinates; "time" writes the formatted "time" coordinate.
    backgroundalpha : float
        Transparency of the text background. 1 is opaque, 0 is transparent.

    Returns
    -------
    matplotlib.axes.Axes or None
        `ax` (None when `ax` is None). When no text can be built, nothing is drawn
        and `ax` is returned unchanged.
    """
    text = None
    if param == "location":
        if "lat" in xr_obj.coords and "lon" in xr_obj.coords:
            text = "lat={:.2f}, lon={:.2f}".format(
                float(xr_obj["lat"]), float(xr_obj["lon"])
            )
        else:
            warnings.warn(
                'show_lat_lon set to True, but "lat" and/or "lon" not found in coords'
            )
    if param == "time":
        if "time" in xr_obj.coords:
            text = str(xr_obj.time.dt.strftime("%Y-%m-%d").values)
        else:
            warnings.warn('show_time set to True, but "time" not found in coords')

    # Validate and resolve the location even if there is nothing to draw.
    loc, box_a, ha, va = loc_mpl(loc)

    if not text:
        return ax

    if ax is None:
        # No axes (e.g. facetgrids): the location is reused as figure coordinates.
        plt.figtext(
            loc[0],
            loc[1],
            text,
            ha=ha,
            va=va,
            fontsize=12,
        )
        return None

    t = mpl.offsetbox.TextArea(
        text, textprops=dict(transform=ax.transAxes, ha=ha, va=va)
    )
    tt = mpl.offsetbox.AnnotationBbox(
        t,
        loc,
        xycoords="axes fraction",
        box_alignment=box_a,
        pad=0.05,
        bboxprops=dict(
            facecolor="white",
            alpha=backgroundalpha,
            edgecolor="w",
            boxstyle="Square, pad=0.5",
        ),
    )
    ax.add_artist(tt)
    return ax
582

583

584
def find_logo(logo: str | pathlib.Path) -> str:
    """Return the path of a logo managed by :py:class:`Logos`.

    Parameters
    ----------
    logo : str or Path
        Key of the logo to look up in ``Logos()``. If falsy, the default logo is used.

    Raises
    ------
    ValueError
        If the lookup yields no logo path.
    """
    logos = Logos()
    logo_path = logos[logo] if logo else logos.default

    if logo_path is not None:
        return logo_path

    raise ValueError(
        "No logo found. Please install one with the figanos.Logos().set_logo() method."
    )
597

598

599
def load_image(
    im: str | pathlib.Path,
    height: float | None,
    width: float | None,
    keep_ratio: bool = True,
) -> np.ndarray:
    """Load an image and scale it to a specified height and/or width.

    Parameters
    ----------
    im : str or Path
        The image to be scaled. PNG and SVG formats are supported (extension is case-insensitive).
    height : float, optional
        The desired height of the image. If None, the original height is used.
    width : float, optional
        The desired width of the image. If None, the original width is used.
    keep_ratio : bool
        If True, the aspect ratio of the original image is maintained. Default is True.
        When both `height` and `width` are given, `height` is used and `width` is derived.

    Returns
    -------
    np.ndarray
        The scaled image.

    Raises
    ------
    ValueError
        If the file extension is neither ".png" nor ".svg".
    """
    suffix = pathlib.Path(im).suffix.lower()

    if suffix == ".png":
        image = mpl.pyplot.imread(im)
        if height is None and width is None:
            return image

        original_height, original_width = image.shape[:2]
        warnings.warn(
            "The scikit-image library is used to resize PNG images. This may affect logo image quality."
        )
        if not keep_ratio:
            # Use the requested dimensions, falling back to the original ones.
            if height is None:
                height = original_height
            if width is None:
                width = original_width
        elif height is not None:
            if width is not None:
                warnings.warn("Both height and width provided, using height.")
            # Derive the width from the height, preserving the aspect ratio.
            width = (height / original_height) * original_width
        else:
            # Only width is provided: derive the height, preserving the aspect ratio.
            height = (width / original_width) * original_height

        # Keep any trailing channel axis (absent for grayscale images).
        return resize(image, (height, width) + image.shape[2:], anti_aliasing=True)

    if suffix == ".svg":
        import io

        cairo_kwargs = dict(url=str(im))
        if not keep_ratio:
            if height is not None and width is not None:
                cairo_kwargs.update(output_height=height, output_width=width)
        elif height is not None:
            if width is not None:
                warnings.warn("Both height and width provided, using height.")
            cairo_kwargs.update(output_height=height)
        elif width is not None:
            cairo_kwargs.update(output_width=width)

        # Render in memory: reopening a NamedTemporaryFile by name fails on Windows.
        png_bytes = cairosvg.svg2png(**cairo_kwargs)
        return mpl.pyplot.imread(io.BytesIO(png_bytes))

    raise ValueError(
        f"Unsupported image format '{pathlib.Path(im).suffix}'. "
        "Only PNG and SVG images are supported."
    )
664

665

666
def plot_logo(
    ax: matplotlib.axes.Axes,
    loc: str | tuple[float, float] | int,
    logo: str | pathlib.Path | Logos | None = None,
    height: float | None = None,
    width: float | None = None,
    keep_ratio: bool = True,
    **offset_image_kwargs,
) -> matplotlib.axes.Axes:
    r"""Place a logo on the plot area.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Matplotlib axes object on which to place the logo.
    loc : string, int or tuple
        Location of the logo, replicating https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html.
        If a tuple, must be in axes coordinates.
    logo : str, Path, figanos.Logos, optional
        A name (str) or Path to a logo file, or a name of an already-installed logo.
        If an existing logo is not found, the logo will be installed and accessible via the filename.
        If a Logos instance is given, its default logo is used.
        The default logo is the Figanos logo. To install the Ouranos (or another) logo consult the Usage page.
        The logo must be in 'png' or 'svg' format.
    height : float, optional
        The desired height of the image. If None, the original height is used.
    width : float, optional
        The desired width of the image. If None, the original width is used.
    keep_ratio : bool, optional
        If True, the aspect ratio of the original image is maintained. Default is True.
    \*\*offset_image_kwargs
        Arguments to pass to matplotlib.offsetbox.OffsetImage().

    Returns
    -------
    matplotlib.axes.Axes
    """
    logo_path = logo.default if isinstance(logo, Logos) else find_logo(logo)

    image = load_image(logo_path, height, width, keep_ratio)
    imagebox = mpl.offsetbox.OffsetImage(image, **offset_image_kwargs)

    # Text alignments returned by loc_mpl are irrelevant for an image.
    xy, box_a, _, _ = loc_mpl(loc)
    ab = mpl.offsetbox.AnnotationBbox(
        imagebox,
        xy,
        frameon=False,
        xycoords="axes fraction",
        box_alignment=box_a,
        pad=0.05,
    )
    ax.add_artist(ab)
    return ax
724

725

726
def split_legend(
    ax: matplotlib.axes.Axes,
    in_plot: bool = False,
    axis_factor: float = 0.15,
    label_gap: float = 0.02,
) -> matplotlib.axes.Axes:
    """Draw line labels at the end of each line, or outside the plot.

    Only legend handles that are non-empty :py:class:`matplotlib.lines.Line2D` objects
    are labelled; other handles (e.g. labelled shading) are skipped.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axis containing the legend.
    in_plot : bool
        If True, prolong plot area to fit labels. If False, print labels outside of plot area. Default: False.
    axis_factor : float
        If in_plot is True, fraction of the x-axis length to add at the far right of the plot. Default: 0.15.
    label_gap : float
        If in_plot is True, fraction of the x-axis length to add as a gap between line and label. Default: 0.02.

    Returns
    -------
    matplotlib.axes.Axes
    """
    # TODO: check for and fix overlapping labels
    x_lower, x_upper = ax.get_xbound()
    x_span = x_upper - x_lower
    label_bump = x_span * label_gap

    if in_plot is True:
        # Widen the x-axis so the labels fit inside the plot area.
        ax.set_xbound(lower=x_lower, upper=x_upper + x_span * axis_factor)

    # x in axes coordinates, y in data coordinates: labels sit just right of the axes.
    outside = mpl.transforms.blended_transform_factory(ax.transAxes, ax.transData)

    handles, labels = ax.get_legend_handles_labels()
    for handle, label in zip(handles, labels):
        # Only lines have an "end" to label; skip other artists and empty lines.
        if not isinstance(handle, Line2D) or len(handle.get_xdata()) == 0:
            continue

        last_x = handle.get_xdata()[-1]
        last_y = handle.get_ydata()[-1]
        if isinstance(last_x, np.datetime64):
            last_x = mpl.dates.date2num(last_x)

        color = handle.get_color()

        if in_plot is True:
            ax.text(
                last_x + label_bump,
                last_y,
                label,
                ha="left",
                va="center",
                color=color,
            )
        else:
            ax.text(
                1.01,
                last_y,
                label,
                ha="left",
                va="center",
                color=color,
                transform=outside,
            )

    return ax
794

795

796
def fill_between_label(
    sorted_lines: dict[str, Any],
    name: str,
    array_categ: dict[str, Any],
    legend: str,
) -> str | None:
    """Create a label for the shading around a line in line plots.

    Parameters
    ----------
    sorted_lines : dict
        Dictionary created by the sort_lines() function.
    name : str
        Key associated with the object being plotted in the 'data' argument of the timeseries() function.
    array_categ : dict
        The categories of the array, as created by the get_array_categ function.
    legend : str
        Legend mode. Only "full" produces a label.

    Returns
    -------
    str or None
        Label to be applied to the legend element representing the shading, or None
        when `legend` is not "full" or the array category has no shading label.
    """
    if legend != "full":
        return None

    categ = array_categ[name]
    if categ in ("ENS_PCT_VAR_DS", "ENS_PCT_DIM_DS", "ENS_PCT_DIM_DA"):
        # e.g. "10th-90th percentiles", using the suffixes of the lower/upper lines.
        return get_localized_term("{}th-{}th percentiles").format(
            get_suffix(sorted_lines["lower"]), get_suffix(sorted_lines["upper"])
        )
    if categ == "ENS_STATS_VAR_DS":
        return get_localized_term("min-max range")
    return None
836

837

838
def _find_var_group_matches(text: str, var_dict: dict[str, Any]) -> list[Any]:
    """Return the group of every ``var_dict`` key found as a standalone word in ``text``."""
    found = []
    for var, group in var_dict.items():
        # The key must not be embedded in a longer alphabetic word:
        # "tas" matches "tas_mean" but not "stasis".
        # NOTE(review): keys are used as regex patterns (not escaped), as before.
        if re.search(rf"(?:^|[^a-zA-Z])({var})(?:[^a-zA-Z]|$)", text):
            found.append(group)
    return found


def get_var_group(
    path_to_json: str | pathlib.Path,
    da: xr.DataArray | xr.Dataset | None = None,
    unique_str: str | None = None,
) -> str:
    """Get IPCC variable group from DataArray or a string using a json file (figanos/data/ipcc_colors/variable_groups.json).

    If `da` is a Dataset, look in the DataArray of the first variable.

    Parameters
    ----------
    path_to_json : str or pathlib.Path
        JSON file mapping variable names to variable groups.
    da : xr.DataArray or xr.Dataset, optional
        Object whose name (then, if nothing matched, its ``history`` attribute) is searched
        for variable names. Ignored when `unique_str` is given.
    unique_str : str, optional
        String searched for variable names instead of `da`.

    Returns
    -------
    str
        The matched variable group, or "misc" (with a warning) when zero or several groups match.
    """
    with pathlib.Path(path_to_json).open(encoding="utf-8") as _f:
        var_dict = json.load(_f)

    if unique_str:
        matches = _find_var_group_matches(unique_str, var_dict)
    else:
        matches = []
        if isinstance(da, xr.Dataset):
            da = da[list(da.data_vars)[0]]
        # look in DataArray name
        if hasattr(da, "name") and isinstance(da.name, str):
            matches = _find_var_group_matches(da.name, var_dict)
        # fall back on the history attribute
        if hasattr(da, "history") and not matches:
            matches = _find_var_group_matches(da.history, var_dict)

    # Several variable names may map to the same group; only distinct groups count.
    matches = np.unique(matches)

    if len(matches) == 0:
        warnings.warn(
            "Colormap warning: Variable group not found. Use the cmap argument.",
            stacklevel=2,
        )
        return "misc"
    if len(matches) >= 2:
        warnings.warn(
            "Colormap warning: More than one variable group found. Use the cmap argument.",
            stacklevel=2,
        )
        return "misc"
    return str(matches[0])
×
890

891

892
def create_cmap(
    var_group: str | None = None,
    divergent: bool | int = False,
    filename: str | None = None,
) -> matplotlib.colors.Colormap:
    """Create colormap according to variable group.

    Parameters
    ----------
    var_group : str, optional
        Variable group from IPCC scheme. Required when `filename` is not given.
    divergent : bool or int
        Diverging colormap. If False, use sequential colormap. Any other value
        (including an int center such as 0) selects the diverging colormap.
    filename : str, optional
        Name of IPCC colormap file. If not None, 'var_group' and 'divergent' are not used.
        A trailing "_r" (before an optional ".txt") requests the reversed colormap.

    Returns
    -------
    matplotlib.colors.Colormap

    Raises
    ------
    ValueError
        If neither `filename` nor `var_group` is provided.
    """
    reverse = False

    if filename:
        filename = filename.replace(".txt", "")
        if filename.endswith("_r"):
            reverse = True
            filename = filename[:-2]
    else:
        if var_group is None:
            raise ValueError("Either 'var_group' or 'filename' must be provided.")
        if divergent is not False:
            # "misc2" has no dedicated diverging map; use the "misc" one.
            if var_group == "misc2":
                var_group = "misc"
            filename = var_group + "_div"
        elif var_group == "misc":
            filename = "misc_seq_3"  # Batlow
        elif var_group == "misc2":
            filename = "misc_seq_2"  # freezing rain
        else:
            filename = var_group + "_seq"

    # parent should be 'figanos/'
    path = (
        pathlib.Path(__file__).parents[1]
        / "data"
        / "ipcc_colors"
        / "continuous_colormaps_rgb_0-255"
        / (filename + ".txt")
    )

    # The file stores RGB triplets in 0-255; matplotlib expects 0-1.
    rgb_data = np.loadtxt(path) / 255

    cmap = mcolors.LinearSegmentedColormap.from_list("cmap", rgb_data, N=256)
    if reverse:
        cmap = cmap.reversed()

    return cmap
×
957

958

959
def get_rotpole(xr_obj: xr.DataArray | xr.Dataset) -> ccrs.RotatedPole | None:
    """Create a Cartopy crs rotated pole projection/transform from DataArray or Dataset attributes.

    Parameters
    ----------
    xr_obj : xr.DataArray or xr.Dataset
        The xarray object from which to look for the attributes.

    Returns
    -------
    ccrs.RotatedPole or None
        None (with a warning) when the grid-mapping variable or its pole attributes
        cannot be found.
    """
    try:
        if isinstance(xr_obj, xr.Dataset):
            # Uses the cf_xarray accessor; an AttributeError here means it is unavailable.
            gridmap = xr_obj.cf.grid_mapping_names.get("rotated_latitude_longitude", [])

            if len(gridmap) > 1:
                warnings.warn(
                    f"There are conflicting grid_mapping attributes in the dataset. Assuming {gridmap[0]}.",
                    stacklevel=2,
                )

            coord_name = gridmap[0] if gridmap else "rotated_pole"
        else:
            # If it can't find grid_mapping, assume it's rotated_pole
            coord_name = xr_obj.attrs.get("grid_mapping", "rotated_pole")

        # KeyError: the grid-mapping variable or a required pole attribute is missing.
        gm_attrs = xr_obj[coord_name].attrs
        pole_longitude = gm_attrs["grid_north_pole_longitude"]
        pole_latitude = gm_attrs["grid_north_pole_latitude"]
    except (AttributeError, KeyError):
        warnings.warn(
            "Rotated pole not found. Specify a transform if necessary.", stacklevel=2
        )
        return None

    return ccrs.RotatedPole(
        pole_longitude=pole_longitude,
        pole_latitude=pole_latitude,
        # Optional in CF grid mappings; its default is 0.
        central_rotated_longitude=gm_attrs.get("north_pole_grid_longitude", 0.0),
    )
×
996

997

998
def wrap_text(text: str, min_line_len: int = 18, max_line_len: int = 30) -> str:
    """Wrap text.

    Each line is broken inside a window that starts `min_line_len` characters and
    ends `max_line_len` characters after the start of the line. In that window a
    break after ". " is preferred, then after ": ", then at the last space.

    Parameters
    ----------
    text : str
        The text to wrap.
    min_line_len : int
        Minimum length of each line.
    max_line_len : int
        Maximum length of each line. A line of exactly this length is not broken.

    Returns
    -------
    str
        Wrapped text. If a line cannot be broken (no space in the window), a
        warning is emitted and the remaining text is left as is.
    """
    sep = "\n"
    start = min_line_len
    stop = max_line_len
    # Length of the last (not yet wrapped) line.
    remaining = len(text)

    while remaining > max_line_len:
        window = text[start:stop]
        if ". " in window:
            pos = text.find(". ", start, stop) + 1
        elif ": " in window:
            pos = text.find(": ", start, stop) + 1
        elif " " in window:
            pos = text.rfind(" ", start, stop)
        else:
            warnings.warn("No spaces, points or colons to break line at.", stacklevel=2)
            break

        # The character at ``pos`` is a space; replace it with the line separator.
        text = text[:pos] + sep + text[pos + 1 :]

        # The new line starts right after the separator.
        remaining = len(text) - pos - 1
        start = pos + 1 + min_line_len
        stop = pos + 1 + max_line_len

    return text
×
1039

1040

1041
def gpd_to_ccrs(df: gpd.GeoDataFrame, proj: ccrs.CRS) -> gpd.GeoDataFrame:
    """Reproject a GeoDataFrame to a cartopy projection.

    Parameters
    ----------
    df : gpd.GeoDataFrame
        GeoDataFrame (geopandas) geometry to be added to axis.
    proj : ccrs.CRS
        Target projection, taken from the cartopy.crs options.

    Returns
    -------
    gpd.GeoDataFrame
        The result of ``df.to_crs`` with the projection's PROJ init string.
    """
    # cartopy exposes the PROJ definition string of the CRS as ``proj4_init``.
    target_crs = proj.proj4_init
    reprojected = df.to_crs(target_crs)
    return reprojected
×
1058

1059

1060
def convert_scen_name(name: str) -> str:
    """Convert strings containing SSP, RCP or CMIP to their proper format.

    Every occurrence (case-insensitive) in `name` is converted, e.g.
    "ssp245" -> "SSP2-4.5", "rcp45" -> "RCP4.5" and "cmip6" -> "CMIP6".

    Parameters
    ----------
    name : str
        String possibly containing scenario or CMIP identifiers.

    Returns
    -------
    str
        `name` with every identifier reformatted; unchanged if none is found.
    """

    def _format(match: re.Match) -> str:
        prefix, digits = match.group(1), match.group(2)
        if len(digits) == 3:
            digits = f"{digits[0]}-{digits[1]}.{digits[2]}"  # ssp245 to SSP2-4.5
        elif len(digits) == 2:
            digits = f"{digits[0]}.{digits[1]}"  # rcp45 to RCP4.5
        return (prefix + digits).upper()  # cmip5 to CMIP5

    # re.sub handles all matches in one pass (rebuilding from `name` per match
    # would keep only the last conversion).
    return re.sub(r"(SSP|RCP|CMIP)([0-9]{1,3})", _format, name, flags=re.I)
×
1082

1083

1084
def get_scen_color(name: str, path_to_dict: str | pathlib.Path) -> str | None:
    """Get color corresponding to SSP,RCP, model or CMIP substring from a dictionary.

    Parameters
    ----------
    name : str
        String in which the dictionary keys are searched (plain substring match).
    path_to_dict : str or pathlib.Path
        JSON file mapping substrings to colors.

    Returns
    -------
    str or None
        Color of the first key (in file order) contained in `name`, or None if no key matches.
    """
    with pathlib.Path(path_to_dict).open(encoding="utf-8") as _f:
        color_dict = json.load(_f)

    return next(
        (color for entry, color in color_dict.items() if entry in name), None
    )
×
1096

1097

1098
def process_keys(dct: dict[str, Any], func: Callable) -> dict[str, Any]:
    """Apply function to dictionary keys.

    The dictionary is modified in place (and returned); key order is preserved.

    Parameters
    ----------
    dct : dict
        Dictionary whose keys are renamed.
    func : Callable
        Function mapping each old key to its new key.

    Returns
    -------
    dict
        The same `dct` object, with renamed keys. If `func` maps several keys to
        the same new key, the value of the last of them is kept.
    """
    # Compute all new keys from the original mapping first: renaming one key at a
    # time would overwrite an entry whose old key equals another entry's new key.
    renamed = {func(key): value for key, value in dct.items()}
    dct.clear()
    dct.update(renamed)
    return dct
×
1105

1106

1107
def categorical_colors() -> dict[str, str]:
    """Load the categorical colors associated with certain substrings (SSP, RCP, CMIP).

    Returns
    -------
    dict
        Mapping read from ``data/ipcc_colors/categorical_colors.json`` in the figanos package.
    """
    json_path = pathlib.Path(__file__).parents[1].joinpath(
        "data", "ipcc_colors", "categorical_colors.json"
    )
    with json_path.open(encoding="utf-8") as stream:
        colors = json.load(stream)
    return colors
×
1116

1117

1118
def get_mpl_styles() -> dict[str, pathlib.Path]:
    """Return the bundled matplotlib stylesheets as a ``{name: path}`` dictionary.

    Names are the file stems of the ``*.mplstyle`` files found in the ``style``
    folder next to this module, inserted in sorted path order.
    """
    style_dir = pathlib.Path(__file__).parent / "style"
    return {path.stem: path for path in sorted(style_dir.glob("*.mplstyle"))}
×
1123

1124

1125
def set_mpl_style(*args: str, reset: bool = False) -> None:
    """Set the matplotlib style using one or more stylesheets.

    Parameters
    ----------
    args : str
        Name(s) of figanos matplotlib style ('ouranos', 'paper', 'poster') or path(s) to matplotlib stylesheet(s).
        Stylesheets are applied in the given order; unknown names only emit a warning.
    reset : bool
        If True, reset style to matplotlib default before applying the stylesheets.

    Returns
    -------
    None
    """
    if reset is True:
        mpl.style.use("default")

    for style_name in args:
        # Anything ending in ".mplstyle" is treated as a path to a stylesheet.
        if style_name.endswith(".mplstyle"):
            mpl.style.use(style_name)
            continue

        bundled = get_mpl_styles()
        if style_name not in bundled:
            warnings.warn(f"Style {style_name} not found.")
            continue
        mpl.style.use(bundled[style_name])
×
1148

1149

1150
def add_cartopy_features(
    ax: matplotlib.axes.Axes, features: list[str] | dict[str, dict[str, Any]]
) -> matplotlib.axes.Axes:
    """Add cartopy features to matplotlib axes.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes on which to add the features.
    features : list or dict
        List of features, or nested dictionary of format {'feature': {'kwarg':'value'}}.
        Feature names are matched (upper-cased) against the attributes of ``cartopy.feature``.
        A 'scale' entry selects the feature resolution via ``with_scale``; the other
        entries are passed to ``ax.add_feature``. The caller's dictionaries are not modified.

    Returns
    -------
    matplotlib.axes.Axes
        The axis with added features.
    """
    if isinstance(features, list):
        features = {f: {} for f in features}

    for feat, options in features.items():
        # Work on a copy so the caller's options are never mutated (not even temporarily).
        kwargs = dict(options or {})
        # "scale" configures the feature itself, it is not an add_feature() kwarg.
        scale = kwargs.pop("scale", None)
        feature = getattr(cfeature, feat.upper())
        if scale is not None:
            feature = feature.with_scale(scale)
        ax.add_feature(feature, **kwargs)
    return ax
×
1181

1182

1183
def custom_cmap_norm(
    cmap,
    vmin: int | float,
    vmax: int | float,
    levels: int | list[int | float] | None = None,
    divergent: bool | int | float = False,
    linspace_out: bool = False,
) -> matplotlib.colors.Normalize | np.ndarray:
    """Get matplotlib normalization according to main function arguments.

    Parameters
    ----------
    cmap : matplotlib.colors.Colormap or str
        Colormap to be used with the normalization. A string must be the name of a registered matplotlib colormap.
    vmin : int or float
        Minimum of the data to be plotted with the colormap.
    vmax : int or float
        Maximum of the data to be plotted with the colormap.
    levels : int or sequence of numbers, optional
        Number of levels (Python or NumPy integer), or level boundaries in data units
        (list, tuple or array) to use to divide the colormap.
    divergent : bool or int or float
        If int or float, becomes center of cmap. Default center is 0.
    linspace_out : bool
        If True and the boundaries are computed here (`levels` is a number of levels),
        return the array created by np.linspace() instead of the normalization instance.

    Returns
    -------
    matplotlib.colors.Normalize or np.ndarray
        The normalization, or the boundary array when `linspace_out` applies.

    Raises
    ------
    ValueError
        If `cmap` is an unknown colormap name, or if `levels` is a number and the
        center does not lie strictly between the rounded vmin and vmax.
    """
    # get cmap if string
    if isinstance(cmap, str):
        if cmap in plt.colormaps():
            cmap = matplotlib.colormaps[cmap]
        else:
            raise ValueError("Colormap not found")

    # make vmin and vmax prettier: round them outwards to a step chosen from the data span
    span = vmax - vmin
    if span >= 25:
        step = 10
    elif span >= 1:
        step = 1
    elif span >= 0.1:
        step = 0.1
    else:
        step = 0.01
    rvmax = math.ceil(vmax / step) * step
    rvmin = math.floor(vmin / step) * step

    # center
    center = None
    if divergent is not False:
        if divergent is True:
            center = 0
        elif isinstance(divergent, (int, float)):
            center = divergent

    # A (Python or NumPy) integer is a number of levels; any other non-None value
    # is taken as explicit level boundaries.
    levels_is_count = isinstance(levels, (int, np.integer))

    # build norm with options
    if center is not None and levels_is_count:
        if center <= rvmin or center >= rvmax:
            raise ValueError("vmin, center and vmax must be in ascending order.")
        # An odd number of levels is rounded up to the next even number so that
        # the center is always a boundary.
        if levels % 2 == 1:
            half_levels = int((levels + 1) / 2) + 1
        else:
            half_levels = int(levels / 2) + 1

        lin = np.concatenate(
            (
                np.linspace(rvmin, center, num=half_levels),
                np.linspace(center, rvmax, num=half_levels)[1:],
            )
        )
        norm = matplotlib.colors.BoundaryNorm(boundaries=lin, ncolors=cmap.N)

        if linspace_out:
            return lin

    elif levels is not None:
        if levels_is_count:
            lin = np.linspace(rvmin, rvmax, num=levels + 1)
            norm = matplotlib.colors.BoundaryNorm(boundaries=lin, ncolors=cmap.N)

            if linspace_out:
                return lin
        else:
            if center is not None:
                warnings.warn(
                    "Divergent argument ignored when levels is a list. Use levels as a number instead."
                )
            norm = matplotlib.colors.BoundaryNorm(boundaries=levels, ncolors=cmap.N)

    elif center is not None:
        norm = matplotlib.colors.TwoSlopeNorm(center, vmin=rvmin, vmax=rvmax)
    else:
        norm = matplotlib.colors.Normalize(rvmin, rvmax)

    return norm
×
1281

1282

1283
def norm2range(
    data: np.ndarray, target_range: tuple, data_range: tuple | None = None
) -> np.ndarray:
    """Normalize data across a specific range.

    Values are mapped linearly so that ``data_range[0]`` goes to ``target_range[0]``
    and ``data_range[1]`` goes to ``target_range[1]``.

    Parameters
    ----------
    data : np.ndarray or scalar
        Values to rescale.
    target_range : tuple
        (min, max) of the output range.
    data_range : tuple, optional
        (min, max) of the input range. Defaults to the NaN-ignoring min and max of `data`.

    Returns
    -------
    np.ndarray
        Rescaled values (a scalar if `data` is a scalar).

    Raises
    ------
    ValueError
        If `data_range` is None and `data` holds fewer than two values.
    """
    if data_range is None:
        # np.size also works on scalars (size 1), unlike len().
        if np.size(data) > 1:
            data_range = (np.nanmin(data), np.nanmax(data))
        else:
            raise ValueError(
                "data_range must be specified if data has fewer than two values"
            )

    norm = (data - data_range[0]) / (data_range[1] - data_range[0])

    return target_range[0] + (norm * (target_range[1] - target_range[0]))
×
1296

1297

1298
def size_legend_elements(
    data: np.ndarray, sizes: np.ndarray, marker: str, max_entries: int = 6
) -> list[matplotlib.lines.Line2D]:
    """Create handles to use in a point-size legend.

    Parameters
    ----------
    data : np.ndarray
        Data used to determine the point sizes. NaNs are ignored.
    sizes : np.ndarray
        Array of point sizes (in points**2). NaNs are ignored.
    marker : str
        Marker to use in legend.
    max_entries : int
        Maximum number of entries in the legend.

    Returns
    -------
    list of matplotlib.lines.Line2D
        Empty if the sizes span less than one 10 pts**2 increment.
    """
    # NaN-aware extremes (builtin min()/max() are order-dependent with NaNs).
    data_min, data_max = np.nanmin(data), np.nanmax(data)
    size_min, size_max = np.nanmin(sizes), np.nanmax(sizes)

    # how many increments of 10 pts**2 are there in the sizes
    n = int(np.round(size_max - size_min, -1) / 10)
    if n < 1:
        return []

    # divide data in those increments
    lgd_data = np.linspace(data_min, data_max, n)

    # round according to the step between entries (the whole span is divided by n)
    ratio = abs((data_max - data_min) / n)
    rounding = next(
        (step for step in (1000, 100, 10, 5, 1, 0.1, 0.01) if ratio >= step), 0.001
    )
    # Decimal places of the rounding step, to drop float artifacts such as 0.30000000000000004.
    decimals = {0.1: 1, 0.01: 2, 0.001: 3}.get(rounding, 0)

    lgd_data = np.unique(np.round(rounding * np.round(lgd_data / rounding), decimals))

    # convert back to sizes
    lgd_sizes = norm2range(
        data=lgd_data,
        data_range=(data_min, data_max),
        target_range=(size_min, size_max),
    )

    legend_elements = []

    for s, d in zip(lgd_sizes, lgd_data):
        if isinstance(d, float) and d.is_integer():
            label = str(int(d))
        else:
            label = str(d)

        legend_elements.append(
            Line2D(
                [0],
                [0],
                marker=marker,
                color="k",
                lw=0,
                markerfacecolor="w",
                label=label,
                # markersize is in points while sizes are in points**2
                markersize=np.sqrt(np.abs(s)),
            )
        )

    if len(legend_elements) > max_entries:
        # keep every other entry among the first max_entries + 1
        return [legend_elements[i] for i in np.arange(0, max_entries + 1, 2)]
    return legend_elements
×
1378

1379

1380
def add_features_map(
    data,
    ax,
    use_attrs,
    projection,
    features,
    geometries_kw,
    frame,
) -> matplotlib.axes.Axes:
    """Add features such as cartopy, time label, and geometries to a map on a given matplotlib axis.

    Parameters
    ----------
    data : dict, DataArray or Dataset
        Input data do plot. If dictionary, must have only one entry.
    ax : matplotlib axis
        Matplotlib axis on which to plot, with the same projection as the one specified.
    use_attrs : dict
        Dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description').
        Default value is {'title': 'description', 'cbar_label': 'long_name', 'cbar_units': 'units'}.
        Only the keys found in the default dict can be used.
    projection : ccrs.Projection
        The projection to use, taken from the cartopy.crs options. Ignored if ax is not None.
    features : list or dict
        Features to use, as a list or a nested dict containing kwargs. Options are the predefined features from
        cartopy.feature: ['coastline', 'borders', 'lakes', 'land', 'ocean', 'rivers', 'states'].
    geometries_kw : dict
        Arguments passed to cartopy ax.add_geometry() which adds given geometries (GeoDataFrame geometry) to axis.
        Must contain 'geoms'; otherwise a warning is emitted and no geometry is added.
        The caller's dict is not modified.
    frame : bool
        Show or hide frame. Default False.

    Returns
    -------
    matplotlib.axes.Axes
    """
    # add features
    if features:
        add_cartopy_features(ax, features)

    set_plot_attrs(use_attrs, data, ax)

    if frame is False:
        ax.spines["geo"].set_visible(False)

    # add geometries
    if geometries_kw:
        if "geoms" not in geometries_kw:
            warnings.warn(
                'geoms missing from geometries_kw (ex: {"geoms": df["geometry"]})'
            )
            return ax

        # Copy so that reprojecting the geometries does not modify the caller's dict.
        geometries_kw = dict(geometries_kw)
        # Reproject to the user-given crs if any, otherwise to the map projection.
        geometries_kw["geoms"] = gpd_to_ccrs(
            geometries_kw["geoms"], geometries_kw.get("crs", projection)
        )
        # User-supplied values take precedence over these defaults.
        geometries_kw = {
            "crs": projection,
            "facecolor": "none",
            "edgecolor": "black",
        } | geometries_kw

        ax.add_geometries(**geometries_kw)
    return ax
×
1444

1445

1446
def masknan_sizes_key(data, sizes) -> xr.Dataset:
    """Propagate NaNs between the hue and marker-size variables used in xr.plot.scatter().

    After the call, both variables are NaN wherever either of them was NaN.
    `data` is modified in place and also returned.

    Parameters
    ----------
    data : xr.Dataset
        Dataset used to plot. The hue variable is the first data variable other than `sizes`.
    sizes : str
        Name of the variable used to plot the marker size.

    Returns
    -------
    xr.Dataset
    """
    candidates = list(data.keys())
    candidates.remove(sizes)
    hue = candidates[0]

    # First blank the hue where sizes is missing, then blank sizes where the
    # (already updated) hue is missing.
    for target, reference in ((hue, sizes), (sizes, hue)):
        data[target] = data[target].where(~np.isnan(data[reference]))
    return data
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc