• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

Ouranosinc / figanos / 18353693760

08 Oct 2025 06:01PM UTC coverage: 8.075% (-0.05%) from 8.122%
18353693760

push

github

web-flow
Update cookiecutter template, add CITATION.cff (#362)

### What kind of change does this PR introduce?

* Updates the cookiecutter to the latest version
* Removed `black`, `blackdoc`, and `isort`
* Added a `CITATION.cff` file
* `pyproject.toml` is now PEP-639 compliant 
* Contributor Covenant agreement updated to v3.0
* Some dependency updates

### Does this PR introduce a breaking change?

Yes. `black`, `blackdoc`, and `isort` have been dropped in favour of `ruff`.

### Other information:

There may be some small adjustments to the code base: `ruff` is
configured to closely match `black`, but its formatting output is not
exactly identical to `black`'s.

0 of 56 new or added lines in 4 files covered. (0.0%)

2 existing lines in 1 file now uncovered.

156 of 1932 relevant lines covered (8.07%)

0.4 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

3.53
/src/figanos/matplotlib/plot.py
1
# noqa: D100
2
from __future__ import annotations
5✔
3
import copy
5✔
4
import logging
5✔
5
import math
5✔
6
import string
5✔
7
import warnings
5✔
8
from collections.abc import Iterable
5✔
9
from inspect import signature
5✔
10
from pathlib import Path
5✔
11
from typing import Any
5✔
12

13
import cartopy.mpl.geoaxes
5✔
14
import geopandas as gpd
5✔
15
import matplotlib
5✔
16
import matplotlib.axes
5✔
17
import matplotlib.colors
5✔
18
import matplotlib.pyplot as plt
5✔
19
import mpl_toolkits.axisartist.grid_finder as gf
5✔
20
import numpy as np
5✔
21
import pandas as pd
5✔
22
import seaborn as sns
5✔
23
import xarray as xr
5✔
24
from cartopy import crs as ccrs
5✔
25
from matplotlib.cm import ScalarMappable
5✔
26
from matplotlib.lines import Line2D
5✔
27
from matplotlib.projections import PolarAxes
5✔
28
from matplotlib.tri import Triangulation
5✔
29
from mpl_toolkits.axisartist.floating_axes import FloatingSubplot, GridHelperCurveLinear
5✔
30

31
from figanos.matplotlib.utils import (  # masknan_sizes_key,
5✔
32
    add_cartopy_features,
33
    add_features_map,
34
    check_timeindex,
35
    convert_scen_name,
36
    create_cmap,
37
    custom_cmap_norm,
38
    empty_dict,
39
    fill_between_label,
40
    get_array_categ,
41
    get_attributes,
42
    get_localized_term,
43
    get_rotpole,
44
    get_scen_color,
45
    get_var_group,
46
    gpd_to_ccrs,
47
    norm2range,
48
    plot_coords,
49
    process_keys,
50
    set_plot_attrs,
51
    size_legend_elements,
52
    sort_lines,
53
    split_legend,
54
    wrap_text,
55
)
56

57

58
# Module-level logger, named after this module (used e.g. to report missing colormap files).
logger = logging.getLogger(name=__name__)
5✔
59

60

61
def _plot_realizations(
    ax: matplotlib.axes.Axes,
    da: xr.DataArray,
    name: str,
    plot_kw: dict[str, Any],
    non_dict_data: bool,
) -> matplotlib.axes.Axes:
    """
    Plot realizations from a DataArray, inside or outside a Dataset.

    Each value along the ``realization`` dimension is drawn as its own line
    against the ``time`` coordinate.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        The Matplotlib axis object.
    da : DataArray
        The DataArray containing the realizations, along a ``realization`` dimension.
    name : str
        The label to be used in the first part of a composite label.
        Can be the name of the parent Dataset or that of the DataArray.
    plot_kw : dict
        Dictionary of kwargs coming from the timeseries() input, keyed by `name`.
    non_dict_data : bool
        True when the user passed a single Dataset/DataArray instead of a dict;
        `name` is then a placeholder and is left out of the labels.

    Returns
    -------
    matplotlib.axes.Axes
    """
    line_kw = plot_kw[name]
    labelled = False  # whether the single shared legend entry was already emitted

    for r in da.realization.values:
        if line_kw:
            # User-specified kwargs make every realization look identical,
            # so only the first line gets a legend entry.
            if labelled:
                label = ""
            else:
                label = "" if non_dict_data is True else name
                labelled = True
        else:
            label = str(r) if non_dict_data is True else (name + "_" + str(r))

        realization = da.sel(realization=r)
        ax.plot(
            realization["time"],
            realization.values,
            label=label,
            **line_kw,
        )

    return ax
×
109

110

111
def _plot_pct_band(
    ax: matplotlib.axes.Axes,
    array_data: dict[str, xr.DataArray],
    line_label: str,
    line_kw: dict[str, Any],
    name: str,
    array_categ: dict[str, Any],
    legend: str,
) -> list[Line2D]:
    """
    Plot the middle member of an ensemble as a line and shade between its lower and upper members.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axis to be used for plotting.
    array_data : dict
        Ensemble members keyed by a string identifier, passed to sort_lines()
        to determine which is the middle, lower and upper line.
    line_label : str
        Legend label of the middle line.
    line_kw : dict
        Keyword arguments passed to `ax.plot()` for the middle line.
    name : str
        Dictionary key of the plotted data, used to build the shading label.
    array_categ : dict
        Categories of data, passed to fill_between_label().
    legend : str
        Legend type, passed to fill_between_label().

    Returns
    -------
    list of matplotlib.lines.Line2D
        The lines returned by `ax.plot()` for the middle member.
    """
    # create a dictionary labeling the middle, upper and lower line
    sorted_lines = sort_lines(array_data)

    middle = array_data[sorted_lines["middle"]]
    lines = ax.plot(middle["time"], middle.values, label=line_label, **line_kw)

    lower = array_data[sorted_lines["lower"]]
    ax.fill_between(
        lower["time"],
        lower.values,
        array_data[sorted_lines["upper"]].values,
        color=lines[0].get_color(),  # shade with the same colour as the middle line
        linewidth=0.0,
        alpha=0.2,
        label=fill_between_label(sorted_lines, name, array_categ, legend),
    )
    return lines


def _plot_timeseries(
    ax: matplotlib.axes.Axes,
    name: str,
    arr: xr.DataArray | xr.Dataset,
    plot_kw: dict[str, Any],
    non_dict_data: bool,
    array_categ: dict[str, Any],
    legend: str,
) -> matplotlib.axes.Axes:
    """
    Plot figanos timeseries.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axis to be used for plotting.
    name : str
        Dictionary key of the plotted data.
    arr : Dataset/DataArray
        Data to be plotted.
    plot_kw : dict
        Dictionary of kwargs coming from the timeseries() input, keyed by `name`.
        ``plot_kw[name]`` is modified in place: a scenario colour may be added as
        default "color", and any "label" entry is removed.
    non_dict_data : bool
        If True, the data was not passed as a dictionary and `name` is a
        placeholder that is left out of the labels.
    array_categ : dict
        Categories of data (as returned by get_array_categ()), keyed by `name`.
    legend : str
        Legend type.

    Returns
    -------
    matplotlib.axes.Axes

    Raises
    ------
    TypeError
        If an "ENS_REALS_DS" Dataset contains more than one data variable.
    ValueError
        If the category of the data is not supported.
    """
    categ = array_categ[name]
    line_kw = plot_kw[name]

    # look for SSP, RCP, CMIP model color (looked up once, reused below)
    cat_colors = Path(__file__).parents[1] / "data/ipcc_colors/categorical_colors.json"
    scen_color = get_scen_color(name, cat_colors)
    if scen_color:
        line_kw.setdefault("color", scen_color)

    # remove 'label' to avoid error due to double 'label' args
    if "label" in line_kw:
        del line_kw["label"]
        # stacklevel=3: skip this helper and timeseries() to point at the user's call
        warnings.warn(f'"label" entry in plot_kw[{name}] will be ignored.', stacklevel=3)

    if categ == "ENS_REALS_DA":
        _plot_realizations(ax, arr, name, plot_kw, non_dict_data)

    elif categ == "ENS_REALS_DS":
        if len(arr.data_vars) >= 2:
            raise TypeError(
                "To plot multiple ensembles containing realizations, use DataArrays outside a Dataset"
            )
        for sub_arr in arr.data_vars.values():
            _plot_realizations(ax, sub_arr, name, plot_kw, non_dict_data)

    elif categ == "ENS_PCT_DIM_DS":
        for sub_arr in arr.data_vars.values():
            sub_name = (
                sub_arr.name if non_dict_data is True else (name + "_" + sub_arr.name)
            )
            # extract each percentile array from the 'percentiles' dimension
            array_data = {
                str(pct): sub_arr.sel(percentiles=pct)
                for pct in sub_arr.percentiles.values
            }
            _plot_pct_band(ax, array_data, sub_name, line_kw, name, array_categ, legend)

    # other ensembles
    elif categ in ["ENS_PCT_VAR_DS", "ENS_STATS_VAR_DS", "ENS_PCT_DIM_DA"]:
        # extract each array from the dataset variables or the percentiles dimension
        if categ == "ENS_PCT_DIM_DA":
            array_data = {
                str(int(pct)): arr.sel(percentiles=int(pct)) for pct in arr.percentiles
            }
        else:
            array_data = dict(arr.data_vars.items())
        _plot_pct_band(ax, array_data, name, line_kw, name, array_categ, legend)

    # non-ensemble Datasets
    elif categ == "DS":
        labelled = False
        for sub_arr in arr.data_vars.values():
            sub_name = (
                sub_arr.name if non_dict_data is True else (name + "_" + sub_arr.name)
            )
            # if kwargs are specified by user, all lines are the same and we want one legend entry
            if line_kw:
                label = "" if labelled else name
                labelled = True
            else:
                label = sub_name
            ax.plot(sub_arr["time"], sub_arr.values, label=label, **line_kw)

    # non-ensemble DataArrays
    elif categ == "DA":
        ax.plot(arr["time"], arr.values, label=name, **line_kw)

    else:
        # get_array_categ() is expected to reject unsupported structures already
        raise ValueError("Data structure not supported")
    return ax
×
263

264

265
def timeseries(
    data: dict[str, Any] | xr.DataArray | xr.Dataset,
    ax: matplotlib.axes.Axes | None = None,
    use_attrs: dict[str, Any] | None = None,
    fig_kw: dict[str, Any] | None = None,
    plot_kw: dict[str, Any] | None = None,
    legend: str | dict[str, Any] = "lines",
    show_lat_lon: bool | str | int | tuple[float, float] = True,
    enumerate_subplots: bool = False,
) -> matplotlib.axes.Axes:
    """
    Plot time series from 1D Xarray Datasets or DataArrays as line plots.

    Parameters
    ----------
    data : dict or Dataset/DataArray
        Input data to plot. It can be a DataArray, Dataset or a dictionary of DataArrays and/or Datasets.
    ax : matplotlib.axes.Axes, optional
        Matplotlib axis on which to plot.
    use_attrs : dict, optional
        A dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description').
        Default value is {'title': 'description', 'ylabel': 'long_name', 'yunits': 'units'}.
        Only the keys found in the default dict can be used.
    fig_kw : dict, optional
        Arguments to pass to `plt.subplots()`. Only works if `ax` is not provided.
        When using a facetgrid, 'figsize' is passed to the facetgrid and the other
        entries are passed as `subplot_kws` (i.e. to `figure.add_subplot()`).
    plot_kw : dict, optional
        Arguments to pass to the `plot()` function. Changes how the line looks.
        If 'data' is a dictionary, must be a nested dictionary with the same keys as 'data'.
    legend : str (default 'lines') or dict
        'full' (lines and shading), 'lines' (lines only), 'in_plot' (end of lines),
         'edge' (out of plot), 'facetgrid' under figure, 'none' (no legend). If dict, arguments to pass to ax.legend().
    show_lat_lon : bool, tuple, str or int
        If True, show latitude and longitude at the bottom right of the figure.
        Can be a tuple of axis coordinates (from 0 to 1, as a fraction of the axis length) representing
        the location of the text. If a string or an int, the same values as those of the 'loc' parameter
        of matplotlib's legends are accepted.

        ==================   =============
        Location String      Location Code
        ==================   =============
        'upper right'        1
        'upper left'         2
        'lower left'         3
        'lower right'        4
        'right'              5
        'center left'        6
        'center right'       7
        'lower center'       8
        'upper center'       9
        'center'             10
        ==================   =============
    enumerate_subplots: bool
        If True, enumerate subplots with letters.
        Only works with facetgrids (pass `col` or `row` in plot_kw).

    Returns
    -------
    matplotlib.axes.Axes or xarray.plot.FacetGrid
        The axis, or the FacetGrid when `col`/`row` is passed in plot_kw.

    Raises
    ------
    KeyError
        If `plot_kw` contains keys that are not in `data`.
    TypeError
        If `data` does not contain Datasets/DataArrays, or `show_lat_lon` has an unsupported type.
    ValueError
        If `ax` is combined with 'col'/'row'.
    """
    # convert SSP, RCP, CMIP formats in keys
    if isinstance(data, dict):
        data = process_keys(data, convert_scen_name)
    if isinstance(plot_kw, dict):
        plot_kw = process_keys(plot_kw, convert_scen_name)

    # create empty dicts if None
    use_attrs = empty_dict(use_attrs)
    fig_kw = empty_dict(fig_kw)
    plot_kw = empty_dict(plot_kw)

    # if only one data input, insert in dict.
    non_dict_data = False
    if not isinstance(data, dict):
        non_dict_data = True
        data = {"_no_label": data}  # mpl excludes labels starting with "_" from legend
        plot_kw = {"_no_label": empty_dict(plot_kw)}

    # assign keys to plot_kw if not there
    if non_dict_data is False:
        for name in data:
            plot_kw.setdefault(name, {})
        for key in plot_kw:
            if key not in data:
                raise KeyError(
                    'plot_kw must be a nested dictionary with keys corresponding to the keys in "data"'
                )

    # check: type
    for arr in data.values():
        if not isinstance(arr, xr.Dataset | xr.DataArray):
            raise TypeError(
                '"data" must be a xr.Dataset, a xr.DataArray or a dictionary of such objects.'
            )

    # check: 'time' dimension and calendar format
    data = check_timeindex(data)

    # facetgrid mode is requested through 'row'/'col' in the first plot_kw entry
    first_kw = list(plot_kw.values())[0]
    use_facetgrid = "row" in first_kw or "col" in first_kw

    # set fig, ax if not provided
    if ax is None and not use_facetgrid:
        _fig, ax = plt.subplots(**fig_kw)
    elif ax is not None and use_facetgrid:
        raise ValueError("Cannot use 'ax' and 'col'/'row' at the same time.")
    elif ax is None:
        cfig_kw = fig_kw.copy()
        if "figsize" in fig_kw:  # add figsize to plot_kw for facetgrid
            first_kw.setdefault("figsize", fig_kw["figsize"])
            cfig_kw.pop("figsize")
        if cfig_kw:
            # Only the first entry builds the facetgrid; forward the remaining
            # fig_kw to figure.add_subplot() unless the user set subplot_kws.
            first_kw.setdefault("subplot_kws", cfig_kw)
            warnings.warn(
                "Only figsize and figure.add_subplot() arguments can be passed to fig_kw when using facetgrid.", stacklevel=2
            )

    # set default use_attrs values
    if ax:
        use_attrs.setdefault("title", "description")
    else:
        use_attrs.setdefault("suptitle", "description")
    use_attrs.setdefault("ylabel", "long_name")
    use_attrs.setdefault("yunits", "units")

    # dict of array 'categories'
    array_categ = {name: get_array_categ(array) for name, array in data.items()}
    cp_plot_kw = copy.deepcopy(plot_kw)
    first_name = list(data.keys())[0]

    # get data and plot
    for name, arr in data.items():
        if ax:
            _plot_timeseries(ax, name, arr, plot_kw, non_dict_data, array_categ, legend)
            continue

        if name == first_name:
            # create empty DataArray with same dimensions as data first entry to create an empty xr.plot.FacetGrid
            da = arr[list(arr.keys())[0]] if isinstance(arr, xr.Dataset) else arr
            da = da.where(da == np.nan)  # comparison with NaN is never True: all values masked
            im = da.plot(**plot_kw[name], color="white")

        # facetgrid-only keys must not be forwarded to ax.plot()
        for key in ("row", "col", "figsize", "subplot_kws"):
            cp_plot_kw[name].pop(key, None)

        # plot data in every axis of the facetgrid
        n_rows, n_cols = im.axs.shape
        for i in range(n_rows):
            for j in range(n_cols):
                sel_arr = {}
                if "row" in plot_kw[name]:
                    sel_arr[plot_kw[name]["row"]] = i
                if "col" in plot_kw[name]:
                    sel_arr[plot_kw[name]["col"]] = j

                _plot_timeseries(
                    im.axs[i, j],
                    name,
                    arr.isel(**sel_arr).squeeze(),
                    cp_plot_kw,
                    non_dict_data,
                    array_categ,
                    legend,
                )

    # 'none' is documented as "no legend", like None
    show_legend = legend is not None and legend != "none"
    first_arr = list(data.values())[0]

    #  add/modify plot elements according to the first entry.
    if ax:
        set_plot_attrs(
            use_attrs,
            first_arr,
            ax,
            title_loc="left",
            wrap_kw={"min_line_len": 35, "max_line_len": 48},
        )
        ax.set_xlabel(
            get_localized_term("time").capitalize()
        )  # check_timeindex() already checks for 'time'

        # other plot elements
        if show_lat_lon:
            if show_lat_lon is True:
                plot_coords(
                    ax,
                    first_arr,
                    param="location",
                    loc="lower right",
                    backgroundalpha=1,
                )
            elif isinstance(show_lat_lon, str | tuple | int):
                plot_coords(
                    ax,
                    first_arr,
                    param="location",
                    loc=show_lat_lon,
                    backgroundalpha=1,
                )
            else:
                raise TypeError(" show_lat_lon must be a bool, string, int, or tuple")

        if show_legend:
            if not ax.get_legend_handles_labels()[0]:  # check if legend is empty
                pass
            elif legend == "in_plot":
                split_legend(ax, in_plot=True)
            elif legend == "edge":
                split_legend(ax, in_plot=False)
            elif isinstance(legend, dict):
                ax.legend(**legend)
            else:
                ax.legend()

        return ax

    # facetgrid: legend entries are taken from the last panel
    if show_legend:
        last_ax = im.axs[-1, -1]
        handles, labels = last_ax.get_legend_handles_labels()
        if not handles:  # check if legend is empty
            pass
        elif legend == "in_plot":
            split_legend(last_ax, in_plot=True)
        elif legend == "edge":
            split_legend(last_ax, in_plot=False)
        elif isinstance(legend, dict):
            im.fig.legend(**({"handles": handles, "labels": labels} | legend))
        elif legend == "facetgrid":
            im.fig.legend(
                handles,
                labels,
                loc="lower center",
                ncol=len(last_ax.lines),
                bbox_to_anchor=(0.5, -0.05),
            )

    if show_lat_lon:
        if show_lat_lon is True:
            plot_coords(
                None,
                first_arr.isel(lat=0, lon=0),
                param="location",
                loc="lower right",
                backgroundalpha=1,
            )
        elif isinstance(show_lat_lon, str | tuple | int):
            plot_coords(
                None,
                first_arr.isel(lat=0, lon=0),
                param="location",
                loc=show_lat_lon,
                backgroundalpha=1,
            )
        else:
            raise TypeError(" show_lat_lon must be a bool, string, int, or tuple")

    if enumerate_subplots and isinstance(im, xr.plot.facetgrid.FacetGrid):
        for idx, sub_ax in enumerate(im.axs.flat):
            sub_ax.set_title(f"{string.ascii_lowercase[idx]}) {sub_ax.get_title()}")

    return im
×
530

531

532
def gridmap(
5✔
533
    data: dict[str, Any] | xr.DataArray | xr.Dataset,
534
    ax: matplotlib.axes.Axes | None = None,
535
    use_attrs: dict[str, Any] | None = None,
536
    fig_kw: dict[str, Any] | None = None,
537
    plot_kw: dict[str, Any] | None = None,
538
    projection: ccrs.Projection = ccrs.LambertConformal(),
539
    transform: ccrs.Projection | None = None,
540
    features: list[str] | dict[str, dict[str, Any]] | None = None,
541
    geometries_kw: dict[str, Any] | None = None,
542
    contourf: bool = False,
543
    cmap: str | matplotlib.colors.Colormap | None = None,
544
    levels: int | list | np.ndarray | None = None,
545
    divergent: bool | int | float = False,
546
    show_time: bool | str | int | tuple[float, float] = False,
547
    frame: bool = False,
548
    enumerate_subplots: bool = False,
549
) -> matplotlib.axes.Axes:
550
    """
551
    Create map from 2D data.
552

553
    Parameters
554
    ----------
555
    data : dict, DataArray or Dataset
556
        Input data do plot. If dictionary, must have only one entry.
557
    ax : matplotlib axis, optional
558
        Matplotlib axis on which to plot, with the same projection as the one specified.
559
    use_attrs : dict, optional
560
        Dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description').
561
        Default value is {'title': 'description', 'cbar_label': 'long_name', 'cbar_units': 'units'}.
562
        Only the keys found in the default dict can be used.
563
    fig_kw : dict, optional
564
        Arguments to pass to `plt.figure()`.
565
    plot_kw:  dict, optional
566
        Arguments to pass to the `xarray.plot.pcolormesh()` or 'xarray.plot.contourf()' function.
567
    projection : ccrs.Projection
568
        The projection to use, taken from the cartopy.crs options. Ignored if ax is not None.
569
    transform : ccrs.Projection, optional
570
        Transform corresponding to the data coordinate system. If None, an attempt is made to find dimensions matching
571
        ccrs.PlateCarree() or ccrs.RotatedPole().
572
    features : list or dict, optional
573
        Features to use, as a list or a nested dict containing kwargs. Options are the predefined features from
574
        cartopy.feature: ['coastline', 'borders', 'lakes', 'land', 'ocean', 'rivers', 'states'].
575
    geometries_kw : dict, optional
576
        Arguments passed to cartopy ax.add_geometry() which adds given geometries (GeoDataFrame geometry) to axis.
577
    contourf : bool
578
        By default False, use plt.pcolormesh(). If True, use plt.contourf().
579
    cmap : matplotlib.colors.Colormap or str, optional
580
        Colormap to use. If str, can be a matplotlib or name of the file of an IPCC colormap (see data/ipcc_colors).
581
        If None, look for common variables (from data/ipcc_colors/varaibles_groups.json) in the name of the DataArray
582
        or its 'history' attribute and use corresponding colormap, aligned with the IPCC visual style guide 2022
583
        (https://www.ipcc.ch/site/assets/uploads/2022/09/IPCC_AR6_WGI_VisualStyleGuide_2022.pdf).
584
    levels : int, list, np.ndarray, optional
585
        Number of levels to divide the colormap into or list of level boundaries (in data units).
586
    divergent : bool or int or float
587
        If int or float, becomes center of cmap. Default center is 0.
588
    show_time : bool, tuple, string or int.
589
        If True, show time (as date) at the bottom right of the figure.
590
        Can be a tuple of axis coordinates (0 to 1, as a fraction of the axis length) representing the location
591
        of the text. If a string or an int, the same values as those of the 'loc' parameter
592
        of matplotlib's legends are accepted.
593

594
        ==================   =============
595
        Location String      Location Code
596
        ==================   =============
597
        'upper right'        1
598
        'upper left'         2
599
        'lower left'         3
600
        'lower right'        4
601
        'right'              5
602
        'center left'        6
603
        'center right'       7
604
        'lower center'       8
605
        'upper center'       9
606
        'center'             10
607
        ==================   =============
608
    frame : bool
609
        Show or hide frame. Default False.
610
    enumerate_subplots: bool
611
        If True, enumerate subplots with letters.
612
        Only works with facetgrids (pass `col` or `row` in plot_kw).
613

614
    Returns
615
    -------
616
    matplotlib.axes.Axes
617
    """
618
    # create empty dicts if None
619
    use_attrs = empty_dict(use_attrs)
×
620
    fig_kw = empty_dict(fig_kw)
×
621
    plot_kw = empty_dict(plot_kw)
×
622

623
    # set default use_attrs values
624
    use_attrs = {"cbar_label": "long_name", "cbar_units": "units"} | use_attrs
×
625
    if "row" not in plot_kw and "col" not in plot_kw:
×
626
        use_attrs.setdefault("title", "description")
×
627

628
    # extract plot_kw from dict if needed
629
    if isinstance(data, dict) and plot_kw and list(data.keys())[0] in plot_kw.keys():
×
630
        plot_kw = plot_kw[list(data.keys())[0]]
×
631

632
    # if data is dict, extract
633
    if isinstance(data, dict):
×
634
        if len(data) == 1:
×
635
            data = list(data.values())[0]
×
636
        else:
637
            raise ValueError("If `data` is a dict, it must be of length 1.")
×
638

639
    # select data to plot
640
    if isinstance(data, xr.DataArray):
×
641
        plot_data = data.squeeze()
×
642
    elif isinstance(data, xr.Dataset):
×
643
        if len(data.data_vars) > 1:
×
644
            warnings.warn(
×
645
                "data is xr.Dataset; only the first variable will be used in plot", stacklevel=2
646
            )
647
        plot_data = data[list(data.keys())[0]].squeeze()
×
648
    else:
649
        raise TypeError("`data` must contain a xr.DataArray or xr.Dataset")
×
650

651
    # setup transform
652
    if transform is None:
×
653
        if "lat" in data.dims and "lon" in data.dims:
×
654
            transform = ccrs.PlateCarree()
×
655
        if "rlat" in data.dims and "rlon" in data.dims:
×
656
            transform = get_rotpole(data)
×
657

658
    # setup fig, ax
659
    if ax is None and ("row" not in plot_kw.keys() and "col" not in plot_kw.keys()):
×
660
        fig, ax = plt.subplots(subplot_kw={"projection": projection}, **fig_kw)
×
661
    elif ax is not None and ("col" in plot_kw or "row" in plot_kw):
×
662
        raise ValueError("Cannot use 'ax' and 'col'/'row' at the same time.")
×
663
    elif ax is None:
×
664
        plot_kw = {"subplot_kws": {"projection": projection}} | plot_kw
×
665
        cfig_kw = fig_kw.copy()
×
666
        if "figsize" in fig_kw:  # add figsize to plot_kw for facetgrid
×
667
            plot_kw.setdefault("figsize", fig_kw["figsize"])
×
668
            cfig_kw.pop("figsize")
×
669
        if len(cfig_kw) >= 1:
×
670
            plot_kw = {"subplot_kws": {"projection": cfig_kw}} | plot_kw
×
671
            warnings.warn(
×
672
                "Only figsize and figure.add_subplot() arguments can be passed to fig_kw when using facetgrid.", stacklevel=2
673
            )
674

675
    # create cbar label
676
    if (
×
677
        "cbar_units" in use_attrs
678
        and len(get_attributes(use_attrs["cbar_units"], data)) >= 1
679
    ):  # avoids '[]' as label
680
        cbar_label = (
×
681
            get_attributes(use_attrs["cbar_label"], data)
682
            + " ("
683
            + get_attributes(use_attrs["cbar_units"], data)
684
            + ")"
685
        )
686
    else:
687
        cbar_label = get_attributes(use_attrs["cbar_label"], data)
×
688

689
    # colormap
690
    if isinstance(cmap, str):
×
691
        if cmap not in plt.colormaps():
×
692
            try:
×
693
                cmap = create_cmap(filename=cmap)
×
694
            except FileNotFoundError as e:
×
695
                logger.error(e)
×
696
                pass
×
697

698
    elif cmap is None:
×
699
        cdata = Path(__file__).parents[1] / "data/ipcc_colors/variable_groups.json"
×
700
        cmap = create_cmap(
×
701
            get_var_group(path_to_json=cdata, da=plot_data),
702
            divergent=divergent,
703
        )
704
    plot_kw.setdefault("cmap", cmap)
×
705

706
    if levels is not None:
×
707
        if isinstance(levels, Iterable):
×
708
            lin = levels
×
709
        else:
710
            lin = custom_cmap_norm(
×
711
                cmap,
712
                np.nanmin(plot_data.values),
713
                np.nanmax(plot_data.values),
714
                levels=levels,
715
                divergent=divergent,
716
                linspace_out=True,
717
            )
718
        plot_kw.setdefault("levels", lin)
×
719

720
    elif (divergent is not False) and ("levels" not in plot_kw):
×
721
        vmin = plot_kw.pop("vmin", np.nanmin(plot_data.values))
×
722
        vmax = plot_kw.pop("vmax", np.nanmax(plot_data.values))
×
723
        norm = custom_cmap_norm(
×
724
            cmap,
725
            vmin,
726
            vmax,
727
            levels=levels,
728
            divergent=divergent,
729
        )
730
        plot_kw.setdefault("norm", norm)
×
731

732
    # set defaults
733
    if divergent is not False:
×
NEW
734
        if isinstance(divergent, int | float):
×
735
            plot_kw.setdefault("center", divergent)
×
736
        else:
737
            plot_kw.setdefault("center", 0)
×
738

739
    if "add_colorbar" not in plot_kw or plot_kw["add_colorbar"] is not False:
×
740
        plot_kw.setdefault("cbar_kwargs", {})
×
741
        plot_kw["cbar_kwargs"].setdefault("label", wrap_text(cbar_label))
×
742

743
    # bug xlim / ylim + transform in facetgrids
744
    # (see https://github.com/pydata/xarray/issues/8562#issuecomment-1865189766)
745
    if transform and ("xlim" in plot_kw and "ylim" in plot_kw):
×
746
        extent = [
×
747
            plot_kw["xlim"][0],
748
            plot_kw["xlim"][1],
749
            plot_kw["ylim"][0],
750
            plot_kw["ylim"][1],
751
        ]
752
        plot_kw.pop("xlim")
×
753
        plot_kw.pop("ylim")
×
754
    elif transform and ("xlim" in plot_kw or "ylim" in plot_kw):
×
755
        extent = None
×
756
        warnings.warn(
×
757
            "Requires both xlim and ylim with 'transform'. Xlim or ylim was dropped", stacklevel=2
758
        )
759
        if "xlim" in plot_kw.keys():
×
760
            plot_kw.pop("xlim")
×
761
        if "ylim" in plot_kw.keys():
×
762
            plot_kw.pop("ylim")
×
763
    else:
764
        extent = None
×
765

766
    # plot
767
    if ax:
×
768
        plot_kw.setdefault("ax", ax)
×
769
    if transform:
×
770
        plot_kw.setdefault("transform", transform)
×
771

772
    if contourf is False:
×
773
        im = plot_data.plot.pcolormesh(**plot_kw)
×
774
    else:
775
        im = plot_data.plot.contourf(**plot_kw)
×
776

777
    if ax:
×
778
        if extent:
×
779
            ax.set_extent(extent)
×
780

781
        ax = add_features_map(
×
782
            data,
783
            ax,
784
            use_attrs,
785
            projection,
786
            features,
787
            geometries_kw,
788
            frame,
789
        )
790
        if show_time:
×
791
            if isinstance(show_time, bool):
×
792
                plot_coords(
×
793
                    ax,
794
                    plot_data,
795
                    param="time",
796
                    loc="lower right",
797
                    backgroundalpha=1,
798
                )
NEW
799
            elif isinstance(show_time, str | tuple | int):
×
800
                plot_coords(
×
801
                    ax,
802
                    plot_data,
803
                    param="time",
804
                    loc=show_time,
805
                    backgroundalpha=1,
806
                )
807

808
        # when im is an ax, it has a colorbar attribute. If it is a facetgrid, it has a cbar attribute.
809
        if (frame is False) and (
×
810
            (getattr(im, "colorbar", None) is not None)
811
            or (getattr(im, "cbar", None) is not None)
812
        ):
813
            im.colorbar.outline.set_visible(False)
×
814
        return ax
×
815

816
    else:
NEW
817
        for _i, fax in enumerate(im.axs.flat):
×
818
            add_features_map(
×
819
                data,
820
                fax,
821
                use_attrs,
822
                projection,
823
                features,
824
                geometries_kw,
825
                frame,
826
            )
827
            if extent:
×
828
                fax.set_extent(extent)
×
829

830
            # when im is an ax, it has a colorbar attribute. If it is a facetgrid, it has a cbar attribute.
831
        if (frame is False) and (
×
832
            (getattr(im, "colorbar", None) is not None)
833
            or (getattr(im, "cbar", None) is not None)
834
        ):
835
            im.cbar.outline.set_visible(False)
×
836

837
        if show_time:
×
838
            if isinstance(show_time, bool):
×
839
                plot_coords(
×
840
                    None,
841
                    plot_data,
842
                    param="time",
843
                    loc="lower right",
844
                    backgroundalpha=1,
845
                )
NEW
846
            elif isinstance(show_time, str | tuple | int):
×
847
                plot_coords(
×
848
                    None,
849
                    plot_data,
850
                    param="time",
851
                    loc=show_time,
852
                    backgroundalpha=1,
853
                )
854

855
        use_attrs.setdefault("suptitle", "long_name")
×
856
        im = set_plot_attrs(use_attrs, data, facetgrid=im)
×
857
        if enumerate_subplots and isinstance(im, xr.plot.facetgrid.FacetGrid):
×
858
            for idx, ax in enumerate(im.axs.flat):
×
859
                ax.set_title(f"{string.ascii_lowercase[idx]}) {ax.get_title()}")
×
860

861
        return im
×
862

863

864
def gdfmap(
    df: gpd.GeoDataFrame,
    df_col: str,
    ax: cartopy.mpl.geoaxes.GeoAxes | cartopy.mpl.geoaxes.GeoAxesSubplot | None = None,
    fig_kw: dict[str, Any] | None = None,
    plot_kw: dict[str, Any] | None = None,
    projection: ccrs.Projection = ccrs.LambertConformal(),
    features: list[str] | dict[str, dict[str, Any]] | None = None,
    cmap: str | matplotlib.colors.Colormap | None = None,
    levels: int | list[int | float] | None = None,
    divergent: bool | int | float = False,
    cbar: bool = True,
    frame: bool = False,
) -> matplotlib.axes.Axes:
    """
    Create a map plot from geometries.

    Parameters
    ----------
    df : geopandas.GeoDataFrame
        Dataframe containing the geometries and the data to plot. Must have a column named 'geometry'.
    df_col : str
        Name of the column of 'df' containing the data to plot using the colorscale.
        If `boundary`, only the boundary of the geometries is plotted, without colorscale.
    ax : cartopy.mpl.geoaxes.GeoAxes or cartopy.mpl.geoaxes.GeoAxesSubplot, optional
        Matplotlib axis built with a projection, on which to plot.
    fig_kw : dict, optional
        Arguments to pass to `plt.subplots()` (and on to `plt.figure()`). Only used if `ax` is None.
    plot_kw :  dict, optional
        Arguments to pass to the GeoDataFrame.plot() method.
    projection : ccrs.Projection
        The projection to use, taken from the cartopy.crs options. Ignored if ax is not None.
    features : list or dict, optional
        Features to use, as a list or a nested dict containing kwargs. Options are the predefined features from
        cartopy.feature: ['coastline', 'borders', 'lakes', 'land', 'ocean', 'rivers', 'states'].
    cmap : matplotlib.colors.Colormap or str
        Colormap to use. If str, can be a matplotlib colormap name or the name of the file of an IPCC colormap
        (see data/ipcc_colors); an unknown name falls back on the 'slev_seq' colormap with a warning.
        If None, look for common variables (from data/ipcc_colors/variable_groups.json) in the name of df_col
        and use corresponding colormap, aligned with the IPCC visual style guide 2022
        (https://www.ipcc.ch/site/assets/uploads/2022/09/IPCC_AR6_WGI_VisualStyleGuide_2022.pdf).
        Ignored when `df_col` is 'boundary'.
    levels : int or list, optional
        Number of levels or list of level boundaries (in data units) to use to divide the colormap.
    divergent : bool or int or float
        If int or float, becomes center of cmap. Default center is 0.
    cbar : bool
        Show colorbar. Default 'True'.
    frame : bool
        Show or hide frame. Default False.

    Returns
    -------
    matplotlib.axes.Axes
        The axis on which the geometries were drawn.

    Raises
    ------
    TypeError
        If `df` is not a geopandas.GeoDataFrame.
    ValueError
        If `df` has no 'geometry' column.
    """
    # create empty dicts if None
    fig_kw = empty_dict(fig_kw)
    plot_kw = empty_dict(plot_kw)
    features = empty_dict(features)

    # checks
    if not isinstance(df, gpd.GeoDataFrame):
        raise TypeError("df must be an instance of class geopandas.GeoDataFrame")

    if "geometry" not in df.columns:
        raise ValueError("column 'geometry' not found in GeoDataFrame")

    # convert to projection (the one of the axis, if given)
    if ax is None:
        df = gpd_to_ccrs(df=df, proj=projection)
    else:
        df = gpd_to_ccrs(df=df, proj=ax.projection)

    # setup fig, ax
    if ax is None:
        _, ax = plt.subplots(subplot_kw={"projection": projection}, **fig_kw)
        ax.set_aspect("equal")  # recommended by geopandas

    # add features
    if features:
        add_cartopy_features(ax, features)

    # Remember the axes that already exist, so that a colorbar axis added by geopandas can be
    # identified reliably even when `ax` belongs to a figure holding several subplots.
    existing_axes = list(ax.figure.axes)

    if df_col == "boundary":
        df.boundary.plot(ax=ax, **plot_kw)
        if cmap is not None or levels is not None or divergent is not False:
            warnings.warn("Colormap arguments are ignored when plotting 'boundary'.", stacklevel=2)
    else:
        # colormap
        if isinstance(cmap, str):
            if cmap in plt.colormaps():
                cmap = matplotlib.colormaps[cmap]
            else:
                try:
                    cmap = create_cmap(filename=cmap)
                except FileNotFoundError:
                    warnings.warn("invalid cmap, using default", stacklevel=2)
                    cmap = create_cmap(filename="slev_seq")

        elif cmap is None:
            cdata = Path(__file__).parents[1] / "data/ipcc_colors/variable_groups.json"
            cmap = create_cmap(
                get_var_group(unique_str=df_col, path_to_json=cdata),
                divergent=divergent,
            )

        # create normalization for colormap
        plot_kw.setdefault("vmin", df[df_col].min())
        plot_kw.setdefault("vmax", df[df_col].max())

        if (levels is not None) or (divergent is not False):
            norm = custom_cmap_norm(
                cmap,
                plot_kw["vmin"],
                plot_kw["vmax"],
                levels=levels,
                divergent=divergent,
            )
            plot_kw.setdefault("norm", norm)

        # colorbar
        if cbar:
            plot_kw.setdefault("legend", True)
            plot_kw.setdefault("legend_kwds", {})
            plot_kw["legend_kwds"].setdefault("label", df_col)
            plot_kw["legend_kwds"].setdefault("orientation", "horizontal")
            plot_kw["legend_kwds"].setdefault("pad", 0.02)

        # plot
        df.plot(column=df_col, ax=ax, cmap=cmap, **plot_kw)

    if frame is False:
        # cbar: only axes created by the plotting call above, and only if they are colorbars
        for new_ax in ax.figure.axes:
            if new_ax not in existing_axes and "outline" in new_ax.spines:
                new_ax.spines["outline"].set_visible(False)
                new_ax.tick_params(size=0)
        # main axes
        ax.spines["geo"].set_visible(False)

    return ax
×
1002

1003

1004
def violin(
    data: dict[str, Any] | xr.DataArray | xr.Dataset,
    ax: matplotlib.axes.Axes | None = None,
    use_attrs: dict[str, Any] | None = None,
    fig_kw: dict[str, Any] | None = None,
    plot_kw: dict[str, Any] | None = None,
    color: str | int | list[str | int] | None = None,
) -> matplotlib.axes.Axes:
    """
    Make violin plot using seaborn.

    Parameters
    ----------
    data : dict or Dataset/DataArray
        Input data to plot. If a dict, must contain DataArrays and/or Datasets.
        A Dataset with several data variables inside a dict yields one violin per variable, named '<key>_<variable>'.
    ax : matplotlib.axes.Axes, optional
        Matplotlib axis on which to plot.
    use_attrs : dict, optional
        A dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description').
        Default value is {'ylabel': 'long_name', 'yunits': 'units'}, or {'xlabel': 'long_name', 'xunits': 'units'}
        when `plot_kw['orient']` is 'h'. User-given keys are merged over these defaults.
    fig_kw : dict, optional
        Arguments to pass to `plt.subplots()`. Only works if `ax` is not provided.
    plot_kw : dict, optional
        Arguments to pass to the `seaborn.violinplot()` function.
    color :  str, int or list, optional
        Unique color or list of colors to use. Integers point to the applied stylesheet's colors, in zero-indexed order
        (0 is the first color). Passing 'color' or 'palette' in plot_kw overrides this argument.
        A list given here is not modified.

    Returns
    -------
    matplotlib.axes.Axes

    Raises
    ------
    TypeError
        If `data` (or a value of the `data` dict) is not a xr.Dataset or xr.DataArray.
    IndexError
        If an integer color is out of range of the stylesheet colors.
    """
    # create empty dicts if None
    use_attrs = empty_dict(use_attrs)
    fig_kw = empty_dict(fig_kw)
    plot_kw = empty_dict(plot_kw)

    # if data is dict, assemble into one DataFrame
    non_dict_data = True
    if isinstance(data, dict):
        non_dict_data = False
        df = pd.DataFrame()
        for key, xr_obj in data.items():
            if isinstance(xr_obj, xr.Dataset):
                # if one data var, use key
                if len(xr_obj.data_vars) == 1:
                    df[key] = xr_obj[list(xr_obj.data_vars)[0]].values
                # if more than one data var, use key + name of var
                else:
                    for data_var in list(xr_obj.data_vars):
                        df[key + "_" + data_var] = xr_obj[data_var].values

            elif isinstance(xr_obj, xr.DataArray):
                df[key] = xr_obj.values

            else:
                raise TypeError(
                    '"data" must be a xr.Dataset, a xr.DataArray or a dictionary of such objects.'
                )

    elif isinstance(data, xr.Dataset):
        # create dataframe, keeping only the data variables
        df = data.to_dataframe()
        df = df[list(data.data_vars)]

    elif isinstance(data, xr.DataArray):
        # create dataframe, dropping the coordinate columns
        df = data.to_dataframe()
        for coord in list(data.coords):
            if coord in df.columns:
                df = df.drop(columns=coord)

    else:
        raise TypeError(
            '"data" must be a xr.Dataset, a xr.DataArray or a dictionary of such objects.'
        )

    # set fig, ax if not provided
    if ax is None:
        _, ax = plt.subplots(**fig_kw)

    horizontal = plot_kw.get("orient") == "h"

    # set default use_attrs values (the value axis depends on the orientation)
    if horizontal:
        use_attrs = {"xlabel": "long_name", "xunits": "units"} | use_attrs
    else:
        use_attrs = {"ylabel": "long_name", "yunits": "units"} | use_attrs

    #  add/modify plot elements according to the first entry.
    set_plot_obj = data if non_dict_data else list(data.values())[0]

    set_plot_attrs(
        use_attrs,
        xr_obj=set_plot_obj,
        ax=ax,
        title_loc="left",
        wrap_kw={"min_line_len": 35, "max_line_len": 48},
    )

    # color
    # `0` is a valid (zero-indexed) stylesheet color, so integers other than booleans count as a
    # requested color even when falsy.
    if color or (isinstance(color, int) and not isinstance(color, bool)):
        style_colors = matplotlib.rcParams["axes.prop_cycle"].by_key()["color"]
        if isinstance(color, str):
            plot_kw.setdefault("color", color)
        elif isinstance(color, int):
            try:
                plot_kw.setdefault("color", style_colors[color])
            except IndexError as err:
                raise IndexError("Index out of range of stylesheet colors") from err
        elif isinstance(color, list):
            # build a new palette so that the caller's list is left unmodified
            palette = []
            for entry in color:
                if isinstance(entry, int):
                    try:
                        entry_color = style_colors[entry]
                    except IndexError as err:
                        raise IndexError("Index out of range of stylesheet colors") from err
                else:
                    entry_color = entry
                palette.append(entry_color)
            plot_kw.setdefault("palette", palette)

    # plot
    sns.violinplot(df, ax=ax, **plot_kw)

    # grid
    if horizontal:
        ax.grid(visible=True, axis="x")

    return ax
×
1133

1134

1135
def stripes(
    data: dict[str, Any] | xr.DataArray | xr.Dataset,
    ax: matplotlib.axes.Axes | None = None,
    fig_kw: dict[str, Any] | None = None,
    divide: int | None = None,
    cmap: str | matplotlib.colors.Colormap | None = None,
    cmap_center: int | float | None = 0,
    cbar: bool = True,
    cbar_kw: dict[str, Any] | None = None,
) -> matplotlib.axes.Axes:
    """
    Create stripes plot with or without multiple scenarios.

    Parameters
    ----------
    data : dict or DataArray or Dataset
        Data to plot. If a dictionary of xarray objects, each will correspond to a scenario.
        For Datasets, only the first data variable is used. The objects are expected to be one-dimensional
        along a 'time' coordinate with a constant step (in years).
    ax : matplotlib.axes.Axes, optional
        Matplotlib axis on which to plot.
    fig_kw : dict, optional
        Arguments to pass to `plt.subplots()`. Only works if `ax` is not provided.
    divide : int, optional
        Year at which the plot is divided into scenarios. If not provided, the horizontal separators
        will extend over the full time axis.
    cmap : matplotlib.colors.Colormap or str, optional
        Colormap to use. If str, can be a matplotlib or name of the file of an IPCC colormap (see data/ipcc_colors).
        An unknown name is logged and replaced by the automatic colormap described below, with a warning.
        If None, look for common variables (from data/ipcc_colors/variable_groups.json) in the name of the DataArray
        or its 'history' attribute and use corresponding diverging colormap, aligned with the IPCC Visual Style
        Guide 2022 (https://www.ipcc.ch/site/assets/uploads/2022/09/IPCC_AR6_WGI_VisualStyleGuide_2022.pdf).
    cmap_center : int or float, optional
        Center of the colormap in data coordinates. Default is 0.
        If None, a linear normalization between the data minimum and maximum is used.
    cbar : bool
        Show colorbar.
    cbar_kw : dict, optional
        Arguments to pass to the colorbar call.

    Returns
    -------
    matplotlib.axes.Axes

    Raises
    ------
    TypeError
        If `data` contains objects that are neither DataArrays nor Datasets.
    ValueError
        If `data` is empty, or if the time axis has fewer than two steps or a non-constant step.
    """
    # create empty dicts if None
    fig_kw = empty_dict(fig_kw)
    cbar_kw = empty_dict(cbar_kw)

    # init main (figure) axis; it only hosts the stripes and colorbar insets
    if ax is None:
        fig_kw.setdefault("figsize", (10, 5))
        _, ax = plt.subplots(**fig_kw)
    ax.set_yticks([])
    ax.set_xticks([])
    ax.spines[["top", "bottom", "left", "right"]].set_visible(False)

    # init plot axis
    ax_0 = ax.inset_axes([0, 0.15, 1, 0.75])

    # handle non-dict data
    if not isinstance(data, dict):
        data = {"_no_label": data}

    # convert SSP, RCP, CMIP formats in keys
    data = process_keys(data, convert_scen_name)

    # extract DataArrays from datasets, into a new dict so that the input mapping is not modified
    arrays = {}
    for key, obj in data.items():
        if isinstance(obj, xr.DataArray):
            arrays[key] = obj
        elif isinstance(obj, xr.Dataset):
            arrays[key] = obj[list(obj.data_vars)[0]]
        else:
            raise TypeError("data must contain xarray DataArrays or Datasets")
    data = arrays

    n = len(data)
    if n == 0:
        raise ValueError("data must contain at least one DataArray or Dataset")
    first_da = next(iter(data.values()))

    # get time interval
    time_index = first_da.time.dt.year.values
    if len(time_index) < 2:
        raise ValueError("At least two time steps are needed to draw stripes")
    delta_time = np.diff(time_index)
    if not np.all(delta_time == delta_time[0]):
        raise ValueError("Time delta between each array element must be constant")
    dtime = delta_time[0]

    # modify axes: each stripe is a bar of width `dtime` centred on its year, so the stripes
    # exactly fill these limits
    xlim = (time_index.min() - 0.5 * dtime, time_index.max() + 0.5 * dtime)
    ax.set_xlim(*xlim)
    ax_0.set_xlim(*xlim)
    ax_0.set_ylim(0, 1)
    ax_0.set_yticks([])
    ax_0.xaxis.set_ticks_position("top")
    ax_0.tick_params(axis="x", direction="out", zorder=10)
    ax_0.spines[["top", "left", "right", "bottom"]].set_visible(False)

    # create historical/projection divide
    if divide is not None:
        # convert the left edge of the `divide` stripe to axes coordinates (1 is a placeholder y)
        divide_disp = ax_0.transData.transform((divide - 0.5 * dtime, 1))
        divide_ax = ax_0.transAxes.inverted().transform(divide_disp)[0]
    else:
        divide_ax = 0

    # create an inset ax for each da in data, stacked from the bottom up
    subaxes = []
    for i in range(n):
        y = i / n
        subax = ax_0.inset_axes([0, y, 1, 1 / n], transform=ax_0.transAxes)
        subax.set(xlim=ax_0.get_xlim(), ylim=(0, 1), xticks=[], yticks=[])
        subax.spines[["top", "bottom", "left", "right"]].set_visible(False)
        # lines separating axes
        if i > 0:
            subax.spines["bottom"].set_visible(True)
            subax.spines["bottom"].set(
                lw=2,
                color="w",
                bounds=(divide_ax, 1),
                transform=subax.transAxes,
            )
            # circles marking the divide on each separator
            if divide is not None:
                circle = matplotlib.patches.Ellipse(
                    xy=(divide_ax, y),
                    width=0.01,
                    height=0.03,
                    color="w",
                    transform=ax_0.transAxes,
                    zorder=10,
                )
                ax_0.add_patch(circle)
        subaxes.append(subax)

    # get min and max of all data (NaNs ignored), without magic sentinel bounds
    data_min = min(np.nanmin(da.values) for da in data.values())
    data_max = max(np.nanmax(da.values) for da in data.values())

    # colormap
    if isinstance(cmap, str):
        if cmap in plt.colormaps():
            cmap = matplotlib.colormaps[cmap]
        else:
            try:
                cmap = create_cmap(filename=cmap)
            except FileNotFoundError as e:
                # a string cmap cannot be applied below; fall back on the automatic colormap
                logger.error(e)
                warnings.warn("invalid cmap, using default", stacklevel=2)
                cmap = None

    if cmap is None:
        cdata = Path(__file__).parents[1] / "data/ipcc_colors/variable_groups.json"
        cmap = create_cmap(
            get_var_group(path_to_json=cdata, da=first_da),
            divergent=True,
        )

    # create cmap norm
    if cmap_center is not None:
        norm = matplotlib.colors.TwoSlopeNorm(cmap_center, vmin=data_min, vmax=data_max)
    else:
        norm = matplotlib.colors.Normalize(data_min, data_max)

    # plot
    for subax, (key, da) in zip(subaxes, data.items(), strict=True):
        subax.bar(da.time.dt.year, height=1, width=dtime, color=cmap(norm(da.values)))
        # scenario labels are only drawn when the plot is divided
        if divide is not None and key != "_no_label":
            subax.text(
                0.99,
                0.5,
                key,
                transform=subax.transAxes,
                fontsize=14,
                ha="right",
                va="center",
                c="w",
                weight="bold",
            )

    # colorbar
    if cbar is True:
        sm = ScalarMappable(cmap=cmap, norm=norm)
        cax = ax.inset_axes([0.01, 0.05, 0.35, 0.06])
        cbar_tcks = np.arange(math.floor(data_min), math.ceil(data_max), 2)
        # label, taken from the first DataArray
        label = get_attributes("long_name", first_da)
        if label != "":
            if "units" in first_da.attrs:
                u = first_da.units
                label += f" ({u})"
            label = wrap_text(label, max_line_len=40)

        cbar_kw = {
            "cax": cax,
            "orientation": "horizontal",
            "ticks": cbar_tcks,
            "label": label,
        } | cbar_kw
        ax.figure.colorbar(sm, **cbar_kw)
        cax.spines["outline"].set_visible(False)
        cax.set_xscale("linear")

    return ax
×
1346

1347

1348
def heatmap(
    data: xr.DataArray | xr.Dataset | dict[str, Any],
    ax: matplotlib.axes.Axes | None = None,
    use_attrs: dict[str, Any] | None = None,
    fig_kw: dict[str, Any] | None = None,
    plot_kw: dict[str, Any] | None = None,
    transpose: bool = False,
    cmap: str | matplotlib.colors.Colormap | None = "RdBu",
    divergent: bool | int | float = False,
) -> matplotlib.axes.Axes | sns.FacetGrid:
    """
    Create heatmap from a DataArray.

    Parameters
    ----------
    data : dict or DataArray or Dataset
        Input data to plot. If dictionary, must have only one entry.
        For a Dataset, only the first data variable is used.
    ax : matplotlib axis, optional
        Matplotlib axis on which to plot. Cannot be combined with 'col'/'row' in `plot_kw`.
    use_attrs : dict, optional
        Dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description').
        Default value is {'cbar_label': 'long_name'}.
        Only the keys found in the default dict can be used.
    fig_kw : dict, optional
        Arguments to pass to `plt.subplots()`. When 'col'/'row' are given in `plot_kw`, only 'figsize' is used.
    plot_kw :  dict, optional
        Arguments to pass to the 'seaborn.heatmap()' function.
        If 'data' is a dictionary, can be a nested dictionary with the same key as 'data'.
        If 'col' and/or 'row' are given, a seaborn.FacetGrid is built and the keywords are split between
        `seaborn.FacetGrid` and `seaborn.heatmap`; all facets then share one color scale and one colorbar.
    transpose : bool
        If true, the 2D data will be transposed, so that the original x-axis becomes the y-axis and vice versa.
    cmap : matplotlib.colors.Colormap or str, optional
        Colormap to use. If str, can be a matplotlib or name of the file of an IPCC colormap (see data/ipcc_colors).
        An unknown name is logged and replaced by the automatic colormap described below, with a warning.
        If None, look for common variables (from data/ipcc_colors/variables_groups.json) in the name of the DataArray
        or its 'history' attribute and use corresponding colormap, aligned with the IPCC Visual Style Guide 2022
        (https://www.ipcc.ch/site/assets/uploads/2022/09/IPCC_AR6_WGI_VisualStyleGuide_2022.pdf).
    divergent : bool or int or float
        If int or float, becomes center of cmap. Default center is 0.

    Returns
    -------
    matplotlib.axes.Axes or seaborn.FacetGrid
        The axis holding the heatmap, or the FacetGrid when 'col' or 'row' is given in `plot_kw`.

    Raises
    ------
    ValueError
        If `data` is a dict not of length 1, if both `ax` and 'col'/'row' are given, if the data to draw
        is not two-dimensional, or if `plot_kw` holds keywords unknown to seaborn when faceting.
    TypeError
        If `data` does not contain a xr.DataArray or xr.Dataset.
    """
    # create empty dicts if None
    use_attrs = empty_dict(use_attrs)
    fig_kw = empty_dict(fig_kw)
    plot_kw = empty_dict(plot_kw)

    # set default use_attrs values
    use_attrs.setdefault("cbar_label", "long_name")

    # if data is dict, extract its single entry (and the matching nested plot_kw, if any)
    if isinstance(data, dict):
        if len(data) != 1:
            raise ValueError("If `data` is a dict, it must be of length 1.")
        key, data = next(iter(data.items()))
        if plot_kw and key in plot_kw:
            plot_kw = plot_kw[key]

    # select data to plot
    if isinstance(data, xr.DataArray):
        da = data
    elif isinstance(data, xr.Dataset):
        if len(data.data_vars) > 1:
            warnings.warn(
                "data is xr.Dataset; only the first variable will be used in plot", stacklevel=2
            )
        da = list(data.values())[0]
    else:
        raise TypeError("`data` must contain a xr.DataArray or xr.Dataset")

    # setup fig, axis
    facet = "col" in plot_kw or "row" in plot_kw
    if ax is None and not facet:
        _, ax = plt.subplots(**fig_kw)
    elif ax is not None and facet:
        raise ValueError("Cannot use 'ax' and 'col'/'row' at the same time.")
    elif ax is None:
        if any(k != "figsize" for k in fig_kw):
            warnings.warn(
                "Only figsize arguments can be passed to fig_kw when using facetgrid.", stacklevel=2
            )
        plot_kw.setdefault("col", None)
        plot_kw.setdefault("row", None)
        plot_kw.setdefault("margin_titles", True)
        # the value column of the long-format DataFrame needs a name
        if da.name is None:
            da = da.to_dataset(name="data").data
        da_name = da.name

    # create cbar label
    if (
        "cbar_units" in use_attrs
        and len(get_attributes(use_attrs["cbar_units"], data)) >= 1
    ):  # avoids '()' as label
        cbar_label = (
            get_attributes(use_attrs["cbar_label"], data)
            + " ("
            + get_attributes(use_attrs["cbar_units"], data)
            + ")"
        )
    else:
        cbar_label = get_attributes(use_attrs["cbar_label"], data)

    # colormap
    if isinstance(cmap, str) and cmap not in plt.colormaps():
        try:
            cmap = create_cmap(filename=cmap)
        except FileNotFoundError as e:
            # an unknown name would only fail later in seaborn; use the automatic colormap instead
            logger.error(e)
            warnings.warn("invalid cmap, using default", stacklevel=2)
            cmap = None

    if cmap is None:
        cdata = Path(__file__).parents[1] / "data/ipcc_colors/variable_groups.json"
        cmap = create_cmap(
            get_var_group(path_to_json=cdata, da=da),
            divergent=divergent,
        )

    # convert data to DataFrame
    if transpose:
        da = da.transpose()
    if not facet:
        if len(da.dims) != 2:
            raise ValueError("DataArray must have exactly two dimensions")
        # index (first dim) is drawn on the y axis, columns (second dim) on the x axis
        df = da.to_pandas()
    else:
        facet_dims = {plot_kw["col"], plot_kw["row"]} - {None}
        # Keep the DataArray's dimension order: a set difference would make the x/y orientation
        # arbitrary, and would also ignore `transpose`.
        remaining = [d for d in da.dims if d not in facet_dims]
        if len(remaining) != 2:
            raise ValueError("DataArray must have exactly two dimensions")
        # Same orientation as the non-facet case (first dim on y, second on x);
        # `draw_heatmap` expects its positional args as (x, y, values).
        y_dim, x_dim = remaining
        heatmap_dims = [x_dim, y_dim]
        df = da.to_dataframe().reset_index()

    # set defaults
    if divergent is not False:
        if isinstance(divergent, int | float):
            plot_kw.setdefault("center", divergent)
        else:
            plot_kw.setdefault("center", 0)

    if "cbar" not in plot_kw or plot_kw["cbar"] is not False:
        plot_kw.setdefault("cbar_kws", {})
        plot_kw["cbar_kws"].setdefault("label", wrap_text(cbar_label))

    plot_kw.setdefault("cmap", cmap)

    # plot
    def draw_heatmap(*args, **kwargs):
        # Draw one heatmap; with positional (x, y, values) column names, pivot long data first.
        data = kwargs.pop("data")
        d = (
            data
            if len(args) == 0
            # Any sorting should be performed before sending a DataArray in `fg.heatmap`
            else data.pivot_table(
                index=args[1], columns=args[0], values=args[2], sort=False
            )
        )
        ax = sns.heatmap(d, **kwargs)
        ax.set_xticklabels(
            ax.get_xticklabels(),
            rotation=45,
            ha="right",
            rotation_mode="anchor",
        )
        ax.tick_params(axis="both", direction="out")
        set_plot_attrs(
            use_attrs,
            da,
            ax,
            title_loc="center",
            wrap_kw={"min_line_len": 35, "max_line_len": 44},
        )
        return ax

    if ax is not None:
        return draw_heatmap(data=df, ax=ax, **plot_kw)

    # When using xarray's FacetGrid, `plot_kw` can be used in the FacetGrid and in the plotting function
    # With Seaborn, we need to be more careful and separate keywords.
    heatmap_params = signature(sns.heatmap).parameters
    facetgrid_params = signature(sns.FacetGrid).parameters
    plot_kw_hm = {k: v for k, v in plot_kw.items() if k in heatmap_params}
    plot_kw_fg = {k: v for k, v in plot_kw.items() if k in facetgrid_params}
    unused_keys = set(plot_kw) - set(plot_kw_fg) - set(plot_kw_hm)
    if unused_keys:
        raise ValueError(
            f"`heatmap` got unexpected keywords in `plot_kw`: {unused_keys}. Keywords in `plot_kw` should be keywords "
            "allowed in `sns.heatmap` or `sns.FacetGrid`. "
        )

    # All facets share a single colorbar, so they must share one color scale; otherwise seaborn
    # would compute vmin/vmax independently for each facet (mirrors seaborn's own defaults).
    values = da.values
    if plot_kw_hm.get("robust", False):
        vmin, vmax = np.nanpercentile(values, 2), np.nanpercentile(values, 98)
    else:
        vmin, vmax = np.nanmin(values), np.nanmax(values)
    plot_kw_hm.setdefault("vmin", float(vmin))
    plot_kw_hm.setdefault("vmax", float(vmax))

    # `cbar`/`cbar_ax` are set explicitly below; take them out to avoid duplicate keywords
    show_cbar = plot_kw_hm.pop("cbar", True)
    cax = plot_kw_hm.pop("cbar_ax", None)

    g = sns.FacetGrid(df, **plot_kw_fg)
    if show_cbar and cax is None:
        cax = g.figure.add_axes([0.95, 0.05, 0.02, 0.9])
    g.map_dataframe(
        draw_heatmap,
        *heatmap_dims,
        da_name,
        **plot_kw_hm,
        cbar=show_cbar,
        cbar_ax=cax,
    )
    g.figure.subplots_adjust(right=0.9)
    if "figsize" in fig_kw:
        g.figure.set_size_inches(*fig_kw["figsize"])
    return g
×
1558

1559

1560
def scattermap(
    data: dict[str, Any] | xr.DataArray | xr.Dataset,
    ax: matplotlib.axes.Axes | None = None,
    use_attrs: dict[str, Any] | None = None,
    fig_kw: dict[str, Any] | None = None,
    plot_kw: dict[str, Any] | None = None,
    projection: ccrs.Projection = ccrs.LambertConformal(),
    transform: ccrs.Projection | None = None,
    features: list[str] | dict[str, dict[str, Any]] | None = None,
    geometries_kw: dict[str, Any] | None = None,
    sizes: str | bool | None = None,
    size_range: tuple = (10, 60),
    cmap: str | matplotlib.colors.Colormap | None = None,
    levels: int | None = None,
    divergent: bool | int | float = False,
    legend_kw: dict[str, Any] | None = None,
    show_time: bool | str | int | tuple[float, float] = False,
    frame: bool = False,
    enumerate_subplots: bool = False,
) -> matplotlib.axes.Axes:
    """
    Make a scatter plot of georeferenced data on a map.

    Parameters
    ----------
    data : dict, DataArray or Dataset
        Input data to plot. If dictionary, must have only one entry, whose value is a DataArray or a Dataset.
    ax : matplotlib axis, optional
        Matplotlib axis on which to plot, with the same projection as the one specified.
    use_attrs : dict, optional
        Dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description').
        Default value is {'title': 'description', 'cbar_label': 'long_name', 'cbar_units': 'units'}.
        Only the keys found in the default dict can be used.
    fig_kw : dict, optional
        Arguments to pass to `plt.figure()`.
    plot_kw :  dict, optional
        Arguments to pass to `plt.scatter()`.
        If 'data' is a dictionary, can be a dictionary with the same key as 'data'.
    projection : ccrs.Projection
        The projection to use, taken from the cartopy.crs options. Ignored if ax is not None.
    transform : ccrs.Projection, optional
        Transform corresponding to the data coordinate system. If None, an attempt is made to find dimensions matching
        ccrs.PlateCarree() or ccrs.RotatedPole().
    features : list or dict, optional
        Features to use, as a list or a nested dict containing kwargs. Options are the predefined features from
        cartopy.feature: ['coastline', 'borders', 'lakes', 'land', 'ocean', 'rivers', 'states'].
    geometries_kw : dict, optional
        Arguments passed to cartopy ax.add_geometry() which adds given geometries (GeoDataFrame geometry) to axis.
    sizes : bool or str, optional
        String name of the plotted variable or of a coordinate to use for determining point size.
        If True, use the same data as in the colorbar.
        If not given, a `markersize` entry of `plot_kw` is used instead.
    size_range : tuple
        Tuple of the minimum and maximum size of the points.
    cmap : matplotlib.colors.Colormap or str, optional
        Colormap to use. If str, can be a matplotlib or name of the file of an IPCC colormap (see data/ipcc_colors).
        If None, look for common variables (from data/ipcc_colors/variables_groups.json) in the name of the DataArray
        or its 'history' attribute and use corresponding colormap, aligned with the IPCC Visual Style Guide 2022
        (https://www.ipcc.ch/site/assets/uploads/2022/09/IPCC_AR6_WGI_VisualStyleGuide_2022.pdf).
    levels : int or sequence, optional
        Number of levels to divide the colormap into, or the explicit level boundaries.
    divergent : bool or int or float
        If int or float, becomes center of cmap. Default center is 0.
    legend_kw : dict, optional
        Arguments to pass to plt.legend(). Some defaults {"loc": "lower left", "facecolor": "w", "framealpha": 1,
            "edgecolor": "w", "bbox_to_anchor": (-0.05, 0)}
    show_time : bool, tuple, string or int.
        If True, show time (as date) at the bottom right of the figure.
        Can be a tuple of axis coordinates (0 to 1, as a fraction of the axis length) representing the location
        of the text. If a string or an int, the same values as those of the 'loc' parameter
        of matplotlib's legends are accepted.

        ==================   =============
        Location String      Location Code
        ==================   =============
        'upper right'        1
        'upper left'         2
        'lower left'         3
        'lower right'        4
        'right'              5
        'center left'        6
        'center right'       7
        'lower center'       8
        'upper center'       9
        'center'             10
        ==================   =============
    frame : bool
        Show or hide frame. Default False.
    enumerate_subplots: bool
        If True, enumerate subplots with letters.
        Only works with facetgrids (pass `col` or `row` in plot_kw).

    Returns
    -------
    matplotlib.axes.Axes or xarray.plot.facetgrid.FacetGrid
        The axis that was plotted on, or the FacetGrid when `col` or `row` is passed in `plot_kw`.

    Notes
    -----
    The `plot_kw` dictionary passed by the caller (including its `cbar_kwargs`) is not modified in place;
    all defaults are applied to a copy, so the same `plot_kw` can safely be reused across calls.
    """
    # create empty dicts if None
    use_attrs = empty_dict(use_attrs)
    fig_kw = empty_dict(fig_kw)
    plot_kw = empty_dict(plot_kw)
    legend_kw = empty_dict(legend_kw)

    # extract plot_kw from dict if needed
    if isinstance(data, dict) and len(data) == 1 and plot_kw:
        data_key = next(iter(data))
        if data_key in plot_kw:
            plot_kw = empty_dict(plot_kw[data_key])
    # Work on a copy: vmin/vmax, colorbar label, etc. are set below and must not leak into the caller's dict.
    plot_kw = dict(plot_kw)

    # set default use_attrs values (the merge builds a new dict, the caller's one is untouched)
    use_attrs = {"cbar_label": "long_name", "cbar_units": "units"} | use_attrs
    if "row" not in plot_kw and "col" not in plot_kw:
        use_attrs.setdefault("title", "description")

    # figanos does not use xr.plot.scatter default markersize
    if "markersize" in plot_kw:
        markersize = plot_kw.pop("markersize")
        if not sizes:
            sizes = markersize

    # if data is dict, extract its single entry (the Dataset branch below warns about extra variables)
    if isinstance(data, dict):
        if len(data) != 1:
            raise ValueError("If `data` is a dict, it must be of length 1.")
        data = next(iter(data.values())).squeeze()

    # select data to plot and its xr.Dataset
    if isinstance(data, xr.DataArray):
        plot_data = data
        data = xr.Dataset({plot_data.name: plot_data})
    elif isinstance(data, xr.Dataset):
        if len(data.data_vars) > 1:
            warnings.warn(
                "data is xr.Dataset; only the first variable will be used in plot", stacklevel=2
            )
        plot_data = data[list(data.keys())[0]]
    else:
        raise TypeError("`data` must contain a xr.DataArray or xr.Dataset")

    # setup transform
    if transform is None:
        if "rlat" in data.dims and "rlon" in data.dims:
            transform = get_rotpole(data)
        elif "lat" in data.coords and "lon" in data.coords:  # need to work with station dims
            transform = ccrs.PlateCarree()

    # setup fig, ax
    use_facetgrid = "row" in plot_kw or "col" in plot_kw
    if ax is None and not use_facetgrid:
        _, ax = plt.subplots(subplot_kw={"projection": projection}, **fig_kw)
    elif ax is not None and use_facetgrid:
        raise ValueError("Cannot use 'ax' and 'col'/'row' at the same time.")
    elif ax is None:
        plot_kw = {"subplot_kws": {"projection": projection}} | plot_kw
        cfig_kw = fig_kw.copy()
        if "figsize" in fig_kw:  # add figsize to plot_kw for facetgrid
            plot_kw.setdefault("figsize", fig_kw["figsize"])
            cfig_kw.pop("figsize")
        if cfig_kw:
            # the remaining fig_kw entries are not forwarded to the FacetGrid
            warnings.warn(
                "Only figsize and figure.add_subplot() arguments can be passed to fig_kw when using facetgrid.", stacklevel=2
            )

    # create cbar label
    cbar_label = get_attributes(use_attrs["cbar_label"], data)
    if "cbar_units" in use_attrs:
        cbar_units = get_attributes(use_attrs["cbar_units"], data)
        if len(cbar_units) >= 1:  # avoids '[]' as label
            cbar_label = f"{cbar_label} ({cbar_units})"

    if plot_kw.get("add_colorbar") is not False:
        # copy so the caller's cbar_kwargs dict is not filled with this call's label
        cbar_kwargs = dict(empty_dict(plot_kw.get("cbar_kwargs")))
        cbar_kwargs.setdefault("label", wrap_text(cbar_label))
        cbar_kwargs.setdefault("pad", 0.015)
        plot_kw["cbar_kwargs"] = cbar_kwargs

    # colormap
    if isinstance(cmap, str):
        if cmap not in plt.colormaps():
            try:
                cmap = create_cmap(filename=cmap)
            except FileNotFoundError as e:
                # keep the string unchanged and pass it on as the colormap name
                logger.error(e)
    elif cmap is None:
        cdata = Path(__file__).parents[1] / "data/ipcc_colors/variable_groups.json"
        cmap = create_cmap(
            get_var_group(path_to_json=cdata, da=plot_data),
            divergent=divergent,
        )

    # nans (not required for plotting since xarray.plot handles np.nan, but needs to be found for sizes legend and to
    # inform user on how many stations were dropped). Use the total size: data may be multidimensional (facetgrid).
    mask = ~np.isnan(plot_data.values)
    n_nan = mask.size - np.count_nonzero(mask)
    if n_nan > 0:
        warnings.warn(
            f"{n_nan} nan values were dropped when plotting the color values", stacklevel=2
        )

    # point sizes
    if sizes:
        if sizes is True:
            sdata = plot_data
        elif isinstance(sizes, str):
            # `data` is a Dataset here, so compare against the plotted variable's name
            if sizes == plot_data.name:
                sdata = plot_data
            elif sizes in data.coords:
                sdata = plot_data[sizes]
            else:
                raise ValueError(f"{sizes} not found")
        else:
            raise TypeError("sizes must be a string or a bool")

        # nans sizes
        smask = ~np.isnan(sdata.values) & mask
        n_size_nan = np.count_nonzero(mask) - np.count_nonzero(smask)
        if n_size_nan > 0:
            warnings.warn(
                f"{n_size_nan} nan values were dropped when setting the point size", stacklevel=2
            )
            mask = smask

        pt_sizes = norm2range(
            data=sdata.where(mask).values,
            target_range=size_range,
            data_range=None,
        )
        plot_kw.setdefault("add_legend", False)
        # facetgrid: sizes are reset per facet further below
        plot_kw.setdefault("s", pt_sizes if ax else pt_sizes[0])

    # norm
    data_min = np.nanmin(plot_data.values[mask])
    data_max = np.nanmax(plot_data.values[mask])
    plot_kw.setdefault("vmin", data_min)
    plot_kw.setdefault("vmax", data_max)
    if levels is not None:
        if isinstance(levels, Iterable):
            lin = levels
        else:
            lin = custom_cmap_norm(
                cmap,
                data_min,
                data_max,
                levels=levels,
                divergent=divergent,
                linspace_out=True,
            )
        plot_kw.setdefault("levels", lin)

    elif (divergent is not False) and ("levels" not in plot_kw):
        vmin = plot_kw.pop("vmin", data_min)
        vmax = plot_kw.pop("vmax", data_max)
        norm = custom_cmap_norm(
            cmap,
            vmin,
            vmax,
            levels=levels,
            divergent=divergent,
        )
        plot_kw.setdefault("norm", norm)

    # matplotlib.pyplot.scatter treats "edgecolor" and "edgecolors" as aliases so we accept "edgecolor" and convert it
    if "edgecolor" in plot_kw and "edgecolors" not in plot_kw:
        plot_kw["edgecolors"] = plot_kw.pop("edgecolor")

    # set defaults (vmin, vmax are removed further below since they conflict with norm)
    plot_kw = {
        "cmap": cmap,
        "transform": transform,
        "zorder": 8,
        "marker": "o",
    } | plot_kw

    # check if edgecolors in plot_kw and match len of plot_data
    if "edgecolors" in plot_kw:
        n_points = len(plot_data.where(mask).values)
        if matplotlib.colors.is_color_like(plot_kw["edgecolors"]):
            plot_kw["edgecolors"] = np.repeat(plot_kw["edgecolors"], n_points)
        elif len(plot_kw["edgecolors"]) != len(plot_data.values):
            plot_kw["edgecolors"] = np.repeat(plot_kw["edgecolors"][0], n_points)
            warnings.warn(
                "Length of edgecolors does not match length of data. Only first edgecolor is used for plotting.", stacklevel=2
            )
        else:
            plot_kw["edgecolors"] = np.asarray(plot_kw["edgecolors"])[mask]
    else:
        plot_kw.setdefault("edgecolors", "none")

    for key in ["vmin", "vmax"]:
        plot_kw.pop(key, None)

    # plot
    plot_kw = {"x": "lon", "y": "lat", "hue": plot_data.name} | plot_kw
    if ax:
        plot_kw.setdefault("ax", ax)

    plot_data_masked = plot_data.where(mask).to_dataset()
    im = plot_data_masked.plot.scatter(**plot_kw)

    def _add_time_label(target_ax):
        """Write the time coordinate on `target_ax` (None in facetgrid mode) as requested by `show_time`."""
        if not show_time:
            return
        if isinstance(show_time, bool):
            loc = "lower right"
        elif isinstance(show_time, str | tuple | int):
            loc = show_time
        else:
            return
        plot_coords(
            target_ax,
            plot_data,
            param="time",
            loc=loc,
            backgroundalpha=1,
        )

    # add features
    if ax:
        ax = add_features_map(
            data,
            ax,
            use_attrs,
            projection,
            features,
            geometries_kw,
            frame,
        )
        _add_time_label(ax)

        if (frame is False) and (im.colorbar is not None):
            im.colorbar.outline.set_visible(False)

    else:
        for i, fax in enumerate(im.axs.flat):
            fax = add_features_map(
                data,
                fax,
                use_attrs,
                projection,
                features,
                geometries_kw,
                frame,
            )

            if sizes:
                # correct markersize for facetgrid
                scat = fax.collections[0]
                scat.set_sizes(pt_sizes[i])

        if (frame is False) and (im.cbar is not None):
            im.cbar.outline.set_visible(False)

        _add_time_label(None)

    # size legend
    if sizes:
        size_values = sdata.values[mask]
        size_points = pt_sizes[mask]
        legend_elements = size_legend_elements(
            size_values.reshape(-1, 1),
            size_points.reshape(-1, 1),
            max_entries=6,
            marker=plot_kw["marker"],
        )
        # legend spacing grows with the largest marker size
        ls = 0.5 + size_range[1] / 100 * 0.125 if size_range[1] > 200 else 0.5

        legend_kw = {
            "loc": "lower left",
            "facecolor": "w",
            "framealpha": 1,
            "edgecolor": "w",
            "labelspacing": ls,
            "handles": legend_elements,
            "bbox_to_anchor": (-0.05, -0.1),
        } | legend_kw

        if "title" not in legend_kw:
            if hasattr(sdata, "long_name"):
                lgd_title = wrap_text(
                    sdata.long_name, min_line_len=1, max_line_len=15
                )
                if hasattr(sdata, "units"):
                    lgd_title += f" ({sdata.units})"
            else:
                # `sizes` is True when the plotted variable itself sets the sizes: use its name, not "True"
                lgd_title = sizes if isinstance(sizes, str) else sdata.name
            legend_kw["title"] = lgd_title

        if ax:
            lgd = ax.legend(**legend_kw)
            lgd.set_zorder(11)
        else:
            im.figlegend = im.fig.legend(**legend_kw)

    if ax:
        return ax

    im.fig.suptitle(get_attributes("long_name", data))
    im.set_titles(template="{value}")
    if enumerate_subplots and isinstance(im, xr.plot.facetgrid.FacetGrid):
        for idx, sub_ax in enumerate(im.axs.flat):
            sub_ax.set_title(f"{string.ascii_lowercase[idx]}) {sub_ax.get_title()}")

    return im
1996

1997

1998
def taylordiagram(
    data: xr.DataArray | dict[str, xr.DataArray],
    plot_kw: dict[str, Any] | None = None,
    fig_kw: dict[str, Any] | None = None,
    std_range: tuple = (0, 1.5),
    contours: int | None = 4,
    contours_kw: dict[str, Any] | None = None,
    ref_std_line: bool = False,
    legend_kw: dict[str, Any] | None = None,
    std_label: str | None = None,
    corr_label: str | None = None,
    colors_key: str | None = None,
    markers_key: str | None = None,
):
    """
    Build a Taylor diagram.

    Based on the following code: https://gist.github.com/ycopin/3342888.

    Parameters
    ----------
    data : xr.DataArray or dict
        DataArray or dictionary of DataArrays created by xclim.sdba.measures.taylordiagram, each corresponding
        to a point on the diagram. The dictionary keys will become their labels.
        Dimensions other than 'taylor_param' are stacked, producing one point per combination of their values.
    plot_kw : dict, optional
        Arguments to pass to the `plot()` function. Changes how the markers look.
        If 'data' is a dictionary, must be a nested dictionary with the same keys as 'data'
        (plus an optional 'reference' key for the reference point).
    fig_kw : dict, optional
        Arguments to pass to `plt.figure()`.
    std_range : tuple
        Range of the x and y axes, in units of the highest standard deviation in the data.
    contours : int, optional
        Number of rmse contours to plot.
    contours_kw : dict, optional
        Arguments to pass to `plt.contour()` for the rmse contours.
    ref_std_line : bool, optional
        If True, draws a circular line on radius `std = ref_std`. Default: False
    legend_kw : dict, optional
        Arguments to pass to `plt.legend()`.
    std_label : str, optional
        Label for the standard deviation (x and y) axes.
    corr_label : str, optional
        Label for the correlation axis.
    colors_key : str, optional
        Attribute or dimension of DataArrays used to separate DataArrays into groups with different colors. If present,
        it overrides the "color" key in `plot_kw`.
    markers_key : str, optional
        Attribute or dimension of DataArrays used to separate DataArrays into groups with different markers. If present,
        it overrides the "marker" key in `plot_kw`.

    Returns
    -------
    (plt.figure, mpl_toolkits.axisartist.floating_axes.FloatingSubplot, plt.legend)

    Raises
    ------
    TypeError
        If an element of `data` is not a DataArray.
    ValueError
        If a DataArray lacks the 'taylor_param' dimension, if 'reference' is used as a key,
        if the reference standard deviations differ, or if no point with a non-negative correlation is left.
    """
    plot_kw = empty_dict(plot_kw)
    fig_kw = empty_dict(fig_kw)
    contours_kw = empty_dict(contours_kw)
    # copy: a default "loc" is set below and must not leak into the caller's dict
    legend_kw = dict(empty_dict(legend_kw))

    # preserve order of dimensions if used for marker/color
    ordered_markers_type = None
    ordered_colors_type = None

    # convert SSP, RCP, CMIP formats in keys
    if isinstance(data, dict):
        # shallow copy: entries are added and removed below
        data = dict(process_keys(data, convert_scen_name))
    if isinstance(plot_kw, dict):
        plot_kw = process_keys(plot_kw, convert_scen_name)

    # if only one data input, insert in dict.
    if not isinstance(data, dict):
        data = {"_no_label": data}  # mpl excludes labels starting with "_" from legend
        plot_kw = {"_no_label": dict(plot_kw)}
    else:
        # Copy each per-point dict: they are updated below (color, marker, label) and must neither
        # alias one another nor modify the caller's dictionaries.
        plot_kw = {
            k: dict(v) if isinstance(v, dict) else v for k, v in plot_kw.items()
        }

    # check type
    for key, v in data.items():
        if not isinstance(v, xr.DataArray):
            raise TypeError("All objects in 'data' must be xarray DataArrays.")
        if "taylor_param" not in v.dims:
            raise ValueError("All DataArrays must contain a 'taylor_param' dimension.")
        if key == "reference":
            raise ValueError("'reference' is not allowed as a key in data.")

    # If there are other dimensions than 'taylor_param', create a bigger dict with them
    for data_key in list(data):
        da = data[data_key]
        # keep the DataArray's own dimension order so generated keys are reproducible between runs
        dims = [d for d in da.dims if d != "taylor_param"]
        if not dims:
            continue
        if markers_key in dims:
            ordered_markers_type = da[markers_key].values
        if colors_key in dims:
            ordered_colors_type = da[colors_key].values

        # a data key without plot_kw entry simply gets default styling
        base_kw = empty_dict(plot_kw.pop(data_key, None))
        da = da.stack(pl_dims=dims)
        for i, dim_key in enumerate(da.pl_dims.values):
            if isinstance(dim_key, (list, tuple)):
                dim_key = "-".join(str(k) for k in dim_key)
            da0 = da.isel(pl_dims=i)
            # if colors_key/markers_key is a dimension, add it as an attribute for later use
            if markers_key in dims:
                da0.attrs[markers_key] = da0[markers_key].values.item()
            if colors_key in dims:
                da0.attrs[colors_key] = da0[colors_key].values.item()
            new_data_key = (
                f"{data_key}-{dim_key}" if data_key != "_no_label" else dim_key
            )
            data[new_data_key] = da0
            # independent copy per point, otherwise styles set later overwrite each other
            plot_kw[new_data_key] = dict(base_kw)
        data.pop(data_key)

    # remove negative (and undefined) correlations
    kept = {
        key: da for key, da in data.items() if da.sel(taylor_param="corr").values >= 0
    }
    removed = [key for key in data if key not in kept]
    if removed:
        warnings.warn(
            f"{len(removed)} points with negative correlations will not be plotted: {', '.join(removed)}", stacklevel=2
        )
    data = kept
    if not data:
        raise ValueError(
            "No data left to plot: all points have negative or undefined correlations."
        )

    # add missing keys to plot_kw
    for key in data:
        plot_kw.setdefault(key, {})

    # extract ref to be used in plot
    ref_std = next(iter(data.values())).sel(taylor_param="ref_std").values
    # check if ref is the same in all DataArrays
    if len(data) > 1:
        for da in data.values():
            if da.sel(taylor_param="ref_std").values != ref_std:
                raise ValueError(
                    "All reference standard deviation values must be identical"
                )

    # get highest std for axis limits
    max_std = [ref_std] + [
        max(
            da.sel(taylor_param="ref_std").values,
            da.sel(taylor_param="sim_std").values,
        ).astype(float)
        for da in data.values()
    ]

    # make labels
    if not std_label:
        std_term = get_localized_term("standard deviation").capitalize()
        try:
            units = next(iter(data.values())).units
        except AttributeError:
            std_label = std_term
        else:
            std_label = std_term if units == "" else f"{std_term} ({units})"

    if not corr_label:
        try:
            is_pearson = "Pearson" in next(iter(data.values())).correlation_type
        except AttributeError:
            is_pearson = False
        corr_term = "pearson correlation" if is_pearson else "correlation"
        corr_label = get_localized_term(corr_term).capitalize()

    # build diagram
    transform = PolarAxes.PolarTransform()

    # Setup the axis, here we map angles in degrees to angles in radius
    # Correlation labels
    rlocs = np.array([0, 0.2, 0.4, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99, 1])
    tlocs = np.arccos(rlocs)  # Conversion to polar angles
    gl1 = gf.FixedLocator(tlocs)  # Positions
    tf1 = gf.DictFormatter(dict(zip(tlocs, map(str, rlocs), strict=True)))
    # Standard deviation axis extent
    radius_min = std_range[0] * max(max_std)
    radius_max = std_range[1] * max(max_std)

    # Set up the axes range in the parameter "extremes"
    ghelper = GridHelperCurveLinear(
        transform,
        extremes=(0, np.pi / 2, radius_min, radius_max),
        grid_locator1=gl1,
        tick_formatter1=tf1,
    )

    fig = plt.figure(**fig_kw)
    floating_ax = FloatingSubplot(fig, 111, grid_helper=ghelper)
    fig.add_subplot(floating_ax)

    # Adjust axes
    floating_ax.axis["top"].set_axis_direction("bottom")  # "Angle axis"
    floating_ax.axis["top"].toggle(ticklabels=True, label=True)
    floating_ax.axis["top"].major_ticklabels.set_axis_direction("top")
    floating_ax.axis["top"].label.set_axis_direction("top")
    floating_ax.axis["top"].label.set_text(corr_label)

    floating_ax.axis["left"].set_axis_direction("bottom")  # "X axis"
    floating_ax.axis["left"].label.set_text(std_label)

    floating_ax.axis["right"].set_axis_direction("top")  # "Y axis"
    floating_ax.axis["right"].toggle(ticklabels=True, label=True)
    floating_ax.axis["right"].major_ticklabels.set_axis_direction("left")
    floating_ax.axis["right"].label.set_text(std_label)

    floating_ax.axis["bottom"].set_visible(False)  # Useless

    # Contours along standard deviations
    floating_ax.grid(visible=True, alpha=0.4)
    floating_ax.set_title("")

    ax = floating_ax.get_aux_axes(transform)  # return the axes that can be plotted on

    # plot reference
    ref_kw = {
        "color": "#154504",
        "marker": "s",
        "label": get_localized_term("reference"),
    } | empty_dict(plot_kw.pop("reference", None))

    ref_pt = ax.scatter(0, ref_std, **ref_kw)

    points = [ref_pt]  # legend entries, completed below

    # plot a circular line along `ref_std`
    if ref_std_line:
        angles_for_line = np.linspace(0, np.pi / 2, 100)
        radii_for_line = np.full_like(angles_for_line, ref_std)
        ax.plot(
            angles_for_line,
            radii_for_line,
            color=ref_kw["color"],
            linewidth=0.5,
            linestyle="-",
        )

    # rmse contours from reference standard deviation
    if contours:
        radii, angles = np.meshgrid(
            np.linspace(radius_min, radius_max),
            np.linspace(0, np.pi / 2),
        )
        # Compute centered RMS difference
        rms = np.sqrt(ref_std**2 + radii**2 - 2 * ref_std * radii * np.cos(angles))

        contours_kw = {"linestyles": "--", "linewidths": 0.5} | contours_kw
        ct = ax.contour(angles, radii, rms, levels=contours, **contours_kw)

        ax.clabel(ct, ct.levels, fontsize=8)

        # invisible line used only to get an "rmse" legend entry
        ct_line = ax.plot(
            [0],
            [0],
            ls=contours_kw["linestyles"],
            lw=1,
            c=contours_kw.get("colors", "k"),
            label="rmse",
        )
        points.append(ct_line[0])

    # get color options (cycled with a modulo when there are more groups than colors)
    style_colors = matplotlib.rcParams["axes.prop_cycle"].by_key()["color"]
    cat_colors = Path(__file__).parents[1] / "data/ipcc_colors/categorical_colors.json"
    # get marker options (only used if `markers_key` is set)
    style_markers = "oDv^<>p*hH+x|_"

    # set colors and markers styles based on discriminating attributes (if specified)
    if colors_key or markers_key:
        if colors_key:
            # unique values in first-seen order, so colors are stable between runs
            colors_type = (
                ordered_colors_type
                if ordered_colors_type is not None
                else list(dict.fromkeys(da.attrs[colors_key] for da in data.values()))
            )
            # get_scen_color : look for SSP, RCP, CMIP model color
            colorsd = {
                k: get_scen_color(k, cat_colors) or style_colors[i % len(style_colors)]
                for i, k in enumerate(colors_type)
            }
        if markers_key:
            markers_type = (
                ordered_markers_type
                if ordered_markers_type is not None
                else list(dict.fromkeys(da.attrs[markers_key] for da in data.values()))
            )
            markersd = {
                k: style_markers[i % len(style_markers)]
                for i, k in enumerate(markers_type)
            }

        for key, da in data.items():
            if colors_key:
                plot_kw[key]["color"] = colorsd[da.attrs[colors_key]]
            if markers_key:
                plot_kw[key]["marker"] = markersd[da.attrs[markers_key]]

    # plot scatter
    for i, (key, da) in enumerate(data.items()):
        # look for SSP, RCP, CMIP model color
        if colors_key is None:
            plot_kw[key].setdefault(
                "color",
                get_scen_color(key, cat_colors) or style_colors[i % len(style_colors)],
            )
        # set defaults
        plot_kw[key] = {"label": key} | plot_kw[key]

        # legend will be handled later in this case
        if markers_key or colors_key:
            plot_kw[key]["label"] = ""

        # plot
        pt = ax.scatter(
            np.arccos(da.sel(taylor_param="corr").values),
            da.sel(taylor_param="sim_std").values,
            **plot_kw[key],
        )
        points.append(pt)

    # legend
    legend_kw.setdefault("loc", "upper right")
    legend = fig.legend(points, [p.get_label() for p in points], **legend_kw)

    # plot new legend if markers/colors represent a certain dimension
    if colors_key or markers_key:
        handles = list(floating_ax.get_legend_handles_labels()[0])
        if markers_key:
            for k, m in markersd.items():
                handles.append(Line2D([0], [0], color="k", label=k, marker=m, ls=""))
        if colors_key:
            for k, c in colorsd.items():
                handles.append(Line2D([0], [0], color=c, label=k, ls="-"))
        legend.remove()
        legend = fig.legend(handles=handles, **legend_kw)

    return fig, floating_ax, legend
2345

2346

2347
def _hatchmap_prepare_data(
    data: dict[str, Any] | xr.DataArray | xr.Dataset,
    plot_kw: dict[str, Any],
) -> tuple[dict[Any, xr.DataArray], dict[Any, dict[str, Any]], xr.Dataset | None]:
    """
    Normalize the inputs of `hatchmap` into one DataArray and one kwargs dict per entry.

    Parameters
    ----------
    data : dict, DataArray or Dataset
        Input given to `hatchmap`.
    plot_kw : dict
        Either flat kwargs shared by every entry, or kwargs nested by entry key
        (flat kwargs are then used for the entries that have no nested dict).

    Returns
    -------
    plot_data : dict
        Entry key -> DataArray to plot.
    plot_kw : dict
        Entry key -> private copy of the kwargs for that entry (the caller's dicts are not modified).
    dattrs : xr.Dataset or None
        Dataset whose attributes are used for titles, if a Dataset was given.
    """
    dattrs = None
    plot_data: dict[Any, xr.DataArray] = {}
    if isinstance(data, xr.DataArray):
        plot_data[data.name] = data
    elif isinstance(data, xr.Dataset):
        dattrs = data
        plot_data = {var: data[var] for var in data.data_vars}
    elif isinstance(data, dict):
        for key, val in data.items():
            if isinstance(val, xr.Dataset):
                dattrs = val
                plot_data[key] = val[list(val.data_vars)[0]]
                warnings.warn(
                    "Only first variable of Dataset is plotted.", stacklevel=3
                )
            else:
                plot_data[key] = val
    else:
        raise TypeError("`data` must be a dict, an xr.DataArray or an xr.Dataset.")

    if not plot_data:
        raise ValueError("`data` does not contain anything to plot.")

    # kwargs keyed by an entry name apply only to that entry; all other kwargs are shared
    shared_kw = {kk: vv for kk, vv in plot_kw.items() if kk not in plot_data}
    nested_kw = {}
    for key in plot_data:
        entry_kw = plot_kw[key] if key in plot_kw else shared_kw
        # copy two levels deep so that later setdefault/pop calls never reach the caller's dicts
        nested_kw[key] = {
            kk: (dict(vv) if isinstance(vv, dict) else vv)
            for kk, vv in entry_kw.items()
        }
    return plot_data, nested_kw, dattrs


def hatchmap(
    data: dict[str, Any] | xr.DataArray | xr.Dataset,
    ax: matplotlib.axes.Axes | None = None,
    use_attrs: dict[str, Any] | None = None,
    fig_kw: dict[str, Any] | None = None,
    plot_kw: dict[str, Any] | None = None,
    projection: ccrs.Projection = ccrs.LambertConformal(),
    transform: ccrs.Projection | None = None,
    features: list[str] | dict[str, dict[str, Any]] | None = None,
    geometries_kw: dict[str, Any] | None = None,
    levels: int | None = None,
    legend_kw: dict[str, Any] | bool = True,
    show_time: bool | str | int | tuple[float, float] = False,
    frame: bool = False,
    enumerate_subplots: bool = False,
) -> matplotlib.axes.Axes:
    """
    Create map of hatches from 2D data.

    Parameters
    ----------
    data : dict, DataArray or Dataset
        Input data to plot. For a Dataset inside a dict, only its first variable is plotted.
    ax : matplotlib axis, optional
        Matplotlib axis on which to plot, with the same projection as the one specified.
    use_attrs : dict, optional
        Dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description').
        Default value is {'title': 'description'}.
        Only the keys found in the default dict can be used.
    fig_kw : dict, optional
        Arguments to pass to `plt.figure()`. With facetgrids, only 'figsize' and
        `figure.add_subplot()` arguments are used.
    plot_kw:  dict, optional
        Arguments to pass to 'xarray.plot.contourf()' function.
        If 'data' is a dictionary, can be a nested dictionary with the same keys as 'data'.
    projection : ccrs.Projection
        The projection to use, taken from the cartopy.ccrs options. Ignored if ax is not None.
    transform : ccrs.Projection, optional
        Transform corresponding to the data coordinate system. If None, an attempt is made to find dimensions matching
        ccrs.PlateCarree() or ccrs.RotatedPole().
    features : list or dict, optional
        Features to use, as a list or a nested dict containing kwargs. Options are the predefined features from
        cartopy.feature: ['coastline', 'borders', 'lakes', 'land', 'ocean', 'rivers', 'states'].
    geometries_kw : dict, optional
        Arguments passed to cartopy ax.add_geometry() which adds given geometries (GeoDataFrame geometry) to axis.
    levels : int, optional
        Number of default hatch patterns available when 'levels' is given in `plot_kw` without 'hatches'.
        If None, all default patterns are available.
    legend_kw : dict or boolean, optional
        Arguments to pass to `ax.legend()` (or `fig.legend()` for facetgrids).
        True (default) adds a legend with default arguments; no legend is added if legend_kw == False.
    show_time : bool, tuple, string or int.
        If True, show time (as date) at the bottom right of the figure.
        Can be a tuple of axis coordinates (0 to 1, as a fraction of the axis length) representing the location
        of the text. If a string or an int, the same values as those of the 'loc' parameter
        of matplotlib's legends are accepted.

        ==================   =============
        Location String      Location Code
        ==================   =============
        'upper right'        1
        'upper left'         2
        'lower left'         3
        'lower right'        4
        'right'              5
        'center left'        6
        'center right'       7
        'lower center'       8
        'upper center'       9
        'center'             10
        ==================   =============
    frame : bool
        Show or hide frame. Default False.
    enumerate_subplots: bool
        If True, enumerate subplots with letters.
        Only works with facetgrids (pass `col` or `row` in plot_kw).

    Returns
    -------
    matplotlib.axes.Axes
        The axis, or the xarray FacetGrid when `col`/`row` are given in `plot_kw`.
    """
    # default hatches
    dfh = [
        "/",
        "\\",
        "|",
        "-",
        "+",
        "x",
        "o",
        "O",
        ".",
        "*",
        "//",
        "\\\\",
        "||",
        "--",
        "++",
        "xx",
        "oo",
        "OO",
        "..",
        "**",
    ]

    # work on copies so that the caller's dictionaries are never mutated
    use_attrs = dict(empty_dict(use_attrs))
    fig_kw = empty_dict(fig_kw)
    # legend_kw=False disables the legend; True means "legend with default arguments"
    add_legend = legend_kw is not False
    legend_kw = {} if isinstance(legend_kw, bool) else dict(empty_dict(legend_kw))

    plot_data, plot_kw, dattrs = _hatchmap_prepare_data(data, empty_dict(plot_kw))
    keys = list(plot_data)
    first_key, last_key = keys[0], keys[-1]
    first_kw = plot_kw[first_key]

    # setup transform from first data entry
    trdata = plot_data[first_key]
    if transform is None:
        if "lat" in trdata.dims and "lon" in trdata.dims:
            transform = ccrs.PlateCarree()
        elif "rlat" in trdata.dims and "rlon" in trdata.dims:
            transform = get_rotpole(trdata)

    # bug xlim / ylim + transform in facetgrids
    # (see https://github.com/pydata/xarray/issues/8562#issuecomment-1865189766)
    extent = None
    if transform and ("xlim" in first_kw or "ylim" in first_kw):
        if "xlim" in first_kw and "ylim" in first_kw:
            extent = [
                first_kw["xlim"][0],
                first_kw["xlim"][1],
                first_kw["ylim"][0],
                first_kw["ylim"][1],
            ]
        else:
            warnings.warn(
                "Requires both xlim and ylim with 'transform'. Xlim or ylim was dropped",
                stacklevel=2,
            )
        for kw in plot_kw.values():
            kw.pop("xlim", None)
            kw.pop("ylim", None)

    # setup fig, ax
    has_facets = "row" in first_kw or "col" in first_kw
    if ax is None and not has_facets:
        _, ax = plt.subplots(subplot_kw={"projection": projection}, **fig_kw)
    elif ax is not None and has_facets:
        raise ValueError("Cannot use 'ax' and 'col'/'row' at the same time.")
    elif ax is None:
        cfig_kw = copy.deepcopy(fig_kw)
        if "figsize" in cfig_kw:  # the facetgrid is created by the first entry
            first_kw.setdefault("figsize", cfig_kw.pop("figsize"))
        # remaining fig_kw entries are forwarded to figure.add_subplot();
        # explicit subplot_kws win over fig_kw, which wins over the projection default
        for kw in plot_kw.values():
            kw["subplot_kws"] = (
                {"projection": projection} | cfig_kw | (kw.get("subplot_kws") or {})
            )
        if cfig_kw:
            warnings.warn(
                "Only figsize and figure.add_subplot() arguments can be passed to fig_kw when using facetgrid.",
                stacklevel=2,
            )

    pat_leg = []
    n = 0
    for k, v in plot_data.items():
        kw = plot_kw[k]
        if "levels" in kw:
            # if levels, plot multiple hatchings from one data entry
            if len(plot_data) > 1:
                raise TypeError(
                    "To plot levels only one xr.DataArray or xr.Dataset accepted"
                )

            # nans, counted over the whole array
            mask = ~np.isnan(v.values)
            n_nan = int(mask.size - mask.sum())
            if n_nan:
                warnings.warn(
                    f"{n_nan} nan values were dropped when plotting the pattern values",
                    stacklevel=2,
                )
            lev = kw["levels"]
            n_lev = int(lev) if np.ndim(lev) == 0 else len(lev)
            if "hatches" in kw and n_lev != len(kw["hatches"]):
                warnings.warn(
                    "Hatches number is not equivalent to number of levels", stacklevel=2
                )

            # user-supplied hatches (if any) take precedence over the defaults
            kw = {
                "hatches": dfh[0:levels],
                "colors": "none",
                "add_colorbar": False,
            } | kw
            plot_kw[k] = kw

            # mask coordinate added to a copy, leaving the caller's DataArray untouched
            v = v.assign_coords(mask=(v.dims, mask))

            kw.setdefault("transform", transform)
            if ax is not None:
                kw.setdefault("ax", ax)

            # NaN cells stay NaN and are left unfilled by contourf
            im = v.plot.contourf(**kw)

            if add_legend:
                if ax is not None:
                    artists, labels = im.legend_elements(str_format="{:2.1f}".format)
                    ax.legend(artists, labels, **legend_kw)
                else:
                    im.figlegend = im.fig.legend(**legend_kw)

        else:
            # since pattern remove colors and colorbar from plotting (done by gridmap)
            kw = {"colors": "none", "add_colorbar": False} | kw

            if "hatches" not in kw:
                # contourf expects a list; wrapping keeps multi-character patterns like "//" intact
                kw["hatches"] = [dfh[n % len(dfh)]]
                n += 1
            elif isinstance(kw["hatches"], str):  # make sure the hatches are in a list
                warnings.warn(
                    "Hatches argument must be of type 'list'. Wrapping string argument as list.",
                    stacklevel=2,
                )
                kw["hatches"] = [kw["hatches"]]

            kw.setdefault("transform", transform)
            plot_kw[k] = kw

            if ax is not None:
                im = v.plot.contourf(ax=ax, **kw)
            else:
                # facetgrid: the first entry creates the grid (with colors, which are cleared below)
                if k == first_key and has_facets:
                    c_pkw = kw.copy()
                    if c_pkw["colors"] == "none":
                        c_pkw.pop("colors")
                    im = v.plot.contourf(**c_pkw)

                for i, fax in enumerate(im.axs.flat):
                    cleared = k == first_key and kw["colors"] == "none"
                    if cleared:
                        fax.clear()
                    if cleared or len(plot_data) > 1:
                        # select data to plot in loop to plot on facetgrids axis
                        c_pkw = kw.copy()
                        c_pkw.pop("subplot_kws", None)
                        c_pkw.pop("figsize", None)  # cannot be combined with `ax`
                        sel = {}
                        if "row" in c_pkw:
                            sel[c_pkw.pop("row")] = i
                        elif "col" in c_pkw:
                            sel[c_pkw.pop("col")] = i
                        v.isel(sel).plot.contourf(ax=fax, **c_pkw)

                    if k == last_key:
                        add_features_map(
                            dattrs,
                            fax,
                            use_attrs,
                            projection,
                            features,
                            geometries_kw,
                            frame,
                        )
                        if extent:
                            fax.set_extent(extent)

            pat_leg.append(
                matplotlib.patches.Patch(hatch=kw["hatches"][0], fill=False, label=k)
            )

    if pat_leg and add_legend:
        legend_kw = {
            "loc": "lower right",
            "handleheight": 2,
            "handlelength": 4,
        } | legend_kw

        if ax is not None:
            ax.legend(handles=pat_leg, **legend_kw)
        else:
            im.figlegend = im.fig.legend(handles=pat_leg, **legend_kw)

    # add features
    if ax is not None:
        if extent:
            ax.set_extent(extent)
        if dattrs:
            use_attrs.setdefault("title", "description")

        ax = add_features_map(
            dattrs,
            ax,
            use_attrs,
            projection,
            features,
            geometries_kw,
            frame,
        )

        if show_time:
            if isinstance(show_time, bool):
                plot_coords(
                    ax,
                    plot_data,
                    param="time",
                    loc="lower right",
                    backgroundalpha=1,
                )
            elif isinstance(show_time, str | tuple | int):
                plot_coords(
                    ax,
                    plot_data,
                    param="time",
                    loc=show_time,
                    backgroundalpha=1,
                )

        # when im is an ax, it has a colorbar attribute. If it is a facetgrid, it has a cbar attribute.
        if (frame is False) and (
            (getattr(im, "colorbar", None) is not None)
            or (getattr(im, "cbar", None) is not None)
        ):
            im.colorbar.outline.set_visible(False)

            set_plot_attrs(use_attrs, dattrs, ax, wrap_kw={"max_line_len": 60})
        return ax

    # when im is an ax, it has a colorbar attribute. If it is a facetgrid, it has a cbar attribute.
    if (frame is False) and (
        (getattr(im, "colorbar", None) is not None)
        or (getattr(im, "cbar", None) is not None)
    ):
        im.cbar.outline.set_visible(False)

    if show_time:
        if show_time is True:
            plot_coords(
                None,
                dattrs,
                param="time",
                loc="lower right",
                backgroundalpha=1,
            )
        elif isinstance(show_time, str | tuple | int):
            plot_coords(None, dattrs, param="time", loc=show_time, backgroundalpha=1)

    if dattrs:
        use_attrs.setdefault("suptitle", "long_name")
        set_plot_attrs(use_attrs, dattrs, facetgrid=im)

    if enumerate_subplots and isinstance(im, xr.plot.facetgrid.FacetGrid):
        for idx, fax in enumerate(im.axs.flat):
            fax.set_title(f"{string.ascii_lowercase[idx]}) {fax.get_title()}")

    return im
×
2737

2738

2739
def _add_lead_time_coord(da, ref):
    """
    Add a "Lead time" coordinate, in years since `ref`, to `da`. Modifies da in-place.

    Parameters
    ----------
    da : xr.DataArray
        Data with a datetime-like `time` coordinate (as passed by `partition`).
    ref : str or int
        Reference year. A date string such as "2000-06-01" is also accepted;
        only its leading year is used.

    Returns
    -------
    xr.DataArray
        The lead time, i.e. the year of each time step minus the reference year.
    """
    try:
        ref_year = int(ref)
    except ValueError:
        # `ref` may be the same date string used to slice the time axis (e.g. "2000-06"),
        # so keep only the year before the first "-".
        ref_year = int(str(ref).strip().split("-")[0])
    lead_time = da.time.dt.year - ref_year
    da["Lead time"] = lead_time
    da["Lead time"].attrs["units"] = f"years from {ref}"
    return lead_time
×
2745

2746

2747
def partition(
    data: xr.DataArray | xr.Dataset,
    ax: matplotlib.axes.Axes | None = None,
    start_year: str | None = None,
    show_num: bool = True,
    fill_kw: dict[str, Any] | None = None,
    line_kw: dict[str, Any] | None = None,
    fig_kw: dict[str, Any] | None = None,
    legend_kw: dict[str, Any] | None = None,
) -> matplotlib.axes.Axes:
    """
    Figure of the partition of total uncertainty by components.

    Uncertainty fractions can be computed with xclim (https://xclim.readthedocs.io/en/stable/api.html#uncertainty-partitioning).
    Make sure to use `fraction=True` in the xclim function call.

    Parameters
    ----------
    data : xr.DataArray or xr.Dataset
        Variance over time of the different components of uncertainty.
        Output of a `xclim.ensembles._partitioning` function.
        If a Dataset is given, only its first variable is used.
    ax : matplotlib axis, optional
        Matplotlib axis on which to plot.
    start_year : str
        If None, the x-axis will be the time in year.
        If str, the data is restricted to times from start_year onward and the x-axis
        shows the number of years since start_year.
    show_num : bool
        If True, show the number of elements for each uncertainty components in parentheses in the legend.
        `data` should have attributes named after the components with a list of its elements.
    fill_kw : dict
        Keyword arguments passed to `ax.fill_between`.
        It is possible to pass a dictionary of keywords for each component (uncertainty coordinates).
    line_kw : dict
        Keyword arguments passed to `ax.plot` for the lines in between the components.
        The default is {color="k", lw=2}. We recommend always using lw>=2.
    fig_kw : dict
        Keyword arguments passed to `plt.subplots`.
    legend_kw : dict
        Keyword arguments passed to `ax.legend`.

    Returns
    -------
    mpl.axes.Axes

    Raises
    ------
    TypeError
        If `data` is neither a DataArray nor a Dataset.
    ValueError
        If the units of the data are not '%'.
    """
    # select data to plot (validate the type before touching any attribute)
    if isinstance(data, xr.Dataset):  # in case, it was saved to disk before plotting.
        if len(data.data_vars) > 1:
            warnings.warn(
                "data is xr.Dataset; only the first variable will be used in plot",
                stacklevel=2,
            )
        data = data[list(data.keys())[0]]
    elif not isinstance(data, xr.DataArray):
        raise TypeError("`data` must contain a xr.DataArray or xr.Dataset")
    data = data.squeeze()

    if data.attrs.get("units") != "%":
        raise ValueError(
            "The units are not %. Use `fraction=True` in the xclim function call."
        )

    fill_kw = empty_dict(fill_kw)
    line_kw = dict(empty_dict(line_kw))  # copied: defaults are added below
    fig_kw = empty_dict(fig_kw)
    legend_kw = empty_dict(legend_kw)

    if ax is None:
        _, ax = plt.subplots(**fig_kw)

    # Select data from reference year onward
    if start_year:
        data = data.sel(time=slice(start_year, None))

        # Lead time coordinate (added to the sliced copy, not to the caller's data)
        time = _add_lead_time_coord(data, start_year)
        ax.set_xlabel(f"Lead time (years from {start_year})")
    else:
        time = data.time.dt.year

    components = data.uncertainty.values

    # fill_kw that are direct (not with uncertainty as key)
    fk_direct = {k: v for k, v in fill_kw.items() if k not in components}

    # Draw stacked areas, one per component, then the remaining variability up to 100%
    past_y = 0
    black_lines = []
    for u in components:
        if u in ("total", "variability"):
            continue
        present_y = past_y + data.sel(uncertainty=u)
        num = len(data.attrs.get(u, []))  # compatible with pre PR PR #1529
        label = f"{u} ({num})" if show_num and num else u
        ax.fill_between(
            time,
            past_y,
            present_y,
            label=label,
            **fill_kw.get(u, fk_direct),
        )
        black_lines.append(present_y)
        past_y = present_y
    ax.fill_between(
        time,
        past_y,
        100,
        label="variability",
        **fill_kw.get("variability", fk_direct),
    )

    # Draw black lines between components (none when only total/variability exist)
    line_kw.setdefault("color", "k")
    line_kw.setdefault("lw", 2)
    if black_lines:
        ax.plot(time, np.array(black_lines).T, **line_kw)

    ax.xaxis.set_major_locator(matplotlib.ticker.MultipleLocator(20))
    ax.xaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator(n=5))

    ax.yaxis.set_major_locator(matplotlib.ticker.MultipleLocator(10))
    ax.yaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator(n=2))

    long_name = data.attrs.get("long_name", data.name)
    ax.set_ylabel(f"{long_name} ({data.attrs['units']})")

    ax.set_ylim(0, 100)
    ax.legend(**legend_kw)

    return ax
×
2878

2879

2880
def triheatmap(
    data: xr.DataArray | xr.Dataset,
    z: str,
    ax: matplotlib.axes.Axes | None = None,
    use_attrs: dict[str, Any] | None = None,
    fig_kw: dict[str, Any] | None = None,
    plot_kw: dict[str, Any] | None | list = None,
    cmap: str | matplotlib.colors.Colormap | None = None,
    divergent: bool | int | float = False,
    cbar: bool | str = "unique",
    cbar_kw: dict[str, Any] | None | list = None,
) -> matplotlib.axes.Axes:
    """
    Create a triangle heatmap from a DataArray.

    Note that most of the code comes from:
    https://stackoverflow.com/questions/66048529/how-to-create-a-heatmap-where-each-cell-is-divided-into-4-triangles

    Parameters
    ----------
    data : DataArray or Dataset
        Input data to plot. If a Dataset, only its first variable is used.
    z: str
        Dimension to plot on the triangles. Its length should be 2 or 4.
    ax : matplotlib axis, optional
        Matplotlib axis on which to plot, with the same projection as the one specified.
    use_attrs : dict, optional
        Dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description').
        Default value is {'cbar_label': 'long_name',"cbar_units": "units"}.
        Valid keys are: 'title', 'xlabel', 'ylabel', 'cbar_label', 'cbar_units'.
    fig_kw : dict, optional
        Arguments to pass to `plt.figure()`.
    plot_kw :  dict, optional
        Arguments to pass to the 'plt.tripcolor()' function.
        It can be a list of dictionaries to pass different arguments to each type of triangles (upper/lower or north/east/south/west).
        `cmap` and a white edge color are used as defaults for every dictionary.
    cmap : matplotlib.colors.Colormap or str, optional
        Colormap to use. If str, can be a matplotlib or name of the file of an IPCC colormap (see data/ipcc_colors).
        If None, look for common variables (from data/ipcc_colors/variables_groups.json) in the name of the DataArray
        or its 'history' attribute and use corresponding colormap, aligned with the IPCC Visual Style Guide 2022
        (https://www.ipcc.ch/site/assets/uploads/2022/09/IPCC_AR6_WGI_VisualStyleGuide_2022.pdf).
        If a str names neither, matplotlib's default colormap is used.
    divergent : bool or int or float
        If int or float, becomes center of cmap. Default center is 0.
    cbar : {False, True, 'unique', 'each'}
        If False, don't show the colorbar.
        If True or 'unique', show a unique colorbar for all triangle types. (The cbar of the first triangle is used).
        If 'each', show a colorbar for each triangle type.
    cbar_kw : dict or list
        Arguments to pass to 'fig.colorbar()'.
        It can be a list of dictionaries to pass different arguments to each type of triangles (upper/lower or north/east/south/west).

    Returns
    -------
    matplotlib.axes.Axes
    """
    # create empty dicts if None (use_attrs is copied because defaults are added below)
    use_attrs = dict(empty_dict(use_attrs))
    fig_kw = empty_dict(fig_kw)
    plot_kw = empty_dict(plot_kw)
    cbar_kw = empty_dict(cbar_kw)

    # select data to plot
    if isinstance(data, xr.DataArray):
        da = data
    elif isinstance(data, xr.Dataset):
        if len(data.data_vars) > 1:
            warnings.warn(
                "data is xr.Dataset; only the first variable will be used in plot",
                stacklevel=2,
            )
        da = list(data.values())[0]
    else:
        raise TypeError("`data` must contain a xr.DataArray or xr.Dataset")

    # setup fig, axis
    if ax is None:
        _, ax = plt.subplots(**fig_kw)

    # colormap
    if isinstance(cmap, str):
        if cmap not in plt.colormaps():
            try:
                cmap = create_cmap(filename=cmap)
            except FileNotFoundError:
                # fall back on matplotlib's default colormap
                logging.getLogger(__name__).warning(
                    "Colormap not found. Using default."
                )
                cmap = None

    elif cmap is None:
        cdata = Path(__file__).parents[1] / "data/ipcc_colors/variable_groups.json"
        cmap = create_cmap(
            get_var_group(path_to_json=cdata, da=da),
            divergent=divergent,
        )

    # prep data: one array per value of `z`, with axes ordered like the other dims of `da`
    d = [da.sel(**{z: v}).values for v in da[z].values]

    other_dims = [di for di in da.dims if di != z]
    if len(other_dims) > 2:
        warnings.warn(
            "More than 3 dimensions in data. The first two after dim will be used as the dimensions of the heatmap.",
            stacklevel=2,
        )
    if len(other_dims) < 2:
        raise ValueError(
            "Data must have 3 dimensions. If you only have 2 dimensions, use fg.heatmap."
        )

    if plot_kw == {} and cbar in ["unique", True]:
        warnings.warn(
            'With cbar="unique" only the colorbar of the first triangle'
            " will be shown. No `plot_kw` was passed. vmin and vmax will be set the max"
            " and min of data.",
            stacklevel=2,
        )
        plot_kw = {"vmax": da.max().values, "vmin": da.min().values}

    # one kwargs dict per triangle type; defaults apply to each, user values win,
    # and the caller's dictionaries are left untouched
    if isinstance(plot_kw, dict):
        plot_kw = [plot_kw] * len(d)
    plot_kw = [{"cmap": cmap, "ec": "white"} | kw for kw in plot_kw]

    labels_x = da[other_dims[0]].values
    labels_y = da[other_dims[1]].values
    m, n = d[0].shape[0], d[0].shape[1]  # m along x (other_dims[0]), n along y

    # plot
    if len(d) == 2:
        xss, ys = np.meshgrid(np.arange(m + 1), np.arange(n + 1))
        triangles1 = [
            (i + j * (m + 1), i + 1 + j * (m + 1), i + (j + 1) * (m + 1))
            for j in range(n)
            for i in range(m)
        ]
        triangles2 = [
            (
                i + 1 + j * (m + 1),
                i + 1 + (j + 1) * (m + 1),
                i + (j + 1) * (m + 1),
            )
            for j in range(n)
            for i in range(m)
        ]
        triangul = [
            Triangulation(xss.ravel(), ys.ravel(), triangles)
            for triangles in (triangles1, triangles2)
        ]
        tick_offset = 0.5  # cells span [i, i + 1]

    elif len(d) == 4:
        xv, yv = np.meshgrid(
            np.arange(-0.5, m), np.arange(-0.5, n)
        )  # vertices of the little squares
        xc, yc = np.meshgrid(
            np.arange(0, m), np.arange(0, n)
        )  # centers of the little squares
        x = np.concatenate([xv.ravel(), xc.ravel()])
        y = np.concatenate([yv.ravel(), yc.ravel()])
        cstart = (m + 1) * (n + 1)  # indices of the centers

        triangles_n = [
            (i + j * (m + 1), i + 1 + j * (m + 1), cstart + i + j * m)
            for j in range(n)
            for i in range(m)
        ]
        triangles_e = [
            (i + 1 + j * (m + 1), i + 1 + (j + 1) * (m + 1), cstart + i + j * m)
            for j in range(n)
            for i in range(m)
        ]
        triangles_s = [
            (
                i + 1 + (j + 1) * (m + 1),
                i + (j + 1) * (m + 1),
                cstart + i + j * m,
            )
            for j in range(n)
            for i in range(m)
        ]
        triangles_w = [
            (i + (j + 1) * (m + 1), i + j * (m + 1), cstart + i + j * m)
            for j in range(n)
            for i in range(m)
        ]
        triangul = [
            Triangulation(x, y, triangles)
            for triangles in [
                triangles_n,
                triangles_e,
                triangles_s,
                triangles_w,
            ]
        ]
        tick_offset = 0  # cells are centered on integer coordinates

    else:
        raise ValueError(
            f"The length of the dimension ({z},{len(d)}) should be either 2 or 4. It represents the number of triangles."
        )

    # Triangles are listed row by row (y outer, x inner: index j * m + i), while each value
    # array is ordered (x, y); transpose before flattening so cell (i, j) gets val[i, j].
    imgs = [
        ax.tripcolor(t, np.ravel(val.T), **plotkw)
        for t, val, plotkw in zip(triangul, d, plot_kw, strict=False)
    ]
    ax.set_xticks(np.arange(m) + tick_offset, labels=labels_x, rotation=45)
    ax.set_yticks(np.arange(n) + tick_offset, labels=labels_y, rotation=90)

    ax.set_title(get_attributes(use_attrs.get("title", None), data))
    ax.set_xlabel(other_dims[0])
    ax.set_ylabel(other_dims[1])
    if "xlabel" in use_attrs:
        ax.set_xlabel(get_attributes(use_attrs["xlabel"], data))
    if "ylabel" in use_attrs:
        ax.set_ylabel(get_attributes(use_attrs["ylabel"], data))
    ax.set_aspect("equal", "box")
    ax.invert_yaxis()
    ax.tick_params(left=False, bottom=False)
    ax.spines["bottom"].set_visible(False)
    ax.spines["left"].set_visible(False)

    # create cbar label
    # set default use_attrs values
    use_attrs.setdefault("cbar_label", "long_name")
    use_attrs.setdefault("cbar_units", "units")
    cbar_units = get_attributes(use_attrs["cbar_units"], data)
    if len(cbar_units) >= 1:  # avoids '()' as label
        cbar_label = (
            get_attributes(use_attrs["cbar_label"], data) + " (" + cbar_units + ")"
        )
    else:
        cbar_label = get_attributes(use_attrs["cbar_label"], data)

    # one kwargs dict per triangle type, each defaulting to the computed label
    if isinstance(cbar_kw, dict):
        cbar_kw = [cbar_kw] * len(d)
    cbar_kw = [{"label": cbar_label} | kw for kw in cbar_kw]

    if cbar is True or cbar == "unique":
        plt.colorbar(imgs[0], ax=ax, **cbar_kw[0])

    elif cbar == "each":
        for i in reversed(range(len(d))):  # switch order of colour bars
            plt.colorbar(imgs[i], ax=ax, **cbar_kw[i])

    return ax
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc