• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

Ouranosinc / figanos / 13506883248

24 Feb 2025 08:09PM UTC coverage: 8.124% (-0.03%) from 8.149%
13506883248

push

github

web-flow
Hatchmap and scattermap fixes (#195)

<!-- Please ensure the PR fulfills the following requirements! -->
<!-- If this is your first PR, make sure to add your details to the
AUTHORS.rst! -->
### Pull Request Checklist:
- [ ] This PR addresses an already opened issue (for bug fixes /
features)
  - This PR fixes #xyz
- [ ] (If applicable) Documentation has been added / updated (for bug
fixes / features).
- [ ] (If applicable) Tests have been added.
- [ ] CHANGES.rst has been updated (with summary of main changes).
- [ ] Link to issue (:issue:`number`) and pull request (:pull:`number`)
has been added.

### What kind of change does this PR introduce?

Some fixes and simplification of code:
- Option ‘no legend’ in `hatchmap`.
- Ensure ‘hatches’ are passed as a list.
- `scattermap` can use `’edgecolor’` and `’egdecolors’` interchangeably
like `matplotlib.pyplot.scatter’
- removed `deepcopy` operation in `hatchmap` and `scattermap` which is
already called in `figanos.matplotlib.utils.empyt_dict(). Consequently,
helper variables are renamed (`plot_kw_pop` -> `plot_kw` in
`scattermap`; `dc` -> `plot_kw` in `hatchmap`.
- Correct `extend` to `extent`

### Does this PR introduce a breaking change?

No


### Other information:

0 of 53 new or added lines in 1 file covered. (0.0%)

1 existing line in 1 file now uncovered.

155 of 1908 relevant lines covered (8.12%)

0.49 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

3.65
/src/figanos/matplotlib/plot.py
1
# noqa: D100
2
from __future__ import annotations
6✔
3

4
import copy
6✔
5
import logging
6✔
6
import math
6✔
7
import string
6✔
8
import warnings
6✔
9
from collections.abc import Iterable
6✔
10
from inspect import signature
6✔
11
from pathlib import Path
6✔
12
from typing import Any
6✔
13

14
import cartopy.mpl.geoaxes
6✔
15
import geopandas as gpd
6✔
16
import matplotlib
6✔
17
import matplotlib.axes
6✔
18
import matplotlib.cm
6✔
19
import matplotlib.colors
6✔
20
import matplotlib.pyplot as plt
6✔
21
import mpl_toolkits.axisartist.grid_finder as gf
6✔
22
import numpy as np
6✔
23
import pandas as pd
6✔
24
import seaborn as sns
6✔
25
import xarray as xr
6✔
26
from cartopy import crs as ccrs
6✔
27
from matplotlib.cm import ScalarMappable
6✔
28
from matplotlib.lines import Line2D
6✔
29
from matplotlib.projections import PolarAxes
6✔
30
from matplotlib.tri import Triangulation
6✔
31
from mpl_toolkits.axisartist.floating_axes import FloatingSubplot, GridHelperCurveLinear
6✔
32

33
from figanos.matplotlib.utils import (  # masknan_sizes_key,
6✔
34
    add_cartopy_features,
35
    add_features_map,
36
    check_timeindex,
37
    convert_scen_name,
38
    create_cmap,
39
    custom_cmap_norm,
40
    empty_dict,
41
    fill_between_label,
42
    get_array_categ,
43
    get_attributes,
44
    get_localized_term,
45
    get_rotpole,
46
    get_scen_color,
47
    get_var_group,
48
    gpd_to_ccrs,
49
    norm2range,
50
    plot_coords,
51
    process_keys,
52
    set_plot_attrs,
53
    size_legend_elements,
54
    sort_lines,
55
    split_legend,
56
    wrap_text,
57
)
58

59
logger = logging.getLogger(__name__)
6✔
60

61

62
def _plot_realizations(
6✔
63
    ax: matplotlib.axes.Axes,
64
    da: xr.DataArray,
65
    name: str,
66
    plot_kw: dict[str, Any],
67
    non_dict_data: dict[str, Any],
68
) -> matplotlib.axes.Axes:
69
    """Plot realizations from a DataArray, inside or outside a Dataset.
70

71
    Parameters
72
    ----------
73
    ax : matplotlib.axes.Axes
74
        The Matplotlib axis object.
75
    da : DataArray
76
        The DataArray containing the realizations.
77
    name : str
78
        The label to be used in the first part of a composite label.
79
        Can be the name of the parent Dataset or that of the DataArray.
80
    plot_kw : dict
81
        Dictionary of kwargs coming from the timeseries() input.
82
    non_dict_data : dict
83
        TBD.
84

85
    Returns
86
    -------
87
    matplotlib.axes.Axes
88
    """
89
    ignore_label = False
×
90

91
    for r in da.realization.values:
×
92
        if plot_kw[name]:  # if kwargs (all lines identical)
×
93
            if not ignore_label:  # if label not already in legend
×
94
                label = "" if non_dict_data is True else name
×
95
                ignore_label = True
×
96
            else:
97
                label = ""
×
98
        else:
99
            label = str(r) if non_dict_data is True else (name + "_" + str(r))
×
100

101
        ax.plot(
×
102
            da.sel(realization=r)["time"],
103
            da.sel(realization=r).values,
104
            label=label,
105
            **plot_kw[name],
106
        )
107

108
    return ax
×
109

110

111
def _plot_timeseries(
6✔
112
    ax: matplotlib.axes.Axes,
113
    name: str,
114
    arr: xr.DataArray | xr.Dataset,
115
    plot_kw: dict[str, Any],
116
    non_dict_data: bool,
117
    array_categ: dict[str, Any],
118
    legend: str,
119
) -> matplotlib.axes.Axes:
120
    """Plot figanos timeseries.
121

122
    Parameters
123
    ----------
124
    ax: matplotlib.axes.Axes
125
        Axe to be used for plotting.
126
    name : str
127
        Dictionary key of the plotted data.
128
    arr : Dataset/DataArray
129
        Data to be plotted.
130
    plot_kw : dict
131
        Dictionary of kwargs coming from the timeseries() input.
132
    non_dic_data : bool
133
        If True, plot_kw is not a dictionary.
134
    array_categ: dict
135
        Categories of data.
136
    legend: str
137
        Legend type.
138

139
    Returns
140
    -------
141
    matplotlib.axes.Axes
142
    """
143
    lines_dict = {}  # created to facilitate accessing line properties later
×
144
    # look for SSP, RCP, CMIP model color
145
    cat_colors = Path(__file__).parents[1] / "data/ipcc_colors/categorical_colors.json"
×
146
    if get_scen_color(name, cat_colors):
×
147
        plot_kw[name].setdefault("color", get_scen_color(name, cat_colors))
×
148

149
    #  remove 'label' to avoid error due to double 'label' args
150
    if "label" in plot_kw[name]:
×
151
        del plot_kw[name]["label"]
×
152
        warnings.warn(f'"label" entry in plot_kw[{name}] will be ignored.')
×
153

154
    if array_categ[name] == "ENS_REALS_DA":
×
155
        _plot_realizations(ax, arr, name, plot_kw, non_dict_data)
×
156

157
    elif array_categ[name] == "ENS_REALS_DS":
×
158
        if len(arr.data_vars) >= 2:
×
159
            raise TypeError(
×
160
                "To plot multiple ensembles containing realizations, use DataArrays outside a Dataset"
161
            )
162
        for k, sub_arr in arr.data_vars.items():
×
163
            _plot_realizations(ax, sub_arr, name, plot_kw, non_dict_data)
×
164

165
    elif array_categ[name] == "ENS_PCT_DIM_DS":
×
166
        for k, sub_arr in arr.data_vars.items():
×
167
            sub_name = (
×
168
                sub_arr.name if non_dict_data is True else (name + "_" + sub_arr.name)
169
            )
170

171
            # extract each percentile array from the dims
172
            array_data = {}
×
173
            for pct in sub_arr.percentiles.values:
×
174
                array_data[str(pct)] = sub_arr.sel(percentiles=pct)
×
175

176
            # create a dictionary labeling the middle, upper and lower line
177
            sorted_lines = sort_lines(array_data)
×
178

179
            # plot
180
            lines_dict[sub_name] = ax.plot(
×
181
                array_data[sorted_lines["middle"]]["time"],
182
                array_data[sorted_lines["middle"]].values,
183
                label=sub_name,
184
                **plot_kw[name],
185
            )
186

187
            ax.fill_between(
×
188
                array_data[sorted_lines["lower"]]["time"],
189
                array_data[sorted_lines["lower"]].values,
190
                array_data[sorted_lines["upper"]].values,
191
                color=lines_dict[sub_name][0].get_color(),
192
                linewidth=0.0,
193
                alpha=0.2,
194
                label=fill_between_label(sorted_lines, name, array_categ, legend),
195
            )
196

197
    # other ensembles
198
    elif array_categ[name] in [
×
199
        "ENS_PCT_VAR_DS",
200
        "ENS_STATS_VAR_DS",
201
        "ENS_PCT_DIM_DA",
202
    ]:
203
        # extract each array from the datasets
204
        array_data = {}
×
205
        if array_categ[name] == "ENS_PCT_DIM_DA":
×
206
            for pct in arr.percentiles:
×
207
                array_data[str(int(pct))] = arr.sel(percentiles=int(pct))
×
208
        else:
209
            for k, v in arr.data_vars.items():
×
210
                array_data[k] = v
×
211

212
        # create a dictionary labeling the middle, upper and lower line
213
        sorted_lines = sort_lines(array_data)
×
214

215
        # plot
216
        lines_dict[name] = ax.plot(
×
217
            array_data[sorted_lines["middle"]]["time"],
218
            array_data[sorted_lines["middle"]].values,
219
            label=name,
220
            **plot_kw[name],
221
        )
222

223
        ax.fill_between(
×
224
            array_data[sorted_lines["lower"]]["time"],
225
            array_data[sorted_lines["lower"]].values,
226
            array_data[sorted_lines["upper"]].values,
227
            color=lines_dict[name][0].get_color(),
228
            linewidth=0.0,
229
            alpha=0.2,
230
            label=fill_between_label(sorted_lines, name, array_categ, legend),
231
        )
232

233
    #  non-ensemble Datasets
234
    elif array_categ[name] == "DS":
×
235
        ignore_label = False
×
236
        for k, sub_arr in arr.data_vars.items():
×
237
            sub_name = (
×
238
                sub_arr.name if non_dict_data is True else (name + "_" + sub_arr.name)
239
            )
240

241
            #  if kwargs are specified by user, all lines are the same and we want one legend entry
242
            if plot_kw[name]:
×
243
                label = name if not ignore_label else ""
×
244
                ignore_label = True
×
245
            else:
246
                label = sub_name
×
247

248
            lines_dict[sub_name] = ax.plot(
×
249
                sub_arr["time"], sub_arr.values, label=label, **plot_kw[name]
250
            )
251

252
    #  non-ensemble DataArrays
253
    elif array_categ[name] in ["DA"]:
×
254
        lines_dict[name] = ax.plot(arr["time"], arr.values, label=name, **plot_kw[name])
×
255

256
    else:
257
        raise ValueError(
×
258
            "Data structure not supported"
259
        )  # can probably be removed along with elif logic above,
260
        # given that get_array_categ() also does this check
261
    return ax
×
262

263

264
def timeseries(
6✔
265
    data: dict[str, Any] | xr.DataArray | xr.Dataset,
266
    ax: matplotlib.axes.Axes | None = None,
267
    use_attrs: dict[str, Any] | None = None,
268
    fig_kw: dict[str, Any] | None = None,
269
    plot_kw: dict[str, Any] | None = None,
270
    legend: str = "lines",
271
    show_lat_lon: bool | str | int | tuple[float, float] = True,
272
    enumerate_subplots: bool = False,
273
) -> matplotlib.axes.Axes:
274
    """Plot time series from 1D Xarray Datasets or DataArrays as line plots.
275

276
    Parameters
277
    ----------
278
    data : dict or Dataset/DataArray
279
        Input data to plot. It can be a DataArray, Dataset or a dictionary of DataArrays and/or Datasets.
280
    ax : matplotlib.axes.Axes, optional
281
        Matplotlib axis on which to plot.
282
    use_attrs : dict, optional
283
        A dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description').
284
        Default value is {'title': 'description', 'ylabel': 'long_name', 'yunits': 'units'}.
285
        Only the keys found in the default dict can be used.
286
    fig_kw : dict, optional
287
        Arguments to pass to `plt.subplots()`. Only works if `ax` is not provided.
288
    plot_kw : dict, optional
289
        Arguments to pass to the `plot()` function. Changes how the line looks.
290
        If 'data' is a dictionary, must be a nested dictionary with the same keys as 'data'.
291
    legend : str (default 'lines') or dict
292
        'full' (lines and shading), 'lines' (lines only), 'in_plot' (end of lines),
293
         'edge' (out of plot), 'facetgrid' under figure, 'none' (no legend). If dict, arguments to pass to ax.legend().
294
    show_lat_lon : bool, tuple, str or int
295
        If True, show latitude and longitude at the bottom right of the figure.
296
        Can be a tuple of axis coordinates (from 0 to 1, as a fraction of the axis length) representing
297
        the location of the text. If a string or an int, the same values as those of the 'loc' parameter
298
        of matplotlib's legends are accepted.
299

300
        ==================   =============
301
        Location String      Location Code
302
        ==================   =============
303
        'upper right'        1
304
        'upper left'         2
305
        'lower left'         3
306
        'lower right'        4
307
        'right'              5
308
        'center left'        6
309
        'center right'       7
310
        'lower center'       8
311
        'upper center'       9
312
        'center'             10
313
        ==================   =============
314
    enumerate_subplots: bool
315
        If True, enumerate subplots with letters.
316
        Only works with facetgrids (pass `col` or `row` in plot_kw).
317

318
    Returns
319
    -------
320
    matplotlib.axes.Axes
321
    """
322
    # convert SSP, RCP, CMIP formats in keys
323
    if isinstance(data, dict):
×
324
        data = process_keys(data, convert_scen_name)
×
325
    if isinstance(plot_kw, dict):
×
326
        plot_kw = process_keys(plot_kw, convert_scen_name)
×
327

328
    # create empty dicts if None
329
    use_attrs = empty_dict(use_attrs)
×
330
    fig_kw = empty_dict(fig_kw)
×
331
    plot_kw = empty_dict(plot_kw)
×
332

333
    # if only one data input, insert in dict.
334
    non_dict_data = False
×
335
    if not isinstance(data, dict):
×
336
        non_dict_data = True
×
337
        data = {"_no_label": data}  # mpl excludes labels starting with "_" from legend
×
338
        plot_kw = {"_no_label": empty_dict(plot_kw)}
×
339

340
    # assign keys to plot_kw if not there
341
    if non_dict_data is False:
×
342
        for name in data:
×
343
            if name not in plot_kw:
×
344
                plot_kw[name] = {}
×
345
        for key in plot_kw:
×
346
            if key not in data:
×
347
                raise KeyError(
×
348
                    'plot_kw must be a nested dictionary with keys corresponding to the keys in "data"'
349
                )
350

351
    # check: type
352
    for name, arr in data.items():
×
353
        if not isinstance(arr, (xr.Dataset, xr.DataArray)):
×
354
            raise TypeError(
×
355
                '"data" must be a xr.Dataset, a xr.DataArray or a dictionary of such objects.'
356
            )
357

358
    # check: 'time' dimension and calendar format
359
    data = check_timeindex(data)
×
360

361
    # set fig, ax if not provided
362
    if ax is None and (
×
363
        "row" not in list(plot_kw.values())[0].keys()
364
        and "col" not in list(plot_kw.values())[0].keys()
365
    ):
366
        fig, ax = plt.subplots(**fig_kw)
×
367
    elif ax is not None and (
×
368
        "col" in list(plot_kw.values())[0].keys()
369
        or "row" in list(plot_kw.values())[0].keys()
370
    ):
371
        raise ValueError("Cannot use 'ax' and 'col'/'row' at the same time.")
×
372
    elif ax is None:
×
373
        cfig_kw = fig_kw.copy()
×
374
        if "figsize" in fig_kw:  # add figsize to plot_kw for facetgrid
×
375
            list(plot_kw.values())[0].setdefault("figsize", fig_kw["figsize"])
×
376
            cfig_kw.pop("figsize")
×
377
        if cfig_kw:
×
378
            for v in plot_kw.values():
×
379
                {"subplots_kws": cfig_kw} | v
×
380
            warnings.warn(
×
381
                "Only figsize and figure.add_subplot() arguments can be passed to fig_kw when using facetgrid."
382
            )
383

384
    # set default use_attrs values
385
    if ax:
×
386
        use_attrs.setdefault("title", "description")
×
387
    else:
388
        use_attrs.setdefault("suptitle", "description")
×
389
    use_attrs.setdefault("ylabel", "long_name")
×
390
    use_attrs.setdefault("yunits", "units")
×
391

392
    # dict of array 'categories'
393
    array_categ = {name: get_array_categ(array) for name, array in data.items()}
×
394
    cp_plot_kw = copy.deepcopy(plot_kw)
×
395
    # get data and plot
396
    for name, arr in data.items():
×
397
        if ax:
×
398
            _plot_timeseries(ax, name, arr, plot_kw, non_dict_data, array_categ, legend)
×
399
        else:
400
            if name == list(data.keys())[0]:
×
401
                # create empty DataArray with same dimensions as data first entry to create an empty xr.plot.FacetGrid
402
                if isinstance(arr, xr.Dataset):
×
403
                    da = arr[list(arr.keys())[0]]
×
404
                else:
405
                    da = arr
×
406
                da = da.where(da == np.nan)
×
407
                im = da.plot(**plot_kw[name], color="white")
×
408

409
            [
×
410
                cp_plot_kw[name].pop(key)
411
                for key in ["row", "col", "figsize"]
412
                if key in cp_plot_kw[name].keys()
413
            ]
414

415
            # plot data in every axis of the facetgrid
416
            for i in range(0, im.axs.shape[0]):
×
417
                for j in range(0, im.axs.shape[1]):
×
418
                    sel_arr = {}
×
419

420
                    if "row" in plot_kw[name]:
×
421
                        sel_arr[plot_kw[name]["row"]] = i
×
422
                    if "col" in plot_kw[name]:
×
423
                        sel_arr[plot_kw[name]["col"]] = j
×
424

425
                    _plot_timeseries(
×
426
                        im.axs[i, j],
427
                        name,
428
                        arr.isel(**sel_arr).squeeze(),
429
                        cp_plot_kw,
430
                        non_dict_data,
431
                        array_categ,
432
                        legend,
433
                    )
434

435
    #  add/modify plot elements according to the first entry.
436
    if ax:
×
437
        set_plot_attrs(
×
438
            use_attrs,
439
            list(data.values())[0],
440
            ax,
441
            title_loc="left",
442
            wrap_kw={"min_line_len": 35, "max_line_len": 48},
443
        )
444
        ax.set_xlabel(
×
445
            get_localized_term("time").capitalize()
446
        )  # check_timeindex() already checks for 'time'
447

448
        # other plot elements
449
        if show_lat_lon:
×
450
            if show_lat_lon is True:
×
451
                plot_coords(
×
452
                    ax,
453
                    list(data.values())[0],
454
                    param="location",
455
                    loc="lower right",
456
                    backgroundalpha=1,
457
                )
458
            elif isinstance(show_lat_lon, (str, tuple, int)):
×
459
                plot_coords(
×
460
                    ax,
461
                    list(data.values())[0],
462
                    param="location",
463
                    loc=show_lat_lon,
464
                    backgroundalpha=1,
465
                )
466
            else:
467
                raise TypeError(" show_lat_lon must be a bool, string, int, or tuple")
×
468

469
        if legend is not None:
×
470
            if not ax.get_legend_handles_labels()[0]:  # check if legend is empty
×
471
                pass
×
472
            elif legend == "in_plot":
×
473
                split_legend(ax, in_plot=True)
×
474
            elif legend == "edge":
×
475
                split_legend(ax, in_plot=False)
×
476
            elif isinstance(legend, dict):
×
477
                ax.legend(**legend)
×
478
            else:
479
                ax.legend()
×
480

481
        return ax
×
482
    else:
483
        if legend is not None:
×
484
            if not im.axs[-1, -1].get_legend_handles_labels()[
×
485
                0
486
            ]:  # check if legend is empty
487
                pass
×
488
            elif legend == "in_plot":
×
489
                split_legend(im.axs[-1, -1], in_plot=True)
×
490
            elif legend == "edge":
×
491
                split_legend(im.axs[-1, -1], in_plot=False)
×
492
            elif isinstance(legend, dict):
×
493
                handles, labels = im.axs[-1, -1].get_legend_handles_labels()
×
494
                legend = {"handles": handles, "labels": labels} | legend
×
495
                im.fig.legend(**legend)
×
496
            elif legend == "facetgrid":
×
497
                handles, labels = im.axs[-1, -1].get_legend_handles_labels()
×
498
                im.fig.legend(
×
499
                    handles,
500
                    labels,
501
                    loc="lower center",
502
                    ncol=len(im.axs[-1, -1].lines),
503
                    bbox_to_anchor=(0.5, -0.05),
504
                )
505

506
        if show_lat_lon:
×
507
            if show_lat_lon is True:
×
508
                plot_coords(
×
509
                    None,
510
                    list(data.values())[0].isel(lat=0, lon=0),
511
                    param="location",
512
                    loc="lower right",
513
                    backgroundalpha=1,
514
                )
515
            elif isinstance(show_lat_lon, (str, tuple, int)):
×
516
                plot_coords(
×
517
                    None,
518
                    list(data.values())[0].isel(lat=0, lon=0),
519
                    param="location",
520
                    loc=show_lat_lon,
521
                    backgroundalpha=1,
522
                )
523
        if enumerate_subplots and isinstance(im, xr.plot.facetgrid.FacetGrid):
×
524
            for idx, ax in enumerate(im.axs.flat):
×
525
                ax.set_title(f"{string.ascii_lowercase[idx]}) {ax.get_title()}")
×
526

527
        return im
×
528

529

530
def gridmap(
6✔
531
    data: dict[str, Any] | xr.DataArray | xr.Dataset,
532
    ax: matplotlib.axes.Axes | None = None,
533
    use_attrs: dict[str, Any] | None = None,
534
    fig_kw: dict[str, Any] | None = None,
535
    plot_kw: dict[str, Any] | None = None,
536
    projection: ccrs.Projection = ccrs.LambertConformal(),
537
    transform: ccrs.Projection | None = None,
538
    features: list[str] | dict[str, dict[str, Any]] | None = None,
539
    geometries_kw: dict[str, Any] | None = None,
540
    contourf: bool = False,
541
    cmap: str | matplotlib.colors.Colormap | None = None,
542
    levels: int | list | np.ndarray | None = None,
543
    divergent: bool | int | float = False,
544
    show_time: bool | str | int | tuple[float, float] = False,
545
    frame: bool = False,
546
    enumerate_subplots: bool = False,
547
) -> matplotlib.axes.Axes:
548
    """Create map from 2D data.
549

550
    Parameters
551
    ----------
552
    data : dict, DataArray or Dataset
553
        Input data do plot. If dictionary, must have only one entry.
554
    ax : matplotlib axis, optional
555
        Matplotlib axis on which to plot, with the same projection as the one specified.
556
    use_attrs : dict, optional
557
        Dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description').
558
        Default value is {'title': 'description', 'cbar_label': 'long_name', 'cbar_units': 'units'}.
559
        Only the keys found in the default dict can be used.
560
    fig_kw : dict, optional
561
        Arguments to pass to `plt.figure()`.
562
    plot_kw:  dict, optional
563
        Arguments to pass to the `xarray.plot.pcolormesh()` or 'xarray.plot.contourf()' function.
564
    projection : ccrs.Projection
565
        The projection to use, taken from the cartopy.crs options. Ignored if ax is not None.
566
    transform : ccrs.Projection, optional
567
        Transform corresponding to the data coordinate system. If None, an attempt is made to find dimensions matching
568
        ccrs.PlateCarree() or ccrs.RotatedPole().
569
    features : list or dict, optional
570
        Features to use, as a list or a nested dict containing kwargs. Options are the predefined features from
571
        cartopy.feature: ['coastline', 'borders', 'lakes', 'land', 'ocean', 'rivers', 'states'].
572
    geometries_kw : dict, optional
573
        Arguments passed to cartopy ax.add_geometry() which adds given geometries (GeoDataFrame geometry) to axis.
574
    contourf : bool
575
        By default False, use plt.pcolormesh(). If True, use plt.contourf().
576
    cmap : matplotlib.colors.Colormap or str, optional
577
        Colormap to use. If str, can be a matplotlib or name of the file of an IPCC colormap (see data/ipcc_colors).
578
        If None, look for common variables (from data/ipcc_colors/varaibles_groups.json) in the name of the DataArray
579
        or its 'history' attribute and use corresponding colormap, aligned with the IPCC visual style guide 2022
580
        (https://www.ipcc.ch/site/assets/uploads/2022/09/IPCC_AR6_WGI_VisualStyleGuide_2022.pdf).
581
    levels : int, list, np.ndarray, optional
582
        Number of levels to divide the colormap into or list of level boundaries (in data units).
583
    divergent : bool or int or float
584
        If int or float, becomes center of cmap. Default center is 0.
585
    show_time : bool, tuple, string or int.
586
        If True, show time (as date) at the bottom right of the figure.
587
        Can be a tuple of axis coordinates (0 to 1, as a fraction of the axis length) representing the location
588
        of the text. If a string or an int, the same values as those of the 'loc' parameter
589
        of matplotlib's legends are accepted.
590

591
        ==================   =============
592
        Location String      Location Code
593
        ==================   =============
594
        'upper right'        1
595
        'upper left'         2
596
        'lower left'         3
597
        'lower right'        4
598
        'right'              5
599
        'center left'        6
600
        'center right'       7
601
        'lower center'       8
602
        'upper center'       9
603
        'center'             10
604
        ==================   =============
605
    frame : bool
606
        Show or hide frame. Default False.
607
    enumerate_subplots: bool
608
        If True, enumerate subplots with letters.
609
        Only works with facetgrids (pass `col` or `row` in plot_kw).
610

611
    Returns
612
    -------
613
    matplotlib.axes.Axes
614
    """
615
    # create empty dicts if None
616
    use_attrs = empty_dict(use_attrs)
×
617
    fig_kw = empty_dict(fig_kw)
×
618
    plot_kw = empty_dict(plot_kw)
×
619

620
    # set default use_attrs values
621
    use_attrs = {"cbar_label": "long_name", "cbar_units": "units"} | use_attrs
×
622
    if "row" not in plot_kw and "col" not in plot_kw:
×
623
        use_attrs.setdefault("title", "description")
×
624

625
    # extract plot_kw from dict if needed
626
    if isinstance(data, dict) and plot_kw and list(data.keys())[0] in plot_kw.keys():
×
627
        plot_kw = plot_kw[list(data.keys())[0]]
×
628

629
    # if data is dict, extract
630
    if isinstance(data, dict):
×
631
        if len(data) == 1:
×
632
            data = list(data.values())[0]
×
633
        else:
634
            raise ValueError("If `data` is a dict, it must be of length 1.")
×
635

636
    # select data to plot
637
    if isinstance(data, xr.DataArray):
×
638
        plot_data = data.squeeze()
×
639
    elif isinstance(data, xr.Dataset):
×
640
        if len(data.data_vars) > 1:
×
641
            warnings.warn(
×
642
                "data is xr.Dataset; only the first variable will be used in plot"
643
            )
644
        plot_data = data[list(data.keys())[0]].squeeze()
×
645
    else:
646
        raise TypeError("`data` must contain a xr.DataArray or xr.Dataset")
×
647

648
    # setup transform
649
    if transform is None:
×
650
        if "lat" in data.dims and "lon" in data.dims:
×
651
            transform = ccrs.PlateCarree()
×
652
        if "rlat" in data.dims and "rlon" in data.dims:
×
653
            if hasattr(data, "rotated_pole"):
×
654
                transform = get_rotpole(data)
×
655

656
    # setup fig, ax
657
    if ax is None and ("row" not in plot_kw.keys() and "col" not in plot_kw.keys()):
×
658
        fig, ax = plt.subplots(subplot_kw={"projection": projection}, **fig_kw)
×
659
    elif ax is not None and ("col" in plot_kw or "row" in plot_kw):
×
660
        raise ValueError("Cannot use 'ax' and 'col'/'row' at the same time.")
×
661
    elif ax is None:
×
662
        plot_kw = {"subplot_kws": {"projection": projection}} | plot_kw
×
663
        cfig_kw = fig_kw.copy()
×
664
        if "figsize" in fig_kw:  # add figsize to plot_kw for facetgrid
×
665
            plot_kw.setdefault("figsize", fig_kw["figsize"])
×
666
            cfig_kw.pop("figsize")
×
667
        if len(cfig_kw) >= 1:
×
668
            plot_kw = {"subplot_kws": {"projection": cfig_kw}} | plot_kw
×
669
            warnings.warn(
×
670
                "Only figsize and figure.add_subplot() arguments can be passed to fig_kw when using facetgrid."
671
            )
672

673
    # create cbar label
674
    if (
×
675
        "cbar_units" in use_attrs
676
        and len(get_attributes(use_attrs["cbar_units"], data)) >= 1
677
    ):  # avoids '[]' as label
678
        cbar_label = (
×
679
            get_attributes(use_attrs["cbar_label"], data)
680
            + " ("
681
            + get_attributes(use_attrs["cbar_units"], data)
682
            + ")"
683
        )
684
    else:
685
        cbar_label = get_attributes(use_attrs["cbar_label"], data)
×
686

687
    # colormap
688
    if isinstance(cmap, str):
×
689
        if cmap not in plt.colormaps():
×
690
            try:
×
691
                cmap = create_cmap(filename=cmap)
×
692
            except FileNotFoundError as e:
×
693
                logger.error(e)
×
694
                pass
×
695

696
    elif cmap is None:
×
697
        cdata = Path(__file__).parents[1] / "data/ipcc_colors/variable_groups.json"
×
698
        cmap = create_cmap(
×
699
            get_var_group(path_to_json=cdata, da=plot_data),
700
            divergent=divergent,
701
        )
702
    plot_kw.setdefault("cmap", cmap)
×
703

704
    if levels is not None:
×
705
        if isinstance(levels, Iterable):
×
706
            lin = levels
×
707
        else:
708
            lin = custom_cmap_norm(
×
709
                cmap,
710
                np.nanmin(plot_data.values),
711
                np.nanmax(plot_data.values),
712
                levels=levels,
713
                divergent=divergent,
714
                linspace_out=True,
715
            )
716
        plot_kw.setdefault("levels", lin)
×
717

718
    elif (divergent is not False) and ("levels" not in plot_kw):
×
719
        norm = custom_cmap_norm(
×
720
            cmap,
721
            np.nanmin(plot_data.values),
722
            np.nanmax(plot_data.values),
723
            levels=levels,
724
            divergent=divergent,
725
        )
726
        plot_kw.setdefault("norm", norm)
×
727

728
    # set defaults
729
    if divergent is not False:
×
730
        if isinstance(divergent, (int, float)):
×
731
            plot_kw.setdefault("center", divergent)
×
732
        else:
733
            plot_kw.setdefault("center", 0)
×
734

735
    if "add_colorbar" not in plot_kw or plot_kw["add_colorbar"] is not False:
×
736
        plot_kw.setdefault("cbar_kwargs", {})
×
737
        plot_kw["cbar_kwargs"].setdefault("label", wrap_text(cbar_label))
×
738

739
    # bug xlim / ylim + transform in facetgrids
740
    # (see https://github.com/pydata/xarray/issues/8562#issuecomment-1865189766)
741
    if transform and ("xlim" in plot_kw and "ylim" in plot_kw):
×
742
        extent = [
×
743
            plot_kw["xlim"][0],
744
            plot_kw["xlim"][1],
745
            plot_kw["ylim"][0],
746
            plot_kw["ylim"][1],
747
        ]
748
        plot_kw.pop("xlim")
×
749
        plot_kw.pop("ylim")
×
750
    elif transform and ("xlim" in plot_kw or "ylim" in plot_kw):
×
751
        extent = None
×
752
        warnings.warn(
×
753
            "Requires both xlim and ylim with 'transform'. Xlim or ylim was dropped"
754
        )
755
        if "xlim" in plot_kw.keys():
×
756
            plot_kw.pop("xlim")
×
757
        if "ylim" in plot_kw.keys():
×
758
            plot_kw.pop("ylim")
×
759
    else:
760
        extent = None
×
761

762
    # plot
763
    if ax:
×
764
        plot_kw.setdefault("ax", ax)
×
765
    if transform:
×
766
        plot_kw.setdefault("transform", transform)
×
767

768
    if contourf is False:
×
769
        im = plot_data.plot.pcolormesh(**plot_kw)
×
770
    else:
771
        im = plot_data.plot.contourf(**plot_kw)
×
772

773
    if ax:
×
774
        if extent:
×
775
            ax.set_extent(extent)
×
776

777
        ax = add_features_map(
×
778
            data,
779
            ax,
780
            use_attrs,
781
            projection,
782
            features,
783
            geometries_kw,
784
            frame,
785
        )
786
        if show_time:
×
787
            if isinstance(show_time, bool):
×
788
                plot_coords(
×
789
                    ax,
790
                    plot_data,
791
                    param="time",
792
                    loc="lower right",
793
                    backgroundalpha=1,
794
                )
795
            elif isinstance(show_time, (str, tuple, int)):
×
796
                plot_coords(
×
797
                    ax,
798
                    plot_data,
799
                    param="time",
800
                    loc=show_time,
801
                    backgroundalpha=1,
802
                )
803

804
        # when im is an ax, it has a colorbar attribute. If it is a facetgrid, it has a cbar attribute.
805
        if (frame is False) and (
×
806
            (getattr(im, "colorbar", None) is not None)
807
            or (getattr(im, "cbar", None) is not None)
808
        ):
809
            im.colorbar.outline.set_visible(False)
×
810
        return ax
×
811

812
    else:
813
        for i, fax in enumerate(im.axs.flat):
×
814
            add_features_map(
×
815
                data,
816
                fax,
817
                use_attrs,
818
                projection,
819
                features,
820
                geometries_kw,
821
                frame,
822
            )
823
            if extent:
×
824
                fax.set_extent(extent)
×
825

826
            # when im is an ax, it has a colorbar attribute. If it is a facetgrid, it has a cbar attribute.
827
        if (frame is False) and (
×
828
            (getattr(im, "colorbar", None) is not None)
829
            or (getattr(im, "cbar", None) is not None)
830
        ):
831
            im.cbar.outline.set_visible(False)
×
832

833
        if show_time:
×
834
            if isinstance(show_time, bool):
×
835
                plot_coords(
×
836
                    None,
837
                    plot_data,
838
                    param="time",
839
                    loc="lower right",
840
                    backgroundalpha=1,
841
                )
842
            elif isinstance(show_time, (str, tuple, int)):
×
843
                plot_coords(
×
844
                    None,
845
                    plot_data,
846
                    param="time",
847
                    loc=show_time,
848
                    backgroundalpha=1,
849
                )
850

851
        use_attrs.setdefault("suptitle", "long_name")
×
852
        im = set_plot_attrs(use_attrs, data, facetgrid=im)
×
853
        if enumerate_subplots and isinstance(im, xr.plot.facetgrid.FacetGrid):
×
854
            for idx, ax in enumerate(im.axs.flat):
×
855
                ax.set_title(f"{string.ascii_lowercase[idx]}) {ax.get_title()}")
×
856

857
        return im
×
858

859

860
def gdfmap(
6✔
861
    df: gpd.GeoDataFrame,
862
    df_col: str,
863
    ax: cartopy.mpl.geoaxes.GeoAxes | cartopy.mpl.geoaxes.GeoAxesSubplot | None = None,
864
    fig_kw: dict[str, Any] | None = None,
865
    plot_kw: dict[str, Any] | None = None,
866
    projection: ccrs.Projection = ccrs.LambertConformal(),
867
    features: list[str] | dict[str, dict[str, Any]] | None = None,
868
    cmap: str | matplotlib.colors.Colormap | None = None,
869
    levels: int | list[int | float] | None = None,
870
    divergent: bool | int | float = False,
871
    cbar: bool = True,
872
    frame: bool = False,
873
) -> matplotlib.axes.Axes:
874
    """Create a map plot from geometries.
875

876
    Parameters
877
    ----------
878
    df : geopandas.GeoDataFrame
879
        Dataframe containing the geometries and the data to plot. Must have a column named 'geometry'.
880
    df_col : str
881
        Name of the column of 'df' containing the data to plot using the colorscale.
882
    ax : cartopy.mpl.geoaxes.GeoAxes or cartopy.mpl.geoaxes.GeoaxesSubplot, optional
883
        Matplotlib axis built with a projection, on which to plot.
884
    fig_kw : dict, optional
885
        Arguments to pass to `plt.figure()`.
886
    plot_kw :  dict, optional
887
        Arguments to pass to the GeoDataFrame.plot() method.
888
    projection : ccrs.Projection
889
        The projection to use, taken from the cartopy.crs options. Ignored if ax is not None.
890
    features : list or dict, optional
891
        Features to use, as a list or a nested dict containing kwargs. Options are the predefined features from
892
        cartopy.feature: ['coastline', 'borders', 'lakes', 'land', 'ocean', 'rivers', 'states'].
893
    cmap : matplotlib.colors.Colormap or str
894
        Colormap to use. If str, can be a matplotlib or name of the file of an IPCC colormap (see data/ipcc_colors).
895
        If None, look for common variables (from data/ipcc_colors/varaibles_groups.json) in the name of df_col
896
        and use corresponding colormap, aligned with the IPCC visual style guide 2022
897
        (https://www.ipcc.ch/site/assets/uploads/2022/09/IPCC_AR6_WGI_VisualStyleGuide_2022.pdf).
898
    levels : int or list, optional
899
        Number of  levels or list of level boundaries (in data units) to use to divide the colormap.
900
    divergent : bool or int or float
901
        If int or float, becomes center of cmap. Default center is 0.
902
    cbar : bool
903
        Show colorbar. Default 'True'.
904
    frame : bool
905
        Show or hide frame. Default False.
906

907
    Returns
908
    -------
909
    matplotlib.axes.Axes
910
    """
911
    # create empty dicts if None
912
    fig_kw = empty_dict(fig_kw)
×
913
    plot_kw = empty_dict(plot_kw)
×
914
    features = empty_dict(features)
×
915

916
    # checks
917
    if not isinstance(df, gpd.GeoDataFrame):
×
918
        raise TypeError("df myst be an instance of class geopandas.GeoDataFrame")
×
919

920
    if "geometry" not in df.columns:
×
921
        raise ValueError("column 'geometry' not found in GeoDataFrame")
×
922

923
    # convert to projection
924
    if ax is None:
×
925
        df = gpd_to_ccrs(df=df, proj=projection)
×
926
    else:
927
        df = gpd_to_ccrs(df=df, proj=ax.projection)
×
928

929
    # setup fig, ax
930
    if ax is None:
×
931
        fig, ax = plt.subplots(subplot_kw={"projection": projection}, **fig_kw)
×
932
        ax.set_aspect("equal")  # recommended by geopandas
×
933

934
    # add features
935
    if features:
×
936
        add_cartopy_features(ax, features)
×
937

938
    # colormap
939
    if isinstance(cmap, str):
×
940
        if cmap in plt.colormaps():
×
941
            cmap = matplotlib.colormaps[cmap]
×
942
        else:
943
            try:
×
944
                cmap = create_cmap(filename=cmap)
×
945
            except FileNotFoundError:
×
946
                warnings.warn("invalid cmap, using default")
×
947
                cmap = create_cmap(filename="slev_seq")
×
948

949
    elif cmap is None:
×
950
        cdata = Path(__file__).parents[1] / "data/ipcc_colors/variable_groups.json"
×
951
        cmap = create_cmap(
×
952
            get_var_group(unique_str=df_col, path_to_json=cdata),
953
            divergent=divergent,
954
        )
955

956
    # create normalization for colormap
957
    plot_kw.setdefault("vmin", df[df_col].min())
×
958
    plot_kw.setdefault("vmax", df[df_col].max())
×
959

960
    if (levels is not None) or (divergent is not False):
×
961
        norm = custom_cmap_norm(
×
962
            cmap,
963
            plot_kw["vmin"],
964
            plot_kw["vmax"],
965
            levels=levels,
966
            divergent=divergent,
967
        )
968
        plot_kw.setdefault("norm", norm)
×
969

970
    # colorbar
971
    if cbar:
×
972
        plot_kw.setdefault("legend", True)
×
973
        plot_kw.setdefault("legend_kwds", {})
×
974
        plot_kw["legend_kwds"].setdefault("label", df_col)
×
975
        plot_kw["legend_kwds"].setdefault("orientation", "horizontal")
×
976
        plot_kw["legend_kwds"].setdefault("pad", 0.02)
×
977

978
    # plot
979
    plot = df.plot(column=df_col, ax=ax, cmap=cmap, **plot_kw)
×
980

981
    if frame is False:
×
982
        # cbar
983
        plot.figure.axes[1].spines["outline"].set_visible(False)
×
984
        plot.figure.axes[1].tick_params(size=0)
×
985
        # main axes
986
        ax.spines["geo"].set_visible(False)
×
987

988
    return ax
×
989

990

991
def violin(
6✔
992
    data: dict[str, Any] | xr.DataArray | xr.Dataset,
993
    ax: matplotlib.axes.Axes | None = None,
994
    use_attrs: dict[str, Any] | None = None,
995
    fig_kw: dict[str, Any] | None = None,
996
    plot_kw: dict[str, Any] | None = None,
997
    color: str | int | list[str | int] | None = None,
998
) -> matplotlib.axes.Axes:
999
    """Make violin plot using seaborn.
1000

1001
    Parameters
1002
    ----------
1003
    data : dict or Dataset/DataArray
1004
        Input data to plot. If a dict, must contain DataArrays and/or Datasets.
1005
    ax : matplotlib.axes.Axes, optional
1006
        Matplotlib axis on which to plot.
1007
    use_attrs : dict, optional
1008
        A dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description').
1009
        Default value is {'title': 'description', 'ylabel': 'long_name', 'yunits': 'units'}.
1010
        Only the keys found in the default dict can be used.
1011
    fig_kw : dict, optional
1012
        Arguments to pass to `plt.subplots()`. Only works if `ax` is not provided.
1013
    plot_kw : dict, optional
1014
        Arguments to pass to the `seaborn.violinplot()` function.
1015
    color :  str, int or list, optional
1016
        Unique color or list of colors to use. Integers point to the applied stylesheet's colors, in zero-indexed order.
1017
        Passing 'color' or 'palette' in plot_kw overrides this argument.
1018

1019
    Returns
1020
    -------
1021
    matplotlib.axes.Axes
1022
    """
1023
    # create empty dicts if None
1024
    use_attrs = empty_dict(use_attrs)
×
1025
    fig_kw = empty_dict(fig_kw)
×
1026
    plot_kw = empty_dict(plot_kw)
×
1027

1028
    # if data is dict, assemble into one DataFrame
1029
    non_dict_data = True
×
1030
    if isinstance(data, dict):
×
1031
        non_dict_data = False
×
1032
        df = pd.DataFrame()
×
1033
        for key, xr_obj in data.items():
×
1034
            if isinstance(xr_obj, xr.Dataset):
×
1035
                # if one data var, use key
1036
                if len(list(xr_obj.data_vars)) == 1:
×
1037
                    df[key] = xr_obj[list(xr_obj.data_vars)[0]].values
×
1038
                # if more than one data var, use key + name of var
1039
                else:
1040
                    for data_var in list(xr_obj.data_vars):
×
1041
                        df[key + "_" + data_var] = xr_obj[data_var].values
×
1042

1043
            elif isinstance(xr_obj, xr.DataArray):
×
1044
                df[key] = xr_obj.values
×
1045

1046
            else:
1047
                raise TypeError(
×
1048
                    '"data" must be a xr.Dataset, a xr.DataArray or a dictionary of such objects.'
1049
                )
1050

1051
    elif isinstance(data, xr.Dataset):
×
1052
        # create dataframe
1053
        df = data.to_dataframe()
×
1054
        df = df[data.data_vars]
×
1055

1056
    elif isinstance(data, xr.DataArray):
×
1057
        # create dataframe
1058
        df = data.to_dataframe()
×
1059
        for coord in list(data.coords):
×
1060
            if coord in df.columns:
×
1061
                df = df.drop(columns=coord)
×
1062

1063
    else:
1064
        raise TypeError(
×
1065
            '"data" must be a xr.Dataset, a xr.DataArray or a dictionary of such objects.'
1066
        )
1067

1068
    # set fig, ax if not provided
1069
    if ax is None:
×
1070
        fig, ax = plt.subplots(**fig_kw)
×
1071

1072
    # set default use_attrs values
1073
    if "orient" in plot_kw and plot_kw["orient"] == "h":
×
1074
        use_attrs = {"xlabel": "long_name", "xunits": "units"} | use_attrs
×
1075
    else:
1076
        use_attrs = {"ylabel": "long_name", "yunits": "units"} | use_attrs
×
1077

1078
    #  add/modify plot elements according to the first entry.
1079
    if non_dict_data:
×
1080
        set_plot_obj = data
×
1081
    else:
1082
        set_plot_obj = list(data.values())[0]
×
1083

1084
    set_plot_attrs(
×
1085
        use_attrs,
1086
        xr_obj=set_plot_obj,
1087
        ax=ax,
1088
        title_loc="left",
1089
        wrap_kw={"min_line_len": 35, "max_line_len": 48},
1090
    )
1091

1092
    # color
1093
    if color:
×
1094
        style_colors = matplotlib.rcParams["axes.prop_cycle"].by_key()["color"]
×
1095
        if isinstance(color, str):
×
1096
            plot_kw.setdefault("color", color)
×
1097
        elif isinstance(color, int):
×
1098
            try:
×
1099
                plot_kw.setdefault("color", style_colors[color])
×
1100
            except IndexError:
×
1101
                raise IndexError("Index out of range of stylesheet colors")
×
1102
        elif isinstance(color, list):
×
1103
            for c, i in zip(color, np.arange(len(color))):
×
1104
                if isinstance(c, int):
×
1105
                    try:
×
1106
                        color[i] = style_colors[c]
×
1107
                    except IndexError:
×
1108
                        raise IndexError("Index out of range of stylesheet colors")
×
1109
            plot_kw.setdefault("palette", color)
×
1110

1111
    # plot
1112
    sns.violinplot(df, ax=ax, **plot_kw)
×
1113

1114
    # grid
1115
    if "orient" in plot_kw and plot_kw["orient"] == "h":
×
1116
        ax.grid(visible=True, axis="x")
×
1117

1118
    return ax
×
1119

1120

1121
def stripes(
6✔
1122
    data: dict[str, Any] | xr.DataArray | xr.Dataset,
1123
    ax: matplotlib.axes.Axes | None = None,
1124
    fig_kw: dict[str, Any] | None = None,
1125
    divide: int | None = None,
1126
    cmap: str | matplotlib.colors.Colormap | None = None,
1127
    cmap_center: int | float = 0,
1128
    cbar: bool = True,
1129
    cbar_kw: dict[str, Any] | None = None,
1130
) -> matplotlib.axes.Axes:
1131
    """Create stripes plot with or without multiple scenarios.
1132

1133
    Parameters
1134
    ----------
1135
    data : dict or DataArray or Dataset
1136
        Data to plot. If a dictionary of xarray objects, each will correspond to a scenario.
1137
    ax : matplotlib.axes.Axes, optional
1138
        Matplotlib axis on which to plot.
1139
    fig_kw : : dict, optional
1140
        Arguments to pass to `plt.subplots()`. Only works if `ax` is not provided.
1141
    divide : int, optional
1142
        Year at which the plot is divided into scenarios. If not provided, the horizontal separators
1143
        will extend over the full time axis.
1144
    cmap : matplotlib.colors.Colormap or str, optional
1145
        Colormap to use. If str, can be a matplotlib or name of the file of an IPCC colormap (see data/ipcc_colors).
1146
        If None, look for common variables (from data/ipcc_colors/variables_groups.json) in the name of the DataArray
1147
        or its 'history' attribute and use corresponding diverging colormap, aligned with the IPCC Visual Style
1148
        Guide 2022 (https://www.ipcc.ch/site/assets/uploads/2022/09/IPCC_AR6_WGI_VisualStyleGuide_2022.pdf).
1149
    cmap_center : int or float
1150
        Center of the colormap in data coordinates. Default is 0.
1151
    cbar : bool
1152
        Show colorbar.
1153
    cbar_kw : dict, optional
1154
        Arguments to pass to plt.colorbar.
1155

1156
    Returns
1157
    -------
1158
    matplotlib.axes.Axes
1159
    """
1160
    # create empty dicts if None
1161
    fig_kw = empty_dict(fig_kw)
×
1162
    cbar_kw = empty_dict(cbar_kw)
×
1163

1164
    # init main (figure) axis
1165
    if ax is None:
×
1166
        fig_kw.setdefault("figsize", (10, 5))
×
1167
        fig, ax = plt.subplots(**fig_kw)
×
1168
    ax.set_yticks([])
×
1169
    ax.set_xticks([])
×
1170
    ax.spines[["top", "bottom", "left", "right"]].set_visible(False)
×
1171

1172
    # init plot axis
1173
    ax_0 = ax.inset_axes([0, 0.15, 1, 0.75])
×
1174

1175
    # handle non-dict data
1176
    if not isinstance(data, dict):
×
1177
        data = {"_no_label": data}
×
1178

1179
    # convert SSP, RCP, CMIP formats in keys
1180
    data = process_keys(data, convert_scen_name)
×
1181

1182
    n = len(data)
×
1183

1184
    # extract DataArrays from datasets
1185
    for key, obj in data.items():
×
1186
        if isinstance(obj, xr.DataArray):
×
1187
            pass
×
1188
        elif isinstance(obj, xr.Dataset):
×
1189
            data[key] = obj[list(obj.data_vars)[0]]
×
1190
        else:
1191
            raise TypeError("data must contain xarray DataArrays or Datasets")
×
1192

1193
    # get time interval
1194
    time_index = list(data.values())[0].time.dt.year.values
×
1195
    delta_time = [
×
1196
        time_index[i] - time_index[i - 1] for i in np.arange(1, len(time_index), 1)
1197
    ]
1198

1199
    if all(i == delta_time[0] for i in delta_time):
×
1200
        dtime = delta_time[0]
×
1201
    else:
1202
        raise ValueError("Time delta between each array element must be constant")
×
1203

1204
    # modify axes
1205
    ax.set_xlim(min(time_index) - 0.5 * dtime, max(time_index) + 0.5 * dtime)
×
1206
    ax_0.set_xlim(min(time_index) - 0.5 * dtime, max(time_index) + 0.5 * dtime)
×
1207
    ax_0.set_ylim(0, 1)
×
1208
    ax_0.set_yticks([])
×
1209
    ax_0.xaxis.set_ticks_position("top")
×
1210
    ax_0.tick_params(axis="x", direction="out", zorder=10)
×
1211
    ax_0.spines[["top", "left", "right", "bottom"]].set_visible(False)
×
1212

1213
    # width of bars, to fill x axis limits
1214
    width = (max(time_index) + 0.5 - min(time_index) - 0.5) / len(time_index)
×
1215

1216
    # create historical/projection divide
1217
    if divide is not None:
×
1218
        # convert divide year to transAxes
1219
        divide_disp = ax_0.transData.transform(
×
1220
            (divide - width * 0.5, 1)
1221
        )  # left limit of stripe, 1 is placeholder
1222
        divide_ax = ax_0.transAxes.inverted().transform(divide_disp)
×
1223
        divide_ax = divide_ax[0]
×
1224
    else:
1225
        divide_ax = 0
×
1226

1227
    # create an inset ax for each da in data
1228
    subaxes = {}
×
1229
    for i in np.arange(n):
×
1230
        name = "subax_" + str(i)
×
1231
        y = (1 / n) * i
×
1232
        subaxes[name] = ax_0.inset_axes([0, y, 1, 1 / n], transform=ax_0.transAxes)
×
1233
        subaxes[name].set(xlim=ax_0.get_xlim(), ylim=(0, 1), xticks=[], yticks=[])
×
1234
        subaxes[name].spines[["top", "bottom", "left", "right"]].set_visible(False)
×
1235
        # lines separating axes
1236
        if i > 0:
×
1237
            subaxes[name].spines["bottom"].set_visible(True)
×
1238
            subaxes[name].spines["bottom"].set(
×
1239
                lw=2,
1240
                color="w",
1241
                bounds=(divide_ax, 1),
1242
                transform=subaxes[name].transAxes,
1243
            )
1244
            # circles
1245
            if divide:
×
1246
                circle = matplotlib.patches.Ellipse(
×
1247
                    xy=(divide_ax, y),
1248
                    width=0.01,
1249
                    height=0.03,
1250
                    color="w",
1251
                    transform=ax_0.transAxes,
1252
                    zorder=10,
1253
                )
1254
                ax_0.add_patch(circle)
×
1255

1256
    # get max and min of all data
1257
    data_min = 1e6
×
1258
    data_max = -1e6
×
1259
    for da in data.values():
×
1260
        if min(da.values) < data_min:
×
1261
            data_min = min(da.values)
×
1262
        if max(da.values) > data_max:
×
1263
            data_max = max(da.values)
×
1264

1265
    # colormap
1266
    if isinstance(cmap, str):
×
1267
        if cmap in plt.colormaps():
×
1268
            cmap = matplotlib.colormaps[cmap]
×
1269
        else:
1270
            try:
×
1271
                cmap = create_cmap(filename=cmap)
×
1272
            except FileNotFoundError as e:
×
1273
                logger.error(e)
×
1274
                pass
×
1275

1276
    elif cmap is None:
×
1277
        cdata = Path(__file__).parents[1] / "data/ipcc_colors/variable_groups.json"
×
1278
        cmap = create_cmap(
×
1279
            get_var_group(path_to_json=cdata, da=list(data.values())[0]),
1280
            divergent=True,
1281
        )
1282

1283
    # create cmap norm
1284
    if cmap_center is not None:
×
1285
        norm = matplotlib.colors.TwoSlopeNorm(cmap_center, vmin=data_min, vmax=data_max)
×
1286
    else:
1287
        norm = matplotlib.colors.Normalize(data_min, data_max)
×
1288

1289
    # plot
1290
    for (name, subax), (key, da) in zip(subaxes.items(), data.items()):
×
1291
        subax.bar(da.time.dt.year, height=1, width=dtime, color=cmap(norm(da.values)))
×
1292
        if divide:
×
1293
            if key != "_no_label":
×
1294
                subax.text(
×
1295
                    0.99,
1296
                    0.5,
1297
                    key,
1298
                    transform=subax.transAxes,
1299
                    fontsize=14,
1300
                    ha="right",
1301
                    va="center",
1302
                    c="w",
1303
                    weight="bold",
1304
                )
1305

1306
    # colorbar
1307
    if cbar is True:
×
1308
        sm = ScalarMappable(cmap=cmap, norm=norm)
×
1309
        cax = ax.inset_axes([0.01, 0.05, 0.35, 0.06])
×
1310
        cbar_tcks = np.arange(math.floor(data_min), math.ceil(data_max), 2)
×
1311
        # label
1312
        da = list(data.values())[0]
×
1313
        label = get_attributes("long_name", da)
×
1314
        if label != "":
×
1315
            if "units" in da.attrs:
×
1316
                u = da.units
×
1317
                label += f" ({u})"
×
1318
            label = wrap_text(label, max_line_len=40)
×
1319

1320
        cbar_kw = {
×
1321
            "cax": cax,
1322
            "orientation": "horizontal",
1323
            "ticks": cbar_tcks,
1324
            "label": label,
1325
        } | cbar_kw
1326
        plt.colorbar(sm, **cbar_kw)
×
1327
        cax.spines["outline"].set_visible(False)
×
1328
        cax.set_xscale("linear")
×
1329

1330
    return ax
×
1331

1332

1333
def heatmap(
6✔
1334
    data: xr.DataArray | xr.Dataset | dict[str, Any],
1335
    ax: matplotlib.axes.Axes | None = None,
1336
    use_attrs: dict[str, Any] | None = None,
1337
    fig_kw: dict[str, Any] | None = None,
1338
    plot_kw: dict[str, Any] | None = None,
1339
    transpose: bool = False,
1340
    cmap: str | matplotlib.colors.Colormap | None = "RdBu",
1341
    divergent: bool | int | float = False,
1342
) -> matplotlib.axes.Axes:
1343
    """Create heatmap from a DataArray.
1344

1345
    Parameters
1346
    ----------
1347
    data : dict or DataArray or Dataset
1348
        Input data do plot. If dictionary, must have only one entry.
1349
    ax : matplotlib axis, optional
1350
        Matplotlib axis on which to plot, with the same projection as the one specified.
1351
    use_attrs : dict, optional
1352
        Dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description').
1353
        Default value is {'cbar_label': 'long_name'}.
1354
        Only the keys found in the default dict can be used.
1355
    fig_kw : dict, optional
1356
        Arguments to pass to `plt.figure()`.
1357
    plot_kw :  dict, optional
1358
        Arguments to pass to the 'seaborn.heatmap()' function.
1359
        If 'data' is a dictionary, can be a nested dictionary with the same key as 'data'.
1360
    transpose : bool
1361
        If true, the 2D data will be transposed, so that the original x-axis becomes the y-axis and vice versa.
1362
    cmap : matplotlib.colors.Colormap or str, optional
1363
        Colormap to use. If str, can be a matplotlib or name of the file of an IPCC colormap (see data/ipcc_colors).
1364
        If None, look for common variables (from data/ipcc_colors/variables_groups.json) in the name of the DataArray
1365
        or its 'history' attribute and use corresponding colormap, aligned with the IPCC Visual Style Guide 2022
1366
        (https://www.ipcc.ch/site/assets/uploads/2022/09/IPCC_AR6_WGI_VisualStyleGuide_2022.pdf).
1367
    divergent : bool or int or float
1368
        If int or float, becomes center of cmap. Default center is 0.
1369

1370
    Returns
1371
    -------
1372
    matplotlib.axes.Axes
1373
    """
1374
    # create empty dicts if None
1375
    use_attrs = empty_dict(use_attrs)
×
1376
    fig_kw = empty_dict(fig_kw)
×
1377
    plot_kw = empty_dict(plot_kw)
×
1378

1379
    # set default use_attrs values
1380
    use_attrs.setdefault("cbar_label", "long_name")
×
1381

1382
    # if data is dict, extract
1383
    if isinstance(data, dict):
×
1384
        if plot_kw and list(data.keys())[0] in plot_kw.keys():
×
1385
            plot_kw = plot_kw[list(data.keys())[0]]
×
1386
        if len(data) == 1:
×
1387
            data = list(data.values())[0]
×
1388
        else:
1389
            raise ValueError("If `data` is a dict, it must be of length 1.")
×
1390

1391
    # select data to plot
1392
    if isinstance(data, xr.DataArray):
×
1393
        da = data
×
1394
    elif isinstance(data, xr.Dataset):
×
1395
        if len(data.data_vars) > 1:
×
1396
            warnings.warn(
×
1397
                "data is xr.Dataset; only the first variable will be used in plot"
1398
            )
1399
        da = list(data.values())[0]
×
1400
    else:
1401
        raise TypeError("`data` must contain a xr.DataArray or xr.Dataset")
×
1402

1403
    # setup fig, axis
1404
    if ax is None and ("row" not in plot_kw.keys() and "col" not in plot_kw.keys()):
×
1405
        fig, ax = plt.subplots(**fig_kw)
×
1406
    elif ax is not None and ("col" in plot_kw or "row" in plot_kw):
×
1407
        raise ValueError("Cannot use 'ax' and 'col'/'row' at the same time.")
×
1408
    elif ax is None:
×
1409
        if any([k != "figsize" for k in fig_kw.keys()]):
×
1410
            warnings.warn(
×
1411
                "Only figsize arguments can be passed to fig_kw when using facetgrid."
1412
            )
1413
        plot_kw.setdefault("col", None)
×
1414
        plot_kw.setdefault("row", None)
×
1415
        plot_kw.setdefault("margin_titles", True)
×
1416
        heatmap_dims = list(
×
1417
            set(da.dims)
1418
            - {d for d in [plot_kw["col"], plot_kw["row"]] if d is not None}
1419
        )
1420
        if da.name is None:
×
1421
            da = da.to_dataset(name="data").data
×
1422
        da_name = da.name
×
1423

1424
    # create cbar label
1425
    if (
×
1426
        "cbar_units" in use_attrs
1427
        and len(get_attributes(use_attrs["cbar_units"], data)) >= 1
1428
    ):  # avoids '()' as label
1429
        cbar_label = (
×
1430
            get_attributes(use_attrs["cbar_label"], data)
1431
            + " ("
1432
            + get_attributes(use_attrs["cbar_units"], data)
1433
            + ")"
1434
        )
1435
    else:
1436
        cbar_label = get_attributes(use_attrs["cbar_label"], data)
×
1437

1438
    # colormap
1439
    if isinstance(cmap, str):
×
1440
        if cmap not in plt.colormaps():
×
1441
            try:
×
1442
                cmap = create_cmap(filename=cmap)
×
1443
            except FileNotFoundError as e:
×
1444
                logger.error(e)
×
1445
                pass
×
1446

1447
    elif cmap is None:
×
1448
        cdata = Path(__file__).parents[1] / "data/ipcc_colors/variable_groups.json"
×
1449
        cmap = create_cmap(
×
1450
            get_var_group(path_to_json=cdata, da=da),
1451
            divergent=divergent,
1452
        )
1453

1454
    # convert data to DataFrame
1455
    if transpose:
×
1456
        da = da.transpose()
×
1457
    if "col" not in plot_kw and "row" not in plot_kw:
×
1458
        if len(da.dims) != 2:
×
1459
            raise ValueError("DataArray must have exactly two dimensions")
×
1460
        df = da.to_pandas()
×
1461
    else:
1462
        if len(heatmap_dims) != 2:
×
1463
            raise ValueError("DataArray must have exactly two dimensions")
×
1464
        df = da.to_dataframe().reset_index()
×
1465

1466
    # set defaults
1467
    if divergent is not False:
×
1468
        if isinstance(divergent, (int, float)):
×
1469
            plot_kw.setdefault("center", divergent)
×
1470
        else:
1471
            plot_kw.setdefault("center", 0)
×
1472

1473
    if "cbar" not in plot_kw or plot_kw["cbar"] is not False:
×
1474
        plot_kw.setdefault("cbar_kws", {})
×
1475
        plot_kw["cbar_kws"].setdefault("label", wrap_text(cbar_label))
×
1476

1477
    plot_kw.setdefault("cmap", cmap)
×
1478

1479
    # plot
1480
    def draw_heatmap(*args, **kwargs):
×
1481
        data = kwargs.pop("data")
×
1482
        d = (
×
1483
            data
1484
            if len(args) == 0
1485
            # Any sorting should be performed before sending a DataArray in `fg.heatmap`
1486
            else data.pivot_table(
1487
                index=args[1], columns=args[0], values=args[2], sort=False
1488
            )
1489
        )
1490
        ax = sns.heatmap(d, **kwargs)
×
1491
        ax.set_xticklabels(
×
1492
            ax.get_xticklabels(),
1493
            rotation=45,
1494
            ha="right",
1495
            rotation_mode="anchor",
1496
        )
1497
        ax.tick_params(axis="both", direction="out")
×
1498
        set_plot_attrs(
×
1499
            use_attrs,
1500
            da,
1501
            ax,
1502
            title_loc="center",
1503
            wrap_kw={"min_line_len": 35, "max_line_len": 44},
1504
        )
1505
        return ax
×
1506

1507
    if ax is not None:
×
1508
        ax = draw_heatmap(data=df, ax=ax, **plot_kw)
×
1509
        return ax
×
1510
    elif "col" in plot_kw or "row" in plot_kw:
×
1511
        # When using xarray's FacetGrid, `plot_kw` can be used in the FacetGrid and in the plotting function
1512
        # With Seaborn, we need to be more careful and separate keywords.
1513
        plot_kw_hm = {
×
1514
            k: v for k, v in plot_kw.items() if k in signature(sns.heatmap).parameters
1515
        }
1516
        plot_kw_fg = {
×
1517
            k: v for k, v in plot_kw.items() if k in signature(sns.FacetGrid).parameters
1518
        }
1519
        unused_keys = (
×
1520
            set(plot_kw.keys()) - set(plot_kw_fg.keys()) - set(plot_kw_hm.keys())
1521
        )
1522
        if unused_keys != set():
×
1523
            raise ValueError(
×
1524
                f"`heatmap` got unexpected keywords in `plot_kw`: {unused_keys}. Keywords in `plot_kw` should be keywords "
1525
                "allowed in `sns.heatmap` or `sns.FacetGrid`. "
1526
            )
1527

1528
        g = sns.FacetGrid(df, **plot_kw_fg)
×
1529
        cax = g.fig.add_axes([0.95, 0.05, 0.02, 0.9])
×
1530
        g.map_dataframe(
×
1531
            draw_heatmap,
1532
            *heatmap_dims,
1533
            da_name,
1534
            **plot_kw_hm,
1535
            cbar=True,
1536
            cbar_ax=cax,
1537
        )
1538
        g.fig.subplots_adjust(right=0.9)
×
1539
        if "figsize" in fig_kw.keys():
×
1540
            g.fig.set_size_inches(*fig_kw["figsize"])
×
1541
        return g
×
1542

1543

1544
def scattermap(
6✔
1545
    data: dict[str, Any] | xr.DataArray | xr.Dataset,
1546
    ax: matplotlib.axes.Axes | None = None,
1547
    use_attrs: dict[str, Any] | None = None,
1548
    fig_kw: dict[str, Any] | None = None,
1549
    plot_kw: dict[str, Any] | None = None,
1550
    projection: ccrs.Projection = ccrs.LambertConformal(),
1551
    transform: ccrs.Projection | None = None,
1552
    features: list[str] | dict[str, dict[str, Any]] | None = None,
1553
    geometries_kw: dict[str, Any] | None = None,
1554
    sizes: str | bool | None = None,
1555
    size_range: tuple = (10, 60),
1556
    cmap: str | matplotlib.colors.Colormap | None = None,
1557
    levels: int | None = None,
1558
    divergent: bool | int | float = False,
1559
    legend_kw: dict[str, Any] | None = None,
1560
    show_time: bool | str | int | tuple[float, float] = False,
1561
    frame: bool = False,
1562
    enumerate_subplots: bool = False,
1563
) -> matplotlib.axes.Axes:
1564
    """Make a scatter plot of georeferenced data on a map.
1565

1566
    Parameters
1567
    ----------
1568
    data : dict, DataArray or Dataset
1569
        Input data do plot. If dictionary, must have only one entry.
1570
    ax : matplotlib axis, optional
1571
        Matplotlib axis on which to plot, with the same projection as the one specified.
1572
    use_attrs : dict, optional
1573
        Dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description').
1574
        Default value is {'title': 'description', 'cbar_label': 'long_name', 'cbar_units': 'units'}.
1575
        Only the keys found in the default dict can be used.
1576
    fig_kw : dict, optional
1577
        Arguments to pass to `plt.figure()`.
1578
    plot_kw :  dict, optional
1579
        Arguments to pass to `plt.scatter()`.
1580
        If 'data' is a dictionary, can be a dictionary with the same key as 'data'.
1581
    projection : ccrs.Projection
1582
        The projection to use, taken from the cartopy.crs options. Ignored if ax is not None.
1583
    transform : ccrs.Projection, optional
1584
        Transform corresponding to the data coordinate system. If None, an attempt is made to find dimensions matching
1585
        ccrs.PlateCarree() or ccrs.RotatedPole().
1586
    features : list or dict, optional
1587
        Features to use, as a list or a nested dict containing kwargs. Options are the predefined features from
1588
        cartopy.feature: ['coastline', 'borders', 'lakes', 'land', 'ocean', 'rivers', 'states'].
1589
    geometries_kw : dict, optional
1590
        Arguments passed to cartopy ax.add_geometry() which adds given geometries (GeoDataFrame geometry) to axis.
1591
    sizes : bool or str, optional
1592
        String name of the coordinate to use for determining point size. If True, use the same data as in the colorbar.
1593
    size_range : tuple
1594
        Tuple of the minimum and maximum size of the points.
1595
    cmap : matplotlib.colors.Colormap or str, optional
1596
        Colormap to use. If str, can be a matplotlib or name of the file of an IPCC colormap (see data/ipcc_colors).
1597
        If None, look for common variables (from data/ipcc_colors/variables_groups.json) in the name of the DataArray
1598
        or its 'history' attribute and use corresponding colormap, aligned with the IPCC Visual Style Guide 2022
1599
        (https://www.ipcc.ch/site/assets/uploads/2022/09/IPCC_AR6_WGI_VisualStyleGuide_2022.pdf).
1600
    levels : int, optional
1601
        Number of levels to divide the colormap into.
1602
    divergent : bool or int or float
1603
        If int or float, becomes center of cmap. Default center is 0.
1604
    legend_kw : dict, optional
1605
        Arguments to pass to plt.legend(). Some defaults {"loc": "lower left", "facecolor": "w", "framealpha": 1,
1606
            "edgecolor": "w", "bbox_to_anchor": (-0.05, 0)}
1607
    show_time : bool, tuple, string or int.
1608
        If True, show time (as date) at the bottom right of the figure.
1609
        Can be a tuple of axis coordinates (0 to 1, as a fraction of the axis length) representing the location
1610
        of the text. If a string or an int, the same values as those of the 'loc' parameter
1611
        of matplotlib's legends are accepted.
1612

1613
        ==================   =============
1614
        Location String      Location Code
1615
        ==================   =============
1616
        'upper right'        1
1617
        'upper left'         2
1618
        'lower left'         3
1619
        'lower right'        4
1620
        'right'              5
1621
        'center left'        6
1622
        'center right'       7
1623
        'lower center'       8
1624
        'upper center'       9
1625
        'center'             10
1626
        ==================   =============
1627
    frame : bool
1628
        Show or hide frame. Default False.
1629
    enumerate_subplots: bool
1630
        If True, enumerate subplots with letters.
1631
        Only works with facetgrids (pass `col` or `row` in plot_kw).
1632

1633
    Returns
1634
    -------
1635
    matplotlib.axes.Axes
1636
    """
1637
    # create empty dicts if None
1638
    use_attrs = empty_dict(use_attrs)
×
1639
    fig_kw = empty_dict(fig_kw)
×
1640
    plot_kw = empty_dict(plot_kw)
×
1641
    legend_kw = empty_dict(legend_kw)
×
1642

1643
    # set default use_attrs values
1644
    use_attrs = {"cbar_label": "long_name", "cbar_units": "units"} | use_attrs
×
1645
    if "row" not in plot_kw and "col" not in plot_kw:
×
1646
        use_attrs.setdefault("title", "description")
×
1647

1648
    # extract plot_kw from dict if needed
UNCOV
1649
    if isinstance(data, dict) and plot_kw and list(data.keys())[0] in plot_kw.keys():
×
NEW
1650
        plot_kw = plot_kw[list(data.keys())[0]]
×
1651

1652
    # figanos does not use xr.plot.scatter default markersize
1653
    if "markersize" in plot_kw.keys():
×
1654
        if not sizes:
×
1655
            sizes = plot_kw["markersize"]
×
NEW
1656
        plot_kw.pop("markersize")
×
1657

1658
    # if data is dict, extract
1659
    if isinstance(data, dict):
×
1660
        if len(data) == 1:
×
1661
            data = list(data.values())[0].squeeze()
×
1662
            if len(data.data_vars) > 1:
×
1663
                warnings.warn(
×
1664
                    "data is xr.Dataset; only the first variable will be used in plot"
1665
                )
1666
        else:
1667
            raise ValueError("If `data` is a dict, it must be of length 1.")
×
1668

1669
    # select data to plot and its xr.Dataset
1670
    if isinstance(data, xr.DataArray):
×
1671
        plot_data = data
×
1672
        data = xr.Dataset({plot_data.name: plot_data})
×
1673
    elif isinstance(data, xr.Dataset):
×
1674
        if len(data.data_vars) > 1:
×
1675
            warnings.warn(
×
1676
                "data is xr.Dataset; only the first variable will be used in plot"
1677
            )
1678
        plot_data = data[list(data.keys())[0]]
×
1679
    else:
1680
        raise TypeError("`data` must contain a xr.DataArray or xr.Dataset")
×
1681

1682
    # setup transform
1683
    if transform is None:
×
1684
        if "rlat" in data.dims and "rlon" in data.dims:
×
1685
            if hasattr(data, "rotated_pole"):
×
1686
                transform = get_rotpole(data)
×
1687
        elif (
×
1688
            "lat" in data.coords and "lon" in data.coords
1689
        ):  # need to work with station dims
1690
            transform = ccrs.PlateCarree()
×
1691

1692
    # setup fig, ax
1693
    if ax is None and ("row" not in plot_kw.keys() and "col" not in plot_kw.keys()):
×
1694
        fig, ax = plt.subplots(subplot_kw={"projection": projection}, **fig_kw)
×
1695
    elif ax is not None and ("col" in plot_kw or "row" in plot_kw):
×
1696
        raise ValueError("Cannot use 'ax' and 'col'/'row' at the same time.")
×
1697
    elif ax is None:
×
NEW
1698
        plot_kw = {"subplot_kws": {"projection": projection}} | plot_kw
×
1699
        cfig_kw = fig_kw.copy()
×
1700
        if "figsize" in fig_kw:  # add figsize to plot_kw for facetgrid
×
NEW
1701
            plot_kw.setdefault("figsize", fig_kw["figsize"])
×
1702
            cfig_kw.pop("figsize")
×
1703
        if len(cfig_kw) >= 1:
×
NEW
1704
            plot_kw = {"subplot_kws": {"projection": projection}} | plot_kw
×
1705
            warnings.warn(
×
1706
                "Only figsize and figure.add_subplot() arguments can be passed to fig_kw when using facetgrid."
1707
            )
1708

1709
    # create cbar label
1710
    if (
×
1711
        "cbar_units" in use_attrs
1712
        and len(get_attributes(use_attrs["cbar_units"], data)) >= 1
1713
    ):  # avoids '[]' as label
1714
        cbar_label = (
×
1715
            get_attributes(use_attrs["cbar_label"], data)
1716
            + " ("
1717
            + get_attributes(use_attrs["cbar_units"], data)
1718
            + ")"
1719
        )
1720
    else:
1721
        cbar_label = get_attributes(use_attrs["cbar_label"], data)
×
1722

1723
    if "add_colorbar" not in plot_kw or plot_kw["add_colorbar"] is not False:
×
NEW
1724
        plot_kw.setdefault("cbar_kwargs", {})
×
NEW
1725
        plot_kw["cbar_kwargs"].setdefault("label", wrap_text(cbar_label))
×
NEW
1726
        plot_kw["cbar_kwargs"].setdefault("pad", 0.015)
×
1727

1728
    # colormap
1729
    if isinstance(cmap, str):
×
1730
        if cmap not in plt.colormaps():
×
1731
            try:
×
1732
                cmap = create_cmap(filename=cmap)
×
1733
            except FileNotFoundError as e:
×
1734
                logger.error(e)
×
1735
                pass
×
1736

1737
    elif cmap is None:
×
1738
        cdata = Path(__file__).parents[1] / "data/ipcc_colors/variable_groups.json"
×
1739
        cmap = create_cmap(
×
1740
            get_var_group(path_to_json=cdata, da=plot_data),
1741
            divergent=divergent,
1742
        )
1743

1744
    # nans (not required for plotting since xarray.plot handles np.nan, but needs to be found for sizes legend and to
1745
    # inform user on how many stations were dropped)
1746
    mask = ~np.isnan(plot_data.values)
×
1747
    if np.sum(mask) < len(mask):
×
1748
        warnings.warn(
×
1749
            f"{len(mask) - np.sum(mask)} nan values were dropped when plotting the color values"
1750
        )
1751

1752
    # point sizes
1753
    if sizes:
×
1754
        if sizes is True:
×
1755
            sdata = plot_data
×
1756
        elif isinstance(sizes, str):
×
1757
            if hasattr(data, "name") and getattr(data, "name") == sizes:
×
1758
                sdata = plot_data
×
1759
            elif sizes in list(data.coords.keys()):
×
1760
                sdata = plot_data[sizes]
×
1761
            else:
1762
                raise ValueError(f"{sizes} not found")
×
1763
        else:
1764
            raise TypeError("sizes must be a string or a bool")
×
1765

1766
        # nans sizes
1767
        smask = ~np.isnan(sdata.values) & mask
×
1768
        if np.sum(smask) < np.sum(mask):
×
1769
            warnings.warn(
×
1770
                f"{np.sum(mask) - np.sum(smask)} nan values were dropped when setting the point size"
1771
            )
1772
            mask = smask
×
1773

1774
        pt_sizes = norm2range(
×
1775
            data=sdata.where(mask).values,
1776
            target_range=size_range,
1777
            data_range=None,
1778
        )
NEW
1779
        plot_kw.setdefault("add_legend", False)
×
1780
        if ax:
×
NEW
1781
            plot_kw.setdefault("s", pt_sizes)
×
1782
        else:
NEW
1783
            plot_kw.setdefault("s", pt_sizes[0])
×
1784

1785
    # norm
NEW
1786
    plot_kw.setdefault("vmin", np.nanmin(plot_data.values[mask]))
×
NEW
1787
    plot_kw.setdefault("vmax", np.nanmax(plot_data.values[mask]))
×
1788
    if levels is not None:
×
1789
        if isinstance(levels, Iterable):
×
1790
            lin = levels
×
1791
        else:
1792
            lin = custom_cmap_norm(
×
1793
                cmap,
1794
                np.nanmin(plot_data.values[mask]),
1795
                np.nanmax(plot_data.values[mask]),
1796
                levels=levels,
1797
                divergent=divergent,
1798
                linspace_out=True,
1799
            )
NEW
1800
        plot_kw.setdefault("levels", lin)
×
1801

1802
    elif (divergent is not False) and ("levels" not in plot_kw):
×
1803
        norm = custom_cmap_norm(
×
1804
            cmap,
1805
            np.nanmin(plot_data.values[mask]),
1806
            np.nanmax(plot_data.values[mask]),
1807
            levels=levels,
1808
            divergent=divergent,
1809
        )
NEW
1810
        plot_kw.setdefault("norm", norm)
×
1811

1812
    # matplotlib.pyplot.scatter treats "edgecolor" and "edgecolors" as aliases so we accept "edgecolor" and convert it
NEW
1813
    if "edgecolor" in plot_kw and "edgecolors" not in plot_kw:
×
NEW
1814
        plot_kw["edgecolors"] = plot_kw["edgecolor"]
×
NEW
1815
        plot_kw.pop("edgecolor")
×
1816

1817
    # set defaults and create copy without vmin, vmax (conflicts with norm)
NEW
1818
    plot_kw = {
×
1819
        "cmap": cmap,
1820
        "transform": transform,
1821
        "zorder": 8,
1822
        "marker": "o",
1823
    } | plot_kw
1824

1825
    # check if edgecolors in plot_kw and match len of plot_data
1826
    if "edgecolors" in plot_kw:
×
1827
        if matplotlib.colors.is_color_like(plot_kw["edgecolors"]):
×
NEW
1828
            plot_kw["edgecolors"] = np.repeat(
×
1829
                plot_kw["edgecolors"], len(plot_data.where(mask).values)
1830
            )
1831
        elif len(plot_kw["edgecolors"]) != len(plot_data.values):
×
NEW
1832
            plot_kw["edgecolors"] = np.repeat(
×
1833
                plot_kw["edgecolors"][0], len(plot_data.where(mask).values)
1834
            )
1835
            warnings.warn(
×
1836
                "Length of edgecolors does not match length of data. Only first edgecolor is used for plotting."
1837
            )
1838
        else:
1839
            if isinstance(plot_kw["edgecolors"], list):
×
NEW
1840
                plot_kw["edgecolors"] = np.array(plot_kw["edgecolors"])
×
NEW
1841
            plot_kw["edgecolors"] = plot_kw["edgecolors"][mask]
×
1842
    else:
NEW
1843
        plot_kw.setdefault("edgecolors", "none")
×
1844

NEW
1845
    for key in ["vmin", "vmax"]:
×
NEW
1846
        plot_kw.pop(key)
×
1847
    # plot
NEW
1848
    plot_kw = {"x": "lon", "y": "lat", "hue": plot_data.name} | plot_kw
×
1849
    if ax:
×
NEW
1850
        plot_kw.setdefault("ax", ax)
×
1851

NEW
1852
    plot_data_masked = plot_data.where(mask).to_dataset()
×
NEW
1853
    im = plot_data_masked.plot.scatter(**plot_kw)
×
1854

1855
    # add features
1856
    if ax:
×
1857
        ax = add_features_map(
×
1858
            data,
1859
            ax,
1860
            use_attrs,
1861
            projection,
1862
            features,
1863
            geometries_kw,
1864
            frame,
1865
        )
1866

1867
        if show_time:
×
1868
            if isinstance(show_time, bool):
×
1869
                plot_coords(
×
1870
                    ax,
1871
                    plot_data,
1872
                    param="time",
1873
                    loc="lower right",
1874
                    backgroundalpha=1,
1875
                )
1876
            elif isinstance(show_time, (str, tuple, int)):
×
1877
                plot_coords(
×
1878
                    ax,
1879
                    plot_data,
1880
                    param="time",
1881
                    loc=show_time,
1882
                    backgroundalpha=1,
1883
                )
1884

1885
        if (frame is False) and (im.colorbar is not None):
×
1886
            im.colorbar.outline.set_visible(False)
×
1887

1888
    else:
1889
        for i, fax in enumerate(im.axs.flat):
×
1890
            fax = add_features_map(
×
1891
                data,
1892
                fax,
1893
                use_attrs,
1894
                projection,
1895
                features,
1896
                geometries_kw,
1897
                frame,
1898
            )
1899

1900
            if sizes:
×
1901
                # correct markersize for facetgrid
1902
                scat = fax.collections[0]
×
1903
                scat.set_sizes(pt_sizes[i])
×
1904

1905
        if (frame is False) and (im.cbar is not None):
×
1906
            im.cbar.outline.set_visible(False)
×
1907

1908
        if show_time:
×
1909
            if isinstance(show_time, bool):
×
1910
                plot_coords(
×
1911
                    None,
1912
                    plot_data,
1913
                    param="time",
1914
                    loc="lower right",
1915
                    backgroundalpha=1,
1916
                )
1917
            elif isinstance(show_time, (str, tuple, int)):
×
1918
                plot_coords(
×
1919
                    None,
1920
                    plot_data,
1921
                    param="time",
1922
                    loc=show_time,
1923
                    backgroundalpha=1,
1924
                )
1925

1926
    # size legend
1927
    if sizes:
×
1928
        legend_elements = size_legend_elements(
×
1929
            np.resize(sdata.values[mask], (sdata.values[mask].size, 1)),
1930
            np.resize(pt_sizes[mask], (pt_sizes[mask].size, 1)),
1931
            max_entries=6,
1932
            marker=plot_kw["marker"],
1933
        )
1934
        # legend spacing
1935
        if size_range[1] > 200:
×
1936
            ls = 0.5 + size_range[1] / 100 * 0.125
×
1937
        else:
1938
            ls = 0.5
×
1939

1940
        legend_kw = {
×
1941
            "loc": "lower left",
1942
            "facecolor": "w",
1943
            "framealpha": 1,
1944
            "edgecolor": "w",
1945
            "labelspacing": ls,
1946
            "handles": legend_elements,
1947
            "bbox_to_anchor": (-0.05, -0.1),
1948
        } | legend_kw
1949

1950
        if "title" not in legend_kw:
×
1951
            if hasattr(sdata, "long_name"):
×
1952
                lgd_title = wrap_text(
×
1953
                    getattr(sdata, "long_name"), min_line_len=1, max_line_len=15
1954
                )
1955
                if hasattr(sdata, "units"):
×
1956
                    lgd_title += f" ({getattr(sdata, 'units')})"
×
1957
            else:
1958
                lgd_title = sizes
×
1959
            legend_kw.setdefault("title", lgd_title)
×
1960

1961
        if ax:
×
1962
            lgd = ax.legend(**legend_kw)
×
1963
            lgd.set_zorder(11)
×
1964
        else:
1965
            im.figlegend = im.fig.legend(**legend_kw)
×
1966
        # im._adjust_fig_for_guide(im.figlegend)
1967

1968
    if ax:
×
1969
        return ax
×
1970
    else:
1971
        im.fig.suptitle(get_attributes("long_name", data))
×
1972
        im.set_titles(template="{value}")
×
1973
        if enumerate_subplots and isinstance(im, xr.plot.facetgrid.FacetGrid):
×
1974
            for idx, ax in enumerate(im.axs.flat):
×
1975
                ax.set_title(f"{string.ascii_lowercase[idx]}) {ax.get_title()}")
×
1976

1977
        return im
×
1978

1979

1980
def taylordiagram(
6✔
1981
    data: xr.DataArray | dict[str, xr.DataArray],
1982
    plot_kw: dict[str, Any] | None = None,
1983
    fig_kw: dict[str, Any] | None = None,
1984
    std_range: tuple = (0, 1.5),
1985
    contours: int | None = 4,
1986
    contours_kw: dict[str, Any] | None = None,
1987
    ref_std_line: bool = False,
1988
    legend_kw: dict[str, Any] | None = None,
1989
    std_label: str | None = None,
1990
    corr_label: str | None = None,
1991
    colors_key: str | None = None,
1992
    markers_key: str | None = None,
1993
):
1994
    """Build a Taylor diagram.
1995

1996
    Based on the following code: https://gist.github.com/ycopin/3342888.
1997

1998
    Parameters
1999
    ----------
2000
    data : xr.DataArray or dict
2001
        DataArray or dictionary of DataArrays created by xclim.sdba.measures.taylordiagram, each corresponding
2002
        to a point on the diagram. The dictionary keys will become their labels.
2003
    plot_kw : dict, optional
2004
        Arguments to pass to the `plot()` function. Changes how the markers look.
2005
        If 'data' is a dictionary, must be a nested dictionary with the same keys as 'data'.
2006
    fig_kw : dict, optional
2007
        Arguments to pass to `plt.figure()`.
2008
    std_range : tuple
2009
        Range of the x and y axes, in units of the highest standard deviation in the data.
2010
    contours : int, optional
2011
        Number of rsme contours to plot.
2012
    contours_kw : dict, optional
2013
        Arguments to pass to `plt.contour()` for the rmse contours.
2014
    ref_std_line : bool, optional
2015
        If True, draws a circular line on radius `std = ref_std`. Default: False
2016
    legend_kw : dict, optional
2017
        Arguments to pass to `plt.legend()`.
2018
    std_label : str, optional
2019
        Label for the standard deviation (x and y) axes.
2020
    corr_label : str, optional
2021
        Label for the correlation axis.
2022
    colors_key : str, optional
2023
        Attribute or dimension of DataArrays used to separate DataArrays into groups with different colors. If present,
2024
        it overrides the "color" key in `plot_kw`.
2025
    markers_key : str, optional
2026
        Attribute or dimension of DataArrays used to separate DataArrays into groups with different markers. If present,
2027
        it overrides the "marker" key in `plot_kw`.
2028

2029
    Returns
2030
    -------
2031
    (plt.figure, mpl_toolkits.axisartist.floating_axes.FloatingSubplot, plt.legend)
2032
    """
2033
    plot_kw = empty_dict(plot_kw)
×
2034
    fig_kw = empty_dict(fig_kw)
×
2035
    contours_kw = empty_dict(contours_kw)
×
2036
    legend_kw = empty_dict(legend_kw)
×
2037

2038
    # preserve order of dimensions if used for marker/color
2039
    ordered_markers_type = None
×
2040
    ordered_colors_type = None
×
2041

2042
    # convert SSP, RCP, CMIP formats in keys
2043
    if isinstance(data, dict):
×
2044
        data = process_keys(data, convert_scen_name)
×
2045
    if isinstance(plot_kw, dict):
×
2046
        plot_kw = process_keys(plot_kw, convert_scen_name)
×
2047

2048
    # if only one data input, insert in dict.
2049
    if not isinstance(data, dict):
×
2050
        data = {"_no_label": data}  # mpl excludes labels starting with "_" from legend
×
2051
        plot_kw = {"_no_label": empty_dict(plot_kw)}
×
2052
    elif not plot_kw:
×
2053
        plot_kw = {k: {} for k in data.keys()}
×
2054
    # check type
2055
    for key, v in data.items():
×
2056
        if not isinstance(v, xr.DataArray):
×
2057
            raise TypeError("All objects in 'data' must be xarray DataArrays.")
×
2058
        if "taylor_param" not in v.dims:
×
2059
            raise ValueError("All DataArrays must contain a 'taylor_param' dimension.")
×
2060
        if key == "reference":
×
2061
            raise ValueError("'reference' is not allowed as a key in data.")
×
2062

2063
    # If there are other dimensions than 'taylor_param', create a bigger dict with them
2064
    data_keys = list(data.keys())
×
2065
    for data_key in data_keys:
×
2066
        da = data[data_key]
×
2067
        dims = list(set(da.dims) - {"taylor_param"})
×
2068
        if dims != []:
×
2069
            if markers_key in dims:
×
2070
                ordered_markers_type = da[markers_key].values
×
2071
            if colors_key in dims:
×
2072
                ordered_colors_type = da[colors_key].values
×
2073

2074
            da = da.stack(pl_dims=dims)
×
2075
            for i, dim_key in enumerate(da.pl_dims.values):
×
2076
                if isinstance(dim_key, list) or isinstance(dim_key, tuple):
×
2077
                    dim_key = "-".join([str(k) for k in dim_key])
×
2078
                da0 = da.isel(pl_dims=i)
×
2079
                # if colors_key/markers_key is a dimension, add it as an attribute for later use
2080
                if markers_key in dims:
×
2081
                    da0.attrs[markers_key] = da0[markers_key].values.item()
×
2082
                if colors_key in dims:
×
2083
                    da0.attrs[colors_key] = da0[colors_key].values.item()
×
2084
                new_data_key = (
×
2085
                    f"{data_key}-{dim_key}" if data_key != "_no_label" else dim_key
2086
                )
2087
                data[new_data_key] = da0
×
2088
                plot_kw[new_data_key] = empty_dict(plot_kw[f"{data_key}"])
×
2089
            data.pop(data_key)
×
2090
            plot_kw.pop(data_key)
×
2091

2092
    # remove negative correlations
2093
    initial_len = len(data)
×
2094
    removed = [
×
2095
        key for key, da in data.items() if da.sel(taylor_param="corr").values < 0
2096
    ]
2097
    data = {
×
2098
        key: da for key, da in data.items() if da.sel(taylor_param="corr").values >= 0
2099
    }
2100
    if len(data) != initial_len:
×
2101
        warnings.warn(
×
2102
            f"{initial_len - len(data)} points with negative correlations will not be plotted: {', '.join(removed)}"
2103
        )
2104

2105
    # add missing keys to plot_kw
2106
    for key in data.keys():
×
2107
        if key not in plot_kw:
×
2108
            plot_kw[key] = {}
×
2109

2110
    # extract ref to be used in plot
2111
    ref_std = list(data.values())[0].sel(taylor_param="ref_std").values
×
2112
    # check if ref is the same in all DataArrays and get the highest std (for ax limits)
2113
    if len(data) > 1:
×
2114
        for key, da in data.items():
×
2115
            if da.sel(taylor_param="ref_std").values != ref_std:
×
2116
                raise ValueError(
×
2117
                    "All reference standard deviation values must be identical"
2118
                )
2119

2120
    # get highest std for axis limits
2121
    max_std = [ref_std]
×
2122
    for key, da in data.items():
×
2123
        max_std.append(
×
2124
            float(
2125
                max(
2126
                    da.sel(taylor_param="ref_std").values,
2127
                    da.sel(taylor_param="sim_std").values,
2128
                )
2129
            )
2130
        )
2131

2132
    # make labels
2133
    if not std_label:
×
2134
        try:
×
2135
            units = list(data.values())[0].units
×
2136
            std_label = get_localized_term("standard deviation")
×
2137
            std_label = std_label if units == "" else f"{std_label} ({units})"
×
2138
        except AttributeError:
×
2139
            std_label = get_localized_term("standard deviation").capitalize()
×
2140

2141
    if not corr_label:
×
2142
        try:
×
2143
            if "Pearson" in list(data.values())[0].correlation_type:
×
2144
                corr_label = get_localized_term("pearson correlation").capitalize()
×
2145
            else:
2146
                corr_label = get_localized_term("correlation").capitalize()
×
2147
        except AttributeError:
×
2148
            corr_label = get_localized_term("correlation").capitalize()
×
2149

2150
    # build diagram
2151
    transform = PolarAxes.PolarTransform()
×
2152

2153
    # Setup the axis, here we map angles in degrees to angles in radius
2154
    # Correlation labels
2155
    rlocs = np.array([0, 0.2, 0.4, 0.6, 0.7, 0.8, 0.9, 0.95, 0.99, 1])
×
2156
    tlocs = np.arccos(rlocs)  # Conversion to polar angles
×
2157
    gl1 = gf.FixedLocator(tlocs)  # Positions
×
2158
    tf1 = gf.DictFormatter(dict(zip(tlocs, map(str, rlocs))))
×
2159
    # Standard deviation axis extent
2160
    radius_min = std_range[0] * max(max_std)
×
2161
    radius_max = std_range[1] * max(max_std)
×
2162

2163
    # Set up the axes range in the parameter "extremes"
2164
    ghelper = GridHelperCurveLinear(
×
2165
        transform,
2166
        extremes=(0, np.pi / 2, radius_min, radius_max),
2167
        grid_locator1=gl1,
2168
        tick_formatter1=tf1,
2169
    )
2170

2171
    fig = plt.figure(**fig_kw)
×
2172
    floating_ax = FloatingSubplot(fig, 111, grid_helper=ghelper)
×
2173
    fig.add_subplot(floating_ax)
×
2174

2175
    # Adjust axes
2176
    floating_ax.axis["top"].set_axis_direction("bottom")  # "Angle axis"
×
2177
    floating_ax.axis["top"].toggle(ticklabels=True, label=True)
×
2178
    floating_ax.axis["top"].major_ticklabels.set_axis_direction("top")
×
2179
    floating_ax.axis["top"].label.set_axis_direction("top")
×
2180
    floating_ax.axis["top"].label.set_text(corr_label)
×
2181

2182
    floating_ax.axis["left"].set_axis_direction("bottom")  # "X axis"
×
2183
    floating_ax.axis["left"].label.set_text(std_label)
×
2184

2185
    floating_ax.axis["right"].set_axis_direction("top")  # "Y axis"
×
2186
    floating_ax.axis["right"].toggle(ticklabels=True, label=True)
×
2187
    floating_ax.axis["right"].major_ticklabels.set_axis_direction("left")
×
2188
    floating_ax.axis["right"].label.set_text(std_label)
×
2189

2190
    floating_ax.axis["bottom"].set_visible(False)  # Useless
×
2191

2192
    # Contours along standard deviations
2193
    floating_ax.grid(visible=True, alpha=0.4)
×
2194
    floating_ax.set_title("")
×
2195

2196
    ax = floating_ax.get_aux_axes(transform)  # return the axes that can be plotted on
×
2197

2198
    # plot reference
2199
    if "reference" in plot_kw:
×
2200
        ref_kw = plot_kw.pop("reference")
×
2201
    else:
2202
        ref_kw = {}
×
2203
    ref_kw = {
×
2204
        "color": "#154504",
2205
        "marker": "s",
2206
        "label": get_localized_term("reference"),
2207
    } | ref_kw
2208

2209
    ref_pt = ax.scatter(0, ref_std, **ref_kw)
×
2210

2211
    points = [ref_pt]  # set up for later
×
2212

2213
    # plot a circular line along `ref_std`
2214
    if ref_std_line:
×
2215
        angles_for_line = np.linspace(0, np.pi / 2, 100)
×
2216
        radii_for_line = np.full_like(angles_for_line, ref_std)
×
2217
        ax.plot(
×
2218
            angles_for_line,
2219
            radii_for_line,
2220
            color=ref_kw["color"],
2221
            linewidth=0.5,
2222
            linestyle="-",
2223
        )
2224

2225
    # rmse contours from reference standard deviation
2226
    if contours:
×
2227
        radii, angles = np.meshgrid(
×
2228
            np.linspace(radius_min, radius_max),
2229
            np.linspace(0, np.pi / 2),
2230
        )
2231
        # Compute centered RMS difference
2232
        rms = np.sqrt(ref_std**2 + radii**2 - 2 * ref_std * radii * np.cos(angles))
×
2233

2234
        contours_kw = {"linestyles": "--", "linewidths": 0.5} | contours_kw
×
2235
        ct = ax.contour(angles, radii, rms, levels=contours, **contours_kw)
×
2236

2237
        ax.clabel(ct, ct.levels, fontsize=8)
×
2238

2239
        # points.append(ct_line)
2240
        ct_line = ax.plot(
×
2241
            [0],
2242
            [0],
2243
            ls=contours_kw["linestyles"],
2244
            lw=1,
2245
            c="k" if "colors" not in contours_kw else contours_kw["colors"],
2246
            label="rmse",
2247
        )
2248
        points.append(ct_line[0])
×
2249

2250
    # get color options
2251
    style_colors = matplotlib.rcParams["axes.prop_cycle"].by_key()["color"]
×
2252
    if len(data) > len(style_colors):
×
2253
        style_colors = style_colors * math.ceil(len(data) / len(style_colors))
×
2254
    cat_colors = Path(__file__).parents[1] / "data/ipcc_colors/categorical_colors.json"
×
2255
    # get marker options (only used if `markers_key` is set)
2256
    style_markers = "oDv^<>p*hH+x|_"
×
2257
    if len(data) > len(style_markers):
×
2258
        style_markers = style_markers * math.ceil(len(data) / len(style_markers))
×
2259

2260
    # set colors and markers styles based on discrimnating attributes (if specified)
2261
    if colors_key or markers_key:
×
2262
        if colors_key:
×
2263
            # get_scen_color : look for SSP, RCP, CMIP model color
2264
            colors_type = (
×
2265
                ordered_colors_type
2266
                if ordered_colors_type is not None
2267
                else {da.attrs[colors_key] for da in data.values()}
2268
            )
2269
            colorsd = {
×
2270
                k: get_scen_color(k, cat_colors) or style_colors[i]
2271
                for i, k in enumerate(colors_type)
2272
            }
2273
        if markers_key:
×
2274
            markers_type = (
×
2275
                ordered_markers_type
2276
                if ordered_markers_type is not None
2277
                else {da.attrs[markers_key] for da in data.values()}
2278
            )
2279
            markersd = {k: style_markers[i] for i, k in enumerate(markers_type)}
×
2280

2281
        for key, da in data.items():
×
2282
            if colors_key:
×
2283
                plot_kw[key]["color"] = colorsd[da.attrs[colors_key]]
×
2284
            if markers_key:
×
2285
                plot_kw[key]["marker"] = markersd[da.attrs[markers_key]]
×
2286

2287
    # plot scatter
2288
    for (key, da), i in zip(data.items(), range(len(data))):
×
2289
        # look for SSP, RCP, CMIP model color
2290
        if colors_key is None:
×
2291
            plot_kw[key].setdefault(
×
2292
                "color", get_scen_color(key, cat_colors) or style_colors[i]
2293
            )
2294
        # set defaults
2295
        plot_kw[key] = {"label": key} | plot_kw[key]
×
2296

2297
        # legend will be handled later in this case
2298
        if markers_key or colors_key:
×
2299
            plot_kw[key]["label"] = ""
×
2300

2301
        # plot
2302
        pt = ax.scatter(
×
2303
            np.arccos(da.sel(taylor_param="corr").values),
2304
            da.sel(taylor_param="sim_std").values,
2305
            **plot_kw[key],
2306
        )
2307
        points.append(pt)
×
2308

2309
    # legend
2310
    legend_kw.setdefault("loc", "upper right")
×
2311
    legend = fig.legend(points, [pt.get_label() for pt in points], **legend_kw)
×
2312

2313
    # plot new legend if markers/colors represent a certain dimension
2314
    if colors_key or markers_key:
×
2315
        handles = list(floating_ax.get_legend_handles_labels()[0])
×
2316
        if markers_key:
×
2317
            for k, m in markersd.items():
×
2318
                handles.append(Line2D([0], [0], color="k", label=k, marker=m, ls=""))
×
2319
        if colors_key:
×
2320
            for k, c in colorsd.items():
×
2321
                handles.append(Line2D([0], [0], color=c, label=k, ls="-"))
×
2322
        legend.remove()
×
2323
        legend = fig.legend(handles=handles, **legend_kw)
×
2324

2325
    return fig, floating_ax, legend
×
2326

2327

2328
def hatchmap(
6✔
2329
    data: dict[str, Any] | xr.DataArray | xr.Dataset,
2330
    ax: matplotlib.axes.Axes | None = None,
2331
    use_attrs: dict[str, Any] | None = None,
2332
    fig_kw: dict[str, Any] | None = None,
2333
    plot_kw: dict[str, Any] | None = None,
2334
    projection: ccrs.Projection = ccrs.LambertConformal(),
2335
    transform: ccrs.Projection | None = None,
2336
    features: list[str] | dict[str, dict[str, Any]] | None = None,
2337
    geometries_kw: dict[str, Any] | None = None,
2338
    levels: int | None = None,
2339
    legend_kw: dict[str, Any] | bool = True,
2340
    show_time: bool | str | int | tuple[float, float] = False,
2341
    frame: bool = False,
2342
    enumerate_subplots: bool = False,
2343
) -> matplotlib.axes.Axes:
2344
    """Create map of hatches from 2D data.
2345

2346
    Parameters
2347
    ----------
2348
    data : dict, DataArray or Dataset
2349
        Input data do plot.
2350
    ax : matplotlib axis, optional
2351
        Matplotlib axis on which to plot, with the same projection as the one specified.
2352
    use_attrs : dict, optional
2353
        Dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description').
2354
        Default value is {'title': 'description'}.
2355
        Only the keys found in the default dict can be used.
2356
    fig_kw : dict, optional
2357
        Arguments to pass to `plt.figure()`.
2358
    plot_kw:  dict, optional
2359
        Arguments to pass to 'xarray.plot.contourf()' function.
2360
        If 'data' is a dictionary, can be a nested dictionary with the same keys as 'data'.
2361
    projection : ccrs.Projection
2362
        The projection to use, taken from the cartopy.ccrs options. Ignored if ax is not None.
2363
    transform : ccrs.Projection, optional
2364
        Transform corresponding to the data coordinate system. If None, an attempt is made to find dimensions matching
2365
        ccrs.PlateCarree() or ccrs.RotatedPole().
2366
    features : list or dict, optional
2367
        Features to use, as a list or a nested dict containing kwargs. Options are the predefined features from
2368
        cartopy.feature: ['coastline', 'borders', 'lakes', 'land', 'ocean', 'rivers', 'states'].
2369
    geometries_kw : dict, optional
2370
        Arguments passed to cartopy ax.add_geometry() which adds given geometries (GeoDataFrame geometry) to axis.
2371
    legend_kw : dict or boolean, optional
2372
        Arguments to pass to `ax.legend()`. No legend is added if legend_kw == False.
2373
    show_time : bool, tuple, string or int.
2374
        If True, show time (as date) at the bottom right of the figure.
2375
        Can be a tuple of axis coordinates (0 to 1, as a fraction of the axis length) representing the location
2376
        of the text. If a string or an int, the same values as those of the 'loc' parameter
2377
        of matplotlib's legends are accepted.
2378

2379
        ==================   =============
2380
        Location String      Location Code
2381
        ==================   =============
2382
        'upper right'        1
2383
        'upper left'         2
2384
        'lower left'         3
2385
        'lower right'        4
2386
        'right'              5
2387
        'center left'        6
2388
        'center right'       7
2389
        'lower center'       8
2390
        'upper center'       9
2391
        'center'             10
2392
        ==================   =============
2393
    frame : bool
2394
        Show or hide frame. Default False.
2395
    enumerate_subplots: bool
2396
        If True, enumerate subplots with letters.
2397
        Only works with facetgrids (pass `col` or `row` in plot_kw).
2398

2399
    Returns
2400
    -------
2401
    matplotlib.axes.Axes
2402
    """
2403
    # default hatches
2404
    dfh = [
×
2405
        "/",
2406
        "\\",
2407
        "|",
2408
        "-",
2409
        "+",
2410
        "x",
2411
        "o",
2412
        "O",
2413
        ".",
2414
        "*",
2415
        "//",
2416
        "\\\\",
2417
        "||",
2418
        "--",
2419
        "++",
2420
        "xx",
2421
        "oo",
2422
        "OO",
2423
        "..",
2424
        "**",
2425
    ]
2426

2427
    # create empty dicts if None
2428
    use_attrs = empty_dict(use_attrs)
×
2429
    fig_kw = empty_dict(fig_kw)
×
2430
    plot_kw = empty_dict(plot_kw)
×
2431
    legend_kw = empty_dict(legend_kw)
×
2432

2433
    dattrs = None
×
2434
    plot_data = {}
×
2435

2436
    # convert data to dict (if not one)
2437
    if not isinstance(data, dict):
×
2438
        if isinstance(data, xr.DataArray):
×
2439
            plot_data = {data.name: data}
×
NEW
2440
            if data.name not in plot_kw.keys():
×
NEW
2441
                plot_kw = {data.name: plot_kw}
×
2442
        elif isinstance(data, xr.Dataset):
×
2443
            dattrs = data
×
2444
            plot_data = {var: data[var] for var in data.data_vars}
×
2445
            for v in plot_data.keys():
×
2446
                if v not in plot_kw.keys():
×
NEW
2447
                    plot_kw[v] = plot_kw
×
2448
    else:
2449
        for k, v in data.items():
×
2450
            if k not in plot_kw.keys():
×
NEW
2451
                plot_kw[k] = plot_kw
×
2452
            if isinstance(v, xr.Dataset):
×
2453
                dattrs = k
×
2454
                plot_data[k] = v[list(v.data_vars)[0]]
×
2455
                warnings.warn("Only first variable of Dataset is plotted.")
×
2456
            else:
2457
                plot_data[k] = v
×
2458

2459
    # setup transform from first data entry
2460
    trdata = list(plot_data.values())[0]
×
2461
    if transform is None:
×
2462
        if "lat" in trdata.dims and "lon" in trdata.dims:
×
2463
            transform = ccrs.PlateCarree()
×
2464
        elif "rlat" in trdata.dims and "rlon" in trdata.dims:
×
2465
            if hasattr(list(plot_data.values())[0], "rotated_pole"):
×
2466
                transform = get_rotpole(list(plot_data.values())[0])
×
2467

2468
    # bug xlim / ylim + transform in facetgrids
2469
    # (see https://github.com/pydata/xarray/issues/8562#issuecomment-1865189766)
2470
    if transform and (
×
2471
        "xlim" in list(plot_kw.values())[0] and "ylim" in list(plot_kw.values())[0]
2472
    ):
NEW
2473
        extent = [
×
2474
            list(plot_kw.values())[0]["xlim"][0],
2475
            list(plot_kw.values())[0]["xlim"][1],
2476
            list(plot_kw.values())[0]["ylim"][0],
2477
            list(plot_kw.values())[0]["ylim"][1],
2478
        ]
NEW
2479
        [v.pop(lim) for lim in ["xlim", "ylim"] for v in plot_kw.values() if lim in v]
×
2480

2481
    elif transform and (
×
2482
        "xlim" in list(plot_kw.values())[0] or "ylim" in list(plot_kw.values())[0]
2483
    ):
NEW
2484
        extent = None
×
2485
        warnings.warn(
×
2486
            "Requires both xlim and ylim with 'transform'. Xlim or ylim was dropped"
2487
        )
NEW
2488
        [v.pop(lim) for lim in ["xlim", "ylim"] for v in plot_kw.values() if lim in v]
×
2489

2490
    else:
NEW
2491
        extent = None
×
2492

2493
    # setup fig, ax
2494
    if ax is None and (
×
2495
        "row" not in list(plot_kw.values())[0].keys()
2496
        and "col" not in list(plot_kw.values())[0].keys()
2497
    ):
2498
        fig, ax = plt.subplots(subplot_kw={"projection": projection}, **fig_kw)
×
2499
    elif ax is not None and (
×
2500
        "col" in list(plot_kw.values())[0].keys()
2501
        or "row" in list(plot_kw.values())[0].keys()
2502
    ):
2503
        raise ValueError("Cannot use 'ax' and 'col'/'row' at the same time.")
×
2504
    elif ax is None:
×
NEW
2505
        [
×
2506
            v.setdefault("subplot_kws", {}).setdefault("projection", projection)
2507
            for v in plot_kw.values()
2508
        ]
NEW
2509
        cfig_kw = copy.deepcopy(fig_kw)
×
2510
        if "figsize" in fig_kw:  # add figsize to plot_kw for facetgrid
×
2511
            plot_kw[0].setdefault("figsize", fig_kw["figsize"])
×
2512
            cfig_kw.pop("figsize")
×
2513
        if cfig_kw:
×
2514
            for v in plot_kw.values():
×
2515
                {"subplots_kws": cfig_kw} | v
×
2516
            warnings.warn(
×
2517
                "Only figsize and figure.add_subplot() arguments can be passed to fig_kw when using facetgrid."
2518
            )
2519

2520
    pat_leg = []
×
2521
    n = 0
×
2522
    for k, v in plot_data.items():
×
2523
        # if levels plot multiple hatching from one data entry
2524
        if "levels" in plot_kw[k] and len(plot_data) == 1:
×
2525
            # nans
2526
            mask = ~np.isnan(v.values)
×
2527
            if np.sum(mask) < len(mask):
×
2528
                warnings.warn(
×
2529
                    f"{len(mask) - np.sum(mask)} nan values were dropped when plotting the pattern values"
2530
                )
2531
            if "hatches" in plot_kw[k] and plot_kw[k]["levels"] != len(
×
2532
                plot_kw[k]["hatches"]
2533
            ):
2534
                warnings.warn("Hatches number is not equivalent to number of levels")
×
2535
                hatches = dfh[0:levels]
×
2536
            if "hatches" not in plot_kw[k]:
×
2537
                hatches = dfh[0:levels]
×
2538

2539
            plot_kw[k] = {
×
2540
                "hatches": hatches,
2541
                "colors": "none",
2542
                "add_colorbar": False,
2543
            } | plot_kw[k]
2544

2545
            if "lat" in v.dims:
×
2546
                v.coords["mask"] = (("lat", "lon"), mask)
×
2547
            else:
2548
                v.coords["mask"] = (("rlat", "rlon"), mask)
×
2549

2550
            plot_kw[k].setdefault("transform", transform)
×
2551
            if ax:
×
2552
                plot_kw[k].setdefault("ax", ax)
×
2553

2554
            im = v.where(mask is not True).plot.contourf(**plot_kw[k])
×
2555
            artists, labels = im.legend_elements(str_format="{:2.1f}".format)
×
2556

NEW
2557
            if ax and legend_kw:
×
2558
                ax.legend(artists, labels, **legend_kw)
×
NEW
2559
            elif legend_kw:
×
2560
                im.figlegend = im.fig.legend(**legend_kw)
×
2561

2562
        elif len(plot_data) > 1 and "levels" in plot_kw[k]:
×
2563
            raise TypeError(
×
2564
                "To plot levels only one xr.DataArray or xr.Dataset accepted"
2565
            )
2566
        else:
2567
            # since pattern remove colors and colorbar from plotting (done by gridmap)
2568
            plot_kw[k] = {"colors": "none", "add_colorbar": False} | plot_kw[k]
×
2569

2570
            if "hatches" not in plot_kw[k].keys():
×
2571
                plot_kw[k]["hatches"] = dfh[n]
×
2572
                n += 1
×
NEW
2573
            elif isinstance(
×
2574
                plot_kw[k]["hatches"], str
2575
            ):  # make sure the hatches are in a list
NEW
2576
                warnings.warn(
×
2577
                    "Hatches argument must be of type 'list'. Wrapping string argument as list."
2578
                )
NEW
2579
                plot_kw[k]["hatches"] = [plot_kw[k]["hatches"]]
×
2580

2581
            plot_kw[k].setdefault("transform", transform)
×
2582
            if ax:
×
2583
                im = v.plot.contourf(ax=ax, **plot_kw[k])
×
2584

2585
            if not ax:
×
2586
                if k == list(plot_data.keys())[0]:
×
2587
                    im = v.plot.contourf(**plot_kw[k])
×
2588

2589
                for i, fax in enumerate(im.axs.flat):
×
2590
                    if len(plot_data) > 1 and k != list(plot_data.keys())[0]:
×
2591
                        # select data to plot from DataSet in loop to plot on facetgrids axis
2592
                        c_pkw = plot_kw[k].copy()
×
2593
                        c_pkw.pop("subplot_kws")
×
2594
                        sel = {}
×
2595
                        if "row" in c_pkw.keys():
×
2596
                            sel[c_pkw["row"]] = i
×
2597
                            c_pkw.pop("row")
×
2598
                        elif "col" in c_pkw.keys():
×
2599
                            sel[c_pkw["col"]] = i
×
2600
                            c_pkw.pop("col")
×
2601
                        v.isel(sel).plot.contourf(ax=fax, **c_pkw)
×
2602

2603
                    if k == list(plot_data.keys())[-1]:
×
2604
                        add_features_map(
×
2605
                            dattrs,
2606
                            fax,
2607
                            use_attrs,
2608
                            projection,
2609
                            features,
2610
                            geometries_kw,
2611
                            frame,
2612
                        )
NEW
2613
                        if extent:
×
NEW
2614
                            fax.set_extent(extent)
×
2615

2616
            pat_leg.append(
×
2617
                matplotlib.patches.Patch(
2618
                    hatch=plot_kw[k]["hatches"][0], fill=False, label=k
2619
                )
2620
            )
2621

NEW
2622
    if pat_leg and legend_kw:
×
2623
        legend_kw = {
×
2624
            "loc": "lower right",
2625
            "handleheight": 2,
2626
            "handlelength": 4,
2627
        } | legend_kw
2628

NEW
2629
        if ax and legend_kw:
×
2630
            ax.legend(handles=pat_leg, **legend_kw)
×
NEW
2631
        elif legend_kw:
×
2632
            im.figlegend = im.fig.legend(handles=pat_leg, **legend_kw)
×
2633

2634
    # add features
2635
    if ax:
×
NEW
2636
        if extent:
×
NEW
2637
            ax.set_extent(extent)
×
2638
        if dattrs:
×
2639
            use_attrs.setdefault("title", "description")
×
2640

2641
        ax = add_features_map(
×
2642
            dattrs,
2643
            ax,
2644
            use_attrs,
2645
            projection,
2646
            features,
2647
            geometries_kw,
2648
            frame,
2649
        )
2650

2651
        if show_time:
×
2652
            if isinstance(show_time, bool):
×
2653
                plot_coords(
×
2654
                    ax,
2655
                    plot_data,
2656
                    param="time",
2657
                    loc="lower right",
2658
                    backgroundalpha=1,
2659
                )
2660
            elif isinstance(show_time, (str, tuple, int)):
×
2661
                plot_coords(
×
2662
                    ax,
2663
                    plot_data,
2664
                    param="time",
2665
                    loc=show_time,
2666
                    backgroundalpha=1,
2667
                )
2668

2669
        # when im is an ax, it has a colorbar attribute. If it is a facetgrid, it has a cbar attribute.
2670
        if (frame is False) and (
×
2671
            (getattr(im, "colorbar", None) is not None)
2672
            or (getattr(im, "cbar", None) is not None)
2673
        ):
2674
            im.colorbar.outline.set_visible(False)
×
2675

2676
            set_plot_attrs(use_attrs, dattrs, ax, wrap_kw={"max_line_len": 60})
×
2677
        return ax
×
2678

2679
    else:
2680
        # when im is an ax, it has a colorbar attribute. If it is a facetgrid, it has a cbar attribute.
2681
        if (frame is False) and (
×
2682
            (getattr(im, "colorbar", None) is not None)
2683
            or (getattr(im, "cbar", None) is not None)
2684
        ):
2685
            im.cbar.outline.set_visible(False)
×
2686

2687
        if show_time:
×
2688
            if show_time is True:
×
2689
                plot_coords(
×
2690
                    None,
2691
                    dattrs,
2692
                    param="time",
2693
                    loc="lower right",
2694
                    backgroundalpha=1,
2695
                )
2696
            elif isinstance(show_time, (str, tuple, int)):
×
2697
                plot_coords(
×
2698
                    None, dattrs, param="time", loc=show_time, backgroundalpha=1
2699
                )
2700
        if dattrs:
×
2701
            use_attrs.setdefault("suptitle", "long_name")
×
2702
            set_plot_attrs(use_attrs, dattrs, facetgrid=im)
×
2703

2704
        if enumerate_subplots and isinstance(im, xr.plot.facetgrid.FacetGrid):
×
2705
            for idx, ax in enumerate(im.axs.flat):
×
2706
                ax.set_title(f"{string.ascii_lowercase[idx]}) {ax.get_title()}")
×
2707

2708
        return im
×
2709

2710

2711
def _add_lead_time_coord(da, ref):
6✔
2712
    """Add a lead time coordinate to the data. Modifies da in-place."""
2713
    lead_time = da.time.dt.year - int(ref)
×
2714
    da["Lead time"] = lead_time
×
2715
    da["Lead time"].attrs["units"] = f"years from {ref}"
×
2716
    return lead_time
×
2717

2718

2719
def partition(
6✔
2720
    data: xr.DataArray | xr.Dataset,
2721
    ax: matplotlib.axes.Axes | None = None,
2722
    start_year: str | None = None,
2723
    show_num: bool = True,
2724
    fill_kw: dict[str, Any] | None = None,
2725
    line_kw: dict[str, Any] | None = None,
2726
    fig_kw: dict[str, Any] | None = None,
2727
    legend_kw: dict[str, Any] | None = None,
2728
) -> matplotlib.axes.Axes:
2729
    """Figure of the partition of total uncertainty by components.
2730

2731
    Uncertainty fractions can be computed with xclim (https://xclim.readthedocs.io/en/stable/api.html#uncertainty-partitioning).
2732
    Make sure the use `fraction=True` in the xclim function call.
2733

2734
    Parameters
2735
    ----------
2736
    data : xr.DataArray or xr.Dataset
2737
        Variance over time of the different components of uncertainty.
2738
        Output of a `xclim.ensembles._partitioning` function.
2739
    ax : matplotlib axis, optional
2740
        Matplotlib axis on which to plot.
2741
    start_year : str
2742
        If None, the x-axis will be the time in year.
2743
        If str, the x-axis will show the number of year since start_year.
2744
    show_num : bool
2745
        If True, show the number of elements for each uncertainty components in parentheses in the legend.
2746
        `data` should have attributes named after the components with a list of its the elements.
2747
    fill_kw : dict
2748
        Keyword arguments passed to `ax.fill_between`.
2749
        It is possible to pass a dictionary of keywords for each component (uncertainty coordinates).
2750
    line_kw : dict
2751
        Keyword arguments passed to `ax.plot` for the lines in between the components.
2752
        The default is {color="k", lw=2}. We recommend always using lw>=2.
2753
    fig_kw : dict
2754
        Keyword arguments passed to `plt.subplots`.
2755
    legend_kw : dict
2756
        Keyword arguments passed to `ax.legend`.
2757

2758
    Returns
2759
    -------
2760
    mpl.axes.Axes
2761
    """
2762
    if isinstance(data, xr.Dataset):
×
2763
        if len(data.data_vars) > 1:
×
2764
            warnings.warn(
×
2765
                "data is xr.Dataset; only the first variable will be used in plot"
2766
            )
2767
        data = data[list(data.keys())[0]].squeeze()
×
2768

2769
    if data.attrs["units"] != "%":
×
2770
        raise ValueError(
×
2771
            "The units are not %. Use `fraction=True` in the xclim function call."
2772
        )
2773

2774
    fill_kw = empty_dict(fill_kw)
×
2775
    line_kw = empty_dict(line_kw)
×
2776
    fig_kw = empty_dict(fig_kw)
×
2777
    legend_kw = empty_dict(legend_kw)
×
2778

2779
    # select data to plot
2780
    if isinstance(data, xr.DataArray):
×
2781
        data = data.squeeze()
×
2782
    elif isinstance(data, xr.Dataset):  # in case, it was saved to disk before plotting.
×
2783
        if len(data.data_vars) > 1:
×
2784
            warnings.warn(
×
2785
                "data is xr.Dataset; only the first variable will be used in plot"
2786
            )
2787
        data = data[list(data.keys())[0]].squeeze()
×
2788
    else:
2789
        raise TypeError("`data` must contain a xr.DataArray or xr.Dataset")
×
2790

2791
    if ax is None:
×
2792
        fig, ax = plt.subplots(**fig_kw)
×
2793

2794
    # Select data from reference year onward
2795
    if start_year:
×
2796
        data = data.sel(time=slice(start_year, None))
×
2797

2798
        # Lead time coordinate
2799
        time = _add_lead_time_coord(data, start_year)
×
2800
        ax.set_xlabel(f"Lead time (years from {start_year})")
×
2801
    else:
2802
        time = data.time.dt.year
×
2803

2804
    # fill_kw that are direct (not with uncertainty as key)
2805
    fk_direct = {k: v for k, v in fill_kw.items() if (k not in data.uncertainty.values)}
×
2806

2807
    # Draw areas
2808
    past_y = 0
×
2809
    black_lines = []
×
2810
    for u in data.uncertainty.values:
×
2811
        if u not in ["total", "variability"]:
×
2812
            present_y = past_y + data.sel(uncertainty=u)
×
2813
            num = len(data.attrs.get(u, []))  # compatible with pre PR PR #1529
×
2814
            label = f"{u} ({num})" if show_num and num else u
×
2815
            ax.fill_between(
×
2816
                time,
2817
                past_y,
2818
                present_y,
2819
                label=label,
2820
                **fill_kw.get(u, fk_direct),
2821
            )
2822
            black_lines.append(present_y)
×
2823
            past_y = present_y
×
2824
    ax.fill_between(
×
2825
        time,
2826
        past_y,
2827
        100,
2828
        label="variability",
2829
        **fill_kw.get("variability", fk_direct),
2830
    )
2831

2832
    # Draw black lines
2833
    line_kw.setdefault("color", "k")
×
2834
    line_kw.setdefault("lw", 2)
×
2835
    ax.plot(time, np.array(black_lines).T, **line_kw)
×
2836

2837
    ax.xaxis.set_major_locator(matplotlib.ticker.MultipleLocator(20))
×
2838
    ax.xaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator(n=5))
×
2839

2840
    ax.yaxis.set_major_locator(matplotlib.ticker.MultipleLocator(10))
×
2841
    ax.yaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator(n=2))
×
2842

2843
    ax.set_ylabel(f"{data.attrs['long_name']} ({data.attrs['units']})")  #
×
2844

2845
    ax.set_ylim(0, 100)
×
2846
    ax.legend(**legend_kw)
×
2847

2848
    return ax
×
2849

2850

2851
def triheatmap(
6✔
2852
    data: xr.DataArray | xr.Dataset,
2853
    z: str,
2854
    ax: matplotlib.axes.Axes | None = None,
2855
    use_attrs: dict[str, Any] | None = None,
2856
    fig_kw: dict[str, Any] | None = None,
2857
    plot_kw: dict[str, Any] | None | list = None,
2858
    cmap: str | matplotlib.colors.Colormap | None = None,
2859
    divergent: bool | int | float = False,
2860
    cbar: bool | str = "unique",
2861
    cbar_kw: dict[str, Any] | None | list = None,
2862
) -> matplotlib.axes.Axes:
2863
    """Create a triangle heatmap from a DataArray.
2864

2865
    Note that most of the code comes from:
2866
    https://stackoverflow.com/questions/66048529/how-to-create-a-heatmap-where-each-cell-is-divided-into-4-triangles
2867

2868
    Parameters
2869
    ----------
2870
    data : DataArray or Dataset
2871
        Input data do plot.
2872
    z: str
2873
        Dimension to plot on the triangles. Its length should be 2 or 4.
2874
    ax : matplotlib axis, optional
2875
        Matplotlib axis on which to plot, with the same projection as the one specified.
2876
    use_attrs : dict, optional
2877
        Dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description').
2878
        Default value is {'cbar_label': 'long_name',"cbar_units": "units"}.
2879
        Valid keys are: 'title', 'xlabel', 'ylabel', 'cbar_label', 'cbar_units'.
2880
    fig_kw : dict, optional
2881
        Arguments to pass to `plt.figure()`.
2882
    plot_kw :  dict, optional
2883
        Arguments to pass to the 'plt.tripcolor()' function.
2884
        It can be a list of dictionaries to pass different arguments to each type of triangles (upper/lower or north/east/south/west).
2885
    cmap : matplotlib.colors.Colormap or str, optional
2886
        Colormap to use. If str, can be a matplotlib or name of the file of an IPCC colormap (see data/ipcc_colors).
2887
        If None, look for common variables (from data/ipcc_colors/variables_groups.json) in the name of the DataArray
2888
        or its 'history' attribute and use corresponding colormap, aligned with the IPCC Visual Style Guide 2022
2889
        (https://www.ipcc.ch/site/assets/uploads/2022/09/IPCC_AR6_WGI_VisualStyleGuide_2022.pdf).
2890
    divergent : bool or int or float
2891
        If int or float, becomes center of cmap. Default center is 0.
2892
    cbar : {False, True, 'unique', 'each'}
2893
        If False, don't show the colorbar.
2894
        If True or 'unique', show a unique colorbar for all triangle types. (The cbar of the first triangle is used).
2895
        If 'each', show a colorbar for each triangle type.
2896
    cbar_kw : dict or list
2897
        Arguments to pass to 'fig.colorbar()'.
2898
        It can be a list of dictionaries to pass different arguments to each type of triangles (upper/lower or north/east/south/west).
2899

2900
    Returns
2901
    -------
2902
    matplotlib.axes.Axes
2903
    """
2904
    # create empty dicts if None
2905
    use_attrs = empty_dict(use_attrs)
×
2906
    fig_kw = empty_dict(fig_kw)
×
2907
    plot_kw = empty_dict(plot_kw)
×
2908
    cbar_kw = empty_dict(cbar_kw)
×
2909

2910
    # select data to plot
2911
    if isinstance(data, xr.DataArray):
×
2912
        da = data
×
2913
    elif isinstance(data, xr.Dataset):
×
2914
        if len(data.data_vars) > 1:
×
2915
            warnings.warn(
×
2916
                "data is xr.Dataset; only the first variable will be used in plot"
2917
            )
2918
        da = list(data.values())[0]
×
2919
    else:
2920
        raise TypeError("`data` must contain a xr.DataArray or xr.Dataset")
×
2921

2922
    # setup fig, axis
2923
    if ax is None:
×
2924
        fig, ax = plt.subplots(**fig_kw)
×
2925

2926
    # colormap
2927
    if isinstance(cmap, str):
×
2928
        if cmap not in plt.colormaps():
×
2929
            try:
×
2930
                cmap = create_cmap(filename=cmap)
×
2931
            except FileNotFoundError:
×
2932
                pass
2933
                logging.log("Colormap not found. Using default.")
×
2934

2935
    elif cmap is None:
×
2936
        cdata = Path(__file__).parents[1] / "data/ipcc_colors/variable_groups.json"
×
2937
        cmap = create_cmap(
×
2938
            get_var_group(path_to_json=cdata, da=da),
2939
            divergent=divergent,
2940
        )
2941

2942
    # prep data
2943
    d = [da.sel(**{z: v}).values for v in da[z].values]
×
2944

2945
    other_dims = [di for di in da.dims if di != z]
×
2946
    if len(other_dims) > 2:
×
2947
        warnings.warn(
×
2948
            "More than 3 dimensions in data. The first two after dim will be used as the dimensions of the heatmap."
2949
        )
2950
    if len(other_dims) < 2:
×
2951
        raise ValueError(
×
2952
            "Data must have 3 dimensions. If you only have 2 dimensions, use fg.heatmap."
2953
        )
2954

2955
    if plot_kw == {} and cbar in ["unique", True]:
×
2956
        warnings.warn(
×
2957
            'With cbar="unique" only the colorbar of the first triangle'
2958
            " will be shown. No `plot_kw` was passed. vmin and vmax will be set the max"
2959
            " and min of data."
2960
        )
2961
        plot_kw = {"vmax": da.max().values, "vmin": da.min().values}
×
2962

2963
    if isinstance(plot_kw, dict):
×
2964
        plot_kw.setdefault("cmap", cmap)
×
2965
        plot_kw.setdefault("ec", "white")
×
2966
        plot_kw = [plot_kw for _ in range(len(d))]
×
2967

2968
    labels_x = da[other_dims[0]].values
×
2969
    labels_y = da[other_dims[1]].values
×
2970
    m, n = d[0].shape[0], d[0].shape[1]
×
2971

2972
    # plot
2973
    if len(d) == 2:
×
2974
        x = np.arange(m + 1)
×
2975
        y = np.arange(n + 1)
×
2976
        xss, ys = np.meshgrid(x, y)
×
2977
        zs = (xss * ys) % 10
×
2978
        triangles1 = [
×
2979
            (i + j * (m + 1), i + 1 + j * (m + 1), i + (j + 1) * (m + 1))
2980
            for j in range(n)
2981
            for i in range(m)
2982
        ]
2983
        triangles2 = [
×
2984
            (
2985
                i + 1 + j * (m + 1),
2986
                i + 1 + (j + 1) * (m + 1),
2987
                i + (j + 1) * (m + 1),
2988
            )
2989
            for j in range(n)
2990
            for i in range(m)
2991
        ]
2992
        triang1 = Triangulation(xss.ravel(), ys.ravel(), triangles1)
×
2993
        triang2 = Triangulation(xss.ravel(), ys.ravel(), triangles2)
×
2994
        triangul = [triang1, triang2]
×
2995

2996
        imgs = [
×
2997
            ax.tripcolor(t, np.ravel(val), **plotkw)
2998
            for t, val, plotkw in zip(triangul, d, plot_kw)
2999
        ]
3000

3001
        ax.set_xticks(np.array(range(m)) + 0.5, labels=labels_x, rotation=45)
×
3002
        ax.set_yticks(np.array(range(n)) + 0.5, labels=labels_y, rotation=90)
×
3003

3004
    elif len(d) == 4:
×
3005
        xv, yv = np.meshgrid(
×
3006
            np.arange(-0.5, m), np.arange(-0.5, n)
3007
        )  # vertices of the little squares
3008
        xc, yc = np.meshgrid(
×
3009
            np.arange(0, m), np.arange(0, n)
3010
        )  # centers of the little squares
3011
        x = np.concatenate([xv.ravel(), xc.ravel()])
×
3012
        y = np.concatenate([yv.ravel(), yc.ravel()])
×
3013
        cstart = (m + 1) * (n + 1)  # indices of the centers
×
3014

3015
        triangles_n = [
×
3016
            (i + j * (m + 1), i + 1 + j * (m + 1), cstart + i + j * m)
3017
            for j in range(n)
3018
            for i in range(m)
3019
        ]
3020
        triangles_e = [
×
3021
            (i + 1 + j * (m + 1), i + 1 + (j + 1) * (m + 1), cstart + i + j * m)
3022
            for j in range(n)
3023
            for i in range(m)
3024
        ]
3025
        triangles_s = [
×
3026
            (
3027
                i + 1 + (j + 1) * (m + 1),
3028
                i + (j + 1) * (m + 1),
3029
                cstart + i + j * m,
3030
            )
3031
            for j in range(n)
3032
            for i in range(m)
3033
        ]
3034
        triangles_w = [
×
3035
            (i + (j + 1) * (m + 1), i + j * (m + 1), cstart + i + j * m)
3036
            for j in range(n)
3037
            for i in range(m)
3038
        ]
3039
        triangul = [
×
3040
            Triangulation(x, y, triangles)
3041
            for triangles in [
3042
                triangles_n,
3043
                triangles_e,
3044
                triangles_s,
3045
                triangles_w,
3046
            ]
3047
        ]
3048

3049
        imgs = [
×
3050
            ax.tripcolor(t, np.ravel(val), **plotkw)
3051
            for t, val, plotkw in zip(triangul, d, plot_kw)
3052
        ]
3053
        ax.set_xticks(np.array(range(m)), labels=labels_x, rotation=45)
×
3054
        ax.set_yticks(np.array(range(n)), labels=labels_y, rotation=90)
×
3055

3056
    else:
3057
        raise ValueError(
×
3058
            f"The length of the dimensiondim ({z},{len(d)}) should be either 2 or 4. It represents the number of triangles."
3059
        )
3060

3061
    ax.set_title(get_attributes(use_attrs.get("title", None), data))
×
3062
    ax.set_xlabel(other_dims[0])
×
3063
    ax.set_ylabel(other_dims[1])
×
3064
    if "xlabel" in use_attrs:
×
3065
        ax.set_xlabel(get_attributes(use_attrs["xlabel"], data))
×
3066
    if "ylabel" in use_attrs:
×
3067
        ax.set_ylabel(get_attributes(use_attrs["ylabel"], data))
×
3068
    ax.set_aspect("equal", "box")
×
3069
    ax.invert_yaxis()
×
3070
    ax.tick_params(left=False, bottom=False)
×
3071
    ax.spines["bottom"].set_visible(False)
×
3072
    ax.spines["left"].set_visible(False)
×
3073

3074
    # create cbar label
3075
    # set default use_attrs values
3076
    use_attrs.setdefault("cbar_label", "long_name")
×
3077
    use_attrs.setdefault("cbar_units", "units")
×
3078
    if (
×
3079
        "cbar_units" in use_attrs
3080
        and len(get_attributes(use_attrs["cbar_units"], data)) >= 1
3081
    ):  # avoids '()' as label
3082
        cbar_label = (
×
3083
            get_attributes(use_attrs["cbar_label"], data)
3084
            + " ("
3085
            + get_attributes(use_attrs["cbar_units"], data)
3086
            + ")"
3087
        )
3088
    else:
3089
        cbar_label = get_attributes(use_attrs["cbar_label"], data)
×
3090

3091
    if isinstance(cbar_kw, dict):
×
3092
        cbar_kw.setdefault("label", cbar_label)
×
3093
        cbar_kw = [cbar_kw for _ in range(len(d))]
×
3094
    if cbar == "unique":
×
3095
        plt.colorbar(imgs[0], ax=ax, **cbar_kw[0])
×
3096

3097
    elif (cbar == "each") or (cbar is True):
×
3098
        for i in reversed(range(len(d))):  # switch order of colour bars
×
3099
            plt.colorbar(imgs[i], ax=ax, **cbar_kw[i])
×
3100

3101
    return ax
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc