• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

morganjwilliams / pyrolite / 11836306515

14 Nov 2024 11:28AM UTC coverage: 90.339% (-1.3%) from 91.611%
11836306515

push

github

morganjwilliams
Update example notebook

6237 of 6904 relevant lines covered (90.34%)

10.83 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

85.71
/pyrolite/plot/__init__.py
1
"""
2
Submodule with various plotting and visualisation functions.
3
"""
4

5
import warnings
12✔
6

7
import matplotlib
12✔
8
import matplotlib.pyplot as plt
12✔
9
import mpltern
12✔
10
import numpy as np
12✔
11
import pandas as pd
12✔
12

13
warnings.filterwarnings("ignore", "Unknown section")
12✔
14

15
from .. import geochem
12✔
16
from ..comp.codata import ILR, close
12✔
17
from ..util.distributions import get_scaler, sample_kde
12✔
18
from ..util.log import Handle
12✔
19
from ..util.meta import get_additional_params, subkwargs
12✔
20
from ..util.pd import to_frame, _check_components
12✔
21
from ..util.plot.axes import init_axes, label_axes
12✔
22
from ..util.plot.helpers import plot_cooccurence
12✔
23
from ..util.plot.style import _export_nonRCstyles, linekwargs, scatterkwargs
12✔
24
from . import density, parallel, spider, stem
12✔
25
from .color import process_color
12✔
26

27
# module-level logging handle, named after this submodule
logger = Handle(__name__)

# "pyroplot" is listed alongside the submodules so the accessor appears in the docs
__all__ = ["density", "spider", "pyroplot"]
12✔
31

32

33
class pyroplot_matplotlib(object):
12✔
34
    def __init__(self, obj):
        """
        Set up the pyrolite plotting accessor for a pandas object.

        Parameters
        ----------
        obj
            Object the accessor is attached to; it is passed through
            :meth:`_validate` and then stored for use by the plotting methods.

        Notes
        -----
            Keeping the pandas-facing methods on this accessor lets them coexist
            with the array-based plotting functions, which provides some
            separation of concerns.
        """
        self._validate(obj)
        self._obj = obj

        # re-export the custom (non-rcParams) styling whenever an accessor is created
        _export_nonRCstyles()
12✔
48

49
    @staticmethod
12✔
50
    def _validate(obj):
12✔
51
        pass
12✔
52

53
    def cooccurence(self, ax=None, normalize=True, log=False, colorbar=False, **kwargs):
        """
        Draw the co-occurence frequency matrix of the columns of this object.

        Parameters
        ----------
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        normalize : :class:`bool`
            Whether to normalize the cooccurence to compare disparate variables.
        log : :class:`bool`
            Whether to take the log of the cooccurence.
        colorbar : :class:`bool`
            Whether to append a colorbar.

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the cooccurence plot is added.

        Notes
        -----
            Remaining keyword arguments are forwarded to
            :func:`~pyrolite.util.plot.helpers.plot_cooccurence`.
        """
        frame = to_frame(self._obj)
        column_labels = frame.columns
        ax = plot_cooccurence(
            frame.values,
            ax=ax,
            normalize=normalize,
            log=log,
            colorbar=colorbar,
            **kwargs,
        )
        # both axes of the matrix are labelled with the column names
        ax.set_xticklabels(column_labels, minor=False, rotation=90)
        ax.set_yticklabels(column_labels, minor=False)
        return ax
12✔
81

82
    def density(self, components: list = None, ax=None, axlabels=True, **kwargs):
        r"""
        Plot histograms (mode='hist2d'|'hexbin') or kernel density estimates from
        point data. This is a convenience wrapper around
        :func:`~pyrolite.plot.density.density` (see `Other Parameters`, below), where
        further parameters for relevant `matplotlib` functions are also listed.

        Parameters
        ----------
        components : :class:`list`, :code:`None`
            Elements or compositional components to plot.
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        axlabels : :class:`bool`, True
            Whether to add x-y axis labels.

        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the density diagram is plotted.

        """
        frame = to_frame(self._obj)
        components = _check_components(frame, components=components)

        # the array-based function receives the selected columns as floats
        values = frame.reindex(columns=components).astype(float).values
        ax = density.density(values, ax=ax, **kwargs)

        if axlabels:
            label_axes(ax, labels=components)
        return ax
12✔
116

117
    def heatscatter(
        self,
        components: list = None,
        ax=None,
        axlabels=True,
        logx=False,
        logy=False,
        **kwargs,
    ):
        r"""
        Heatmapped scatter plots using the pyroplot API. Points are coloured by a
        kernel density estimate sampled at each point. See further parameters
        for `matplotlib.pyplot.scatter` function below.

        Parameters
        ----------
        components : :class:`list`, :code:`None`
            Elements or compositional components to plot.
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        axlabels : :class:`bool`, :code:`True`
            Whether to add x-y axis labels.
        logx : :class:`bool`, `False`
            Whether to log-transform x values before the KDE for bivariate plots.
        logy : :class:`bool`, `False`
            Whether to log-transform y values before the KDE for bivariate plots.

        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the heatmapped scatterplot is added.

        """
        obj = to_frame(self._obj)
        components = _check_components(obj, components=components)

        # select the columns once; the same selection feeds the KDE and the scatter
        subset = obj.reindex(columns=components)
        data = subset.values
        # separate array for the sample locations, as the two inputs were
        # previously built independently
        samples = data.copy()

        if len(components) == 3:
            # three components: the KDE is estimated on ILR(close(x))

            def kdetfm(x):
                return ILR(close(x))

        else:
            # otherwise optionally log-transform each axis before the KDE
            kdetfm = get_scaler(
                np.log if logx else None, np.log if logy else None
            )

        # only the keyword arguments accepted by sample_kde are passed to it
        zi = sample_kde(
            data, samples, transform=kdetfm, **subkwargs(kwargs, sample_kde)
        )
        # the density estimate becomes the colour of each point
        kwargs.update({"c": zi})
        ax = subset.pyroplot.scatter(ax=ax, axlabels=axlabels, **kwargs)
        return ax
12✔
169

170
    def parallel(
        self,
        components=None,
        rescale=False,
        legend=False,
        ax=None,
        **kwargs,
    ):
        """
        Create a parallel coordinate plot from the columns of the
        :class:`~pandas.DataFrame`, using :func:`pyrolite.plot.parallel.parallel`.

        Parameters
        ----------
        components : :class:`list`, :code:`None`
            Components to use as axes for the plot.
        rescale : :class:`bool`
            Whether to rescale values to [-1, 1].
        legend : :class:`bool`, :code:`False`
            Whether to include or suppress the legend.
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the parallel coordinates plot is added.

        Todo
        ----
        * Adapt figure size based on number of columns.

        """
        frame = to_frame(self._obj)
        # the frame itself is handed over; column selection happens downstream
        return parallel.parallel(
            frame,
            components=components,
            rescale=rescale,
            legend=legend,
            ax=ax,
            **kwargs,
        )
12✔
215

216
    def plot(self, components: list = None, ax=None, axlabels=True, **kwargs):
        r"""
        Convenience method for line plots using the pyroplot API. See
        further parameters for `matplotlib.pyplot.plot` function below.

        Parameters
        ----------
        components : :class:`list`, :code:`None`
            Elements or compositional components to plot.
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        axlabels : :class:`bool`, :code:`True`
            Whether to add x-y axis labels.
        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the plot is added.

        """
        frame = to_frame(self._obj)
        components = _check_components(frame, components=components)

        # three components are drawn on a ternary projection
        is_ternary = len(components) == 3
        ax = init_axes(
            ax=ax, projection="ternary" if is_ternary else None, **kwargs
        )

        line_kw = linekwargs(kwargs)
        coords = frame.reindex(columns=components).values.T
        ax.plot(*coords, **line_kw)
        # NOTE: if a multi-valued color is given, line colors could be updated here

        if axlabels:
            label_axes(ax, labels=components)

        ax.tick_params("both")
        return ax
×
251

252
    def REE(
        self,
        index="elements",
        ax=None,
        mode="plot",
        dropPm=True,
        scatter_kw=None,
        line_kw=None,
        **kwargs,
    ):
        """Pass the pandas object to :func:`pyrolite.plot.spider.REE_v_radii`.

        Parameters
        ----------
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        index : :class:`str`
            Whether to plot radii ('radii') on the principal x-axis, or elements
            ('elements').
        mode : :class:`str`, :code`["plot", "fill", "binkde", "ckde", "kde", "hist"]`
            Mode for plot. Plot will produce a line-scatter diagram. Fill will return
            a filled range. Density will return a conditional density diagram.
        dropPm : :class:`bool`
            Whether to exclude the (almost) non-existent element Promethium from the REE
            list.
        scatter_kw : :class:`dict`, :code:`None`
            Keyword parameters to be passed to the scatter plotting function.
            An empty dictionary is used where none is given.
        line_kw : :class:`dict`, :code:`None`
            Keyword parameters to be passed to the line plotting function.
            An empty dictionary is used where none is given.
        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the REE plot is added.

        """
        # build fresh dictionaries per call rather than sharing mutable defaults
        scatter_kw = {} if scatter_kw is None else scatter_kw
        line_kw = {} if line_kw is None else line_kw

        obj = to_frame(self._obj)
        # keep only the REE which are present as columns, in the REE list order
        ree = [i for i in geochem.ind.REE(dropPm=dropPm) if i in obj.columns]

        ax = spider.REE_v_radii(
            obj.reindex(columns=ree).astype(float).values,
            index=index,
            ree=ree,
            mode=mode,
            ax=ax,
            scatter_kw=scatter_kw,
            line_kw=line_kw,
            **kwargs,
        )
        ax.set_ylabel(r"$\mathrm{X / X_{Reference}}$")
        return ax
12✔
304

305
    def scatter(self, components: list = None, ax=None, axlabels=True, **kwargs):
        r"""
        Convenience method for scatter plots using the pyroplot API. See
        further parameters for `matplotlib.pyplot.scatter` function below.

        Parameters
        ----------
        components : :class:`list`, :code:`None`
            Elements or compositional components to plot.
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        axlabels : :class:`bool`, :code:`True`
            Whether to add x-y axis labels.
        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the scatterplot is added.

        """
        frame = to_frame(self._obj)
        components = _check_components(frame, components=components)

        # three components are drawn on a ternary projection
        is_ternary = len(components) == 3
        ax = init_axes(
            ax=ax, projection="ternary" if is_ternary else None, **kwargs
        )
        # colour arguments are resolved against the number of rows
        color_kw = process_color(size=frame.index.size, **kwargs)

        with warnings.catch_warnings():
            # the ternary transform emits a needless warning for points which
            # sum to zero; silence it for the duration of the scatter call
            warnings.filterwarnings(
                "ignore", message="invalid value encountered in divide"
            )
            coords = frame.reindex(columns=components).values.T
            ax.scatter(*coords, **scatterkwargs(color_kw))

        if axlabels:
            label_axes(ax, labels=components)

        ax.tick_params("both")
        return ax
12✔
348

349
    def spider(
        self,
        components: list = None,
        indexes: list = None,
        ax=None,
        mode="plot",
        index_order=None,
        autoscale=True,
        scatter_kw=None,
        line_kw=None,
        **kwargs,
    ):
        r"""
        Method for spider plots. Convenience access function to
        :func:`~pyrolite.plot.spider.spider` (see `Other Parameters`, below), where
        further parameters for relevant `matplotlib` functions are also listed.

        Parameters
        ----------
        components : :class:`list`, `None`
            Elements or compositional components to plot. Where not given, the
            columns found in :func:`pyrolite.geochem.ind.common_elements` are used.
        indexes :  :class:`list`, `None`
            Passed through to :func:`~pyrolite.plot.spider.spider` as `indexes`.
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        index_order
            Function to order spider plot indexes (e.g. by incompatibility), or the
            name of an entry in `pyrolite.geochem.ind.ordering`. An unrecognised
            name is reported with a warning and no ordering is applied.
        autoscale : :class:`bool`
            Whether to autoscale the y-axis limits for standard spider plots.
        mode : :class:`str`, :code`["plot", "fill", "binkde", "ckde", "kde", "hist"]`
            Mode for plot. Plot will produce a line-scatter diagram. Fill will return
            a filled range. Density will return a conditional density diagram.
        scatter_kw : :class:`dict`, :code:`None`
            Keyword parameters to be passed to the scatter plotting function.
            An empty dictionary is used where none is given.
        line_kw : :class:`dict`, :code:`None`
            Keyword parameters to be passed to the line plotting function.
            An empty dictionary is used where none is given.
        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the spider diagram is plotted.

        Raises
        ------
        AssertionError
            If there are no components to plot.

        Todo
        ----
            * Add 'compositional data' filter for default components if None is given

        """
        # build fresh dictionaries per call rather than sharing mutable defaults
        scatter_kw = {} if scatter_kw is None else scatter_kw
        line_kw = {} if line_kw is None else line_kw

        obj = to_frame(self._obj)

        if components is None:  # default to plotting elemental data
            common = geochem.ind.common_elements()  # looked up once, not per column
            components = [el for el in obj.columns if el in common]

        # raised explicitly so the check is not stripped when running with -O
        if len(components) == 0:
            raise AssertionError("No components available for the spider plot.")

        if isinstance(index_order, str):
            try:
                index_order = geochem.ind.ordering[index_order]
            except KeyError:
                msg = (
                    "Ordering not applied, as parameter '{}' not recognized."
                    " Select from: {}"
                ).format(index_order, ", ".join(geochem.ind.ordering.keys()))
                logger.warning(msg)
                # leave the components in their existing order
                index_order = None

        if index_order is not None:
            components = index_order(components)

        ax = init_axes(ax=ax, **kwargs)

        # TODO: handle axes which already carry `_pyrolite_components` from a
        # previous spider diagram with specified components

        ax = spider.spider(
            obj.reindex(columns=components).astype(float).values,
            indexes=indexes,
            ax=ax,
            mode=mode,
            autoscale=autoscale,
            scatter_kw=scatter_kw,
            line_kw=line_kw,
            **kwargs,
        )
        # record the plotted components on the axes
        ax._pyrolite_components = components
        ax.set_xticklabels(components, rotation=60)
        return ax
12✔
439

440
    def stem(
        self,
        components: list = None,
        ax=None,
        orientation="horizontal",
        axlabels=True,
        **kwargs,
    ):
        r"""
        Method for creating stem plots. Convenience access function to
        :func:`~pyrolite.plot.stem.stem` (see `Other Parameters`, below), where
        further parameters for relevant `matplotlib` functions are also listed.

        Parameters
        ----------
        components : :class:`list`, :code:`None`
            Elements or compositional components to plot.
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        orientation : :class:`str`
            Orientation of the plot (horizontal or vertical).
        axlabels : :class:`bool`, True
            Whether to add x-y axis labels.
        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the stem diagram is plotted.
        """
        frame = to_frame(self._obj)
        # stem plots need exactly two components
        components = _check_components(frame, components=components, valid_sizes=[2])

        coords = frame.reindex(columns=components).values.T
        ax = stem.stem(
            *coords,
            ax=ax,
            orientation=orientation,
            **process_color(**kwargs),
        )

        if not axlabels:
            return ax

        # any orientation without an "h" is treated as vertical, where the
        # axis labels swap places
        is_horizontal = "h" in orientation.lower()
        labels = components if is_horizontal else components[::-1]
        label_axes(ax, labels=labels)
        return ax
12✔
486

487

488
# public alias for the matplotlib-based accessor class
pyroplot = pyroplot_matplotlib


# Expose the accessor as `.pyroplot` on both Series and DataFrames; note that
# only some of the methods will be valid for series.
pd.api.extensions.register_series_accessor("pyroplot")(pyroplot)
pd.api.extensions.register_dataframe_accessor("pyroplot")(pyroplot)
12✔
494

495

496
# Ideally we would i) check for the same params and ii) aggregate all others across
# inherited or chained functions. For now the parameters are simply imported from
# the docstrings of the functions each method wraps.
_add_additional_parameters = True


def _fill_otherparams(method, *sources):
    """
    Substitute the ``{otherparams}`` placeholder in the docstring of `method`.

    Parameters
    ----------
    method : callable
        Accessor method whose docstring contains the placeholder.
    sources : callable
        Functions handed to :func:`get_additional_params` to build the
        "Other Parameters" text.
    """
    extra = get_additional_params(
        *sources, header="Other Parameters", indent=8, subsections=True
    )
    method.__doc__ = method.__doc__.format(
        otherparams=extra if _add_additional_parameters else ""
    )


# the order matters: heatscatter draws on the already-expanded scatter docstring
_fill_otherparams(pyroplot.density, pyroplot.density, density.density)
_fill_otherparams(pyroplot.parallel, pyroplot.parallel, parallel.parallel)
_fill_otherparams(pyroplot.REE, pyroplot.REE, spider.REE_v_radii)
_fill_otherparams(pyroplot.scatter, pyroplot.scatter, plt.scatter)
_fill_otherparams(pyroplot.plot, pyroplot.plot, plt.plot)
_fill_otherparams(pyroplot.spider, pyroplot.spider, spider.spider)
_fill_otherparams(pyroplot.stem, pyroplot.stem, stem.stem)
_fill_otherparams(pyroplot.heatscatter, pyroplot.scatter)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc