• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

Ouranosinc / figanos / 15281473103

27 May 2025 05:11PM UTC coverage: 8.186% (-0.004%) from 8.19%
15281473103

push

github

web-flow
Fix RGB issue (#324)

<!-- Please ensure the PR fulfills the following requirements! -->
<!-- If this is your first PR, make sure to add your details to the
AUTHORS.rst! -->
### Pull Request Checklist:
- [x] This PR addresses an already opened issue (for bug fixes /
features)
  - This PR fixes #239 
- [x] (If applicable) Documentation has been added / updated (for bug
fixes / features).
- [ ] (If applicable) Tests have been added.
- [x] CHANGELOG.rst has been updated (with summary of main changes).
- [x] Link to issue (:issue:`number`) and pull request (:pull:`number`)
has been added.

### What kind of change does this PR introduce?

* Fix the incorrect conversion from RGB to HEX (RGB values are now used in
categorical_colors.json)

### Does this PR introduce a breaking change?
no


### Other information:

0 of 1 new or added line in 1 file covered. (0.0%)

1 existing line in 1 file now uncovered.

157 of 1918 relevant lines covered (8.19%)

0.41 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

11.35
/src/figanos/matplotlib/utils.py
1
"""Utility functions for figanos figure-creation."""
2

3
from __future__ import annotations
5✔
4

5
import json
5✔
6
import math
5✔
7
import pathlib
5✔
8
import re
5✔
9
import warnings
5✔
10
from collections.abc import Callable
5✔
11
from copy import deepcopy
5✔
12
from tempfile import NamedTemporaryFile
5✔
13
from typing import Any
5✔
14

15
import cairosvg
5✔
16
import cartopy.crs as ccrs
5✔
17
import cartopy.feature as cfeature
5✔
18
import geopandas as gpd
5✔
19
import matplotlib as mpl
5✔
20
import matplotlib.axes
5✔
21
import matplotlib.colors as mcolors
5✔
22
import matplotlib.pyplot as plt
5✔
23
import numpy as np
5✔
24
import pandas as pd
5✔
25
import seaborn
5✔
26
import xarray as xr
5✔
27
import yaml
5✔
28
from matplotlib.lines import Line2D
5✔
29
from skimage.transform import resize
5✔
30
from xclim.core.options import METADATA_LOCALES
5✔
31
from xclim.core.options import OPTIONS as XC_OPTIONS
5✔
32

33
from .._logo import Logos
5✔
34

35
TERMS: dict = {}
"""
A translation dictionary for special terms to appear on the plots.

Keys are terms to translate and they map to "locale": "translation" dictionaries.
The "official" figanos terms are based on figanos/data/terms.yml.
"""


# Load terms translations.
# The encoding is pinned so the file decodes the same way on every platform instead of
# depending on the locale's default encoding (translations may contain non-ASCII text).
with (pathlib.Path(__file__).resolve().parents[1] / "data" / "terms.yml").open(
    encoding="utf-8"
) as f:
    TERMS = yaml.safe_load(f)
5✔
47

48

49
def get_localized_term(term, locale=None):
    """Translate `term` into `locale` using the :py:data:`TERMS` dictionary.

    Parameters
    ----------
    term : str
        A word or short phrase to translate.
    locale : str, optional
        A 2-letter locale name to translate to.
        When omitted, the first entry of xclim's "metadata_locales" option is used,
        falling back to "en" if that option is empty.

    Returns
    -------
    str
        The translated term. `term` is returned unchanged when the locale is "en" or when
        no translation is known (a warning is emitted in the latter case).
    """
    if not locale:
        locale = (XC_OPTIONS[METADATA_LOCALES] or ["en"])[0]

    # English terms are used as the lookup keys, so there is nothing to translate.
    if locale == "en":
        return term

    if term in TERMS:
        if locale in TERMS[term]:
            return TERMS[term][locale]
        warnings.warn(f"No {locale} translation known for term '{term}'.")
    else:
        warnings.warn(f"No translation known for term '{term}'.")
    return term
×
80

81

82
def empty_dict(param) -> dict:
    """Return a deep copy of `param`, or a new empty dict when `param` is None.

    Working on a copy lets callers pop items without modifying the dictionary passed in.
    """
    return {} if param is None else deepcopy(param)
×
87

88

89
def check_timeindex(
    xr_objs: xr.DataArray | xr.Dataset | dict[str, Any],
) -> xr.DataArray | xr.Dataset | dict[str, Any]:
    """Convert CFTime time indexes of Xarray objects to a 'standard' calendar.

    Parameters
    ----------
    xr_objs : xr.DataArray or xr.Dataset or dict
        A single Xarray object, or a dictionary whose values are Xarray objects.

    Returns
    -------
    xr.DataArray or xr.Dataset or dict
        The input, where every object whose "time" index is an ``xr.CFTimeIndex`` has been
        replaced by its ``convert_calendar("standard", ...)`` result.
        A dictionary input is updated in place and returned.
        A warning is emitted for each conversion.
    """

    def _has_cftime_axis(candidate) -> bool:
        # The index is only looked up when a "time" dimension exists.
        return "time" in candidate.dims and isinstance(
            candidate.get_index("time"), xr.CFTimeIndex
        )

    if not isinstance(xr_objs, dict):
        if _has_cftime_axis(xr_objs):
            xr_objs = xr_objs.convert_calendar(
                "standard", use_cftime=None, align_on="year"
            )
            warnings.warn(
                "CFTimeIndex converted to pandas DatetimeIndex with a 'standard' calendar."
            )
        return xr_objs

    for key, item in xr_objs.items():
        if not _has_cftime_axis(item):
            continue
        # Re-assigning an existing key does not change the dict size, so this is safe mid-iteration.
        xr_objs[key] = item.convert_calendar(
            "standard", use_cftime=None, align_on="year"
        )
        warnings.warn(
            "CFTimeIndex converted to pandas DatetimeIndex with a 'standard' calendar."
        )

    return xr_objs
×
128

129

130
def get_array_categ(array: xr.DataArray | xr.Dataset) -> str:
    """Classify an array; the category determines how the array is plotted.

    Parameters
    ----------
    array : Dataset or DataArray
        The array being categorized.

    Returns
    -------
    str
        ENS_PCT_VAR_DS: ensemble percentiles stored as variables
        ENS_PCT_DIM_DA: ensemble percentiles stored as dimension coordinates, DataArray
        ENS_PCT_DIM_DS: ensemble percentiles stored as dimension coordinates, DataSet
        ENS_STATS_VAR_DS: ensemble statistics (min, mean, max) stored as variables
        ENS_REALS_DA: ensemble with 'realization' dim, as DataArray
        ENS_REALS_DS: ensemble with 'realization' dim, as Dataset
        DS: any Dataset that is not recognized as an ensemble
        DA: DataArray

    Raises
    ------
    TypeError
        If `array` is neither a Dataset nor a DataArray.
    """

    def _count_vars_matching(pattern: str) -> int:
        # Number of data variables whose name contains `pattern`.
        return sum(re.search(pattern, var) is not None for var in array.data_vars)

    if isinstance(array, xr.Dataset):
        # Variable-name conventions take precedence over dimension names.
        if _count_vars_matching("_p[0-9]{1,2}") >= 2:
            return "ENS_PCT_VAR_DS"
        if _count_vars_matching("_[Mm]ax|_[Mm]in") >= 2:
            return "ENS_STATS_VAR_DS"
        if "percentiles" in array.dims:
            return "ENS_PCT_DIM_DS"
        if "realization" in array.dims:
            return "ENS_REALS_DS"
        return "DS"

    if isinstance(array, xr.DataArray):
        if "percentiles" in array.dims:
            return "ENS_PCT_DIM_DA"
        if "realization" in array.dims:
            return "ENS_REALS_DA"
        return "DA"

    raise TypeError("Array is not an Xarray Dataset or DataArray")
×
183

184

185
def get_attributes(
    string: str, xr_obj: xr.DataArray | xr.Dataset, locale: str | None = None
) -> str:
    """Fetch attributes or dims corresponding to keys from Xarray objects.

    Searches DataArray attributes first, then the first variable (DataArray) of the Dataset, then Dataset attributes.
    If a locale is activated in xclim's options or a locale is passed, a localized version is given if available.

    Parameters
    ----------
    string : str
        String corresponding to an attribute name.
    xr_obj : DataArray or Dataset
        The Xarray object containing the attributes.
    locale : str, optional
        A 2-letter locale name to translate to.
        Default is None, which will pull the locale
        from xclim's "metadata_locales" option (taking the first).

    Returns
    -------
    str
        The attribute value as stored on the object (not converted),
        or an empty string, with a warning, if the attribute is not found.
    """
    locale = locale or (XC_OPTIONS[METADATA_LOCALES] or ["en"])[0]
    # For non-English locales the localized attribute ("<name>_<locale>") is preferred.
    names = [string] if locale == "en" else [f"{string}_{locale}", string]

    # Attribute mappings to search, in priority order.
    if isinstance(xr_obj, xr.DataArray):
        sources = [xr_obj.attrs]
    elif isinstance(xr_obj, xr.Dataset):
        data_vars = list(xr_obj.data_vars)
        # A Dataset without data variables has no "first variable": skip straight to its
        # own attributes instead of failing with an IndexError.
        sources = [xr_obj[data_vars[0]].attrs] if data_vars else []
        sources.append(xr_obj.attrs)
    else:
        sources = []

    for name in names:
        for attrs in sources:
            if name in attrs:
                return attrs[name]

    warnings.warn(f'Attribute "{string}" not found.')
    return ""
×
230

231

232
def set_plot_attrs(
    attr_dict: dict[str, Any],
    xr_obj: xr.DataArray | xr.Dataset,
    ax: matplotlib.axes.Axes | None = None,
    title_loc: str = "center",
    facetgrid: seaborn.axisgrid.FacetGrid | None = None,
    wrap_kw: dict[str, Any] | None = None,
) -> matplotlib.axes.Axes:
    """Set plot elements according to Dataset or DataArray attributes.

    Each value of `attr_dict` is an attribute name that is resolved with get_attributes().

    Parameters
    ----------
    attr_dict : dict
        Maps plot elements ("title", "ylabel", "yunits", "xlabel", "xunits", "cbar_label",
        "cbar_units", "suptitle") to attribute names. Other keys trigger a warning.
    xr_obj : Dataset or DataArray
        The Xarray object containing the attributes.
    ax : matplotlib axis
        The matplotlib axis of the plot.
    title_loc : str
        Location of the title.
    facetgrid : seaborn.axisgrid.FacetGrid, optional
        When given, the "suptitle" entry is applied to its figure and the facet grid is
        returned instead of `ax`.
    wrap_kw : dict, optional
        Arguments to pass to the wrap_text function for the title.

    Returns
    -------
    matplotlib.axes.Axes
        `ax`, or `facetgrid` when one was passed.
    """
    wrap_kw = empty_dict(wrap_kw)

    supported_keys = (
        "title",
        "ylabel",
        "yunits",
        "xlabel",
        "xunits",
        "cbar_label",
        "cbar_units",
        "suptitle",
    )
    for key in attr_dict:
        if key not in supported_keys:
            warnings.warn(f'Use_attrs element "{key}" not supported')

    def _axis_label(label_key: str, units_key: str) -> str:
        # Units are appended only when the units attribute is non-empty, which avoids
        # labels ending in an empty pair of parentheses.
        has_units = (
            units_key in attr_dict
            and len(get_attributes(attr_dict[units_key], xr_obj)) >= 1
        )
        text = get_attributes(attr_dict[label_key], xr_obj)
        if has_units:
            text = text + " (" + get_attributes(attr_dict[units_key], xr_obj) + ")"
        return wrap_text(text)

    if "title" in attr_dict:
        ax.set_title(
            wrap_text(get_attributes(attr_dict["title"], xr_obj), **wrap_kw),
            loc=title_loc,
        )

    if "ylabel" in attr_dict:
        ax.set_ylabel(_axis_label("ylabel", "yunits"))

    if "xlabel" in attr_dict:
        ax.set_xlabel(_axis_label("xlabel", "xunits"))

    # "cbar_label" and "cbar_units" are accepted here but must be applied by the main
    # plotting function, so nothing is done with them.

    if not facetgrid:
        return ax

    if "suptitle" in attr_dict:
        facetgrid.fig.suptitle(get_attributes(attr_dict["suptitle"], xr_obj), y=1.05)
        facetgrid.set_titles(template="{value}")
    return facetgrid
×
329

330

331
def get_suffix(string: str) -> str:
    """Return the trailing statistic marker of a typical Xclim variable name.

    A name ending in one or two digits yields those digits. A name ending in "_max", "_min"
    or "_mean" (first letter in either case) yields that word without the underscore.

    Raises
    ------
    ValueError
        If `string` ends with none of these suffixes.
    """
    found = re.search("[0-9]{1,2}$|_([Mm]ax|[Mm]in|[Mm]ean)$", string)
    if found is None:
        raise ValueError(f"Mean, min or max not found in {string}")
    # Group 1 is only set for the word suffixes; for digits the whole match is the suffix.
    return found.group(1) or found.group()
×
338

339

340
def sort_lines(array_dict: dict[str, Any]) -> dict[str, str]:
    """Label arrays as 'middle', 'upper' and 'lower' for ensemble plotting.

    The role of each array is derived from the suffix of its name (see get_suffix()):
    max/min/mean words, or a percentile number compared against 50.

    Parameters
    ----------
    array_dict : dict
        Dictionary of format {'name': array...}.

    Returns
    -------
    dict
        Dictionary of {'middle': 'name', 'upper': 'name', 'lower': 'name'}.

    Raises
    ------
    ValueError
        If `array_dict` does not contain exactly three entries, or a suffix is neither
        alphabetic nor numeric.
    """
    if len(array_dict) != 3:
        raise ValueError("Ensembles must contain exactly three arrays")

    word_roles = {
        "max": "upper",
        "Max": "upper",
        "min": "lower",
        "Min": "lower",
        "mean": "middle",
        "Mean": "middle",
    }

    roles = {}
    for key in array_dict:
        tail = get_suffix(key)

        if tail.isalpha():
            if tail in word_roles:
                roles[word_roles[tail]] = key
        elif tail.isdigit():
            percentile = int(tail)
            if percentile > 50:
                roles["upper"] = key
            elif percentile < 50:
                roles["lower"] = key
            else:
                roles["middle"] = key
        else:
            raise ValueError('Arrays names must end in format "_mean" or "_p50" ')
    return roles
×
378

379

380
# Text anchor (axes fraction) and box alignment for every named location, listed in the
# order of matplotlib's integer location codes 1 to 10.
_LOC_ANCHORS = {
    "upper right": ((0.97, 0.97), (1, 1)),
    "upper left": ((0.03, 0.97), (0, 1)),
    "lower left": ((0.03, 0.03), (0, 0)),
    "lower right": ((0.97, 0.03), (1, 0)),
    "right": ((0.97, 0.5), (1, 0.5)),
    "center left": ((0.03, 0.5), (0, 0.5)),
    "center right": ((0.97, 0.5), (1, 0.5)),
    "lower center": ((0.5, 0.03), (0.5, 0)),
    "upper center": ((0.5, 0.97), (0.5, 1)),
    "center": ((0.5, 0.5), (0.5, 0.5)),
}


def loc_mpl(
    loc: str | tuple[int | float, int | float] | int,
) -> tuple[tuple[float, float], tuple[int | float, int | float], str, str]:
    """Find coordinates and alignment associated to loc string.

    Parameters
    ----------
    loc : string, int, or tuple[float, float]
        Location of text, replicating https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html.
        An int is interpreted as a location code between 1 and 10.
        If a tuple, must be in axes coordinates.

    Returns
    -------
    tuple(float, float), tuple(float, float), str, str
        The anchor position in axes coordinates, the box alignment, and the horizontal and
        vertical text alignments.

    Raises
    ------
    ValueError
        If an int is outside 1-10, a tuple coordinate is outside [0, 1], or `loc` is
        neither a known location string, an int nor a tuple.
    """
    if isinstance(loc, int):
        # Check the range explicitly: indexing alone would let 0 and negative codes wrap
        # around to the end of the list instead of being rejected.
        if not 1 <= loc <= len(_LOC_ANCHORS):
            raise ValueError("loc must be between 1 and 10, inclusively")
        loc = list(_LOC_ANCHORS)[loc - 1]

    if isinstance(loc, str) and loc in _LOC_ANCHORS:
        anchor, box_a = _LOC_ANCHORS[loc]

        if "left" in loc:
            ha = "left"
        elif "right" in loc:
            ha = "right"
        else:
            ha = "center"

        if "lower" in loc:
            va = "bottom"
        elif "upper" in loc:
            va = "top"
        else:
            va = "center"

        return anchor, box_a, ha, va

    if isinstance(loc, tuple):
        for value in loc:
            if value > 1 or value < 0:
                raise ValueError(
                    "Text location coordinates must be between 0 and 1, inclusively"
                )
        # Anchor the box on the side of the axes the coordinate is closest to.
        box_a = tuple(1 if value > 0.5 else 0 for value in loc)
        return loc, box_a, "left", "bottom"

    raise ValueError(
        "loc must be a string, int or tuple. "
        "See https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html"
    )
×
485

486

487
def plot_coords(
    ax: matplotlib.axes.Axes | None,
    xr_obj: xr.DataArray | xr.Dataset,
    loc: str | tuple[float, float] | int,
    param: str | None = None,
    backgroundalpha: float = 1,
) -> matplotlib.axes.Axes | None:
    """Place coordinates on plot area.

    Parameters
    ----------
    ax : matplotlib.axes.Axes or None
        Matplotlib axes object on which to place the text.
        If None, will use plt.figtext instead (should be used for facetgrids).
    xr_obj : xr.DataArray or xr.Dataset
        The xarray object from which to fetch the text content.
    loc : string, int or tuple
        Location of text, replicating https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html.
        If a tuple, must be in axes coordinates.
    param : {"location", "time"}, optional
        The parameter used.
    backgroundalpha : float
        Transparency of the text background. 1 is opaque, 0 is transparent.

    Returns
    -------
    matplotlib.axes.Axes or None
        `ax` as passed in (None when the text is drawn with plt.figtext).
    """
    text = None
    if param == "location":
        if "lat" in xr_obj.coords and "lon" in xr_obj.coords:
            text = "lat={:.2f}, lon={:.2f}".format(
                float(xr_obj["lat"]), float(xr_obj["lon"])
            )
        else:
            warnings.warn(
                'show_lat_lon set to True, but "lat" and/or "lon" not found in coords'
            )
    if param == "time":
        if "time" in xr_obj.coords:
            text = str(xr_obj.time.dt.strftime("%Y-%m-%d").values)
        else:
            warnings.warn('show_time set to True, but "time" not found in coords')

    # Resolved even when there is no text so that an invalid `loc` is always reported.
    loc, box_a, ha, va = loc_mpl(loc)

    if not text:
        # Nothing to draw: hand the axes back unchanged rather than returning None.
        return ax

    if ax is None:
        # No axes (e.g. facetgrids): place the text in figure coordinates.
        plt.figtext(
            loc[0],
            loc[1],
            text,
            ha=ha,
            va=va,
            fontsize=12,
        )
        return None

    t = mpl.offsetbox.TextArea(
        text, textprops=dict(transform=ax.transAxes, ha=ha, va=va)
    )
    tt = mpl.offsetbox.AnnotationBbox(
        t,
        loc,
        xycoords="axes fraction",
        box_alignment=box_a,
        pad=0.05,
        bboxprops=dict(
            facecolor="white",
            alpha=backgroundalpha,
            edgecolor="w",
            boxstyle="Square, pad=0.5",
        ),
    )
    ax.add_artist(tt)
    return ax
×
582

583

584
def find_logo(logo: str | pathlib.Path) -> str:
    """Look up a logo in the Logos catalogue.

    A falsy `logo` selects the catalogue's default logo.

    Raises
    ------
    ValueError
        If the lookup yields None.
    """
    catalogue = Logos()
    found = catalogue[logo] if logo else catalogue.default

    if found is not None:
        return found
    raise ValueError(
        "No logo found. Please install one with the figanos.Logos().set_logo() method."
    )
×
597

598

599
def load_image(
    im: str | pathlib.Path,
    height: float | None,
    width: float | None,
    keep_ratio: bool = True,
) -> np.ndarray:
    """Scale an image to a specified height and width.

    Parameters
    ----------
    im : str or Path
        The image to be scaled. PNG and SVG formats are supported.
    height : float, optional
        The desired height of the image. If None, the original height is used.
    width : float, optional
        The desired width of the image. If None, the original width is used.
    keep_ratio : bool
        If True, the aspect ratio of the original image is maintained. Default is True.
        When both `height` and `width` are given, `width` takes precedence.

    Returns
    -------
    np.ndarray
        The scaled image.

    Raises
    ------
    ValueError
        If the file extension is neither ".png" nor ".svg".
    """
    suffix = pathlib.Path(im).suffix.lower()

    if suffix == ".png":
        image = mpl.pyplot.imread(im)
        original_height, original_width = image.shape[:2]

        if height is None and width is None:
            return image

        warnings.warn(
            "The scikit-image library is used to resize PNG images. This may affect logo image quality."
        )
        if not keep_ratio:
            # Only a missing dimension falls back to the original size.
            if height is None:
                height = original_height
            if width is None:
                width = original_width
        elif width is not None:
            if height is not None:
                warnings.warn("Both height and width provided, using width.")
            # Derive the height from the requested width to keep the aspect ratio.
            height = (width / original_width) * original_height
        else:
            # Only height is provided: derive the width from the aspect ratio.
            width = (height / original_height) * original_width

        # `image.shape[2:]` keeps the channel axis when present and also accepts 2-D
        # (grayscale) images, which have none.
        return resize(image, (height, width, *image.shape[2:]), anti_aliasing=True)

    if suffix == ".svg":
        # cairosvg is given a string; `im` may be a Path.
        cairo_kwargs = dict(url=str(im))
        if not keep_ratio:
            if height is not None and width is not None:
                cairo_kwargs.update(output_height=height, output_width=width)
        elif width is not None:
            if height is not None:
                warnings.warn("Both height and width provided, using width.")
            cairo_kwargs.update(output_width=width)
        elif height is not None:
            cairo_kwargs.update(output_height=height)

        # NOTE(review): the temporary file is written to by name while still open here;
        # whether that is possible varies by platform - confirm on all supported targets.
        with NamedTemporaryFile(suffix=".png") as png_file:
            cairo_kwargs.update(write_to=png_file.name)
            cairosvg.svg2png(**cairo_kwargs)
            return mpl.pyplot.imread(png_file.name)

    raise ValueError(
        f"Unsupported image format '{suffix}' for {im}. Only PNG and SVG images are supported."
    )
×
664

665

666
def plot_logo(
    ax: matplotlib.axes.Axes,
    loc: str | tuple[float, float] | int,
    logo: str | pathlib.Path | Logos | None = None,
    height: float | None = None,
    width: float | None = None,
    keep_ratio: bool = True,
    **offset_image_kwargs,
) -> matplotlib.axes.Axes:
    r"""Place a logo on the plot area.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Matplotlib axes object on which to place the logo.
    loc : string, int or tuple
        Location of the logo, replicating https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html.
        If a tuple, must be in axes coordinates.
    logo : str, Path, figanos.Logos, optional
        A name (str) or Path that is resolved with find_logo(), or a Logos instance, in which
        case its default logo is used.
        The image file is read with load_image().
    height : float, optional
        The desired height of the image. If None, the original height is used.
    width : float, optional
        The desired width of the image. If None, the original width is used.
    keep_ratio : bool, optional
        If True, the aspect ratio of the original image is maintained. Default is True.
    \*\*offset_image_kwargs
        Arguments to pass to matplotlib.offsetbox.OffsetImage().

    Returns
    -------
    matplotlib.axes.Axes
    """
    logo_path = logo.default if isinstance(logo, Logos) else find_logo(logo)

    imagebox = mpl.offsetbox.OffsetImage(
        load_image(logo_path, height, width, keep_ratio), **offset_image_kwargs
    )

    # Text alignments are not needed for an image box.
    anchor, box_a, _ha, _va = loc_mpl(loc)
    ax.add_artist(
        mpl.offsetbox.AnnotationBbox(
            imagebox,
            anchor,
            frameon=False,
            xycoords="axes fraction",
            box_alignment=box_a,
            pad=0.05,
        )
    )
    return ax
×
724

725

726
def split_legend(
    ax: matplotlib.axes.Axes,
    in_plot: bool = False,
    axis_factor: float = 0.15,
    label_gap: float = 0.02,
) -> matplotlib.axes.Axes:
    #  TODO: check for and fix overlapping labels
    """Write each legend label next to the last point of its line.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axis containing the legend.
    in_plot : bool
        If True, prolong plot area to fit labels. If False, print labels outside of plot area. Default: False.
    axis_factor : float
        If in_plot is True, fraction of the x-axis length to add at the far right of the plot. Default: 0.15.
    label_gap : float
        If in_plot is True, fraction of the x-axis length to add as a gap between line and label. Default: 0.02.

    Returns
    -------
    matplotlib.axes.Axes
    """
    inside = in_plot is True

    # Both offsets are expressed as fractions of the initial x-axis span.
    x_low, x_high = ax.get_xbound()
    x_span = x_high - x_low
    ax_bump = x_span * axis_factor
    label_bump = x_span * label_gap

    if inside:
        # Make room on the right for the labels.
        ax.set_xbound(lower=x_low, upper=x_high + ax_bump)

    for line, caption in zip(*ax.get_legend_handles_labels()):
        x_end = line.get_xdata()[-1]
        y_end = line.get_ydata()[-1]

        if isinstance(x_end, np.datetime64):
            x_end = mpl.dates.date2num(x_end)

        shared = dict(ha="left", va="center", color=line.get_color())

        if inside:
            ax.text(x_end + label_bump, y_end, caption, **shared)
        else:
            # x in axes coordinates (just right of the frame), y in data coordinates.
            ax.text(
                1.01,
                y_end,
                caption,
                transform=mpl.transforms.blended_transform_factory(
                    ax.transAxes, ax.transData
                ),
                **shared,
            )

    return ax
×
794

795

796
def fill_between_label(
    sorted_lines: dict[str, Any],
    name: str,
    array_categ: dict[str, Any],
    legend: str,
) -> str:
    """Build the legend label for the shaded area around a line in line plots.

    Parameters
    ----------
    sorted_lines : dict
        Dictionary created by the sort_lines() function.
    name : str
        Key associated with the object being plotted in the 'data' argument of the timeseries() function.
    array_categ : dict
        The categories of the array, as created by the get_array_categ function.
    legend : str
        Legend mode. A label is only produced for "full".

    Returns
    -------
    str
        Label to be applied to the legend element representing the shading,
        or None when the legend mode or the array category calls for no label.
    """
    if legend != "full":
        return None

    category = array_categ[name]

    if category in ("ENS_PCT_VAR_DS", "ENS_PCT_DIM_DS", "ENS_PCT_DIM_DA"):
        # The localized template is filled with the lower and upper percentile numbers.
        template = get_localized_term("{}th-{}th percentiles")
        return template.format(
            get_suffix(sorted_lines["lower"]), get_suffix(sorted_lines["upper"])
        )

    if category == "ENS_STATS_VAR_DS":
        return get_localized_term("min-max range")

    return None
×
836

837

838
def get_var_group(
    path_to_json: str | pathlib.Path,
    da: xr.DataArray | None = None,
    unique_str: str | None = None,
) -> str:
    """Get IPCC variable group from DataArray or a string using a json file (figanos/data/ipcc_colors/variable_groups.json).

    If `da` is a Dataset, look in the DataArray of the first variable.

    Parameters
    ----------
    path_to_json : str or pathlib.Path
        Path to a JSON file mapping variable names (keys) to variable groups (values).
    da : xr.DataArray, optional
        Object whose `name`, then `history` (only if the name gave no match), is searched
        for the variable names. Ignored when `unique_str` is given.
    unique_str : str, optional
        String to search instead of `da`.

    Returns
    -------
    str
        The single variable group found, or "misc" (with a warning) when zero or several
        distinct groups are found.
    """
    with pathlib.Path(path_to_json).open(encoding="utf-8") as _f:
        var_dict = json.load(_f)

    def _groups_found_in(text: str) -> list:
        """Return the group of every variable name that occurs in `text` outside of a word."""
        return [
            group
            for var, group in var_dict.items()
            # matches when variable is not inside word
            if re.search(rf"(?:^|[^a-zA-Z])({var})(?:[^a-zA-Z]|$)", text)
        ]

    matches = []

    if unique_str:
        matches = _groups_found_in(unique_str)
    else:
        if isinstance(da, xr.Dataset):
            da = da[list(da.data_vars)[0]]

        # look in DataArray name
        name = getattr(da, "name", None)
        if isinstance(name, str):
            matches = _groups_found_in(name)

        # look in history, but only as a fallback and only if it is text
        if not matches:
            history = getattr(da, "history", None)
            if isinstance(history, str):
                matches = _groups_found_in(history)

    groups = np.unique(matches)

    if len(groups) == 1:
        # np.unique yields numpy string scalars; hand back a plain str as annotated.
        return str(groups[0])

    if len(groups) == 0:
        warnings.warn(
            "Colormap warning: Variable group not found. Use the cmap argument.",
            stacklevel=2,
        )
    else:
        warnings.warn(
            "Colormap warning: More than one variable group found. Use the cmap argument.",
            stacklevel=2,
        )
    return "misc"
×
890

891

892
def create_cmap(
    var_group: str | None = None,
    divergent: bool | int = False,
    filename: str | None = None,
) -> matplotlib.colors.Colormap:
    """Build a colormap from one of the packaged IPCC colour tables.

    Parameters
    ----------
    var_group : str, optional
        Variable group from IPCC scheme. Used to derive the table name when `filename` is not given.
    divergent : bool or int
        Diverging colormap. If False, use sequential colormap.
    filename : str, optional
        Name of IPCC colormap file. If not None, 'var_group' and 'divergent' are not used.
        A ".txt" extension is dropped, and a trailing "_r" requests the reversed colormap.

    Returns
    -------
    matplotlib.colors.Colormap
    """
    folder = "continuous_colormaps_rgb_0-255"
    reverse = False

    if filename:
        filename = filename.replace(".txt", "")
        if filename.endswith("_r"):
            reverse = True
            filename = filename[:-2]
    elif divergent is not False:
        # "misc2" has no divergent table of its own; it shares the "misc" one.
        if var_group == "misc2":
            var_group = "misc"
        filename = var_group + "_div"
    elif var_group == "misc":
        filename = var_group + "_seq_3"  # Batlow
    elif var_group == "misc2":
        filename = "misc_seq_2"  # freezing rain
    else:
        filename = var_group + "_seq"

    # parent should be 'figanos/'
    package_root = pathlib.Path(__file__).parents[1]
    path = package_root.joinpath("data", "ipcc_colors", folder, filename + ".txt")

    # The table is loaded as 0-255 values and rescaled to 0-1 RGB.
    rgb_data = np.loadtxt(path) / 255

    cmap = mcolors.LinearSegmentedColormap.from_list("cmap", rgb_data, N=256)
    return cmap.reversed() if reverse else cmap
×
957

958

959
def get_rotpole(xr_obj: xr.DataArray | xr.Dataset) -> ccrs.RotatedPole | None:
    """Create a Cartopy crs rotated pole projection/transform from DataArray or Dataset attributes.

    Parameters
    ----------
    xr_obj : xr.DataArray or xr.Dataset
        The xarray object from which to look for the attributes.

    Returns
    -------
    ccrs.RotatedPole or None
        The rotated pole built from the `grid_north_pole_longitude`, `grid_north_pole_latitude`
        and `north_pole_grid_longitude` attributes of the grid-mapping variable, or None (with a
        warning) when that variable or one of its attributes cannot be found.
    """
    try:
        if isinstance(xr_obj, xr.Dataset):
            gridmap = xr_obj.cf.grid_mapping_names.get("rotated_latitude_longitude", [])

            if len(gridmap) > 1:
                warnings.warn(
                    f"There are conflicting grid_mapping attributes in the dataset. Assuming {gridmap[0]}."
                )

            coord_name = gridmap[0] if gridmap else "rotated_pole"
        else:
            # If it can't find grid_mapping, assume it's rotated_pole
            coord_name = xr_obj.attrs.get("grid_mapping", "rotated_pole")

        grid_mapping = xr_obj[coord_name]
        return ccrs.RotatedPole(
            pole_longitude=grid_mapping.grid_north_pole_longitude,
            pole_latitude=grid_mapping.grid_north_pole_latitude,
            central_rotated_longitude=grid_mapping.north_pole_grid_longitude,
        )

    except (AttributeError, KeyError):
        # AttributeError: a required attribute is missing.
        # KeyError: the (assumed) grid-mapping variable itself is absent from the object.
        warnings.warn("Rotated pole not found. Specify a transform if necessary.")
        return None
×
996

997

998
def wrap_text(text: str, min_line_len: int = 18, max_line_len: int = 30) -> str:
    """Insert line breaks so that lines stay within a length window.

    For each line, a break point is searched between `min_line_len` and `max_line_len`
    characters after the previous break: the first ". ", else the first ": ", else the
    last space. The space at the break point is replaced by a newline.

    Parameters
    ----------
    text : str
        The text to wrap.
    min_line_len : int
        Minimum length of each line.
    max_line_len : int
        Maximum length of each line.

    Returns
    -------
    str
        Wrapped text. If no break point is found in a window, a warning is emitted and the
        text wrapped so far is returned.
    """
    window_lo, window_hi = min_line_len, max_line_len
    remaining = len(text)

    while remaining > max_line_len:
        # Prefer breaking after sentence punctuation; the +1 below targets the space after it.
        hit = text.find(". ", window_lo, window_hi)
        if hit == -1:
            hit = text.find(": ", window_lo, window_hi)

        if hit != -1:
            cut = hit + 1
        else:
            cut = text.rfind(" ", window_lo, window_hi)
            if cut == -1:
                warnings.warn("No spaces, points or colons to break line at.")
                break

        # Swap the space at `cut` for a newline; the overall length is unchanged.
        text = f"{text[:cut]}\n{text[cut + 1:]}"

        remaining = len(text) - cut
        window_lo = cut + 1 + min_line_len
        window_hi = cut + 1 + max_line_len

    return text
×
1039

1040

1041
def gpd_to_ccrs(df: gpd.GeoDataFrame, proj: ccrs.CRS) -> gpd.GeoDataFrame:
    """Reproject a GeoDataFrame to a cartopy projection.

    The projection's `proj4_init` value is handed to the `to_crs` method of `df`.

    Parameters
    ----------
    df : gpd.GeoDataFrame
        GeoDataFrame (geopandas) geometry to be added to axis.
    proj : ccrs.CRS
        Projection to use, taken from the cartopy.crs options.

    Returns
    -------
    gpd.GeoDataFrame
        GeoDataFrame adjusted to given projection
    """
    return df.to_crs(proj.proj4_init)
×
1058

1059

1060
def convert_scen_name(name: str) -> str:
    """Convert strings containing SSP, RCP or CMIP to their proper format.

    Every occurrence of "SSP", "RCP" or "CMIP" (any case) followed by one to three digits is
    upper-cased and its digits are reformatted:

    - three digits: ``ssp245`` becomes ``SSP2-4.5``
    - two digits: ``rcp45`` becomes ``RCP4.5``
    - one digit: ``cmip5`` becomes ``CMIP5``

    Parameters
    ----------
    name : str
        The string to convert.

    Returns
    -------
    str
        The string with all scenario substrings converted; unchanged if none is found.
    """

    def _format(match: re.Match) -> str:
        """Format a single matched scenario token."""
        prefix, digits = match.group(1), match.group(2)
        if len(digits) == 3:
            digits = f"{digits[0]}-{digits[1]}.{digits[2]}"  # ssp245 to SSP2-4.5
        elif len(digits) == 2:
            digits = f"{digits[0]}.{digits[1]}"  # rcp45 to RCP4.5
        return prefix.upper() + digits  # cmip5 to CMIP5

    # Substituting in a single pass converts every match; rebuilding from the original
    # string for each match would keep only the last conversion.
    return re.sub(r"(SSP|RCP|CMIP)([0-9]{1,3})", _format, name, flags=re.I)
×
1082

1083

1084
def get_scen_color(name: str, path_to_dict: str | pathlib.Path) -> str:
    """Get color corresponding to SSP,RCP, model or CMIP substring from a dictionary.

    Parameters
    ----------
    name : str
        String in which the keys of the color dictionary are searched.
    path_to_dict : str or pathlib.Path
        Path to a JSON file mapping substrings to color components.

    Returns
    -------
    tuple or None
        For the first key (in file order) that is contained in `name`, its components each
        divided by 255, as a tuple. None when no key is contained in `name`.
        NOTE(review): the return annotation still says ``str``; the value built here is a tuple.
    """
    with pathlib.Path(path_to_dict).open(encoding="utf-8") as _f:
        color_dict = json.load(_f)

    for entry, components in color_dict.items():
        if entry in name:
            return tuple(channel / 255 for channel in components)

    return None
×
1097

1098

1099
def process_keys(dct: dict[str, Any], func: Callable) -> dict[str, Any]:
    """Rename every key of a dictionary in place by passing it through a function.

    Parameters
    ----------
    dct : dict
        Dictionary whose keys are replaced. It is modified in place.
    func : Callable
        Function called with each existing key; its result becomes the new key.

    Returns
    -------
    dict
        The same dictionary object, with processed keys.
    """
    # Iterate over a snapshot of the keys because the dictionary is modified in the loop.
    for key in list(dct):
        renamed = func(key)
        dct[renamed] = dct.pop(key)
    return dct
×
1106

1107

1108
def categorical_colors() -> dict[str, str]:
    """Load the categorical colors associated with certain substrings (SSP,RCP,CMIP).

    Returns
    -------
    dict
        Parsed content of ``data/ipcc_colors/categorical_colors.json``, located relative to
        the grandparent directory of this file.
    """
    json_path = pathlib.Path(__file__).parents[1].joinpath(
        "data/ipcc_colors/categorical_colors.json"
    )
    with json_path.open(encoding="utf-8") as _f:
        return json.load(_f)
×
1117

1118

1119
def get_mpl_styles() -> dict[str, pathlib.Path]:
    """Map the name of each available matplotlib stylesheet to its path.

    Returns
    -------
    dict
        One entry per ``*.mplstyle`` file found in the ``style`` directory next to this file,
        keyed by the file name without extension, in sorted path order.
    """
    style_dir = pathlib.Path(__file__).parent / "style"
    return {sheet.stem: sheet for sheet in sorted(style_dir.glob("*.mplstyle"))}
×
1124

1125

1126
def set_mpl_style(*args: str, reset: bool = False) -> None:
    """Apply one or more stylesheets to matplotlib, in the order given.

    Parameters
    ----------
    args : str
        Name(s) of figanos matplotlib style ('ouranos', 'paper', 'poster') or path(s) to matplotlib stylesheet(s).
        Anything ending in ".mplstyle" is treated as a path; any other value must be one of the
        names returned by get_mpl_styles(), otherwise a warning is emitted and it is skipped.
    reset : bool
        If True, reset style to matplotlib default before applying the stylesheets.

    Returns
    -------
    None
    """
    if reset is True:
        mpl.style.use("default")

    for style in args:
        if style.endswith(".mplstyle"):
            mpl.style.use(style)
            continue

        packaged = get_mpl_styles()
        if style in packaged:
            mpl.style.use(packaged[style])
        else:
            warnings.warn(f"Style {style} not found.")
×
1149

1150

1151
def add_cartopy_features(
    ax: matplotlib.axes.Axes, features: list[str] | dict[str, dict[str, Any]]
) -> matplotlib.axes.Axes:
    """Draw predefined cartopy features on an axis.

    Each feature name is upper-cased and looked up as an attribute of `cartopy.feature`.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes on which to add the features.
    features : list or dict
        List of features, or nested dictionary of format {'feature': {'kwarg':'value'}}.
        A 'scale' kwarg, when present, is applied through the feature's `with_scale` method
        instead of being forwarded to `ax.add_feature`.

    Returns
    -------
    matplotlib.axes.Axes
        The axis with added features.
    """
    if isinstance(features, list):
        features = {f: {} for f in features}

    for feat, kwargs in features.items():
        if "scale" in kwargs:
            # 'scale' must not reach add_feature: take it out for the call, then restore it.
            scale = kwargs.pop("scale")
            scaled_feature = getattr(cfeature, feat.upper()).with_scale(scale)
            ax.add_feature(scaled_feature, **kwargs)
            kwargs["scale"] = scale  # put back
        else:
            ax.add_feature(getattr(cfeature, feat.upper()), **kwargs)

    return ax
×
1182

1183

1184
def custom_cmap_norm(
    cmap,
    vmin: int | float,
    vmax: int | float,
    levels: int | list[int | float] | None = None,
    divergent: bool | int | float = False,
    linspace_out: bool = False,
) -> matplotlib.colors.Normalize | np.ndarray:
    """Build the matplotlib normalization matching the main plotting arguments.

    Parameters
    ----------
    cmap: matplotlib.colormap
        Colormap to be used with the normalization. A string is looked up among the
        registered matplotlib colormaps.
    vmin: int or float
        Minimum of the data to be plotted with the colormap.
    vmax: int or float
        Maximum of the data to be plotted with the colormap.
    levels : int or list, optional
        Number of  levels or list of level boundaries (in data units) to use to divide the colormap.
    divergent : bool or int or float
        If int or float, becomes center of cmap. Default center is 0.
    linspace_out: bool
        If True, return array created by np.linspace() instead of normalization instance.
        Only honoured when `levels` is a number.

    Returns
    -------
    matplotlib.colors.Normalize or np.ndarray

    Raises
    ------
    ValueError
        If `cmap` is a string that is not a registered colormap, or if the center does not lie
        strictly between the rounded vmin and vmax when `levels` is an int.
    """
    # get cmap if string
    if isinstance(cmap, str):
        if cmap not in plt.colormaps():
            raise ValueError("Colormap not found")
        cmap = matplotlib.colormaps[cmap]

    # make vmin and vmax prettier: round outwards to a step chosen from the data span
    span = vmax - vmin
    if span >= 25:
        divisor, factor = 10.0, 10
    elif span >= 1:
        divisor, factor = 1, 1
    elif span >= 0.1:
        divisor, factor = 0.1, 0.1
    else:
        divisor, factor = 0.01, 0.01
    rvmax = math.ceil(vmax / divisor) * factor
    rvmin = math.floor(vmin / divisor) * factor

    # center
    if divergent is True:
        center = 0
    elif divergent is not False and isinstance(divergent, (int, float)):
        center = divergent
    else:
        center = None

    # build norm with options
    if center is not None and isinstance(levels, int):
        if center <= rvmin or center >= rvmax:
            raise ValueError("vmin, center and vmax must be in ascending order.")

        # Boundaries on each side of the center; an odd `levels` is rounded up.
        half_levels = (levels + 1) // 2 + 1
        below = np.linspace(rvmin, center, num=half_levels)
        above = np.linspace(center, rvmax, num=half_levels)[1:]  # center already in `below`
        lin = np.concatenate((below, above))
        norm = matplotlib.colors.BoundaryNorm(boundaries=lin, ncolors=cmap.N)
        return lin if linspace_out else norm

    if levels is not None:
        if isinstance(levels, list):
            if center is not None:
                warnings.warn(
                    "Divergent argument ignored when levels is a list. Use levels as a number instead."
                )
            return matplotlib.colors.BoundaryNorm(boundaries=levels, ncolors=cmap.N)

        lin = np.linspace(rvmin, rvmax, num=levels + 1)
        norm = matplotlib.colors.BoundaryNorm(boundaries=lin, ncolors=cmap.N)
        return lin if linspace_out else norm

    if center is not None:
        return matplotlib.colors.TwoSlopeNorm(center, vmin=rvmin, vmax=rvmax)

    return matplotlib.colors.Normalize(rvmin, rvmax)
×
1282

1283

1284
def norm2range(
    data: np.ndarray, target_range: tuple, data_range: tuple | None = None
) -> np.ndarray:
    """Linearly rescale data from one range onto another.

    Parameters
    ----------
    data : np.ndarray
        Values to rescale.
    target_range : tuple
        (low, high) bounds onto which `data_range` is mapped.
    data_range : tuple, optional
        (low, high) bounds of the input scale. Defaults to the NaN-ignoring minimum and
        maximum of `data`, which requires `data` to hold more than one element.

    Returns
    -------
    np.ndarray
        The rescaled values.

    Raises
    ------
    ValueError
        If `data_range` is not given and `data` has at most one element.
    """
    if data_range is None:
        if not len(data) > 1:
            raise ValueError(" if data is not an array, data_range must be specified")
        data_range = (np.nanmin(data), np.nanmax(data))

    source_low, source_high = data_range[0], data_range[1]
    target_low, target_high = target_range[0], target_range[1]

    # Position of each value within the source range (0 at the low bound, 1 at the high bound).
    fraction = (data - source_low) / (source_high - source_low)

    return target_low + (fraction * (target_high - target_low))
×
1297

1298

1299
def size_legend_elements(
    data: np.ndarray, sizes: np.ndarray, marker: str, max_entries: int = 6
) -> list[matplotlib.lines.Line2D]:
    """Create handles to use in a point-size legend.

    Parameters
    ----------
    data : np.ndarray
        Data used to determine the point sizes.
    sizes : np.ndarray
        Array of point sizes.
    marker: str
        Marker to use in legend.
    max_entries : int
        Maximum number of entries in the legend.

    Returns
    -------
    list of matplotlib.lines.Line2D
        One handle per rounded data value. When there are more than `max_entries` handles,
        only those at even positions 0, 2, ... up to `max_entries` are returned.
    """
    data_min, data_max = min(data), max(data)
    size_min, size_max = min(sizes), max(sizes)

    # how many increments of 10 pts**2 are there in the sizes
    # NOTE(review): this is 0 when the size span rounds to 0, which makes the division below fail
    # or yield a non-finite ratio - confirm whether callers can pass such sizes.
    n = int(np.round(size_max - size_min, -1) / 10)

    # divide data in those increments
    lgd_data = np.linspace(data_min, data_max, n)

    # round according to range: width of one increment of the data range.
    # The parentheses matter; without them only the minimum would be divided by n.
    ratio = abs((data_max - data_min) / n)

    # Largest rounding step that does not exceed the increment width (0.001 as a floor).
    rounding = 0.001
    for step in (1000, 100, 10, 5, 1, 0.1, 0.01):
        if ratio >= step:
            rounding = step
            break

    lgd_data = np.unique(rounding * np.round(lgd_data / rounding))

    # convert back to sizes
    lgd_sizes = norm2range(
        data=lgd_data,
        data_range=(data_min, data_max),
        target_range=(size_min, size_max),
    )

    legend_elements = []

    for s, d in zip(lgd_sizes, lgd_data):
        # Show whole numbers without a trailing ".0".
        if isinstance(d, float) and d.is_integer():
            label = str(int(d))
        else:
            label = str(d)

        legend_elements.append(
            Line2D(
                [0],
                [0],
                marker=marker,
                color="k",
                lw=0,
                markerfacecolor="w",
                label=label,
                markersize=np.sqrt(np.abs(s)),
            )
        )

    if len(legend_elements) > max_entries:
        return [legend_elements[i] for i in range(0, max_entries + 1, 2)]
    return legend_elements
×
1379

1380

1381
def add_features_map(
    data,
    ax,
    use_attrs,
    projection,
    features,
    geometries_kw,
    frame,
) -> matplotlib.axes.Axes:
    """Decorate a map axis with cartopy features, attribute-based labels and geometries.

    Parameters
    ----------
    data : dict, DataArray or Dataset
        Input data do plot. If dictionary, must have only one entry.
    ax : matplotlib axis
        Matplotlib axis on which to plot, with the same projection as the one specified.
    use_attrs : dict
        Dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description').
        Default value is {'title': 'description', 'cbar_label': 'long_name', 'cbar_units': 'units'}.
        Only the keys found in the default dict can be used.
    projection : ccrs.Projection
        The projection to use, taken from the cartopy.crs options. Ignored if ax is not None.
    features : list or dict
        Features to use, as a list or a nested dict containing kwargs. Options are the predefined features from
        cartopy.feature: ['coastline', 'borders', 'lakes', 'land', 'ocean', 'rivers', 'states'].
    geometries_kw : dict
        Arguments passed to cartopy ax.add_geometry() which adds given geometries (GeoDataFrame geometry) to axis.
        Its "geoms" entry is replaced in place by the reprojected geometries.
    frame : bool
        Show or hide frame. Default False.

    Returns
    -------
    matplotlib.axes.Axes
    """
    # add features
    if features:
        add_cartopy_features(ax, features)

    set_plot_attrs(use_attrs, data, ax)

    if frame is False:
        ax.spines["geo"].set_visible(False)

    if not geometries_kw:
        return ax

    # add geometries
    if "geoms" not in geometries_kw:
        warnings.warn(
            'geoms missing from geometries_kw (ex: {"geoms": df["geometry"]})'
        )

    # Reproject to the user-supplied crs when there is one, else to the map projection.
    target_crs = geometries_kw["crs"] if "crs" in geometries_kw else projection
    geometries_kw["geoms"] = gpd_to_ccrs(geometries_kw["geoms"], target_crs)

    # User-supplied entries take precedence over these defaults.
    defaults = {
        "crs": projection,
        "facecolor": "none",
        "edgecolor": "black",
    }
    ax.add_geometries(**(defaults | geometries_kw))
    return ax
×
1445

1446

1447
def masknan_sizes_key(data, sizes) -> xr.Dataset:
    """Make the NaN positions of the hue and markersize variables identical.

    The hue variable is taken to be the first key of `data` other than `sizes`. Wherever
    either of the two variables is NaN, the other one is set to NaN as well. `data` is
    modified in place.

    Parameters
    ----------
    data: xr.Dataset
        xr.Dataset used to plot
    sizes: str
        Variable used to plot markersize

    Returns
    -------
    xr.Dataset
    """
    # find variable name
    candidates = list(data.keys())
    candidates.remove(sizes)
    hue_var = candidates[0]

    # Blank the hue variable where 'sizes' is missing ...
    sizes_missing = np.isnan(data[sizes])
    data[hue_var] = data[hue_var].where(~sizes_missing)

    # ... then blank 'sizes' where the (updated) hue variable is missing.
    hue_missing = np.isnan(data[hue_var])
    data[sizes] = data[sizes].where(~hue_missing)

    return data
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc