• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

morganjwilliams / pyrolite / 20402260732

29 Oct 2024 02:20AM UTC coverage: 91.571% (-0.07%) from 91.64%
20402260732

push

github

morganjwilliams
Merge branch 'release/0.3.6' into main

53 of 62 new or added lines in 12 files covered. (85.48%)

3 existing lines in 2 files now uncovered.

6225 of 6798 relevant lines covered (91.57%)

5.49 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

85.19
/pyrolite/plot/__init__.py
1
"""
2
Submodule with various plotting and visualisation functions.
3
"""
4

5
import warnings
6✔
6

7
import matplotlib
6✔
8
import matplotlib.pyplot as plt
6✔
9
import mpltern
6✔
10
import numpy as np
6✔
11
import pandas as pd
6✔
12

13
warnings.filterwarnings("ignore", "Unknown section")
6✔
14

15
from .. import geochem
6✔
16
from ..comp.codata import ILR, close
6✔
17
from ..util.distributions import get_scaler, sample_kde
6✔
18
from ..util.log import Handle
6✔
19
from ..util.meta import get_additional_params, subkwargs
6✔
20
from ..util.pd import to_frame
6✔
21
from ..util.plot.axes import init_axes, label_axes
6✔
22
from ..util.plot.helpers import plot_cooccurence
6✔
23
from ..util.plot.style import _export_nonRCstyles, linekwargs, scatterkwargs
6✔
24
from . import density, parallel, spider, stem
6✔
25
from .color import process_color
6✔
26

27
logger = Handle(__name__)
6✔
28

29
# pyroplot added to __all__ for docs
30
__all__ = ["density", "spider", "pyroplot"]
6✔
31

32

33
def _check_components(obj, components=None, check_size=True, valid_sizes=[2, 3]):
6✔
34
    """
35
    Check that the components provided within a dataframe are consistent with the
36
    form of plot being used.
37

38
    Parameters
39
    ----------
40
    obj : :class:`pandas.DataFrame`
41
        Object to check.
42
    components : :class:`list`
43
        List of components, optionally specified.
44
    check_size : :class:`bool`
45
        Whether to verify the size of the column index.
46
    valid_sizes : :class:`list`
47
        Component list lengths which are valid for the plot type.
48

49
    Returns
50
    -------
51
    :class:`list`
52
        Components for the plot.
53
    """
54
    try:
6✔
55
        if check_size and (obj.columns.size not in valid_sizes):
6✔
56
            assert len(components) in valid_sizes
6✔
57

58
        if components is None:
6✔
59
            components = obj.columns.values
6✔
60
    except AssertionError:
6✔
UNCOV
61
        msg = "Suggest components or provide a slice of the dataframe."
×
UNCOV
62
        raise AssertionError(msg)
×
63
    return components
6✔
64

65

66
# note that only some of these methods will be valid for series
67
@pd.api.extensions.register_series_accessor("pyroplot")
@pd.api.extensions.register_dataframe_accessor("pyroplot")
class pyroplot(object):
    """
    Accessor registered under the name `pyroplot` on both
    :class:`pandas.Series` and :class:`pandas.DataFrame` objects.
    """

    def __init__(self, obj):
        """
        Custom dataframe accessor for pyrolite plotting.

        Notes
        -----
            This accessor enables the coexistence of array-based plotting functions and
            methods for pandas objects. This enables some separation of concerns.
        """
        self._validate(obj)
        self._obj = obj

        # refresh custom styling on creation?
        _export_nonRCstyles()

    @staticmethod
    def _validate(obj):
        # No validation is currently applied; any pandas object is accepted.
        pass

    def cooccurence(self, ax=None, normalize=True, log=False, colorbar=False, **kwargs):
        """
        Plot the co-occurrence frequency matrix for a given input.

        Parameters
        ----------
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        normalize : :class:`bool`
            Whether to normalize the co-occurrence to compare disparate variables.
        log : :class:`bool`
            Whether to take the log of the co-occurrence.
        colorbar : :class:`bool`
            Whether to append a colorbar.

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the co-occurrence plot is added.

        """
        obj = to_frame(self._obj)
        ax = plot_cooccurence(
            obj.values, ax=ax, normalize=normalize, log=log, colorbar=colorbar, **kwargs
        )
        # Tick labels on both axes are the column names of the frame.
        ax.set_xticklabels(obj.columns, minor=False, rotation=90)
        ax.set_yticklabels(obj.columns, minor=False)
        return ax

    def density(self, components: list = None, ax=None, axlabels=True, **kwargs):
        r"""
        Method for plotting histograms (mode='hist2d'|'hexbin') or kernel density
        estimates from point data. Convenience access function to
        :func:`~pyrolite.plot.density.density` (see `Other Parameters`, below), where
        further parameters for relevant `matplotlib` functions are also listed.

        Parameters
        ----------
        components : :class:`list`, :code:`None`
            Elements or compositional components to plot.
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        axlabels : :class:`bool`, True
            Whether to add x-y axis labels.

        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the density diagram is plotted.

        """
        obj = to_frame(self._obj)
        components = _check_components(obj, components=components)

        ax = density.density(
            obj.reindex(columns=components).astype(float).values, ax=ax, **kwargs
        )
        if axlabels:
            label_axes(ax, labels=components)

        return ax

    def heatscatter(
        self,
        components: list = None,
        ax=None,
        axlabels=True,
        logx=False,
        logy=False,
        **kwargs,
    ):
        r"""
        Heatmapped scatter plots using the pyroplot API. See further parameters
        for `matplotlib.pyplot.scatter` function below.

        Parameters
        ----------
        components : :class:`list`, :code:`None`
            Elements or compositional components to plot.
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        axlabels : :class:`bool`, :code:`True`
            Whether to add x-y axis labels.
        logx : :class:`bool`, `False`
            Whether to log-transform x values before the KDE for bivariate plots.
        logy : :class:`bool`, `False`
            Whether to log-transform y values before the KDE for bivariate plots.

        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the heatmapped scatterplot is added.

        """
        obj = to_frame(self._obj)
        components = _check_components(obj, components=components)
        # The points used to build the KDE and the points at which it is sampled
        # are the same data, held as two independent arrays.
        data, samples = (
            obj.reindex(columns=components).values,
            obj.reindex(columns=components).values,
        )
        if len(components) == 3:
            # three components: the KDE is evaluated on ILR(close(x))
            def kdetfm(x):
                return ILR(close(x))

        else:
            # bivariate: optional per-axis log transforms
            kdetfm = get_scaler(np.log if logx else None, np.log if logy else None)

        # Only the keyword arguments which `sample_kde` accepts are passed to it.
        zi = sample_kde(
            data, samples, transform=kdetfm, **subkwargs(kwargs, sample_kde)
        )
        # The sampled density becomes the color of each point.
        kwargs["c"] = zi
        ax = obj.reindex(columns=components).pyroplot.scatter(
            ax=ax, axlabels=axlabels, **kwargs
        )
        return ax

    def parallel(
        self,
        components=None,
        rescale=False,
        legend=False,
        ax=None,
        **kwargs,
    ):
        """
        Create a :func:`pyrolite.plot.parallel.parallel` coordinate plot from
        the columns of the :class:`~pandas.DataFrame`.

        Parameters
        ----------
        components : :class:`list`, :code:`None`
            Components to use as axes for the plot.
        rescale : :class:`bool`
            Whether to rescale values to [-1, 1].
        legend : :class:`bool`, :code:`False`
            Whether to include or suppress the legend.
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the parallel coordinates plot is added.

        Todo
        ----
        * Adapt figure size based on number of columns.

        """

        obj = to_frame(self._obj)
        ax = parallel.parallel(
            obj,
            components=components,
            rescale=rescale,
            legend=legend,
            ax=ax,
            **kwargs,
        )
        return ax

    def plot(self, components: list = None, ax=None, axlabels=True, **kwargs):
        r"""
        Convenience method for line plots using the pyroplot API. See
        further parameters for `matplotlib.pyplot.plot` function below.

        Parameters
        ----------
        components : :class:`list`, :code:`None`
            Elements or compositional components to plot.
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        axlabels : :class:`bool`, :code:`True`
            Whether to add x-y axis labels.
        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the plot is added.

        """
        obj = to_frame(self._obj)
        components = _check_components(obj, components=components)
        # three components are drawn on a ternary projection
        projection = "ternary" if len(components) == 3 else None
        ax = init_axes(ax=ax, projection=projection, **kwargs)
        kw = linekwargs(kwargs)
        ax.plot(*obj.reindex(columns=components).values.T, **kw)
        # if color is multi, could update line colors here
        if axlabels:
            label_axes(ax, labels=components)

        ax.tick_params("both")
        return ax

    def REE(
        self,
        index="elements",
        ax=None,
        mode="plot",
        dropPm=True,
        scatter_kw=None,
        line_kw=None,
        **kwargs,
    ):
        """Pass the pandas object to :func:`pyrolite.plot.spider.REE_v_radii`.

        Parameters
        ----------
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        index : :class:`str`
            Whether to plot radii ('radii') on the principal x-axis, or elements
            ('elements').
        mode : :class:`str`, :code:`["plot", "fill", "binkde", "ckde", "kde", "hist"]`
            Mode for plot. Plot will produce a line-scatter diagram. Fill will return
            a filled range. Density will return a conditional density diagram.
        dropPm : :class:`bool`
            Whether to exclude the (almost) non-existent element Promethium from the REE
            list.
        scatter_kw : :class:`dict`
            Keyword parameters to be passed to the scatter plotting function.
            An empty dictionary is used where none is given.
        line_kw : :class:`dict`
            Keyword parameters to be passed to the line plotting function.
            An empty dictionary is used where none is given.
        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the REE plot is added.

        """
        # Fresh dictionaries per call, so that no state is shared between calls
        # through a mutable default argument.
        scatter_kw = {} if scatter_kw is None else scatter_kw
        line_kw = {} if line_kw is None else line_kw

        obj = to_frame(self._obj)
        # only the REE which are present as columns are plotted
        ree = [i for i in geochem.ind.REE(dropPm=dropPm) if i in obj.columns]

        ax = spider.REE_v_radii(
            obj.reindex(columns=ree).astype(float).values,
            index=index,
            ree=ree,
            mode=mode,
            ax=ax,
            scatter_kw=scatter_kw,
            line_kw=line_kw,
            **kwargs,
        )
        ax.set_ylabel(r"$\mathrm{X / X_{Reference}}$")
        return ax

    def scatter(self, components: list = None, ax=None, axlabels=True, **kwargs):
        r"""
        Convenience method for scatter plots using the pyroplot API. See
        further parameters for `matplotlib.pyplot.scatter` function below.

        Parameters
        ----------
        components : :class:`list`, :code:`None`
            Elements or compositional components to plot.
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        axlabels : :class:`bool`, :code:`True`
            Whether to add x-y axis labels.
        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the scatterplot is added.

        """
        obj = to_frame(self._obj)
        components = _check_components(obj, components=components)

        # three components are drawn on a ternary projection
        projection = "ternary" if len(components) == 3 else None
        ax = init_axes(ax=ax, projection=projection, **kwargs)
        size = obj.index.size
        kw = process_color(size=size, **kwargs)
        with warnings.catch_warnings():
            # ternary transform where points add to zero will give an unnecessary
            # warning; here we suppress it
            warnings.filterwarnings(
                "ignore", message="invalid value encountered in divide"
            )
            ax.scatter(*obj.reindex(columns=components).values.T, **scatterkwargs(kw))

        if axlabels:
            label_axes(ax, labels=components)

        ax.tick_params("both")
        return ax

    def spider(
        self,
        components: list = None,
        indexes: list = None,
        ax=None,
        mode="plot",
        index_order=None,
        autoscale=True,
        scatter_kw=None,
        line_kw=None,
        **kwargs,
    ):
        r"""
        Method for spider plots. Convenience access function to
        :func:`~pyrolite.plot.spider.spider` (see `Other Parameters`, below), where
        further parameters for relevant `matplotlib` functions are also listed.

        Parameters
        ----------
        components : :class:`list`, `None`
            Elements or compositional components to plot.
        indexes :  :class:`list`, `None`
            Elements or compositional components to plot.
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        index_order
            Function to order spider plot indexes (e.g. by incompatibility), or
            the name of an ordering registered in `geochem.ind.ordering`. Where a
            name is not recognized, a warning is logged and no ordering is applied.
        autoscale : :class:`bool`
            Whether to autoscale the y-axis limits for standard spider plots.
        mode : :class:`str`, :code:`["plot", "fill", "binkde", "ckde", "kde", "hist"]`
            Mode for plot. Plot will produce a line-scatter diagram. Fill will return
            a filled range. Density will return a conditional density diagram.
        scatter_kw : :class:`dict`
            Keyword parameters to be passed to the scatter plotting function.
            An empty dictionary is used where none is given.
        line_kw : :class:`dict`
            Keyword parameters to be passed to the line plotting function.
            An empty dictionary is used where none is given.
        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the spider diagram is plotted.

        Todo
        ----
            * Add 'compositional data' filter for default components if None is given

        """
        # Fresh dictionaries per call, so that no state is shared between calls
        # through a mutable default argument.
        scatter_kw = {} if scatter_kw is None else scatter_kw
        line_kw = {} if line_kw is None else line_kw

        obj = to_frame(self._obj)

        if components is None:  # default to plotting elemental data
            components = [
                el for el in obj.columns if el in geochem.ind.common_elements()
            ]

        # Raised explicitly (rather than via `assert`) so the check is retained
        # under `python -O`; the exception type is unchanged.
        if len(components) == 0:
            raise AssertionError("No components available for the spider plot.")

        if isinstance(index_order, str):
            # resolve a named ordering to its ordering function
            try:
                index_order = geochem.ind.ordering[index_order]
            except KeyError:
                msg = (
                    "Ordering not applied, as parameter '{}' not recognized."
                    " Select from: {}"
                ).format(index_order, ", ".join(list(geochem.ind.ordering.keys())))
                logger.warning(msg)
                # honour the warning: an unrecognized name means no reordering,
                # instead of attempting to call the string
                index_order = None

        if index_order is not None:
            components = index_order(components)

        ax = init_axes(ax=ax, **kwargs)

        # TODO: handle spider diagrams which have specified components
        # (i.e. axes which already carry a `_pyrolite_components` attribute)

        ax = spider.spider(
            obj.reindex(columns=components).astype(float).values,
            indexes=indexes,
            ax=ax,
            mode=mode,
            autoscale=autoscale,
            scatter_kw=scatter_kw,
            line_kw=line_kw,
            **kwargs,
        )
        # record which components the axes were drawn with
        ax._pyrolite_components = components
        ax.set_xticklabels(components, rotation=60)
        return ax

    def stem(
        self,
        components: list = None,
        ax=None,
        orientation="horizontal",
        axlabels=True,
        **kwargs,
    ):
        r"""
        Method for creating stem plots. Convenience access function to
        :func:`~pyrolite.plot.stem.stem` (see `Other Parameters`, below), where
        further parameters for relevant `matplotlib` functions are also listed.

        Parameters
        ----------
        components : :class:`list`, :code:`None`
            Elements or compositional components to plot.
        ax : :class:`matplotlib.axes.Axes`, :code:`None`
            The subplot to draw on.
        orientation : :class:`str`
            Orientation of the plot (horizontal or vertical).
        axlabels : :class:`bool`, True
            Whether to add x-y axis labels.
        {otherparams}

        Returns
        -------
        :class:`matplotlib.axes.Axes`
            Axes on which the stem diagram is plotted.
        """
        obj = to_frame(self._obj)
        # stem plots take exactly two components
        components = _check_components(obj, components=components, valid_sizes=[2])

        ax = stem.stem(
            *obj.reindex(columns=components).values.T,
            ax=ax,
            orientation=orientation,
            **process_color(**kwargs),
        )

        if axlabels:
            if "h" not in orientation.lower():
                # labels are reversed for any orientation without an "h" in it
                components = components[::-1]
            label_axes(ax, labels=components)

        return ax
6✔
522

523

524
# ideally we would i) check for the same params and ii) aggregate all others across
525
# inherited or chained functions. This simply imports the params from another docstring
526
_add_additional_parameters = True


def _append_other_parameters(method, *sources):
    """
    Fill the `otherparams` placeholder in the docstring of a `pyroplot` method.

    Parameters
    ----------
    method : callable
        Method whose docstring contains the `otherparams` placeholder.
    *sources : callable
        Callables handed on to `get_additional_params`, from which the
        'Other Parameters' text is generated.

    Notes
    -----
        Where the docstring is missing (e.g. stripped by `python -OO`) nothing is
        done, so that the module can still be imported. Where
        `_add_additional_parameters` is false the placeholder is replaced with an
        empty string.
    """
    doc = method.__doc__
    if doc is None:
        return

    extra = ""
    if _add_additional_parameters:
        extra = get_additional_params(
            *sources, header="Other Parameters", indent=8, subsections=True
        )
    method.__doc__ = doc.format(otherparams=extra)


_append_other_parameters(pyroplot.density, pyroplot.density, density.density)
_append_other_parameters(pyroplot.parallel, pyroplot.parallel, parallel.parallel)
_append_other_parameters(pyroplot.REE, pyroplot.REE, spider.REE_v_radii)
_append_other_parameters(pyroplot.scatter, pyroplot.scatter, plt.scatter)
_append_other_parameters(pyroplot.plot, pyroplot.plot, plt.plot)
_append_other_parameters(pyroplot.spider, pyroplot.spider, spider.spider)
_append_other_parameters(pyroplot.stem, pyroplot.stem, stem.stem)
# heatscatter takes its parameters from the scatter docstring, so this must come
# after the scatter docstring has been expanded above
_append_other_parameters(pyroplot.heatscatter, pyroplot.scatter)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc