• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

morganjwilliams / pyrolite / 5564819928

pending completion
5564819928

push

github

morganjwilliams
Merge branch 'release/0.3.3' into main

249 of 270 new or added lines in 48 files covered. (92.22%)

217 existing lines in 33 files now uncovered.

5971 of 6605 relevant lines covered (90.4%)

10.84 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.86
/pyrolite/plot/spider.py
1
import matplotlib.collections
12✔
2
import matplotlib.lines
12✔
3
import matplotlib.patches
12✔
4
import matplotlib.pyplot as plt
12✔
5
import numpy as np
12✔
6

7
from ..util.log import Handle
12✔
8

9
# Module-level logger for this submodule, obtained via pyrolite's ``Handle`` helper.
logger = Handle(__name__)
12✔
10

11

12
from ..geochem.ind import REE, get_ionic_radii
12✔
13
from ..util.meta import get_additional_params, subkwargs
12✔
14
from ..util.plot.axes import get_twins, init_axes
12✔
15
from ..util.plot.density import (
12✔
16
    conditional_prob_density,
17
    percentile_contour_values_from_meshz,
18
    plot_Z_percentiles,
19
)
20
from ..util.plot.style import (
12✔
21
    DEFAULT_CONT_COLORMAP,
22
    linekwargs,
23
    patchkwargs,
24
    scatterkwargs,
25
)
26
from .color import process_color
12✔
27

28
# Baseline keyword arguments for the scatter (marker) component of spider plots;
# caller-supplied keywords override these.
_scatter_defaults = {"cmap": DEFAULT_CONT_COLORMAP, "marker": "D", "s": 25}
# Baseline keyword arguments for the line component of spider plots.
_line_defaults = {"cmap": DEFAULT_CONT_COLORMAP}
12✔
30

31

32
def _next_cycle_color(ax):
    """
    Draw the next color from the line property cycle of an axes.

    Prefers the line handler's ``get_next_color`` method. Falls back to the
    private ``prop_cycler`` attribute, which is not available in newer
    matplotlib releases, for older versions lacking ``get_next_color``.
    """
    lines = ax._get_lines
    if hasattr(lines, "get_next_color"):
        return lines.get_next_color()
    return next(lines.prop_cycler)["color"]


# could create a spidercollection?
def spider(
    arr,
    indexes=None,
    ax=None,
    label=None,
    logy=True,
    yextent=None,
    mode="plot",
    unity_line=False,
    scatter_kw=None,
    line_kw=None,
    set_ticks=True,
    autoscale=True,
    **kwargs
):
    """
    Plots spidergrams for trace elements data. Additional arguments are typically forwarded
    to respective :mod:`matplotlib` functions :func:`~matplotlib.pyplot.plot` and
    :func:`~matplotlib.pyplot.scatter` (see Other Parameters, below).

    Parameters
    ----------
    arr : :class:`numpy.ndarray`, :code:`None`
        Data array; either 1D (a single series) or 2D with one series per row.
        If :code:`None`, or if it contains no finite values, the configured but
        otherwise blank axes are returned.
    indexes : :class:`numpy.ndarray`
        Numerical indexes of x-axis positions; either 1D (shared by all series) or
        2D (one row per series, with the first row used for the x-axis ticks).
    ax : :class:`matplotlib.axes.Axes`, :code:`None`
        The subplot to draw on.
    label : :class:`str`, :code:`None`
        Label for the individual series (applied to the scatter markers).
    logy : :class:`bool`
        Whether to use a log y-axis.
    yextent : :class:`tuple`
        Extent in the y direction for conditional probability plots, to limit
        the gridspace over which the kernel density estimates are evaluated.
    mode : :class:`str`,  :code:`["plot", "fill", "binkde", "ckde", "kde", "hist"]`
        Mode for plot. Plot will produce a line-scatter diagram. Fill will return
        a filled range. The density modes ("binkde", "ckde", "kde", "hist") will
        return a conditional density diagram.
    unity_line : :class:`bool`
        Add a line at y=1 for reference.
    scatter_kw : :class:`dict`, :code:`None`
        Keyword parameters to be passed to the scatter plotting function.
    line_kw : :class:`dict`, :code:`None`
        Keyword parameters to be passed to the line plotting function.
    set_ticks : :class:`bool`
        Whether to set the x-axis ticks according to the specified index.
    autoscale : :class:`bool`
        Whether to autoscale the y-axis limits for standard spider plots.
    {otherparams}

    Returns
    -------
    :class:`matplotlib.axes.Axes`
        Axes on which the spiderplot is plotted.

    Raises
    ------
    NotImplementedError
        If there is data to plot and `mode` is not one of the accepted modes.

    Notes
    -----
        By using separate lines and scatterplots, values between two missing
        items are still presented.

    Todo
    ----
        * Legend entries

    .. seealso::

        Functions:

            :func:`matplotlib.pyplot.plot`
            :func:`matplotlib.pyplot.scatter`
            :func:`REE_v_radii`
    """
    # avoid shared mutable defaults; the dictionaries are only read below
    scatter_kw = {} if scatter_kw is None else scatter_kw
    line_kw = {} if line_kw is None else line_kw

    if indexes is not None:
        indexes = np.array(indexes)

    # ---------------------------------------------------------------------
    # the number of components sets the default figure width; without data,
    # fall back to the width of the supplied indexes (if any)
    if arr is not None:
        ncomponents = arr.shape[-1]
    elif indexes is not None:
        ncomponents = indexes.shape[-1]
    else:
        ncomponents = 0
    figsize = kwargs.pop("figsize", None) or (ncomponents * 0.3, 4)

    ax = init_axes(ax=ax, figsize=figsize, **kwargs)

    if unity_line:
        ax.axhline(1.0, ls="--", c="k", lw=0.5)

    if logy:
        ax.set_yscale("log")

    if indexes is None:
        indexes = np.arange(ncomponents)

    # with 2D indexes (one row per series), the first row stands in for the
    # shared x positions used for ticks, fills and density grids
    indexes0 = indexes if indexes.ndim == 1 else indexes[0]

    if set_ticks:
        ax.set_xticks(indexes0)

    # if there is no data, return the blank axis
    if (arr is None) or (not np.isfinite(arr).sum()):
        return ax

    # if the indexes are supplied as a 1D array but the data is 2D, we need to expand
    # it to fit the scatter data
    if indexes.ndim < arr.ndim:
        indexes = np.tile(indexes0, (arr.shape[0], 1))

    _mode = mode.lower()
    if "fill" in _mode:
        mins = np.nanmin(arr, axis=0)
        maxs = np.nanmax(arr, axis=0)
        ax.fill_between(indexes0, mins, maxs, **patchkwargs(kwargs))
    elif "plot" in _mode:
        # copy params so that the caller's dictionaries are left untouched
        l_kw, s_kw = {**line_kw}, {**scatter_kw}
        ################################################################################
        if line_kw.get("cmap") is None:
            l_kw["cmap"] = kwargs.get("cmap", None)

        # line-specific keywords take precedence over the global ones
        l_kw = {**kwargs, **l_kw}

        # if a line color hasn't been specified, perhaps we can use the scatter 'c'
        # (taken from the merged keywords, so a 'c' given via line_kw is honoured)
        if l_kw.get("color") is None and l_kw.get("c") is not None:
            l_kw["color"] = l_kw["c"]
        # remove c whether it was specified globally or via line_kw
        l_kw.pop("c", None)
        # if a color option is not specified, get the next cycled color
        if l_kw.get("color") is None:
            # add cycler color as array to suppress singular color warning
            l_kw["color"] = np.array([_next_cycle_color(ax)])

        l_kw = linekwargs(process_color(**{**_line_defaults, **l_kw}))
        # marker explicitly dealt with by scatter
        for k in ["marker", "markers"]:
            l_kw.pop(k, None)
        # one polyline per series, built from (x, y) pairs
        lcoll = matplotlib.collections.LineCollection(
            np.dstack((indexes, arr)), **{"zorder": 1, **l_kw}
        )
        ax.add_collection(lcoll)
        ################################################################################
        # load defaults and any specified parameters in scatter_kw / line_kw
        if s_kw.get("cmap") is None:
            s_kw["cmap"] = kwargs.get("cmap", None)

        s_kw = process_color(**{**_scatter_defaults, **kwargs, **s_kw})
        if s_kw["marker"] is not None:
            s_kw.update(dict(label=label))

            if s_kw.get("c") is not None:
                scattercolor = s_kw.get("c")
            elif s_kw.get("color") is not None:
                scattercolor = s_kw.get("color")
            else:
                # no color recognised - will be default, here we get the
                # color used for the lines (possibly the cycled color added earlier)
                scattercolor = l_kw["color"]

            if scattercolor is not None and not isinstance(scattercolor, (str, tuple)):
                # with one color per series (row) of a 2D array, repeat each color
                # for every point in its series to match the row-major arr.ravel();
                # other shapes (single or per-point colors) are passed through as-is
                if (
                    arr.ndim > 1
                    and np.ndim(scattercolor) == 2
                    and np.shape(scattercolor)[0] > 1
                    and np.shape(scattercolor)[0] == arr.shape[0]
                ):
                    scattercolor = np.repeat(scattercolor, arr.shape[1], axis=0)
            s_kw = scatterkwargs(
                {k: v for k, v in s_kw.items() if k not in ["c", "color"]}
            )
            ax.scatter(
                indexes.ravel(), arr.ravel(), c=scattercolor, **{"zorder": 2, **s_kw}
            )

        # should create a custom legend handle here

        # could modify legend here.
    elif any(i in _mode for i in ["binkde", "ckde", "kde", "hist"]):
        cmap = kwargs.pop("cmap", None)
        if "contours" in kwargs and "vmin" in kwargs:
            msg = "Combining `contours` and `vmin` arguments for density plots should be avoided."
            logger.warning(msg)
        xe, ye, zi, xi, yi = conditional_prob_density(
            arr,
            x=indexes0,
            logy=logy,
            yextent=yextent,
            mode=mode,
            ret_centres=True,
            **kwargs
        )
        # can have issues with nans here?
        vmin = kwargs.pop("vmin", 0)
        # convert the requested percentile threshold into a density value
        vmin = percentile_contour_values_from_meshz(zi, [1.0 - vmin])[1][0]  # pctl
        if "contours" in kwargs:
            pzpkwargs = {  # keyword arguments to forward to plot_Z_percentiles
                **subkwargs(kwargs, plot_Z_percentiles),
                **{"percentiles": kwargs["contours"]},
            }
            plot_Z_percentiles(  # pass all relevant kwargs including contours
                xi, yi, zi=zi, ax=ax, cmap=cmap, vmin=vmin, **pzpkwargs
            )
        else:
            # mask densities below the threshold so they are left unpainted
            zi[zi < vmin] = np.nan
            ax.pcolormesh(
                xe, ye, zi, cmap=cmap, vmin=vmin, **subkwargs(kwargs, ax.pcolormesh)
            )
    else:
        raise NotImplementedError(
            "Accepted modes: {plot, fill, binkde, ckde, kde, hist}"
        )

    if autoscale and arr.size:
        # set the y range to lock to the outermost log-increments
        _ymin, _ymax = np.nanmin(arr), np.nanmax(arr)

        if unity_line:
            # keep the y=1 reference line within view
            _ymin, _ymax = min(_ymin, 1.0), max(_ymax, 1.0)

        if logy:
            # at 5% range in log space, and clip to nearest 'minor' tick
            logmin, logmax = np.log10(_ymin), np.log10(_ymax)
            logy_rng = logmax - logmin

            low, high = 10 ** np.floor(logmin), 10 ** np.floor(logmax)

            _ymin, _ymax = (
                np.floor(10 ** (logmin - 0.05 * logy_rng) / low) * low,
                np.ceil(10 ** (logmax + 0.05 * logy_rng) / high) * high,
            )
        else:
            # pad by 10% of each limit's magnitude for linear scale; identical to
            # 0.9*min / 1.1*max for positive data, but never clips negative data
            _ymin, _ymax = _ymin - 0.1 * abs(_ymin), _ymax + 0.1 * abs(_ymax)

        # non-positive data on a log axis yield non-finite limits; skip those
        if np.isfinite(_ymax) and np.isfinite(_ymin) and (_ymax - _ymin) > 0:
            ax.set_ylim(_ymin, _ymax)
    return ax
12✔
278

279

280
def REE_v_radii(
    arr=None,
    ax=None,
    ree=None,
    index="elements",
    mode="plot",
    logy=True,
    tl_rotation=60,
    unity_line=False,
    scatter_kw=None,
    line_kw=None,
    set_labels=True,
    set_ticks=True,
    **kwargs
):
    r"""
    Creates an axis for a REE diagram with ionic radii along the x axis.

    Parameters
    ----------
    arr : :class:`numpy.ndarray`, :code:`None`
        Data array. If :code:`None`, only the axes are configured.
    ax : :class:`matplotlib.axes.Axes`, :code:`None`
        Optional designation of axes to reconfigure.
    ree : :class:`list`, :code:`None`
        List of REE to use as an index. Defaults to the output of
        :func:`~pyrolite.geochem.ind.REE`.
    index : :class:`str`
        Whether to plot using radii on the x-axis ('radii'), or elements ('elements').
    mode : :class:`str`, :code:`["plot", "fill", "binkde", "ckde", "kde", "hist"]`
        Mode for plot. Plot will produce a line-scatter diagram. Fill will return
        a filled range. Density will return a conditional density diagram.
    logy : :class:`bool`
        Whether to use a log y-axis.
    tl_rotation : :class:`float`
        Rotation of the numerical index labels in degrees.
    unity_line : :class:`bool`
        Add a line at y=1 for reference.
    scatter_kw : :class:`dict`, :code:`None`
        Keyword parameters to be passed to the scatter plotting function.
    line_kw : :class:`dict`, :code:`None`
        Keyword parameters to be passed to the line plotting function.
    set_labels : :class:`bool`
        Whether to set the x-axis ticklabels for the REE.
    set_ticks : :class:`bool`
        Whether to set the x-axis ticks according to the specified index.
    {otherparams}

    Returns
    -------
    :class:`matplotlib.axes.Axes`
        Axes on which the REE_v_radii plot is added.

    Todo
    ----
        * Turn this into a plot template within pyrolite.plot.templates submodule

    .. seealso::

        Functions:

            :func:`matplotlib.pyplot.plot`
            :func:`matplotlib.pyplot.scatter`
            :func:`spider`
            :func:`pyrolite.geochem.transform.lambda_lnREE`

    """
    # avoid evaluating/sharing mutable defaults at import time
    if ree is None:
        ree = REE()
    scatter_kw = {} if scatter_kw is None else scatter_kw
    line_kw = {} if line_kw is None else line_kw

    ax = init_axes(ax=ax, **kwargs)

    # ionic radii of the trivalent ions in 8-fold coordination
    radii = np.array(get_ionic_radii(ree, charge=3, coordination=8))

    # by default the primary axis is labelled with radii and the twin axis with
    # element names; the pairs are swapped below when indexing by element
    xlabels, _xlabels = ["{:1.3f}".format(r) for r in radii], ree
    xticks, _xticks = radii, radii
    xlim = (0.99 * np.min(radii), 1.01 * np.max(radii))
    xlabelrotation, _xlabelrotation = tl_rotation, 0
    xtitle, _xtitle = r"Ionic Radius ($\mathrm{\AA}$)", "Element"

    # data are always positioned by radius; indexing by element only swaps which
    # labels sit on which axis and flips the x direction
    indexes = radii
    invertx = index != "radii"
    if invertx:  # index == 'elements'
        xtitle, _xtitle = _xtitle, xtitle
        xlabels, _xlabels = _xlabels, xlabels
        xlabelrotation, _xlabelrotation = _xlabelrotation, xlabelrotation

    if arr is not None:
        ax = spider(
            arr,
            ax=ax,
            logy=logy,
            mode=mode,
            unity_line=unity_line,
            indexes=indexes,
            scatter_kw=scatter_kw,
            line_kw=line_kw,
            **kwargs
        )

    # reuse an existing y-sharing twin (e.g. when reconfiguring axes) if present
    twinys = get_twins(ax, which="y")
    _ax = twinys[0] if len(twinys) else ax.twiny()

    if set_labels:
        ax.set_xlabel(xtitle)
        _ax.set_xlabel(_xtitle)

    if set_ticks:
        ax.set_xticks(xticks)
        ax.set_xticklabels(xlabels, rotation=xlabelrotation)
        ax.set_xlim(xlim[::-1] if invertx else xlim)

        _ax.set_xticks(_xticks)
        _ax.set_xticklabels(_xlabels, rotation=_xlabelrotation)
        # keep the twin axis aligned with the (possibly inverted) primary axis
        _ax.set_xlim(ax.get_xlim())
    return ax
12✔
403

404

405
# Whether to build the "Other Parameters" docstring sections via
# get_additional_params when this module is imported.
_add_additional_parameters = True


def _fill_otherparams(func, *others):
    """
    Substitute the ``otherparams`` placeholder in the docstring of `func`.

    The replacement is the "Other Parameters" section generated by
    :func:`get_additional_params` from `func` and `others`, or an empty string
    when `_add_additional_parameters` is False (in which case the section is not
    generated at all).

    Docstrings are stripped when Python runs with ``-OO`` (``__doc__`` is None);
    `func` is then left untouched rather than failing at import time.
    """
    if func.__doc__ is None:
        return
    otherparams = ""
    if _add_additional_parameters:
        otherparams = get_additional_params(
            func, *others, indent=4, header="Other Parameters", subsections=True
        )
    func.__doc__ = func.__doc__.format(otherparams=otherparams)


# order matters: spider's docstring is completed before REE_v_radii's section
# is generated from it
_fill_otherparams(spider, plt.scatter, plt.plot, matplotlib.lines.Line2D)
_fill_otherparams(REE_v_radii, spider)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc