• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

morganjwilliams / pyrolite / 17569160869

09 Sep 2025 01:41AM UTC coverage: 91.465% (-0.1%) from 91.614%
17569160869

push

github

morganjwilliams
Add uncertainties, add optional deps for pyproject.toml; WIP demo NB

6226 of 6807 relevant lines covered (91.46%)

10.97 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

85.07
/pyrolite/plot/__init__.py
1
"""
2
Submodule with various plotting and visualisation functions.
3
"""
4

5
import warnings
12✔
6

7
import matplotlib
12✔
8
import matplotlib.pyplot as plt
12✔
9
import mpltern
12✔
10
import numpy as np
12✔
11
import pandas as pd
12✔
12

13
warnings.filterwarnings("ignore", "Unknown section")
12✔
14

15
from .. import geochem
12✔
16
from ..comp.codata import ILR, close
12✔
17
from ..util.distributions import get_scaler, sample_kde
12✔
18
from ..util.log import Handle
12✔
19
from ..util.meta import get_additional_params, subkwargs
12✔
20
from ..util.pd import to_frame
12✔
21
from ..util.plot.axes import init_axes, label_axes
12✔
22
from ..util.plot.helpers import plot_cooccurence
12✔
23
from ..util.plot.style import linekwargs, scatterkwargs
12✔
24
from . import density, parallel, spider, stem
12✔
25
from .color import process_color
12✔
26

27
# Module-level logger, namespaced to this submodule.
logger = Handle(__name__)

# Public names of this submodule; ``pyroplot`` is listed explicitly so that the
# accessor class is picked up when building the documentation.
__all__ = [
    "density",
    "spider",
    "pyroplot",
]
31

32

33
def _check_components(obj, components=None, check_size=True, valid_sizes=(2, 3)):
    """
    Check that the components provided within a dataframe are consistent with the
    form of plot being used.

    Parameters
    ----------
    obj : :class:`pandas.DataFrame`
        Object to check.
    components : :class:`list`, :code:`None`
        List of components, optionally specified. Where :code:`None`, all columns
        of `obj` are used.
    check_size : :class:`bool`
        Whether to verify the size of the column index.
    valid_sizes : :class:`list` | :class:`tuple`
        Component list lengths which are valid for the plot type.

    Returns
    -------
    :class:`list` | :class:`numpy.ndarray`
        Components for the plot: `components` as given, or the column values of
        `obj` where `components` is :code:`None`.

    Raises
    ------
    AssertionError
        If `check_size` is set, the number of columns of `obj` is not one of
        `valid_sizes`, and `components` is either missing or of invalid length.
    """
    if check_size and (obj.columns.size not in valid_sizes):
        # The full frame can't be plotted as-is, so an explicit, correctly-sized
        # component list is required. A missing list is rejected here rather than
        # falling through to len(None) (which would raise an unhelpful TypeError).
        if components is None or len(components) not in valid_sizes:
            # Raised explicitly (not via `assert`) so the check survives
            # `python -O`; AssertionError is kept for caller compatibility.
            raise AssertionError(
                "Suggest components or provide a slice of the dataframe."
            )

    if components is None:
        components = obj.columns.values
    return components
64

65

66
# note that only some of these methods will be valid for series
67
@pd.api.extensions.register_series_accessor("pyroplot")
@pd.api.extensions.register_dataframe_accessor("pyroplot")
class pyroplot(object):
    """
    Plotting accessor registered as ``.pyroplot`` on pandas DataFrames and Series.
    """

    def __init__(self, obj):
        """
        Custom dataframe accessor for pyrolite plotting.

        Notes
        -----
            This accessor enables the coexistence of array-based plotting functions and
            methods for pandas objects. This enables some separation of concerns.
        """
        self._validate(obj)
        self._obj = obj

    @staticmethod
    def _validate(obj):
        # No validation is currently performed; any pandas object is accepted.
        pass

    def cooccurence(self, ax=None, normalize=True, log=False, colorbar=False, **kwargs):
        """
        Plot the co-occurence frequency matrix for a given input.

        Parameters
        ----------
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        normalize : :class:`bool`
            Whether to normalize the cooccurence to compare disparate variables.
        log : :class:`bool`
            Whether to take the log of the cooccurence.
        colorbar : :class:`bool`
            Whether to append a colorbar.

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the cooccurence plot is added.

        """
        obj = to_frame(self._obj)
        ax = plot_cooccurence(
            obj.values, ax=ax, normalize=normalize, log=log, colorbar=colorbar, **kwargs
        )
        # label both dimensions of the matrix with the column names
        ax.set_xticklabels(obj.columns, minor=False, rotation=90)
        ax.set_yticklabels(obj.columns, minor=False)
        return ax

    def density(self, components: list = None, ax=None, axlabels=True, **kwargs):
        r"""
        Method for plotting histograms (mode='hist2d'|'hexbin') or kernel density
        estimates from point data. Convenience access function to
        :func:`~pyrolite.plot.density.density` (see `Other Parameters`, below), where
        further parameters for relevant `matplotlib` functions are also listed.

        Parameters
        ----------
        components : :class:`list`, :code:`None`
            Elements or compositional components to plot.
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        axlabels : :class:`bool`, True
            Whether to add x-y axis labels.

        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the density diagram is plotted.

        """
        obj = to_frame(self._obj)
        components = _check_components(obj, components=components)

        # `density` here is the module-level submodule, not this method
        ax = density.density(
            obj.reindex(columns=components).astype(float).values, ax=ax, **kwargs
        )
        if axlabels:
            label_axes(ax, labels=components)

        return ax

    def heatscatter(
        self,
        components: list = None,
        ax=None,
        axlabels=True,
        logx=False,
        logy=False,
        **kwargs,
    ):
        r"""
        Heatmapped scatter plots using the pyroplot API. See further parameters
        for `matplotlib.pyplot.scatter` function below.

        Parameters
        ----------
        components : :class:`list`, :code:`None`
            Elements or compositional components to plot.
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        axlabels : :class:`bool`, :code:`True`
            Whether to add x-y axis labels.
        logx : :class:`bool`, `False`
            Whether to log-transform x values before the KDE for bivariate plots.
        logy : :class:`bool`, `False`
            Whether to log-transform y values before the KDE for bivariate plots.

        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the heatmapped scatterplot is added.

        """
        obj = to_frame(self._obj)
        components = _check_components(obj, components=components)
        subset = obj.reindex(columns=components)
        data = subset.values
        samples = data.copy()  # evaluated at the data points themselves

        if len(components) == 3:
            # ternary (compositional) data: estimate density in ILR space
            def kdetfm(x):
                return ILR(close(x))

        else:
            # optional per-axis log transforms for bivariate data
            kdetfm = get_scaler(np.log if logx else None, np.log if logy else None)

        zi = sample_kde(
            data, samples, transform=kdetfm, **subkwargs(kwargs, sample_kde)
        )
        kwargs["c"] = zi  # colour each point by its estimated density
        ax = subset.pyroplot.scatter(ax=ax, axlabels=axlabels, **kwargs)
        return ax

    def parallel(
        self,
        components=None,
        rescale=False,
        legend=False,
        ax=None,
        **kwargs,
    ):
        """
        Create a :func:`pyrolite.plot.parallel.parallel`. coordinate plot from
        the columns of the :class:`~pandas.DataFrame`.

        Parameters
        ----------
        components : :class:`list`, :code:`None`
            Components to use as axes for the plot.
        rescale : :class:`bool`
            Whether to rescale values to [-1, 1].
        legend : :class:`bool`, :code:`False`
            Whether to include or suppress the legend.
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the parallel coordinates plot is added.

        Todo
        ----
        * Adapt figure size based on number of columns.

        """

        obj = to_frame(self._obj)
        ax = parallel.parallel(
            obj,
            components=components,
            rescale=rescale,
            legend=legend,
            ax=ax,
            **kwargs,
        )
        return ax

    def plot(self, components: list = None, ax=None, axlabels=True, **kwargs):
        r"""
        Convenience method for line plots using the pyroplot API. See
        further parameters for `matplotlib.pyplot.plot` function below.

        Parameters
        ----------
        components : :class:`list`, :code:`None`
            Elements or compositional components to plot.
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        axlabels : :class:`bool`, :code:`True`
            Whether to add x-y axis labels.
        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the plot is added.

        """
        obj = to_frame(self._obj)
        components = _check_components(obj, components=components)
        # three components are drawn on a ternary projection
        projection = "ternary" if len(components) == 3 else None
        ax = init_axes(ax=ax, projection=projection, **kwargs)
        kw = linekwargs(kwargs)
        ax.plot(*obj.reindex(columns=components).values.T, **kw)
        # if color is multi, could update line colors here
        if axlabels:
            label_axes(ax, labels=components)

        ax.tick_params("both")
        return ax

    def REE(
        self,
        index="elements",
        ax=None,
        mode="plot",
        dropPm=True,
        scatter_kw=None,
        line_kw=None,
        **kwargs,
    ):
        """Pass the pandas object to :func:`pyrolite.plot.spider.REE_v_radii`.

        Parameters
        ----------
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        index : :class:`str`
            Whether to plot radii ('radii') on the principal x-axis, or elements
            ('elements').
        mode : :class:`str`, :code:`["plot", "fill", "binkde", "ckde", "kde", "hist"]`
            Mode for plot. Plot will produce a line-scatter diagram. Fill will return
            a filled range. Density will return a conditional density diagram.
        dropPm : :class:`bool`
            Whether to exclude the (almost) non-existent element Promethium from the REE
            list.
        scatter_kw : :class:`dict`, :code:`None`
            Keyword parameters to be passed to the scatter plotting function.
        line_kw : :class:`dict`, :code:`None`
            Keyword parameters to be passed to the line plotting function.
        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the REE plot is added.

        """
        # fresh dicts per call, so no state can leak between calls via a
        # shared mutable default
        scatter_kw = {} if scatter_kw is None else scatter_kw
        line_kw = {} if line_kw is None else line_kw

        obj = to_frame(self._obj)
        # REE present in the data, in the order given by geochem.ind.REE
        ree = [i for i in geochem.ind.REE(dropPm=dropPm) if i in obj.columns]

        ax = spider.REE_v_radii(
            obj.reindex(columns=ree).astype(float).values,
            index=index,
            ree=ree,
            mode=mode,
            ax=ax,
            scatter_kw=scatter_kw,
            line_kw=line_kw,
            **kwargs,
        )
        ax.set_ylabel(r"$\mathrm{X / X_{Reference}}$")
        return ax

    def scatter(self, components: list = None, ax=None, axlabels=True, **kwargs):
        r"""
        Convenience method for scatter plots using the pyroplot API. See
        further parameters for `matplotlib.pyplot.scatter` function below.

        Parameters
        ----------
        components : :class:`list`, :code:`None`
            Elements or compositional components to plot.
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        axlabels : :class:`bool`, :code:`True`
            Whether to add x-y axis labels.
        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the scatterplot is added.

        """
        obj = to_frame(self._obj)
        components = _check_components(obj, components=components)

        # three components are drawn on a ternary projection
        projection = "ternary" if len(components) == 3 else None
        ax = init_axes(ax=ax, projection=projection, **kwargs)
        size = obj.index.size
        kw = process_color(size=size, **kwargs)
        with warnings.catch_warnings():
            # ternary transform where points add to zero will give an unnecessary
            # warning; here we supress it
            warnings.filterwarnings(
                "ignore", message="invalid value encountered in divide"
            )
            ax.scatter(*obj.reindex(columns=components).values.T, **scatterkwargs(kw))

        if axlabels:
            label_axes(ax, labels=components)

        ax.tick_params("both")
        return ax

    def spider(
        self,
        components: list = None,
        indexes: list = None,
        ax=None,
        mode="plot",
        index_order=None,
        autoscale=True,
        scatter_kw=None,
        line_kw=None,
        **kwargs,
    ):
        r"""
        Method for spider plots. Convenience access function to
        :func:`~pyrolite.plot.spider.spider` (see `Other Parameters`, below), where
        further parameters for relevant `matplotlib` functions are also listed.

        Parameters
        ----------
        components : :class:`list`, :code:`None`
            Elements or compositional components to plot. Defaults to the columns
            which are common elements.
        indexes :  :class:`list`, :code:`None`
            Passed through to :func:`~pyrolite.plot.spider.spider` (presumably
            the x-axis positions of the components).
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        index_order : :class:`str` | callable, :code:`None`
            Function to order spider plot indexes (e.g. by incompatibility), or
            the name of an ordering in :code:`pyrolite.geochem.ind.ordering`. An
            unrecognised name logs a warning and leaves the order unchanged.
        autoscale : :class:`bool`
            Whether to autoscale the y-axis limits for standard spider plots.
        mode : :class:`str`, :code:`["plot", "fill", "binkde", "ckde", "kde", "hist"]`
            Mode for plot. Plot will produce a line-scatter diagram. Fill will return
            a filled range. Density will return a conditional density diagram.
        scatter_kw : :class:`dict`, :code:`None`
            Keyword parameters to be passed to the scatter plotting function.
        line_kw : :class:`dict`, :code:`None`
            Keyword parameters to be passed to the line plotting function.
        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the spider diagram is plotted.

        Todo
        ----
            * Add 'compositional data' filter for default components if None is given

        """
        # fresh dicts per call, so no state can leak between calls via a
        # shared mutable default
        scatter_kw = {} if scatter_kw is None else scatter_kw
        line_kw = {} if line_kw is None else line_kw

        obj = to_frame(self._obj)

        if components is None:  # default to plotting elemental data
            elements = geochem.ind.common_elements()  # hoisted out of the loop
            components = [el for el in obj.columns if el in elements]

        if len(components) == 0:
            # explicit raise (not `assert`) so the check survives `python -O`
            raise AssertionError("No components available to plot.")

        if isinstance(index_order, str):
            try:
                index_order = geochem.ind.ordering[index_order]
            except KeyError:
                msg = (
                    "Ordering not applied, as parameter '{}' not recognized."
                    " Select from: {}"
                ).format(index_order, ", ".join(list(geochem.ind.ordering.keys())))
                logger.warning(msg)
                # keep the given order rather than calling the unrecognised string
                index_order = None

        if index_order is not None:
            components = index_order(components)

        ax = init_axes(ax=ax, **kwargs)

        # TODO: handle spider diagrams which have specified components
        # (i.e. where `ax` already carries `_pyrolite_components`)

        ax = spider.spider(
            obj.reindex(columns=components).astype(float).values,
            indexes=indexes,
            ax=ax,
            mode=mode,
            autoscale=autoscale,
            scatter_kw=scatter_kw,
            line_kw=line_kw,
            **kwargs,
        )
        # record the plotted components on the axes for later reference
        ax._pyrolite_components = components
        ax.set_xticklabels(components, rotation=60)
        return ax

    def stem(
        self,
        components: list = None,
        ax=None,
        orientation="horizontal",
        axlabels=True,
        **kwargs,
    ):
        r"""
        Method for creating stem plots. Convenience access function to
        :func:`~pyrolite.plot.stem.stem` (see `Other Parameters`, below), where
        further parameters for relevant `matplotlib` functions are also listed.

        Parameters
        ----------
        components : :class:`list`, :code:`None`
            Elements or compositional components to plot.
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        orientation : :class:`str`
            Orientation of the plot (horizontal or vertical).
        axlabels : :class:`bool`, True
            Whether to add x-y axis labels.
        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the stem diagram is plotted.
        """
        obj = to_frame(self._obj)
        components = _check_components(obj, components=components, valid_sizes=[2])

        ax = stem.stem(
            *obj.reindex(columns=components).values.T,
            ax=ax,
            orientation=orientation,
            **process_color(**kwargs),
        )

        if axlabels:
            # for non-horizontal orientations the axes are swapped, so the
            # labels are reversed to match
            if "h" not in orientation.lower():
                components = components[::-1]
            label_axes(ax, labels=components)

        return ax
519

520

521
# ideally we would i) check for the same params and ii) aggregate all others across
# inherited or chained functions. This simply imports the params from another docstring
_add_additional_parameters = True


def _format_otherparams(method, *sources):
    """
    Fill the ``{otherparams}`` placeholder in the docstring of `method`.

    Parameters
    ----------
    method : :class:`function`
        Function whose docstring contains the ``{otherparams}`` placeholder; its
        ``__doc__`` is replaced.
    *sources : :class:`function`
        Functions passed, in order, to
        :func:`~pyrolite.util.meta.get_additional_params` to build the
        'Other Parameters' section.
    """
    if method.__doc__ is None:
        # docstrings are stripped under `python -OO`; calling .format on None
        # would otherwise fail at import time
        return
    otherparams = ""
    if _add_additional_parameters:
        otherparams = get_additional_params(
            *sources, header="Other Parameters", indent=8, subsections=True
        )
    method.__doc__ = method.__doc__.format(otherparams=otherparams)


_format_otherparams(pyroplot.density, pyroplot.density, density.density)
_format_otherparams(pyroplot.parallel, pyroplot.parallel, parallel.parallel)
_format_otherparams(pyroplot.REE, pyroplot.REE, spider.REE_v_radii)
_format_otherparams(pyroplot.scatter, pyroplot.scatter, plt.scatter)
_format_otherparams(pyroplot.plot, pyroplot.plot, plt.plot)
_format_otherparams(pyroplot.spider, pyroplot.spider, spider.spider)
_format_otherparams(pyroplot.stem, pyroplot.stem, stem.stem)
# heatscatter draws its extra parameters from the scatter docstring, which must
# therefore already have been formatted above
_format_otherparams(pyroplot.heatscatter, pyroplot.scatter)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc