• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

morganjwilliams / pyrolite / 11836306515

14 Nov 2024 11:28AM UTC coverage: 90.339% (-1.3%) from 91.611%
11836306515

push

github

morganjwilliams
Update example notebook

6237 of 6904 relevant lines covered (90.34%)

10.83 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.91
/pyrolite/plot/spider.py
1
import matplotlib.collections
12✔
2
import matplotlib.lines
12✔
3
import matplotlib.patches
12✔
4
import matplotlib.pyplot as plt
12✔
5
import numpy as np
12✔
6

7
from ..util.log import Handle
12✔
8

9
logger = Handle(__name__)
12✔
10

11

12
from ..geochem.ind import REE, get_ionic_radii
12✔
13
from ..util.meta import get_additional_params, subkwargs
12✔
14
from ..util.plot.axes import get_twins, init_axes
12✔
15
from ..util.plot.density import (
12✔
16
    conditional_prob_density,
17
    percentile_contour_values_from_meshz,
18
    plot_Z_percentiles,
19
)
20
from ..util.plot.style import (
12✔
21
    DEFAULT_CONT_COLORMAP,
22
    linekwargs,
23
    patchkwargs,
24
    scatterkwargs,
25
)
26
from .color import process_color
12✔
27

28
# Baseline keyword arguments merged (with lowest precedence) into the scatter
# (marker) and line components drawn by `spider` in "plot" mode.
_scatter_defaults = {"cmap": DEFAULT_CONT_COLORMAP, "marker": "D", "s": 25}
_line_defaults = {"cmap": DEFAULT_CONT_COLORMAP}
12✔
30

31

32
# could create a spidercollection?
33
def spider(
    arr,
    indexes=None,
    ax=None,
    label=None,
    logy=True,
    yextent=None,
    mode="plot",
    unity_line=False,
    scatter_kw=None,
    line_kw=None,
    set_ticks=True,
    autoscale=True,
    **kwargs,
):
    """
    Plots spidergrams for trace elements data. Additional arguments are typically forwarded
    to respective :mod:`matplotlib` functions :func:`~matplotlib.pyplot.plot` and
    :func:`~matplotlib.pyplot.scatter` (see Other Parameters, below).

    Parameters
    ----------
    arr : :class:`numpy.ndarray`, :code:`None`
        Data array; the last axis corresponds to the x-axis positions. If
        :code:`None` (or if it contains no finite values), the configured but
        otherwise blank axes are returned.
    indexes : :class:`numpy.ndarray`
        Numerical indexes of x-axis positions. Defaults to
        :code:`numpy.arange(arr.shape[-1])`.
    ax : :class:`matplotlib.axes.Axes`, :code:`None`
        The subplot to draw on.
    label : :class:`str`, :code:`None`
        Label for the individual series.
    logy : :class:`bool`
        Whether to use a log y-axis.
    yextent : :class:`tuple`
        Extent in the y direction for conditional probability plots, to limit
        the gridspace over which the kernel density estimates are evaluated.
    mode : :class:`str`, :code:`["plot", "fill", "binkde", "ckde", "kde", "hist"]`
        Mode for plot. Plot will produce a line-scatter diagram. Fill will return
        a filled range. The density modes (binkde, ckde, kde, hist) will return a
        conditional density diagram.
    unity_line : :class:`bool`
        Add a line at y=1 for reference.
    scatter_kw : :class:`dict`, :code:`None`
        Keyword parameters to be passed to the scatter plotting function.
    line_kw : :class:`dict`, :code:`None`
        Keyword parameters to be passed to the line plotting function.
    set_ticks : :class:`bool`
        Whether to set the x-axis ticks according to the specified index.
    autoscale : :class:`bool`
        Whether to autoscale the y-axis limits for standard spider plots.
    {otherparams}

    Returns
    -------
    :class:`matplotlib.axes.Axes`
        Axes on which the spiderplot is plotted.

    Raises
    ------
    NotImplementedError
        If `mode` does not match one of the accepted modes.

    Notes
    -----
        By using separate lines and scatterplots, values between two missing
        items are still presented.

    Todo
    ----
        * Legend entries

    .. seealso::

        Functions:

            :func:`matplotlib.pyplot.plot`
            :func:`matplotlib.pyplot.scatter`
            :func:`REE_v_radii`
    """
    # avoid shared mutable defaults; the dictionaries are only read below
    scatter_kw = {} if scatter_kw is None else scatter_kw
    line_kw = {} if line_kw is None else line_kw

    # ---------------------------------------------------------------------
    # arr may be None (blank axes requested); size the figure from the index
    # in that case rather than failing on arr.shape
    if arr is not None:
        ncomponents = arr.shape[-1]
    elif indexes is not None:
        ncomponents = np.shape(indexes)[-1]
    else:
        ncomponents = 0
    figsize = kwargs.pop("figsize", None) or (max(ncomponents, 1) * 0.3, 4)

    ax = init_axes(ax=ax, figsize=figsize, **kwargs)

    if unity_line:
        ax.axhline(1.0, ls="--", c="k", lw=0.5)

    if logy:
        ax.set_yscale("log")

    indexes = np.arange(ncomponents) if indexes is None else np.array(indexes)
    # with 2D (per-row) indexes, the first row is used for ticks/fill/density
    indexes0 = indexes if indexes.ndim == 1 else indexes[0]

    if set_ticks:
        ax.set_xticks(indexes0)

    # if there is no data, return the blank axis
    if (arr is None) or (not np.isfinite(arr).sum()):
        return ax

    # if the indexes are supplied as a 1D array but the data is 2D, we need to expand
    # it to fit the scatter data
    if indexes.ndim < arr.ndim:
        indexes = np.tile(indexes0, (arr.shape[0], 1))

    _mode = mode.lower()
    if "fill" in _mode:
        mins = np.nanmin(arr, axis=0)
        maxs = np.nanmax(arr, axis=0)
        ax.fill_between(indexes0, mins, maxs, **patchkwargs(kwargs))
    elif "plot" in _mode:
        _spider_lines_and_markers(
            ax, arr, indexes, label, line_kw, scatter_kw, kwargs
        )
    elif any(m in _mode for m in ["binkde", "ckde", "kde", "hist"]):
        _spider_density(ax, arr, indexes0, logy, yextent, mode, kwargs)
    else:
        raise NotImplementedError(
            "Accepted modes: {plot, fill, binkde, ckde, kde, hist}"
        )

    if autoscale and arr.size:
        _autoscale_spider_ylim(ax, arr, logy=logy, unity_line=unity_line)
    return ax


def _spider_lines_and_markers(ax, arr, indexes, label, line_kw, scatter_kw, kwargs):
    """Add the line collection and scatter markers for :func:`spider` 'plot' mode."""
    # copy params so the caller's dictionaries are never modified
    l_kw, s_kw = {**line_kw}, {**scatter_kw}

    # lines ---------------------------------------------------------------
    if line_kw.get("cmap") is None:
        l_kw["cmap"] = kwargs.get("cmap", None)

    l_kw = {**kwargs, **l_kw}

    # if a line color hasn't been specified, perhaps we can use the 'c' value
    # (taken from the merged dict so that a `c` given in line_kw is honoured)
    if l_kw.get("color") is None and l_kw.get("c") is not None:
        l_kw["color"] = l_kw["c"]
    # from here on the line colour is carried by 'color' only
    l_kw.pop("c", None)
    # if a color option is not specified, get the next cycled color
    if l_kw.get("color") is None:
        l_kw["color"] = ax._get_lines.get_next_color()

    l_kw = linekwargs(process_color(**{**_line_defaults, **l_kw}))
    # marker explicitly dealt with by scatter
    for k in ["marker", "markers"]:
        l_kw.pop(k, None)
    lcoll = matplotlib.collections.LineCollection(
        np.dstack((indexes, arr)), **{"zorder": 1, **l_kw}
    )
    ax.add_collection(lcoll)

    # markers -------------------------------------------------------------
    if s_kw.get("cmap") is None:
        s_kw["cmap"] = kwargs.get("cmap", None)

    s_kw = process_color(**{**_scatter_defaults, **kwargs, **s_kw})
    if s_kw["marker"] is None:
        return

    s_kw.update(dict(label=label))

    if s_kw.get("c") is not None:
        scattercolor = s_kw.get("c")
    elif s_kw.get("color") is not None:
        scattercolor = s_kw.get("color")
    else:
        # no color recognised - reuse the cycled line color chosen above
        scattercolor = l_kw["color"]

    if scattercolor is not None and not isinstance(scattercolor, (str, tuple)):
        # colors are arrays by this point; per-row colors are repeated so that
        # they line up with the (row-major) ravel-ed data passed to scatter
        if scattercolor.ndim >= 2 and scattercolor.shape[0] > 1:
            scattercolor = np.tile(scattercolor, arr.shape[1]).reshape(
                -1, scattercolor.shape[1]
            )

    s_kw = scatterkwargs({k: v for k, v in s_kw.items() if k not in ["c", "color"]})
    ax.scatter(
        indexes.ravel(),
        arr.ravel(),
        color=scattercolor,
        **{"zorder": 2, **s_kw},
    )


def _spider_density(ax, arr, indexes0, logy, yextent, mode, kwargs):
    """Draw the conditional density view for the :func:`spider` density modes."""
    # cmap is popped so that it is not forwarded to the density estimator
    cmap = kwargs.pop("cmap", None)
    if "contours" in kwargs and "vmin" in kwargs:
        msg = "Combining `contours` and `vmin` arguments for density plots should be avoided."
        logger.warning(msg)
    xe, ye, zi, xi, yi = conditional_prob_density(
        arr,
        x=indexes0,
        logy=logy,
        yextent=yextent,
        mode=mode,
        ret_centres=True,
        **kwargs,
    )
    # can have issues with nans here?
    vmin = kwargs.pop("vmin", 0)
    # vmin is supplied as a fraction; 1 - vmin is converted to a density
    # threshold via the percentile-contour helper (pctl)
    vmin = percentile_contour_values_from_meshz(zi, [1.0 - vmin])[1][0]
    if "contours" in kwargs:
        pzpkwargs = {  # keyword arguments to forward to plot_Z_percentiles
            **subkwargs(kwargs, plot_Z_percentiles),
            "percentiles": kwargs["contours"],
        }
        plot_Z_percentiles(  # pass all relevant kwargs including contours
            xi, yi, zi=zi, ax=ax, cmap=cmap, vmin=vmin, **pzpkwargs
        )
    else:
        zi[zi < vmin] = np.nan
        ax.pcolormesh(
            xe, ye, zi, cmap=cmap, vmin=vmin, **subkwargs(kwargs, ax.pcolormesh)
        )


def _autoscale_spider_ylim(ax, arr, logy=True, unity_line=False):
    """Set y-limits padded around the data range (log-decade aware for log axes)."""
    _ymin, _ymax = np.nanmin(arr), np.nanmax(arr)

    if unity_line:  # unity line is added at 1 - so may alter the total range
        _ymin, _ymax = min(_ymin, 1.0), max(_ymax, 1.0)

    if logy:
        # pad by 5% of the range in log space, and clip to the nearest 'minor' tick
        logmin, logmax = np.log10(_ymin), np.log10(_ymax)
        logy_rng = logmax - logmin
        low, high = 10 ** np.floor(logmin), 10 ** np.floor(logmax)
        padded_min = logmin - 0.05 * logy_rng
        padded_max = logmax + 0.05 * logy_rng

        if padded_min > 0:
            _ymin = np.floor(10**padded_min / low) * low
            if _ymin <= 0:
                # the padding dropped below the decade of the minimum; clip to a
                # tick of the lower decade instead of producing a zero log-limit
                step = 10 ** np.floor(padded_min)
                _ymin = np.floor(10**padded_min / step) * step
        else:
            _ymin = low
        _ymax = np.ceil(10**padded_max / high) * high
    else:
        # add 10% (of magnitude) either side for linear scale; using abs() keeps
        # negative data inside the limits (identical to 0.9x/1.1x for positives)
        _ymin, _ymax = _ymin - 0.1 * abs(_ymin), _ymax + 0.1 * abs(_ymax)

    # non-positive data on a log axis give non-finite limits; skip those
    if np.isfinite(_ymax) and np.isfinite(_ymin) and (_ymax - _ymin) > 0:
        ax.set_ylim(_ymin, _ymax)
12✔
285

286

287
def REE_v_radii(
    arr=None,
    ax=None,
    ree=REE(),
    index="elements",
    mode="plot",
    logy=True,
    tl_rotation=60,
    unity_line=False,
    scatter_kw=None,
    line_kw=None,
    set_labels=True,
    set_ticks=True,
    **kwargs,
):
    r"""
    Creates an axis for a REE diagram with ionic radii along the x axis.

    Parameters
    ----------
    arr : :class:`numpy.ndarray`
        Data array. If :code:`None`, only the axes are configured.
    ax : :class:`matplotlib.axes.Axes`, :code:`None`
        Optional designation of axes to reconfigure.
    ree : :class:`list`
        List of REE to use as an index.
    index : :class:`str`
        Whether to label the primary x-axis using radii ('radii') or elements
        ('elements'). Data are positioned by ionic radius in both cases; any
        value other than 'radii' is treated as 'elements'.
    mode : :class:`str`, :code:`["plot", "fill", "binkde", "ckde", "kde", "hist"]`
        Mode for plot. Plot will produce a line-scatter diagram. Fill will return
        a filled range. Density will return a conditional density diagram.
    logy : :class:`bool`
        Whether to use a log y-axis.
    tl_rotation : :class:`float`
        Rotation of the numerical index labels in degrees.
    unity_line : :class:`bool`
        Add a line at y=1 for reference.
    scatter_kw : :class:`dict`, :code:`None`
        Keyword parameters to be passed to the scatter plotting function.
    line_kw : :class:`dict`, :code:`None`
        Keyword parameters to be passed to the line plotting function.
    set_labels : :class:`bool`
        Whether to set the x-axis ticklabels for the REE.
    set_ticks : :class:`bool`
        Whether to set the x-axis ticks according to the specified index.
    {otherparams}

    Returns
    -------
    :class:`matplotlib.axes.Axes`
        Axes on which the REE_v_radii plot is added.

    Todo
    ----
        * Turn this into a plot template within pyrolite.plot.templates submodule

    .. seealso::

        Functions:

            :func:`matplotlib.pyplot.plot`
            :func:`matplotlib.pyplot.scatter`
            :func:`spider`
            :func:`pyrolite.geochem.transform.lambda_lnREE`

    """
    # avoid shared mutable defaults; always forward real dicts to `spider`
    scatter_kw = {} if scatter_kw is None else scatter_kw
    line_kw = {} if line_kw is None else line_kw

    ax = init_axes(ax=ax, **kwargs)

    # data are always positioned by ionic radius; `index` decides which labelling
    # goes on the primary vs. the twin x-axis, and whether the x-axis is inverted
    radii = np.array(get_ionic_radii(ree, charge=3, coordination=8))
    indexes = radii
    xlim = (0.99 * np.min(radii), 1.01 * np.max(radii))

    radii_labels = ["{:1.3f}".format(i) for i in radii]
    radii_title = r"Ionic Radius ($\mathrm{\AA}$)"
    element_title = "Element"

    if index == "radii":
        invertx = False
        xtitle, _xtitle = radii_title, element_title
        xlabels, _xlabels = radii_labels, ree
        xlabelrotation, _xlabelrotation = tl_rotation, 0
    else:  # index == "elements" (or any other value)
        invertx = True
        xtitle, _xtitle = element_title, radii_title
        xlabels, _xlabels = ree, radii_labels
        xlabelrotation, _xlabelrotation = 0, tl_rotation

    if arr is not None:
        ax = spider(
            arr,
            ax=ax,
            logy=logy,
            mode=mode,
            unity_line=unity_line,
            indexes=indexes,
            scatter_kw=scatter_kw,
            line_kw=line_kw,
            **kwargs,
        )

    # reuse an existing y-sharing twin (e.g. on repeated calls) before creating one
    twinys = get_twins(ax, which="y")
    _ax = twinys[0] if len(twinys) else ax.twiny()

    if set_labels:
        ax.set_xlabel(xtitle)
        _ax.set_xlabel(_xtitle)

    if set_ticks:
        ax.set_xticks(radii)
        ax.set_xticklabels(xlabels, rotation=xlabelrotation)
        ax.set_xlim(xlim[::-1] if invertx else xlim)

        _ax.set_xticks(radii)
        _ax.set_xticklabels(_xlabels, rotation=_xlabelrotation)
        # keep the twin axis aligned with the (possibly inverted) primary axis
        _ax.set_xlim(ax.get_xlim())
    return ax
12✔
410

411

412
# When True, an auto-generated "Other Parameters" section is substituted for the
# `{otherparams}` placeholder in the docstrings below; otherwise it is blanked.
_add_additional_parameters = True

spider.__doc__ = spider.__doc__.format(
    otherparams=(
        get_additional_params(
            spider,
            plt.scatter,
            plt.plot,
            matplotlib.lines.Line2D,
            indent=4,
            header="Other Parameters",
            subsections=True,
        )
        if _add_additional_parameters
        else ""
    )
)

REE_v_radii.__doc__ = REE_v_radii.__doc__.format(
    otherparams=(
        get_additional_params(
            REE_v_radii,
            spider,
            indent=4,
            header="Other Parameters",
            subsections=True,
        )
        if _add_additional_parameters
        else ""
    )
)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc