• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

Ouranosinc / figanos / 18353693760

08 Oct 2025 06:01PM UTC coverage: 8.075% (-0.05%) from 8.122%
18353693760

push

github

web-flow
Update cookiecutter template, add CITATION.cff (#362)

### What kind of change does this PR introduce?

* Updates the cookiecutter to the latest version
* Removed `black`, `blackdoc`, and `isort`
* Added a `CITATION.cff` file
* `pyproject.toml` is now PEP-639 compliant 
* Contributor Covenant agreement updated to v3.0
* Some dependency updates

### Does this PR introduce a breaking change?

Yes. `black`, `blackdoc`, and `isort` have been dropped for `ruff`.

### Other information:

There may be some small adjustments to the code base. `ruff` is
configured to be very similar to `black`, but it is not 1-1 compliant
with it.

0 of 56 new or added lines in 4 files covered. (0.0%)

2 existing lines in 1 file now uncovered.

156 of 1932 relevant lines covered (8.07%)

0.4 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

11.35
/src/figanos/matplotlib/utils.py
1
"""Utility functions for figanos figure-creation."""
2

3
from __future__ import annotations
5✔
4
import json
5✔
5
import math
5✔
6
import pathlib
5✔
7
import re
5✔
8
import warnings
5✔
9
from collections.abc import Callable
5✔
10
from copy import deepcopy
5✔
11
from tempfile import NamedTemporaryFile
5✔
12
from typing import Any
5✔
13

14
import cairosvg
5✔
15
import cartopy.crs as ccrs
5✔
16
import cartopy.feature as cfeature
5✔
17
import geopandas as gpd
5✔
18
import matplotlib as mpl
5✔
19
import matplotlib.axes
5✔
20
import matplotlib.colors as mcolors
5✔
21
import matplotlib.pyplot as plt
5✔
22
import numpy as np
5✔
23
import pandas as pd
5✔
24
import seaborn
5✔
25
import xarray as xr
5✔
26
import yaml
5✔
27
from matplotlib.lines import Line2D
5✔
28
from skimage.transform import resize
5✔
29
from xclim.core.options import METADATA_LOCALES
5✔
30
from xclim.core.options import OPTIONS as XC_OPTIONS
5✔
31

32
from .._logo import Logos
5✔
33

34

35
TERMS: dict = {}
"""
A translation dictionary for special terms to appear on the plots.

Keys are terms to translate and they map to "locale": "translation" dictionaries.
The "official" figanos terms are based on figanos/data/terms.yml.
"""

# Load terms translations.
# The encoding is given explicitly so the result does not depend on the platform's
# default text encoding (translations may contain non-ASCII characters).
with (pathlib.Path(__file__).resolve().parents[1] / "data" / "terms.yml").open(encoding="utf-8") as f:
    # An empty YAML document loads as None; keep TERMS a dict so membership tests still work.
    TERMS = yaml.safe_load(f) or {}
5✔
47

48

49
def get_localized_term(term, locale=None):
    """
    Translate `term` into the requested `locale`.

    Translations are looked up in the :py:data:`TERMS` dictionary.

    Parameters
    ----------
    term : str
        A word or short phrase to translate.
    locale : str, optional
        A 2-letter locale name to translate to.
        When None (default), the first entry of xclim's "metadata_locales" option is used,
        falling back to "en" if that option is empty.

    Returns
    -------
    str
        The translated term. `term` is returned unchanged when the locale is "en",
        or (after emitting a warning) when no translation is known.
    """
    if not locale:
        configured = XC_OPTIONS[METADATA_LOCALES]
        locale = configured[0] if configured else "en"

    # English terms are the dictionary keys themselves: nothing to look up.
    if locale == "en":
        return term

    if term in TERMS:
        known = TERMS[term]
        if locale in known:
            return known[locale]
        message = f"No {locale} translation known for term '{term}'."
    else:
        message = f"No translation known for term '{term}'."

    warnings.warn(message, stacklevel=2)
    return term
×
81

82

83
def empty_dict(param) -> dict:
    """Return a deep copy of `param`, or a new empty dict when `param` is None."""
    # Copying keeps the caller's dict untouched when items are later popped from the result.
    return deepcopy({} if param is None else param)
×
88

89

90
def check_timeindex(
    xr_objs: xr.DataArray | xr.Dataset | dict[str, Any],
) -> xr.DataArray | xr.Dataset | dict[str, Any]:
    """
    Convert Xarray objects whose time index is a CFTimeIndex to a 'standard' calendar.

    Parameters
    ----------
    xr_objs : xr.DataArray or xr.Dataset or dict
        A single Xarray object, or a dictionary whose values are Xarray objects.

    Returns
    -------
    xr.DataArray or xr.Dataset or dict
        The input, with every CFTime-indexed object replaced by the result of
        ``convert_calendar("standard", use_cftime=None, align_on="year")``.
        A dictionary input is updated in place and returned.
        A warning is emitted for each conversion.
    """
    notice = "CFTimeIndex converted to pandas DatetimeIndex with a 'standard' calendar."

    def _uses_cftime(obj) -> bool:
        # Only objects that have a "time" dimension are inspected.
        return "time" in obj.dims and isinstance(obj.get_index("time"), xr.CFTimeIndex)

    if not isinstance(xr_objs, dict):
        if _uses_cftime(xr_objs):
            xr_objs = xr_objs.convert_calendar("standard", use_cftime=None, align_on="year")
            warnings.warn(notice, stacklevel=2)
        return xr_objs

    for key, item in xr_objs.items():
        if _uses_cftime(item):
            # Re-assigning an existing key is safe while iterating over the dict.
            xr_objs[key] = item.convert_calendar("standard", use_cftime=None, align_on="year")
            warnings.warn(notice, stacklevel=2)

    return xr_objs
×
130

131

132
def get_array_categ(array: xr.DataArray | xr.Dataset) -> str:
    """
    Classify an array; the category determines how the array is plotted.

    Parameters
    ----------
    array : Dataset or DataArray
        The array being categorized.

    Returns
    -------
    str
        ENS_PCT_VAR_DS: ensemble percentiles stored as variables
        ENS_PCT_DIM_DA: ensemble percentiles stored as dimension coordinates, DataArray
        ENS_PCT_DIM_DS: ensemble percentiles stored as dimension coordinates, DataSet
        ENS_STATS_VAR_DS: ensemble statistics (min, mean, max) stored as variables
        ENS_REALS_DA: ensemble with 'realization' dim, as DataArray
        ENS_REALS_DS: ensemble with 'realization' dim, as Dataset
        DS: any Dataset that is not recognized as an ensemble
        DA: DataArray

    Raises
    ------
    TypeError
        If `array` is neither a Dataset nor a DataArray.
    """
    if isinstance(array, xr.Dataset):

        def _n_matching(pattern: str) -> int:
            # Number of data variables whose name contains `pattern`.
            return sum(1 for var in array.data_vars if re.search(pattern, var) is not None)

        # Variable-name conventions take precedence over dimension names.
        if _n_matching("_p[0-9]{1,2}") >= 2:
            return "ENS_PCT_VAR_DS"
        if _n_matching("_[Mm]ax|_[Mm]in") >= 2:
            return "ENS_STATS_VAR_DS"
        if "percentiles" in array.dims:
            return "ENS_PCT_DIM_DS"
        if "realization" in array.dims:
            return "ENS_REALS_DS"
        return "DS"

    if isinstance(array, xr.DataArray):
        if "percentiles" in array.dims:
            return "ENS_PCT_DIM_DA"
        if "realization" in array.dims:
            return "ENS_REALS_DA"
        return "DA"

    raise TypeError("Array is not an Xarray Dataset or DataArray")
×
186

187

188
def get_attributes(
    string: str, xr_obj: xr.DataArray | xr.Dataset, locale: str | None = None
) -> str:
    """
    Fetch attributes or dims corresponding to keys from Xarray objects.

    For a DataArray, its own attributes are searched. For a Dataset, the attributes of its
    first variable are searched first, then the Dataset attributes.
    If a locale is activated in xclim's options or a locale is passed, the localized attribute
    (``"<string>_<locale>"``) is looked up before the plain one.

    Parameters
    ----------
    string : str
        String corresponding to an attribute name.
    xr_obj : DataArray or Dataset
        The Xarray object containing the attributes.
    locale : str, optional
        A 2-letter locale name to translate to.
        Default is None, which will pull the locale
        from xclim's "metadata_locales" option (taking the first).

    Returns
    -------
    str
        The attribute value, or an empty string (with a warning) if it is not found.
    """
    locale = locale or (XC_OPTIONS[METADATA_LOCALES] or ["en"])[0]
    names = [string] if locale == "en" else [f"{string}_{locale}", string]

    # Attribute mappings to search, in priority order.
    if isinstance(xr_obj, xr.DataArray):
        sources = [xr_obj.attrs]
    elif isinstance(xr_obj, xr.Dataset):
        # A Dataset without data variables has no "first variable"; fall through to
        # the Dataset attributes instead of failing with an IndexError.
        first_var = next(iter(xr_obj.data_vars), None)
        sources = [] if first_var is None else [xr_obj[first_var].attrs]
        sources.append(xr_obj.attrs)
    else:
        sources = []

    for name in names:
        for attrs in sources:
            if name in attrs:
                return attrs[name]

    warnings.warn(f'Attribute "{string}" not found.', stacklevel=2)
    return ""
×
234

235

236
def set_plot_attrs(
    attr_dict: dict[str, Any],
    xr_obj: xr.DataArray | xr.Dataset,
    ax: matplotlib.axes.Axes | None = None,
    title_loc: str = "center",
    facetgrid: seaborn.axisgrid.FacetGrid | None = None,
    wrap_kw: dict[str, Any] | None = None,
) -> matplotlib.axes.Axes:
    """
    Set plot elements according to Dataset or DataArray attributes.

    Each text is fetched with get_attributes().

    Parameters
    ----------
    attr_dict : dict
        Maps plot elements ("title", "ylabel", "yunits", "xlabel", "xunits", "cbar_label",
        "cbar_units", "suptitle") to attribute names. Other keys trigger a warning.
    xr_obj : Dataset or DataArray
        The Xarray object containing the attributes.
    ax : matplotlib axis
        The matplotlib axis of the plot. Used for the title and the axis labels.
    title_loc : str
        Location of the title.
    facetgrid : seaborn.axisgrid.FacetGrid, optional
        If given, the "suptitle" element is applied to it and it is returned instead of `ax`.
    wrap_kw : dict, optional
        Arguments to pass to the wrap_text function for the title.

    Returns
    -------
    matplotlib.axes.Axes
        `ax`, or `facetgrid` when one is given.
    """
    wrap_kw = empty_dict(wrap_kw)

    supported = (
        "title",
        "ylabel",
        "yunits",
        "xlabel",
        "xunits",
        "cbar_label",
        "cbar_units",
        "suptitle",
    )
    for key in attr_dict:
        if key not in supported:
            warnings.warn(f'Use_attrs element "{key}" not supported', stacklevel=2)

    def _axis_label(label_key: str, units_key: str) -> str:
        # Units are appended in parentheses only when they are requested and non-empty,
        # which avoids labels ending in "()".
        units = get_attributes(attr_dict[units_key], xr_obj) if units_key in attr_dict else ""
        text = get_attributes(attr_dict[label_key], xr_obj)
        if len(units) >= 1:
            text = text + " (" + units + ")"
        return wrap_text(text)

    if "title" in attr_dict:
        title = get_attributes(attr_dict["title"], xr_obj)
        ax.set_title(wrap_text(title, **wrap_kw), loc=title_loc)

    if "ylabel" in attr_dict:
        ax.set_ylabel(_axis_label("ylabel", "yunits"))

    if "xlabel" in attr_dict:
        ax.set_xlabel(_axis_label("xlabel", "xunits"))

    # "cbar_label" and "cbar_units" are accepted here but must be applied by the calling plot function.

    if not facetgrid:
        return ax

    if "suptitle" in attr_dict:
        suptitle = get_attributes(attr_dict["suptitle"], xr_obj)
        facetgrid.fig.suptitle(suptitle, y=1.05)
        facetgrid.set_titles(template="{value}")
    return facetgrid
×
334

335

336
def get_suffix(string: str) -> str:
    """
    Get suffix of typical Xclim variable names.

    The suffix is either the one or two trailing digits of `string`, or a trailing
    "_max", "_min" or "_mean" (first letter in either case) without its underscore.

    Raises
    ------
    ValueError
        If `string` ends with none of these.
    """
    found = re.search("[0-9]{1,2}$|_[Mm]ax$|_[Mm]in$|_[Mm]ean$", string)
    if found is None:
        raise ValueError(f"Mean, min or max not found in {string}")
    # Statistic names are matched together with their underscore, which is not part of the suffix.
    return found.group().lstrip("_")
×
343

344

345
def sort_lines(array_dict: dict[str, Any]) -> dict[str, str]:
    """
    Label arrays as 'middle', 'upper' and 'lower' for ensemble plotting.

    The role of each array is derived from the suffix of its name (see get_suffix()):
    max / percentiles above 50 are 'upper', min / percentiles below 50 are 'lower',
    mean / the 50th percentile is 'middle'.

    Parameters
    ----------
    array_dict : dict
        Dictionary of format {'name': array...}.

    Returns
    -------
    dict
        Dictionary of {'middle': 'name', 'upper': 'name', 'lower': 'name'}.

    Raises
    ------
    ValueError
        If `array_dict` does not hold exactly three entries, or a name has no usable suffix.
    """
    if len(array_dict) != 3:
        raise ValueError("Ensembles must contain exactly three arrays")

    stat_roles = {
        "max": "upper",
        "Max": "upper",
        "min": "lower",
        "Min": "lower",
        "mean": "middle",
        "Mean": "middle",
    }

    roles = {}
    for name in array_dict:
        suffix = get_suffix(name)

        if suffix.isalpha():
            role = stat_roles.get(suffix)
        elif suffix.isdigit():
            percentile = int(suffix)
            if percentile > 50:
                role = "upper"
            elif percentile < 50:
                role = "lower"
            else:
                role = "middle"
        else:
            raise ValueError('Arrays names must end in format "_mean" or "_p50" ')

        # An alphabetic suffix outside the known statistics is ignored.
        if role is not None:
            roles[role] = name

    return roles
×
384

385

386
def loc_mpl(
5✔
387
    loc: str | tuple[int | float, int | float] | int,
388
) -> tuple[tuple[float, float], tuple[int | float, int | float], str, str]:
389
    """
390
    Find coordinates and alignment associated to loc string.
391

392
    Parameters
393
    ----------
394
    loc : string, int, or tuple[float, float]
395
        Location of text, replicating https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html.
396
        If a tuple, must be in axes coordinates.
397

398
    Returns
399
    -------
400
    tuple(float, float), tuple(float, float), str, str
401
    """
402
    ha = "left"
×
403
    va = "bottom"
×
404

405
    loc_strings = [
×
406
        "upper right",
407
        "upper left",
408
        "lower left",
409
        "lower right",
410
        "right",
411
        "center left",
412
        "center right",
413
        "lower center",
414
        "upper center",
415
        "center",
416
    ]
417

418
    if isinstance(loc, int):
×
419
        try:
×
420
            loc = loc_strings[loc - 1]
×
NEW
421
        except IndexError as err:
×
NEW
422
            raise ValueError("loc must be between 1 and 10, inclusively") from err
×
423

424
    if loc in loc_strings:
×
425
        # ha
426
        if "left" in loc:
×
427
            ha = "left"
×
428
        elif "right" in loc:
×
429
            ha = "right"
×
430
        else:
431
            ha = "center"
×
432

433
        # va
434
        if "lower" in loc:
×
435
            va = "bottom"
×
436
        elif "upper" in loc:
×
437
            va = "top"
×
438
        else:
439
            va = "center"
×
440

441
        # transAxes
442
        if loc == "upper right":
×
443
            loc = (0.97, 0.97)
×
444
            box_a = (1, 1)
×
445
        elif loc == "upper left":
×
446
            loc = (0.03, 0.97)
×
447
            box_a = (0, 1)
×
448
        elif loc == "lower left":
×
449
            loc = (0.03, 0.03)
×
450
            box_a = (0, 0)
×
451
        elif loc == "lower right":
×
452
            loc = (0.97, 0.03)
×
453
            box_a = (1, 0)
×
454
        elif loc == "right":
×
455
            loc = (0.97, 0.5)
×
456
            box_a = (1, 0.5)
×
457
        elif loc == "center left":
×
458
            loc = (0.03, 0.5)
×
459
            box_a = (0, 0.5)
×
460
        elif loc == "center right":
×
461
            loc = (0.97, 0.5)
×
462
            box_a = (0.97, 0.5)
×
463
        elif loc == "lower center":
×
464
            loc = (0.5, 0.03)
×
465
            box_a = (0.5, 0)
×
466
        elif loc == "upper center":
×
467
            loc = (0.5, 0.97)
×
468
            box_a = (0.5, 1)
×
469
        else:
470
            loc = (0.5, 0.5)
×
471
            box_a = (0.5, 0.5)
×
472

473
    elif isinstance(loc, tuple):
×
474
        box_a = []
×
475
        for i in loc:
×
476
            if i > 1 or i < 0:
×
477
                raise ValueError(
×
478
                    "Text location coordinates must be between 0 and 1, inclusively"
479
                )
480
            elif i > 0.5:
×
481
                box_a.append(1)
×
482
            else:
483
                box_a.append(0)
×
484
        box_a = tuple(box_a)
×
485
    else:
486
        raise ValueError(
×
487
            "loc must be a string, int or tuple. "
488
            "See https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html"
489
        )
490

491
    return loc, box_a, ha, va
×
492

493

494
def plot_coords(
    ax: matplotlib.axes.Axes | None,
    xr_obj: xr.DataArray | xr.Dataset,
    loc: str | tuple[float, float] | int,
    param: str | None = None,
    backgroundalpha: float = 1,
) -> matplotlib.axes.Axes:
    """
    Place coordinates on plot area.

    Parameters
    ----------
    ax : matplotlib.axes.Axes or None
        Matplotlib axes object on which to place the text.
        If None, will use plt.figtext instead (should be used for facetgrids).
    xr_obj : xr.DataArray or xr.Dataset
        The xarray object from which to fetch the text content.
    loc : string, int or tuple
        Location of text, replicating https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html.
        If a tuple, must be in axes coordinates.
    param : {"location", "time"}, optional
        The parameter used: "location" writes the "lat" and "lon" coordinates,
        "time" writes the "time" coordinate formatted as a date.
    backgroundalpha : float
        Transparency of the text background. 1 is opaque, 0 is transparent.

    Returns
    -------
    matplotlib.axes.Axes
        `ax`. This is None when no axes was given and the text was drawn with plt.figtext.
    """
    text = None
    if param == "location":
        if "lat" in xr_obj.coords and "lon" in xr_obj.coords:
            text = f"lat={float(xr_obj['lat']):.2f}, lon={float(xr_obj['lon']):.2f}"
        else:
            warnings.warn(
                'show_lat_lon set to True, but "lat" and/or "lon" not found in coords', stacklevel=2
            )
    elif param == "time":
        if "time" in xr_obj.coords:
            text = str(xr_obj.time.dt.strftime("%Y-%m-%d").values)
        else:
            warnings.warn('show_time set to True, but "time" not found in coords', stacklevel=2)

    # Resolved even when there is nothing to write, so that an invalid `loc` is always reported.
    loc, box_a, ha, va = loc_mpl(loc)

    if not text:
        # Nothing to draw: hand the axes back unchanged.
        return ax

    if ax:
        text_area = mpl.offsetbox.TextArea(
            text, textprops=dict(transform=ax.transAxes, ha=ha, va=va)
        )
        text_box = mpl.offsetbox.AnnotationBbox(
            text_area,
            loc,
            xycoords="axes fraction",
            box_alignment=box_a,
            pad=0.05,
            bboxprops=dict(
                facecolor="white",
                alpha=backgroundalpha,
                edgecolor="w",
                boxstyle="Square, pad=0.5",
            ),
        )
        ax.add_artist(text_box)
        return ax

    # No axes (facetgrid case): write the text in figure coordinates.
    plt.figtext(
        loc[0],
        loc[1],
        text,
        ha=ha,
        va=va,
        fontsize=12,
    )
    return None
×
590

591

592
def find_logo(logo: str | pathlib.Path) -> str:
    """
    Look up a logo in the installed logos.

    Parameters
    ----------
    logo : str or Path
        Key of the logo in `Logos()`. If falsy, the default logo is used.

    Returns
    -------
    str
        The value stored in `Logos()` for that logo.

    Raises
    ------
    ValueError
        If the lookup yields None.
    """
    catalogue = Logos()
    logo_path = catalogue[logo] if logo else catalogue.default

    if logo_path is None:
        raise ValueError(
            "No logo found. Please install one with the figanos.Logos().set_logo() method."
        )
    return logo_path
×
605

606

607
def load_image(
    im: str | pathlib.Path,
    height: float | None,
    width: float | None,
    keep_ratio: bool = True,
) -> np.ndarray:
    """
    Scale an image to a specified height and width.

    Parameters
    ----------
    im : str or Path
        The image to be scaled. PNG and SVG formats are supported.
    height : float, optional
        The desired height of the image. If None, the original height is used.
    width : float, optional
        The desired width of the image. If None, the original width is used.
    keep_ratio : bool
        If True, the aspect ratio of the original image is maintained. Default is True.
        In that case, if both `height` and `width` are given, `width` is used and a warning is emitted.

    Returns
    -------
    np.ndarray
        The scaled image.

    Raises
    ------
    ValueError
        If the file suffix is neither ".png" nor ".svg".
    """
    suffix = pathlib.Path(im).suffix

    if suffix == ".png":
        image = mpl.pyplot.imread(im)
        original_height, original_width = image.shape[:2]

        if height is None and width is None:
            return image

        warnings.warn(
            "The scikit-image library is used to resize PNG images. This may affect logo image quality.", stacklevel=2
        )
        if not keep_ratio:
            # Requested dimensions win; the original size only fills in a missing one.
            height = original_height if height is None else height
            width = original_width if width is None else width
        elif width is not None:
            if height is not None:
                warnings.warn("Both height and width provided, using width.", stacklevel=2)
            # Derive the height from the width and the original aspect ratio.
            height = (width / original_width) * original_height
        else:
            # Only height is provided: derive the width from the original aspect ratio.
            width = (height / original_height) * original_width

        # `shape[2:]` is empty for a grayscale image and holds the channel count otherwise.
        return resize(image, (height, width, *image.shape[2:]), anti_aliasing=True)

    if suffix == ".svg":
        cairo_kwargs = dict(url=im)
        if not keep_ratio:
            if height is not None and width is not None:
                cairo_kwargs.update(output_height=height, output_width=width)
        elif width is not None:
            if height is not None:
                warnings.warn("Both height and width provided, using width.", stacklevel=2)
            cairo_kwargs.update(output_width=width)
        elif height is not None:
            cairo_kwargs.update(output_height=height)

        # The SVG is rendered to a temporary PNG, which is then read like any other PNG.
        with NamedTemporaryFile(suffix=".png") as png_file:
            cairo_kwargs.update(write_to=png_file.name)
            cairosvg.svg2png(**cairo_kwargs)
            return mpl.pyplot.imread(png_file.name)

    raise ValueError(f"Unsupported image format '{suffix}'. Only '.png' and '.svg' files are supported.")
×
673

674

675
def plot_logo(
    ax: matplotlib.axes.Axes,
    loc: str | tuple[float, float] | int,
    logo: str | pathlib.Path | Logos | None = None,
    height: float | None = None,
    width: float | None = None,
    keep_ratio: bool = True,
    **offset_image_kwargs,
) -> matplotlib.axes.Axes:
    r"""
    Place a logo on the plot area.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Matplotlib axes object on which to place the logo.
    loc : string, int or tuple
        Location of the logo, replicating https://matplotlib.org/stable/api/_as_gen/matplotlib.pyplot.legend.html.
        If a tuple, must be in axes coordinates.
    logo : str, Path, figanos.Logos, optional
        A name (str) or Path to a logo file, or a name of an already-installed logo.
        If a `Logos` instance is given, its default logo is used.
        If None, the default installed logo is used.
        The image is read with load_image(), which handles '.png' and '.svg' files.
    height : float, optional
        The desired height of the image. If None, the original height is used.
    width : float, optional
        The desired width of the image. If None, the original width is used.
    keep_ratio : bool, optional
        If True, the aspect ratio of the original image is maintained. Default is True.
    \*\*offset_image_kwargs
        Arguments to pass to matplotlib.offsetbox.OffsetImage().

    Returns
    -------
    matplotlib.axes.Axes

    Raises
    ------
    ValueError
        If no logo can be found.
    """
    if isinstance(logo, Logos):
        logo_path = logo.default
        if logo_path is None:
            # Same error as find_logo(), rather than an obscure failure while loading the image.
            raise ValueError(
                "No logo found. Please install one with the figanos.Logos().set_logo() method."
            )
    else:
        logo_path = find_logo(logo)

    image = load_image(logo_path, height, width, keep_ratio)
    imagebox = mpl.offsetbox.OffsetImage(image, **offset_image_kwargs)

    # Text alignments returned by loc_mpl() do not apply to an image box.
    loc, box_a, _, _ = loc_mpl(loc)
    logo_box = mpl.offsetbox.AnnotationBbox(
        imagebox,
        loc,
        frameon=False,
        xycoords="axes fraction",
        box_alignment=box_a,
        pad=0.05,
    )
    ax.add_artist(logo_box)
    return ax
×
734

735

736
def split_legend(
    ax: matplotlib.axes.Axes,
    in_plot: bool = False,
    axis_factor: float = 0.15,
    label_gap: float = 0.02,
) -> matplotlib.axes.Axes:
    """
    Write each line's legend label next to the last point of the line, or outside the plot.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axis containing the legend.
    in_plot : bool
        If True, prolong plot area to fit labels. If False, print labels outside of plot area. Default: False.
    axis_factor : float
        If in_plot is True, fraction of the x-axis length to add at the far right of the plot. Default: 0.15.
    label_gap : float
        If in_plot is True, fraction of the x-axis length to add as a gap between line and label. Default: 0.02.

    Returns
    -------
    matplotlib.axes.Axes
    """
    #  TODO: check for and fix overlapping labels
    x_low, x_high = ax.get_xbound()
    x_span = x_high - x_low
    ax_bump = x_span * axis_factor
    label_bump = x_span * label_gap

    # Make room for the labels on the right-hand side of the plot area.
    if in_plot is True:
        ax.set_xbound(lower=x_low, upper=x_high + ax_bump)

    text_style = {"ha": "left", "va": "center"}

    handles, labels = ax.get_legend_handles_labels()
    for line, label in zip(handles, labels, strict=False):
        end_x = line.get_xdata()[-1]
        end_y = line.get_ydata()[-1]

        # Dates must be expressed as matplotlib date numbers before an offset can be added.
        if isinstance(end_x, np.datetime64):
            end_x = mpl.dates.date2num(end_x)

        color = line.get_color()

        if in_plot is True:
            ax.text(end_x + label_bump, end_y, label, color=color, **text_style)
        else:
            # x in axes coordinates (just right of the plot area), y in data coordinates.
            trans = mpl.transforms.blended_transform_factory(ax.transAxes, ax.transData)
            ax.text(1.01, end_y, label, color=color, transform=trans, **text_style)

    return ax
×
805

806

807
def fill_between_label(
    sorted_lines: dict[str, Any],
    name: str,
    array_categ: dict[str, Any],
    legend: str,
) -> str:
    """
    Build the legend label for the shading drawn around a line in line plots.

    Parameters
    ----------
    sorted_lines : dict
        Dictionary created by the sort_lines() function.
    name : str
        Key associated with the object being plotted in the 'data' argument of the timeseries() function.
    array_categ : dict
        The categories of the array, as created by the get_array_categ function.
    legend : str
        Legend mode. A label is only produced for "full".

    Returns
    -------
    str
        Label to be applied to the legend element representing the shading,
        or None when the legend mode or the array category calls for no label.
    """
    if legend != "full":
        return None

    categ = array_categ[name]

    if categ in ("ENS_PCT_VAR_DS", "ENS_PCT_DIM_DS", "ENS_PCT_DIM_DA"):
        # The localized template keeps two placeholders for the percentile bounds.
        template = get_localized_term("{}th-{}th percentiles")
        return template.format(
            get_suffix(sorted_lines["lower"]), get_suffix(sorted_lines["upper"])
        )

    if categ == "ENS_STATS_VAR_DS":
        return get_localized_term("min-max range")

    return None
×
848

849

850
def get_var_group(
    path_to_json: str | pathlib.Path,
    da: xr.DataArray | None = None,
    unique_str: str | None = None,
) -> str:
    """
    Find the IPCC variable group of a DataArray or of a string.

    The variable-to-group mapping is read from a json file
    (figanos/data/ipcc_colors/variable_groups.json). When `unique_str` is given, it is the only
    text searched. Otherwise the name of `da` is searched and, if nothing was found there, its
    `history` attribute. If `da` is a Dataset, the DataArray of its first variable is used.

    Parameters
    ----------
    path_to_json : str or pathlib.Path
        Path to the json file mapping variable names to variable groups.
    da : xr.DataArray, optional
        Object whose name and history are searched for a known variable name.
    unique_str : str, optional
        String to search instead of `da`.

    Returns
    -------
    str
        The variable group, or "misc" (with a warning) when zero or several groups match.
    """
    with pathlib.Path(path_to_json).open(encoding="utf-8") as _f:
        var_dict = json.load(_f)

    def _groups_in(text):
        # A variable only counts when it is not embedded inside a longer word.
        return [
            var_dict[var]
            for var in var_dict
            if re.search(rf"(?:^|[^a-zA-Z])({var})(?:[^a-zA-Z]|$)", text)
        ]

    if unique_str:
        found = _groups_in(unique_str)
    else:
        if isinstance(da, xr.Dataset):
            da = da[list(da.data_vars)[0]]

        found = []
        # first choice: the DataArray name
        if hasattr(da, "name") and isinstance(da.name, str):
            found = _groups_in(da.name)
        # fallback: the history attribute
        if hasattr(da, "history") and len(found) == 0:
            found = _groups_in(da.history)

    found = np.unique(found)
    n_found = len(found)

    if n_found == 1:
        return found[0]

    if n_found == 0:
        message = "Colormap warning: Variable group not found. Use the cmap argument."
    else:
        message = "Colormap warning: More than one variable group found. Use the cmap argument."
    warnings.warn(message, stacklevel=2)
    return "misc"
×
903

904

905
def create_cmap(
    var_group: str | None = None,
    divergent: bool | int = False,
    filename: str | None = None,
) -> matplotlib.colors.Colormap:
    """
    Build a colormap from one of the IPCC colour files shipped with figanos.

    Parameters
    ----------
    var_group : str, optional
        Variable group from IPCC scheme. Used to pick the colour file when `filename` is not given.
    divergent : bool or int
        Any value other than False selects the diverging colormap of the group;
        False selects the sequential one.
    filename : str, optional
        Name of IPCC colormap file, with or without ".txt". A trailing "_r" requests the reversed
        colormap. If given, 'var_group' and 'divergent' are not used.

    Returns
    -------
    matplotlib.colors.Colormap
    """
    folder = "continuous_colormaps_rgb_0-255"
    reverse = False

    if filename:
        filename = filename.replace(".txt", "")
        if filename.endswith("_r"):
            reverse = True
            filename = filename[:-2]
    elif divergent is not False:
        # "misc2" has no diverging file of its own and shares the "misc" one
        if var_group == "misc2":
            var_group = "misc"
        filename = var_group + "_div"
    elif var_group == "misc":
        filename = var_group + "_seq_3"  # Batlow
    elif var_group == "misc2":
        filename = "misc_seq_2"  # freezing rain
    else:
        filename = var_group + "_seq"

    # parents[1] should be 'figanos/'
    colour_file = pathlib.Path(__file__).parents[1].joinpath(
        "data", "ipcc_colors", folder, filename + ".txt"
    )

    # the files hold 0-255 RGB rows; matplotlib expects 0-1
    rgb_data = np.loadtxt(colour_file) / 255

    cmap = mcolors.LinearSegmentedColormap.from_list("cmap", rgb_data, N=256)
    return cmap.reversed() if reverse else cmap
×
971

972

973
def get_rotpole(xr_obj: xr.DataArray | xr.Dataset) -> ccrs.RotatedPole | None:
    """
    Create a Cartopy crs rotated pole projection/transform from DataArray or Dataset attributes.

    For a Dataset, the grid-mapping variable is looked up through the cf accessor
    ("rotated_latitude_longitude"); for a DataArray, through its "grid_mapping" attribute.
    In both cases the name falls back to "rotated_pole".

    Parameters
    ----------
    xr_obj : xr.DataArray or xr.Dataset
        The xarray object from which to look for the attributes.

    Returns
    -------
    ccrs.RotatedPole or None
        The rotated pole projection, or None (with a warning) when the grid-mapping variable or
        one of its pole attributes cannot be found.
    """
    try:
        if isinstance(xr_obj, xr.Dataset):
            gridmap = xr_obj.cf.grid_mapping_names.get("rotated_latitude_longitude", [])

            if len(gridmap) > 1:
                warnings.warn(
                    f"There are conflicting grid_mapping attributes in the dataset. Assuming {gridmap[0]}.", stacklevel=2
                )

            coord_name = gridmap[0] if gridmap else "rotated_pole"
        else:
            # If it can't find grid_mapping, assume it's rotated_pole
            coord_name = xr_obj.attrs.get("grid_mapping", "rotated_pole")

        pole = xr_obj[coord_name]
        return ccrs.RotatedPole(
            pole_longitude=pole.grid_north_pole_longitude,
            pole_latitude=pole.grid_north_pole_latitude,
            central_rotated_longitude=pole.north_pole_grid_longitude,
        )

    # KeyError: the grid-mapping variable itself is absent from xr_obj.
    # AttributeError: the variable exists but lacks one of the pole attributes.
    except (AttributeError, KeyError):
        warnings.warn("Rotated pole not found. Specify a transform if necessary.", stacklevel=2)
        return None
×
1011

1012

1013
def wrap_text(text: str, min_line_len: int = 18, max_line_len: int = 30) -> str:
    """
    Insert line breaks into a long text.

    Each break replaces one space found between `min_line_len` and `max_line_len` characters
    after the start of the current line. The space following the first ". " in that window is
    preferred, then the one following the first ": ", then the last plain space.

    Parameters
    ----------
    text : str
        The text to wrap.
    min_line_len : int
        Minimum length of each line.
    max_line_len : int
        Maximum length of each line.

    Returns
    -------
    str
        Wrapped text. If no break point exists in the current window, a warning is emitted and
        the rest of the text is left as is.
    """
    newline = "\n"
    lo, hi = min_line_len, max_line_len
    left = len(text)

    while left > max_line_len:
        window = text[lo:hi]

        if ". " in window:
            cut = text.find(". ", lo, hi) + 1
        elif ": " in window:
            cut = text.find(": ", lo, hi) + 1
        elif " " in window:
            cut = text.rfind(" ", lo, hi)
        else:
            warnings.warn("No spaces, points or colons to break line at.", stacklevel=2)
            break

        # `cut` is the index of a space: swap it for the line separator
        text = text[:cut] + newline + text[cut + 1 :]

        left = len(text[cut:])
        lo = cut + 1 + min_line_len
        hi = cut + 1 + max_line_len

    return text
×
1055

1056

1057
def gpd_to_ccrs(df: gpd.GeoDataFrame, proj: ccrs.CRS) -> gpd.GeoDataFrame:
    """
    Reproject a GeoDataFrame to a cartopy projection.

    The projection is handed to geopandas through its ``proj4_init`` string.

    Parameters
    ----------
    df : gpd.GeoDataFrame
        GeoDataFrame (geopandas) geometry to be added to axis.
    proj : ccrs.CRS
        Projection to use, taken from the cartopy.crs options.

    Returns
    -------
    gpd.GeoDataFrame
        GeoDataFrame adjusted to given projection
    """
    proj4_string = proj.proj4_init
    reprojected = df.to_crs(proj4_string)
    return reprojected
×
1075

1076

1077
def convert_scen_name(name: str) -> str:
    """
    Convert strings containing SSP, RCP or CMIP to their proper format.

    Every scenario token in `name` (case-insensitive "SSP", "RCP" or "CMIP" followed by one to
    three digits) is upper-cased and its digits are punctuated:
    three digits become "X-Y.Z" (ssp245 -> SSP2-4.5), two digits become "X.Y" (rcp45 -> RCP4.5),
    and a single digit is kept as is (cmip5 -> CMIP5).

    Parameters
    ----------
    name : str
        String possibly containing scenario tokens.

    Returns
    -------
    str
        `name` with all scenario tokens converted; unchanged if it contains none.
    """

    def _format(match: re.Match) -> str:
        prefix, digits = match.groups()
        if len(digits) == 3:
            digits = f"{digits[0]}-{digits[1]}.{digits[2]}"
        elif len(digits) == 2:
            digits = f"{digits[0]}.{digits[1]}"
        return prefix.upper() + digits

    # Substituting match by match converts every token in a single pass; rebuilding the result
    # from the original string for each token would keep only the last conversion.
    return re.sub(r"(SSP|RCP|CMIP)([0-9]{1,3})", _format, name, flags=re.I)
×
1099

1100

1101
def get_scen_color(name: str, path_to_dict: str | pathlib.Path) -> str:
    """
    Get color corresponding to SSP, RCP, model or CMIP substring from a dictionary.

    The json file at `path_to_dict` maps substrings to 0-255 colour values. The first entry
    (in file order) found inside `name` is returned as a tuple scaled to 0-1; None is returned
    when no entry matches.
    """
    with pathlib.Path(path_to_dict).open(encoding="utf-8") as _f:
        color_dict = json.load(_f)

    for substring in color_dict:
        if substring in name:
            return tuple(channel / 255 for channel in color_dict[substring])

    return None
×
1114

1115

1116
def process_keys(dct: dict[str, Any], func: Callable) -> dict[str, Any]:
    """
    Apply function to dictionary keys.

    The dictionary is modified in place and also returned. Entry order is preserved.

    Parameters
    ----------
    dct : dict
        Dictionary whose keys are to be transformed.
    func : Callable
        Function called on each key; its return value becomes the new key.

    Returns
    -------
    dict
        The same dictionary object, with transformed keys. If `func` maps several keys to the
        same new key, the value of the last one is kept.
    """
    # Compute every renamed pair before touching the dict: renaming entry by entry would
    # overwrite a not-yet-processed entry whenever a new key equals one of the old keys.
    renamed = [(func(key), value) for key, value in dct.items()]
    dct.clear()
    dct.update(renamed)
    return dct
×
1123

1124

1125
def categorical_colors() -> dict[str, str]:
    """
    Load the categorical colors associated with certain substrings (SSP,RCP,CMIP).

    The content of figanos' data/ipcc_colors/categorical_colors.json is returned as parsed.
    """
    # parents[1] should be 'figanos/'
    package_root = pathlib.Path(__file__).parents[1]
    json_file = package_root / "data/ipcc_colors/categorical_colors.json"
    return json.loads(json_file.read_text(encoding="utf-8"))
×
1134

1135

1136
def get_mpl_styles() -> dict[str, pathlib.Path]:
    """
    List the matplotlib stylesheets shipped next to this module.

    Returns
    -------
    dict
        Maps each style name (file name without ".mplstyle") to the path of its stylesheet,
        in sorted path order.
    """
    style_dir = pathlib.Path(__file__).parent / "style"
    return {sheet.stem: sheet for sheet in sorted(style_dir.glob("*.mplstyle"))}
×
1141

1142

1143
def set_mpl_style(*args: str, reset: bool = False) -> None:
    """
    Apply one or more stylesheets to matplotlib, in the order given.

    Parameters
    ----------
    args : str
        Name(s) of figanos matplotlib style ('ouranos', 'paper', 'poster') or path(s) to matplotlib stylesheet(s).
        Anything ending in ".mplstyle" is treated as a path; any other unknown name only triggers a warning.
    reset : bool
        If True, reset style to matplotlib default before applying the stylesheets.

    Returns
    -------
    None
    """
    if reset is True:
        mpl.style.use("default")

    for style in args:
        # explicit stylesheet path
        if style.endswith(".mplstyle"):
            mpl.style.use(style)
            continue

        # otherwise it must be one of the styles shipped with figanos
        known_styles = get_mpl_styles()
        if style in known_styles:
            mpl.style.use(known_styles[style])
        else:
            warnings.warn(f"Style {style} not found.", stacklevel=2)
×
1167

1168

1169
def add_cartopy_features(
    ax: matplotlib.axes.Axes, features: list[str] | dict[str, dict[str, Any]]
) -> matplotlib.axes.Axes:
    """
    Add cartopy features to matplotlib axes.

    Each feature name is upper-cased and looked up in cartopy.feature. A "scale" entry in the
    keyword arguments of a feature is applied through the feature's ``with_scale`` method and is
    not forwarded to ``ax.add_feature``. The `features` argument is never modified.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The axes on which to add the features.
    features : list or dict
        List of features, or nested dictionary of format {'feature': {'kwarg':'value'}}

    Returns
    -------
    matplotlib.axes.Axes
        The axis with added features.
    """
    if isinstance(features, list):
        features = {f: {} for f in features}

    for feat, kwargs in features.items():
        feature = getattr(cfeature, feat.upper())
        if "scale" in kwargs:
            feature = feature.with_scale(kwargs["scale"])
            # Filter into a new dict rather than popping and restoring "scale" in the caller's
            # dict, which would lose the key if add_feature raised.
            kwargs = {key: value for key, value in kwargs.items() if key != "scale"}
        ax.add_feature(feature, **kwargs)
    return ax
×
1201

1202

1203
def custom_cmap_norm(
    cmap,
    vmin: int | float,
    vmax: int | float,
    levels: int | list[int | float] | None = None,
    divergent: bool | int | float = False,
    linspace_out: bool = False,
) -> matplotlib.colors.Normalize | np.ndarray:
    """
    Build the matplotlib normalization matching the arguments of the main plotting functions.

    `vmin` and `vmax` are first rounded outward (floor / ceil) to a step of 10, 1, 0.1 or 0.01
    chosen from the size of ``vmax - vmin``; the normalization is built on those rounded bounds.

    Parameters
    ----------
    cmap : matplotlib.colormap or str
        Colormap to be used with the normalization. A string is looked up among the
        colormaps registered in matplotlib.
    vmin : int or float
        Minimum of the data to be plotted with the colormap.
    vmax : int or float
        Maximum of the data to be plotted with the colormap.
    levels : int or list, optional
        Number of levels or list of level boundaries (in data units) to use to divide the colormap.
    divergent : bool or int or float
        If int or float, becomes center of cmap. Default center is 0.
        Ignored (with a warning) when `levels` is a list.
    linspace_out : bool
        If True, return array created by np.linspace() instead of normalization instance.
        Only has an effect when `levels` is a number.

    Returns
    -------
    matplotlib.colors.Normalize or np.ndarray

    Raises
    ------
    ValueError
        If `cmap` is a string that is not a registered colormap, or if `levels` is an integer
        and the center does not lie strictly between the rounded bounds.
    """
    # resolve colormap names
    if isinstance(cmap, str):
        if cmap not in plt.colormaps():
            raise ValueError("Colormap not found")
        cmap = matplotlib.colormaps[cmap]

    # make vmin and vmax prettier: pick the rounding step from the data span
    span = vmax - vmin
    if span >= 25:
        step = 10
    elif span >= 1:
        step = 1
    elif span >= 0.1:
        step = 0.1
    else:
        step = 0.01
    rvmax = math.ceil(vmax / step) * step
    rvmin = math.floor(vmin / step) * step

    # center of a diverging colormap, if any
    if divergent is True:
        center = 0
    elif divergent is not False and isinstance(divergent, int | float):
        center = divergent
    else:
        center = None

    # diverging colormap split in a given number of levels
    if center is not None and isinstance(levels, int):
        if center <= rvmin or center >= rvmax:
            raise ValueError("vmin, center and vmax must be in ascending order.")

        half_levels = int((levels + 1) / 2 if levels % 2 == 1 else levels / 2) + 1

        below = np.linspace(rvmin, center, num=half_levels)
        # drop the first point of the upper half so the center is not duplicated
        above = np.linspace(center, rvmax, num=half_levels)[1:]
        lin = np.concatenate((below, above))
        norm = matplotlib.colors.BoundaryNorm(boundaries=lin, ncolors=cmap.N)
        return lin if linspace_out else norm

    # continuous normalizations
    if levels is None:
        if center is None:
            return matplotlib.colors.Normalize(rvmin, rvmax)
        return matplotlib.colors.TwoSlopeNorm(center, vmin=rvmin, vmax=rvmax)

    # explicit level boundaries
    if isinstance(levels, list):
        if center is not None:
            warnings.warn(
                "Divergent argument ignored when levels is a list. Use levels as a number instead.", stacklevel=2
            )
        return matplotlib.colors.BoundaryNorm(boundaries=levels, ncolors=cmap.N)

    # evenly spaced levels between the rounded bounds
    lin = np.linspace(rvmin, rvmax, num=levels + 1)
    norm = matplotlib.colors.BoundaryNorm(boundaries=lin, ncolors=cmap.N)
    return lin if linspace_out else norm
×
1302

1303

1304
def norm2range(
    data: np.ndarray, target_range: tuple, data_range: tuple | None = None
) -> np.ndarray:
    """
    Normalize data across a specific range.

    `data` is rescaled linearly so that ``data_range[0]`` maps to ``target_range[0]`` and
    ``data_range[1]`` maps to ``target_range[1]``.

    Parameters
    ----------
    data : np.ndarray
        Values to rescale. A scalar is accepted when `data_range` is given.
    target_range : tuple
        (lower, upper) bounds of the output range.
    data_range : tuple, optional
        (lower, upper) bounds of the input range. Defaults to the NaN-ignoring minimum and
        maximum of `data`, which requires `data` to hold more than one element.

    Returns
    -------
    np.ndarray
        The rescaled data.

    Raises
    ------
    ValueError
        If `data_range` is not given and `data` holds fewer than two elements.
    """
    if data_range is None:
        # np.size also handles scalars (for which len() raises TypeError) and counts all the
        # elements of multi-dimensional arrays, not only the length of the first axis.
        if np.size(data) > 1:
            data_range = (np.nanmin(data), np.nanmax(data))
        else:
            raise ValueError(" if data is not an array, data_range must be specified")

    norm = (data - data_range[0]) / (data_range[1] - data_range[0])

    return target_range[0] + (norm * (target_range[1] - target_range[0]))
×
1317

1318

1319
def size_legend_elements(
    data: np.ndarray, sizes: np.ndarray, marker: str, max_entries: int = 6
) -> list[matplotlib.lines.Line2D]:
    """
    Create handles to use in a point-size legend.

    Legend values are spread evenly over the data range (one per increment of 10 pts**2 in
    `sizes`), rounded to a magnitude chosen from the spacing between them, and converted back
    to marker sizes.

    Parameters
    ----------
    data : np.ndarray
        Data used to determine the point sizes.
    sizes : np.ndarray
        Array of point sizes.
    marker : str
        Marker to use in legend.
    max_entries : int
        Maximum number of entries in the legend. When more handles are generated, every other
        handle among the first ``max_entries + 1`` is kept.

    Returns
    -------
    list of matplotlib.lines.Line2D
    """
    data_min, data_max = min(data), max(data)
    size_min, size_max = min(sizes), max(sizes)

    # how many increments of 10 pts**2 are there in the sizes
    n = int(np.round(size_max - size_min, -1) / 10)

    # divide data in those increments
    lgd_data = np.linspace(data_min, data_max, n)

    # Round according to the spacing between legend values. The whole range is divided by n
    # (not only the minimum), and n == 0 (no legend values) must not trigger a division by zero.
    ratio = abs((data_max - data_min) / n) if n else 0

    for threshold in (1000, 100, 10, 5, 1, 0.1, 0.01):
        if ratio >= threshold:
            rounding = threshold
            break
    else:
        rounding = 0.001

    lgd_data = np.unique(rounding * np.round(lgd_data / rounding))

    # convert back to sizes
    lgd_sizes = norm2range(
        data=lgd_data,
        data_range=(data_min, data_max),
        target_range=(size_min, size_max),
    )

    legend_elements = []

    for s, d in zip(lgd_sizes, lgd_data, strict=False):
        # show whole numbers without a trailing ".0"
        if isinstance(d, float) and d.is_integer():
            label = str(int(d))
        else:
            label = str(d)

        legend_elements.append(
            Line2D(
                [0],
                [0],
                marker=marker,
                color="k",
                lw=0,
                markerfacecolor="w",
                label=label,
                # sizes are areas (pts**2) while markersize is a length
                markersize=np.sqrt(np.abs(s)),
            )
        )

    if len(legend_elements) > max_entries:
        return [legend_elements[i] for i in range(0, max_entries + 1, 2)]
    return legend_elements
×
1400

1401

1402
def add_features_map(
    data,
    ax,
    use_attrs,
    projection,
    features,
    geometries_kw,
    frame,
) -> matplotlib.axes.Axes:
    """
    Add features such as cartopy, time label, and geometries to a map on a given matplotlib axis.

    Parameters
    ----------
    data : dict, DataArray or Dataset
        Input data do plot. If dictionary, must have only one entry.
    ax : matplotlib axis
        Matplotlib axis on which to plot, with the same projection as the one specified.
    use_attrs : dict
        Dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description').
        Default value is {'title': 'description', 'cbar_label': 'long_name', 'cbar_units': 'units'}.
        Only the keys found in the default dict can be used.
    projection : ccrs.Projection
        The projection to use, taken from the cartopy.crs options. Geometries are converted to it
        unless `geometries_kw` holds its own "crs".
    features : list or dict
        Features to use, as a list or a nested dict containing kwargs. Options are the predefined features from
        cartopy.feature: ['coastline', 'borders', 'lakes', 'land', 'ocean', 'rivers', 'states'].
    geometries_kw : dict
        Arguments passed to cartopy ax.add_geometries() which adds given geometries (GeoDataFrame geometry) to axis.
        Must contain "geoms"; otherwise a warning is emitted and no geometry is added.
        The dictionary is not modified.
    frame : bool
        Show or hide frame. Default False.

    Returns
    -------
    matplotlib.axes.Axes
    """
    # add features
    if features:
        add_cartopy_features(ax, features)

    set_plot_attrs(use_attrs, data, ax)

    if frame is False:
        ax.spines["geo"].set_visible(False)

    # add geometries
    if geometries_kw:
        if "geoms" not in geometries_kw:
            warnings.warn(
                'geoms missing from geometries_kw (ex: {"geoms": df["geometry"]})', stacklevel=2
            )
            # nothing to draw: continuing would fail on the missing key
            return ax

        target_crs = geometries_kw.get("crs", projection)
        # Merge into a new dict so the caller's geometries_kw keeps its original geometries.
        geometries_kw = {
            "crs": projection,
            "facecolor": "none",
            "edgecolor": "black",
            **geometries_kw,
            "geoms": gpd_to_ccrs(geometries_kw["geoms"], target_crs),
        }

        ax.add_geometries(**geometries_kw)
    return ax
×
1467

1468

1469
def masknan_sizes_key(data, sizes) -> xr.Dataset:
    """
    Make the NaNs of the hue and markersize variables coincide for xr.plot.scatter().

    The hue variable is taken to be the first variable of `data` other than `sizes`. Wherever
    one of the two variables is NaN, the other is set to NaN as well. `data` is modified in place.

    Parameters
    ----------
    data : xr.Dataset
        xr.Dataset used to plot
    sizes : str
        Variable used to plot markersize

    Returns
    -------
    xr.Dataset
        The same `data` object, with the masked variables.
    """
    # the hue variable is the first one that is not `sizes`
    candidates = [*data.keys()]
    candidates.remove(sizes)
    key = candidates[0]

    # hide hue values where the size is missing
    data[key] = data[key].where(~np.isnan(data[sizes]))

    # then hide sizes where the (already masked) hue is missing
    data[sizes] = data[sizes].where(~np.isnan(data[key]))

    return data
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc