• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

morganjwilliams / pyrolite / 5564819928

pending completion
5564819928

push

github

morganjwilliams
Merge branch 'release/0.3.3' into main

249 of 270 new or added lines in 48 files covered. (92.22%)

217 existing lines in 33 files now uncovered.

5971 of 6605 relevant lines covered (90.4%)

10.84 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

86.67
/pyrolite/plot/__init__.py
1
"""
2
Submodule with various plotting and visualisation functions.
3
"""
4
import warnings
12✔
5

6
import matplotlib
12✔
7
import matplotlib.pyplot as plt
12✔
8
import mpltern
12✔
9
import numpy as np
12✔
10
import pandas as pd
12✔
11

12
warnings.filterwarnings("ignore", "Unknown section")
12✔
13

14
from .. import geochem
12✔
15
from ..comp.codata import ILR, close
12✔
16
from ..util.distributions import get_scaler, sample_kde
12✔
17
from ..util.log import Handle
12✔
18
from ..util.meta import get_additional_params, subkwargs
12✔
19
from ..util.pd import to_frame
12✔
20
from ..util.plot.axes import init_axes, label_axes
12✔
21
from ..util.plot.helpers import plot_cooccurence
12✔
22
from ..util.plot.style import _export_nonRCstyles, linekwargs, scatterkwargs
12✔
23
from . import density, parallel, spider, stem
12✔
24
from .color import process_color
12✔
25

26
logger = Handle(__name__)

# pyroplot is listed alongside the density and spider submodules so that the
# documentation build picks up the accessor class.
__all__ = [
    "density",
    "spider",
    "pyroplot",
]
30

31

32
def _check_components(obj, components=None, check_size=True, valid_sizes=[2, 3]):
12✔
33
    """
34
    Check that the components provided within a dataframe are consistent with the
35
    form of plot being used.
36

37
    Parameters
38
    ----------
39
    obj : :class:`pandas.DataFrame`
40
        Object to check.
41
    components : :class:`list`
42
        List of components, optionally specified.
43
    check_size : :class:`bool`
44
        Whether to verify the size of the column index.
45
    valid_sizes : :class:`list`
46
        Component list lengths which are valid for the plot type.
47

48
    Returns
49
    -------
50
    :class:`list`
51
        Components for the plot.
52
    """
53
    try:
12✔
54
        if check_size and (obj.columns.size not in valid_sizes):
12✔
55
            assert len(components) in valid_sizes
12✔
56

57
        if components is None:
12✔
58
            components = obj.columns.values
12✔
59
    except:
12✔
60
        msg = "Suggest components or provide a slice of the dataframe."
12✔
61
        raise AssertionError(msg)
12✔
62
    return components
12✔
63

64

65
# note that only some of these methods will be valid for series
@pd.api.extensions.register_series_accessor("pyroplot")
@pd.api.extensions.register_dataframe_accessor("pyroplot")
class pyroplot(object):
    """
    Pandas accessor (``.pyroplot``) exposing pyrolite plotting methods on
    :class:`pandas.DataFrame` and :class:`pandas.Series` objects.

    Each method converts the wrapped object to a frame with
    :func:`~pyrolite.util.pd.to_frame` before plotting; only some of the methods
    are meaningful for a series.
    """

    def __init__(self, obj):
        """
        Custom dataframe accessor for pyrolite plotting.

        Notes
        -----
            This accessor enables the coexistence of array-based plotting functions and
            methods for pandas objects. This enables some separation of concerns.
        """
        self._validate(obj)
        self._obj = obj

        # refresh custom styling on creation?
        _export_nonRCstyles()

    @staticmethod
    def _validate(obj):
        # Placeholder hook: no validation of the wrapped object is performed yet.
        pass

    def cooccurence(self, ax=None, normalize=True, log=False, colorbar=False, **kwargs):
        """
        Plot the co-occurence frequency matrix for a given input.

        Parameters
        ----------
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        normalize : :class:`bool`
            Whether to normalize the cooccurence to compare disparate variables.
        log : :class:`bool`
            Whether to take the log of the cooccurence.
        colorbar : :class:`bool`
            Whether to append a colorbar.
        **kwargs
            Passed on to :func:`~pyrolite.util.plot.helpers.plot_cooccurence`.

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the cooccurence plot is added.

        """
        obj = to_frame(self._obj)
        ax = plot_cooccurence(
            obj.values, ax=ax, normalize=normalize, log=log, colorbar=colorbar, **kwargs
        )
        # label both matrix axes with the column names
        ax.set_xticklabels(obj.columns, minor=False, rotation=90)
        ax.set_yticklabels(obj.columns, minor=False)
        return ax

    def density(self, components: list = None, ax=None, axlabels=True, **kwargs):
        r"""
        Method for plotting histograms (mode='hist2d'|'hexbin') or kernel density
        estimates from point data. Convenience access function to
        :func:`~pyrolite.plot.density.density` (see `Other Parameters`, below), where
        further parameters for relevant `matplotlib` functions are also listed.

        Parameters
        ----------
        components : :class:`list`, :code:`None`
            Elements or compositional components to plot.
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        axlabels : :class:`bool`, True
            Whether to add x-y axis labels.

        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the density diagram is plotted.

        """
        obj = to_frame(self._obj)
        components = _check_components(obj, components=components)

        ax = density.density(
            obj.reindex(columns=components).astype(float).values, ax=ax, **kwargs
        )
        if axlabels:
            label_axes(ax, labels=components)

        return ax

    def heatscatter(
        self,
        components: list = None,
        ax=None,
        axlabels=True,
        logx=False,
        logy=False,
        **kwargs,
    ):
        r"""
        Heatmapped scatter plots using the pyroplot API. See further parameters
        for `matplotlib.pyplot.scatter` function below.

        Parameters
        ----------
        components : :class:`list`, :code:`None`
            Elements or compositional components to plot.
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        axlabels : :class:`bool`, :code:`True`
            Whether to add x-y axis labels.
        logx : :class:`bool`, `False`
            Whether to log-transform x values before the KDE for bivariate plots.
        logy : :class:`bool`, `False`
            Whether to log-transform y values before the KDE for bivariate plots.

        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the heatmapped scatterplot is added.

        """
        obj = to_frame(self._obj)
        components = _check_components(obj, components=components)

        data = obj.reindex(columns=components).values
        # The KDE is evaluated at the data points themselves; an independent copy
        # keeps `samples` unaffected by anything done to `data`.
        samples = data.copy()

        if len(components) == 3:
            # ternary data: estimate the density on ILR(close(x))
            def kdetfm(x):
                return ILR(close(x))

        else:
            # bivariate data: optional per-axis log transforms
            kdetfm = get_scaler(np.log if logx else None, np.log if logy else None)

        zi = sample_kde(data, samples, transform=kdetfm, **subkwargs(kwargs, sample_kde))
        kwargs["c"] = zi  # colour each point by its estimated density
        ax = obj.reindex(columns=components).pyroplot.scatter(
            ax=ax, axlabels=axlabels, **kwargs
        )
        return ax

    def parallel(
        self,
        components=None,
        rescale=False,
        legend=False,
        ax=None,
        **kwargs,
    ):
        """
        Create a :func:`pyrolite.plot.parallel.parallel` coordinate plot from
        the columns of the :class:`~pandas.DataFrame`.

        Parameters
        ----------
        components : :class:`list`, :code:`None`
            Components to use as axes for the plot.
        rescale : :class:`bool`
            Whether to rescale values to [-1, 1].
        legend : :class:`bool`, :code:`False`
            Whether to include or suppress the legend.
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the parallel coordinates plot is added.

        Todo
        ----
        * Adapt figure size based on number of columns.

        """

        obj = to_frame(self._obj)
        ax = parallel.parallel(
            obj,
            components=components,
            rescale=rescale,
            legend=legend,
            ax=ax,
            **kwargs,
        )
        return ax

    def plot(self, components: list = None, ax=None, axlabels=True, **kwargs):
        r"""
        Convenience method for line plots using the pyroplot API. See
        further parameters for `matplotlib.pyplot.plot` function below.

        Parameters
        ----------
        components : :class:`list`, :code:`None`
            Elements or compositional components to plot.
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        axlabels : :class:`bool`, :code:`True`
            Whether to add x-y axis labels.
        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the plot is added.

        """
        obj = to_frame(self._obj)
        components = _check_components(obj, components=components)
        # three components are drawn on ternary axes
        projection = "ternary" if len(components) == 3 else None
        ax = init_axes(ax=ax, projection=projection, **kwargs)
        kw = linekwargs(kwargs)
        ax.plot(*obj.reindex(columns=components).values.T, **kw)
        # if color is multi, could update line colors here
        if axlabels:
            label_axes(ax, labels=components)

        ax.tick_params("both")
        # ax.grid()
        # ax.set_aspect("equal")
        return ax

    def REE(
        self,
        index="elements",
        ax=None,
        mode="plot",
        dropPm=True,
        scatter_kw=None,
        line_kw=None,
        **kwargs,
    ):
        """Pass the pandas object to :func:`pyrolite.plot.spider.REE_v_radii`.

        Parameters
        ----------
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        index : :class:`str`
            Whether to plot radii ('radii') on the principal x-axis, or elements
            ('elements').
        mode : :class:`str`, :code:`["plot", "fill", "binkde", "ckde", "kde", "hist"]`
            Mode for plot. Plot will produce a line-scatter diagram. Fill will return
            a filled range. Density will return a conditional density diagram.
        dropPm : :class:`bool`
            Whether to exclude the (almost) non-existent element Promethium from the REE
            list.
        scatter_kw : :class:`dict`, :code:`None`
            Keyword parameters to be passed to the scatter plotting function.
            :code:`None` is treated as an empty dictionary.
        line_kw : :class:`dict`, :code:`None`
            Keyword parameters to be passed to the line plotting function.
            :code:`None` is treated as an empty dictionary.
        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the REE plot is added.

        """
        # fresh dicts per call rather than shared mutable defaults
        scatter_kw = {} if scatter_kw is None else scatter_kw
        line_kw = {} if line_kw is None else line_kw

        obj = to_frame(self._obj)
        # keep only the REE present in the frame, in the order given by geochem.ind.REE
        ree = [i for i in geochem.ind.REE(dropPm=dropPm) if i in obj.columns]

        ax = spider.REE_v_radii(
            obj.reindex(columns=ree).astype(float).values,
            index=index,
            ree=ree,
            mode=mode,
            ax=ax,
            scatter_kw=scatter_kw,
            line_kw=line_kw,
            **kwargs,
        )
        ax.set_ylabel(r"$\mathrm{X / X_{Reference}}$")
        return ax

    def scatter(self, components: list = None, ax=None, axlabels=True, **kwargs):
        r"""
        Convenience method for scatter plots using the pyroplot API. See
        further parameters for `matplotlib.pyplot.scatter` function below.

        Parameters
        ----------
        components : :class:`list`, :code:`None`
            Elements or compositional components to plot.
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        axlabels : :class:`bool`, :code:`True`
            Whether to add x-y axis labels.
        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the scatterplot is added.

        """
        obj = to_frame(self._obj)
        components = _check_components(obj, components=components)

        # three components are drawn on ternary axes
        projection = "ternary" if len(components) == 3 else None
        ax = init_axes(ax=ax, projection=projection, **kwargs)
        size = obj.index.size
        kw = process_color(size=size, **kwargs)
        with warnings.catch_warnings():
            # ternary transform where points add to zero will give an unnecessary
            # warning; here we supress it
            warnings.filterwarnings(
                "ignore", message="invalid value encountered in divide"
            )
            ax.scatter(*obj.reindex(columns=components).values.T, **scatterkwargs(kw))

        if axlabels:
            label_axes(ax, labels=components)

        ax.tick_params("both")
        # ax.grid()
        # ax.set_aspect("equal")
        return ax

    def spider(
        self,
        components: list = None,
        indexes: list = None,
        ax=None,
        mode="plot",
        index_order=None,
        autoscale=True,
        scatter_kw=None,
        line_kw=None,
        **kwargs,
    ):
        r"""
        Method for spider plots. Convenience access function to
        :func:`~pyrolite.plot.spider.spider` (see `Other Parameters`, below), where
        further parameters for relevant `matplotlib` functions are also listed.

        Parameters
        ----------
        components : :class:`list`, `None`
            Elements or compositional components to plot.
        indexes :  :class:`list`, `None`
            Elements or compositional components to plot.
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        index_order : :class:`str` | callable, :code:`None`
            Function to order spider plot indexes (e.g. by incompatibility), or the
            name of an ordering in :code:`geochem.ind.ordering`. Unrecognised names
            are logged and no ordering is applied.
        autoscale : :class:`bool`
            Whether to autoscale the y-axis limits for standard spider plots.
        mode : :class:`str`, :code:`["plot", "fill", "binkde", "ckde", "kde", "hist"]`
            Mode for plot. Plot will produce a line-scatter diagram. Fill will return
            a filled range. Density will return a conditional density diagram.
        scatter_kw : :class:`dict`, :code:`None`
            Keyword parameters to be passed to the scatter plotting function.
            :code:`None` is treated as an empty dictionary.
        line_kw : :class:`dict`, :code:`None`
            Keyword parameters to be passed to the line plotting function.
            :code:`None` is treated as an empty dictionary.
        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the spider diagram is plotted.

        Raises
        ------
        AssertionError
            If there are no components to plot.

        Todo
        ----
            * Add 'compositional data' filter for default components if None is given

        """
        # fresh dicts per call rather than shared mutable defaults
        scatter_kw = {} if scatter_kw is None else scatter_kw
        line_kw = {} if line_kw is None else line_kw

        obj = to_frame(self._obj)

        if components is None:  # default to plotting elemental data
            common = geochem.ind.common_elements()  # evaluate once, not per column
            components = [el for el in obj.columns if el in common]

        # explicit check (not `assert`) so it survives `python -O`
        if len(components) == 0:
            raise AssertionError("No components available to plot.")

        if index_order is not None:
            if isinstance(index_order, str):
                try:
                    index_order = geochem.ind.ordering[index_order]
                except KeyError:
                    msg = (
                        "Ordering not applied, as parameter '{}' not recognized."
                        " Select from: {}"
                    ).format(index_order, ", ".join(list(geochem.ind.ordering.keys())))
                    logger.warning(msg)
                    # honour the warning: skip ordering rather than calling a str
                    index_order = None
            if index_order is not None:
                components = index_order(components)

        ax = init_axes(ax=ax, **kwargs)

        ax = spider.spider(
            obj.reindex(columns=components).astype(float).values,
            indexes=indexes,
            ax=ax,
            mode=mode,
            autoscale=autoscale,
            scatter_kw=scatter_kw,
            line_kw=line_kw,
            **kwargs,
        )
        # record the plotted components on the axes
        ax._pyrolite_components = components
        ax.set_xticklabels(components, rotation=60)
        return ax

    def stem(
        self,
        components: list = None,
        ax=None,
        orientation="horizontal",
        axlabels=True,
        **kwargs,
    ):
        r"""
        Method for creating stem plots. Convenience access function to
        :func:`~pyrolite.plot.stem.stem` (see `Other Parameters`, below), where
        further parameters for relevant `matplotlib` functions are also listed.

        Parameters
        ----------
        components : :class:`list`, :code:`None`
            Elements or compositional components to plot.
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        orientation : :class:`str`
            Orientation of the plot (horizontal or vertical).
        axlabels : :class:`bool`, True
            Whether to add x-y axis labels.
        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the stem diagram is plotted.
        """
        obj = to_frame(self._obj)
        components = _check_components(obj, components=components, valid_sizes=[2])

        ax = stem.stem(
            *obj.reindex(columns=components).values.T,
            ax=ax,
            orientation=orientation,
            **process_color(**kwargs),
        )

        if axlabels:
            # for non-horizontal orientations the label order is swapped
            if "h" not in orientation.lower():
                components = components[::-1]
            label_axes(ax, labels=components)

        return ax
520

521

522
# ideally we would i) check for the same params and ii) aggregate all others across
# inherited or chained functions. This simply imports the params from another docstring
_add_additional_parameters = True


def _format_otherparams(method, *param_sources):
    """
    Fill the ``otherparams`` placeholder in a :class:`pyroplot` method docstring.

    Parameters
    ----------
    method : :class:`function`
        Method whose ``__doc__`` is formatted in place.
    *param_sources
        Positional arguments forwarded to :func:`get_additional_params`.

    Notes
    -----
    When docstrings are stripped (e.g. running under ``python -OO``), ``__doc__``
    is :code:`None`. In that case the method is left untouched, so importing the
    module does not fail.
    """
    if method.__doc__ is None:
        return
    otherparams = ""
    if _add_additional_parameters:
        otherparams = get_additional_params(
            *param_sources, header="Other Parameters", indent=8, subsections=True
        )
    method.__doc__ = method.__doc__.format(otherparams=otherparams)


# Order matters: heatscatter takes its extra parameters from the scatter docstring,
# which must therefore already have been formatted.
_OTHERPARAMS_SOURCES = (
    (pyroplot.density, (pyroplot.density, density.density)),
    (pyroplot.parallel, (pyroplot.parallel, parallel.parallel)),
    (pyroplot.REE, (pyroplot.REE, spider.REE_v_radii)),
    (pyroplot.scatter, (pyroplot.scatter, plt.scatter)),
    (pyroplot.plot, (pyroplot.plot, plt.plot)),
    (pyroplot.spider, (pyroplot.spider, spider.spider)),
    (pyroplot.stem, (pyroplot.stem, stem.stem)),
    (pyroplot.heatscatter, (pyroplot.scatter,)),
)

for _method, _sources in _OTHERPARAMS_SOURCES:
    _format_otherparams(_method, *_sources)
del _method, _sources
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc