• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

morganjwilliams / pyrolite / 17569144994

09 Sep 2025 01:39AM UTC coverage: 90.14% (+0.06%) from 90.077%
17569144994

push

github

morganjwilliams
Update example for interactive plotting API

0 of 30 new or added lines in 2 files covered. (0.0%)

53 existing lines in 6 files now uncovered.

6226 of 6907 relevant lines covered (90.14%)

10.81 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/pyrolite/util/plot/plotly.py
1
"""
2
Plotly backend for a pandas-based pyrolite plot accessor.
3

4
Todo
5
-----
6
* Make margins smaller
7
* Enable passing labels to markers
8
* Make plot variant for density plots
9
"""
10

11
import warnings
×
12

13
import matplotlib.colors
×
14
import numpy as np
×
15
import plotly.graph_objects as go
×
16

17
from ... import geochem
×
18
from ...comp.codata import ILR, close
×
19
from ...plot.color import process_color
×
20
from ..distributions import get_scaler, sample_kde
×
21
from ..log import Handle
×
22
from ..meta import get_additional_params, subkwargs
×
23
from ..pd import _check_components, to_frame
×
24

25
# Module-level logger built with pyrolite's ``Handle`` helper (presumably a thin
# wrapper around ``logging.getLogger`` — confirm). Within this file it is only
# referenced by the commented-out matplotlib-style accessor at the bottom.
logger = Handle(__name__)
×
26

27

28
def to_plotly_color(color, alpha=1):
    """
    Convert a matplotlib-style color specification to a plotly ``rgba`` string.

    Parameters
    ----------
    color
        Any color specification accepted by :func:`matplotlib.colors.to_rgb`
        (named color, hex string, RGB(A) tuple, ...). Any alpha component of
        `color` itself is discarded by ``to_rgb``.
    alpha : :class:`float`
        Opacity of the returned color. It is passed through unscaled, i.e. on a
        0-1 scale rather than 0-255, which is what plotly's ``rgba()`` expects.

    Returns
    -------
    :class:`str`
        Color string of the form ``"rgba(r, g, b, alpha)"`` with integer
        r, g, b channels in 0-255.
    """
    # Round rather than truncate: to_rgb yields floats such as 128 / 255, and
    # int() on a product like 127.99999999999999 would silently lose a level.
    # Rounding also matches matplotlib.colors.to_hex.
    channels = [round(c * 255) for c in matplotlib.colors.to_rgb(color)]
    return "rgba" + str(tuple(channels + [alpha]))
33

34

35
class pyroplot_plotly(object):
    """
    Plotly-backed plotting accessor for pandas DataFrames.

    Each plotting method builds and returns a new
    :class:`plotly.graph_objects.Figure` from the wrapped DataFrame's columns
    (and, for hover text, its index).
    """

    def __init__(self, obj):
        """
        Custom dataframe accessor for pyrolite plotting.

        Parameters
        ----------
        obj : :class:`pandas.DataFrame`
            Data to plot; stored on the accessor as ``_obj``.

        Notes
        -----
        This accessor enables the coexistence of array-based plotting functions
        and methods for pandas objects. This enables some separation of concerns.
        """
        self._validate(obj)
        self._obj = obj

    @staticmethod
    def _validate(obj):
        """Validate the wrapped object (currently a no-op placeholder)."""
        pass

    def scatter(self, color="black", alpha=1, **kwargs):
        """
        Scatter plot of the wrapped DataFrame.

        DataFrames with exactly three columns are drawn as a ternary diagram
        (see :meth:`_ternary`). Otherwise the first column is plotted on x and
        the second on y; any further columns are ignored.

        Parameters
        ----------
        color
            Marker color, in any form accepted by :func:`to_plotly_color`.
        alpha : :class:`float`
            Marker opacity, on a 0-1 scale.
        **kwargs
            Layout properties merged over the default layout (e.g.
            ``width=800``). This applies to both bivariate and ternary plots.

        Returns
        -------
        :class:`plotly.graph_objects.Figure`
        """
        if self._obj.columns.size == 3:
            return self._ternary(color=color, alpha=alpha, **kwargs)

        # Caller overrides win over the defaults, exactly as in _ternary, so
        # layout kwargs behave the same regardless of the number of columns.
        layout = {"width": 600, "plot_bgcolor": "white", **kwargs}
        fig = go.Figure()
        marker = dict(color=to_plotly_color(color, alpha=alpha))
        fig.add_traces(
            [
                go.Scatter(
                    x=self._obj.iloc[:, 0],
                    y=self._obj.iloc[:, 1],
                    mode="markers",
                    marker=marker,
                    showlegend=False,
                    text=self._obj.index.map("Sample {}".format),
                )
            ]
        )
        # Axis styling goes in before the layout so that any axis settings the
        # caller passes via kwargs (e.g. ``xaxis=dict(title=...)``) take effect.
        fig.update_xaxes(linecolor="black", mirror=True, title=self._obj.columns[0])
        fig.update_yaxes(linecolor="black", mirror=True, title=self._obj.columns[1])
        fig.update_layout(layout)
        return fig

    def _ternary(self, color="black", alpha=1, **kwargs):
        """
        Ternary scatter plot of a three-column DataFrame.

        Columns are mapped, in order, to the a, b and c ternary axes. Each axis
        is titled with its column name.

        Parameters
        ----------
        color
            Marker color, in any form accepted by :func:`to_plotly_color`.
        alpha : :class:`float`
            Marker opacity, on a 0-1 scale.
        **kwargs
            Layout properties merged shallowly over the default layout. Passing
            ``ternary=...`` therefore replaces the default ternary axis styling.

        Returns
        -------
        :class:`plotly.graph_objects.Figure`
        """
        axis_names = ["aaxis", "baxis", "caxis"]
        layout = dict(
            width=600,
            plot_bgcolor="white",
            ternary={
                **{
                    a: {"title": c, "showgrid": False, "linecolor": "black"}
                    for a, c in zip(axis_names, self._obj.columns)
                },
                "bgcolor": "white",
            },
        )
        layout.update(kwargs)
        marker = {"color": to_plotly_color(color, alpha=alpha)}
        data = {
            "mode": "markers",
            # a/b/c coordinates come from the three columns, in order
            **dict(zip("abc", [self._obj[c] for c in self._obj.columns])),
            "text": self._obj.index.values,
            "marker": marker,
        }
        fig = go.Figure(go.Scatterternary(data))
        fig.update_layout(layout)
        return fig

    def spider(self, color="black", unity_line=True, alpha=1, text=None, **kwargs):
        """
        Spider (line) plot with one trace per row and a log-scaled y-axis.

        Columns form the x-axis categories; each row becomes a lines+markers
        trace named ``"Sample {index}"``.

        Parameters
        ----------
        color
            Line color, in any form accepted by :func:`to_plotly_color`.
        unity_line : :class:`bool`
            Whether to add a thin dotted black reference line at y = 1.
        alpha : :class:`float`
            Line opacity, on a 0-1 scale.
        text
            Optional hover text, looked up per row as ``text[idx]`` where ``idx``
            is the row's index label (e.g. a :class:`pandas.Series` sharing the
            DataFrame's index).
        **kwargs
            Currently unused.

        Returns
        -------
        :class:`plotly.graph_objects.Figure`
        """
        # NOTE(review): unlike scatter/_ternary, kwargs are not merged into the
        # layout here; confirm whether that is intended.
        layout = dict(width=600, plot_bgcolor="white")
        fig = go.Figure()
        line = dict(color=to_plotly_color(color, alpha=alpha))
        # NOTE(review): with text=None and hoverinfo="text" the hover label is
        # likely empty; consider falling back to the trace name.
        traces = [
            go.Scatter(
                x=self._obj.columns,
                y=row,
                mode="lines+markers",
                line=line,
                showlegend=False,
                hoverinfo="text",
                text=None if text is None else text[idx],
                name="Sample {}".format(idx),
            )
            for idx, row in self._obj.iterrows()
        ]
        if unity_line:
            traces.append(
                go.Scatter(
                    x=self._obj.columns,
                    y=np.ones(self._obj.columns.size),
                    mode="lines",
                    showlegend=False,
                    name=None,
                    line={"color": "black", "dash": "dot", "width": 0.5},
                )
            )
        fig.add_traces(traces)
        fig.update_layout(**layout)
        fig.update_yaxes(type="log", linecolor="black", mirror=True)
        fig.update_xaxes(linecolor="black", mirror=True)
        return fig
×
145

146

147
# class pyroplot_plotly(object):
148
#     def __init__(self, obj):
149
#         """
150
#         Custom dataframe accessor for pyrolite plotting.
151

152
#         Notes
153
#         -----
154
#             This accessor enables the coexistence of array-based plotting functions and
155
#             methods for pandas objects. This enables some separation of concerns.
156
#         """
157
#         self._validate(obj)
158
#         self._obj = obj
159

160
#     @staticmethod
161
#     def _validate(obj):
162
#         pass
163

164
#     def heatscatter(
165
#         self,
166
#         components: list = None,
167
#         ax=None,
168
#         axlabels=True,
169
#         logx=False,
170
#         logy=False,
171
#         **kwargs,
172
#     ):
173
#         r"""
174
#         Heatmapped scatter plots using the pyroplot API. See further parameters
175
#         for `matplotlib.pyplot.scatter` function below.
176

177
#         Parameters
178
#         ----------
179
#         components : :class:`list`, :code:`None`
180
#             Elements or compositional components to plot.
181
#         ax : :class:`matplotlib.axes.Axes`, :code:`None`
182
#             The subplot to draw on.
183
#         axlabels : :class:`bool`, :code:`True`
184
#             Whether to add x-y axis labels.
185
#         logx : :class:`bool`, `False`
186
#             Whether to log-transform x values before the KDE for bivariate plots.
187
#         logy : :class:`bool`, `False`
188
#             Whether to log-transform y values before the KDE for bivariate plots.
189

190
#         {otherparams}
191

192
#         Returns
193
#         -------
194
#         :class:`matplotlib.axes.Axes`
195
#             Axes on which the heatmapped scatterplot is added.
196

197
#         """
198
#         obj = to_frame(self._obj)
199
#         components = _check_components(obj, components=components)
200
#         data, samples = (
201
#             obj.reindex(columns=components).values,
202
#             obj.reindex(columns=components).values,
203
#         )
204
#         kdetfm = [  # log transforms
205
#             get_scaler([None, np.log][logx], [None, np.log][logy]),
206
#             lambda x: ILR(close(x)),
207
#         ][len(components) == 3]
208
#         zi = sample_kde(
209
#             data, samples, transform=kdetfm, **subkwargs(kwargs, sample_kde)
210
#         )
211
#         kwargs.update({"c": zi})
212
#         ax = obj.reindex(columns=components).pyroplot.scatter(
213
#             ax=ax, axlabels=axlabels, **kwargs
214
#         )
215
#         return ax
216

217
#     def plot(self, components: list = None, ax=None, axlabels=True, **kwargs):
218
#         r"""
219
#         Convenience method for line plots using the pyroplot API. See
220
#         further parameters for `matplotlib.pyplot.plot` function below.
221

222
#         Parameters
223
#         ----------
224
#         components : :class:`list`, :code:`None`
225
#             Elements or compositional components to plot.
226
#         ax : :class:`matplotlib.axes.Axes`, :code:`None`
227
#             The subplot to draw on.
228
#         axlabels : :class:`bool`, :code:`True`
229
#             Whether to add x-y axis labels.
230
#         {otherparams}
231

232
#         Returns
233
#         -------
234
#         :class:`matplotlib.axes.Axes`
235
#             Axes on which the plot is added.
236

237
#         """
238
#         obj = to_frame(self._obj)
239
#         components = _check_components(obj, components=components)
240
#         projection = [None, "ternary"][len(components) == 3]
241
#         # ax = init_axes(ax=ax, projection=projection, **kwargs)
242
#         # kw = linekwargs(kwargs)
243
#         ax.plot(*obj.reindex(columns=components).values.T, **kw)
244
#         # if color is multi, could update line colors here
245
#         # if axlabels:
246
#         #    label_axes(ax, labels=components)
247

248
#         ax.tick_params("both")
249
#         # ax.grid()
250
#         # ax.set_aspect("equal")
251
#         return ax
252

253
#     def REE(
254
#         self,
255
#         index="elements",
256
#         ax=None,
257
#         mode="plot",
258
#         dropPm=True,
259
#         scatter_kw={},
260
#         line_kw={},
261
#         **kwargs,
262
#     ):
263
#         """Pass the pandas object to :func:`pyrolite.plot.spider.REE_v_radii`.
264

265
#         Parameters
266
#         ----------
267
#         ax : :class:`matplotlib.axes.Axes`, :code:`None`
268
#             The subplot to draw on.
269
#         index : :class:`str`
270
#             Whether to plot radii ('radii') on the principal x-axis, or elements
271
#             ('elements').
272
#         mode : :class:`str`, :code:`["plot", "fill", "binkde", "ckde", "kde", "hist"]`
273
#             Mode for plot. Plot will produce a line-scatter diagram. Fill will return
274
#             a filled range. Density will return a conditional density diagram.
275
#         dropPm : :class:`bool`
276
#             Whether to exclude the (almost) non-existent element Promethium from the REE
277
#             list.
278
#         scatter_kw : :class:`dict`
279
#             Keyword parameters to be passed to the scatter plotting function.
280
#         line_kw : :class:`dict`
281
#             Keyword parameters to be passed to the line plotting function.
282
#         {otherparams}
283

284
#         Returns
285
#         -------
286
#         :class:`matplotlib.axes.Axes`
287
#             Axes on which the REE plot is added.
288

289
#         """
290
#         obj = to_frame(self._obj)
291
#         ree = [i for i in geochem.ind.REE(dropPm=dropPm) if i in obj.columns]
292

293
#         ax = spider.REE_v_radii(
294
#             obj.reindex(columns=ree).astype(float).values,
295
#             index=index,
296
#             ree=ree,
297
#             mode=mode,
298
#             ax=ax,
299
#             scatter_kw=scatter_kw,
300
#             line_kw=line_kw,
301
#             **kwargs,
302
#         )
303
#         ax.set_ylabel(r"$\mathrm{X / X_{Reference}}$")
304
#         return ax
305

306
#     def scatter(self, components: list = None, ax=None, axlabels=True, **kwargs):
307
#         r"""
308
#         Convenience method for scatter plots using the pyroplot API. See
309
#         further parameters for `matplotlib.pyplot.scatter` function below.
310

311
#         Parameters
312
#         ----------
313
#         components : :class:`list`, :code:`None`
314
#             Elements or compositional components to plot.
315
#         ax : :class:`matplotlib.axes.Axes`, :code:`None`
316
#             The subplot to draw on.
317
#         axlabels : :class:`bool`, :code:`True`
318
#             Whether to add x-y axis labels.
319
#         {otherparams}
320

321
#         Returns
322
#         -------
323
#         :class:`matplotlib.axes.Axes`
324
#             Axes on which the scatterplot is added.
325

326
#         """
327
#         obj = to_frame(self._obj)
328
#         components = _check_components(obj, components=components)
329

330
#         projection = [None, "ternary"][len(components) == 3]
331
#         # ax = init_axes(ax=ax, projection=projection, **kwargs)
332
#         size = obj.index.size
333
#         kw = process_color(size=size, **kwargs)
334
#         with warnings.catch_warnings():
335
#             # ternary transform where points add to zero will give an unnecessary
336
#             # warning; here we suppress it
337
#             warnings.filterwarnings(
338
#                 "ignore", message="invalid value encountered in divide"
339
#             )
340
#             ax.scatter(*obj.reindex(columns=components).values.T, **kw)
341

342
#         # if axlabels:
343
#         #    label_axes(ax, labels=components)
344

345
#         ax.tick_params("both")
346
#         # ax.grid()
347
#         # ax.set_aspect("equal")
348
#         return ax
349

350
#     def spider(
351
#         self,
352
#         components: list = None,
353
#         indexes: list = None,
354
#         ax=None,
355
#         mode="plot",
356
#         index_order=None,
357
#         autoscale=True,
358
#         scatter_kw={},
359
#         line_kw={},
360
#         **kwargs,
361
#     ):
362
#         r"""
363
#         Method for spider plots. Convenience access function to
364
#         :func:`~pyrolite.plot.spider.spider` (see `Other Parameters`, below), where
365
#         further parameters for relevant `matplotlib` functions are also listed.
366

367
#         Parameters
368
#         ----------
369
#         components : :class:`list`, `None`
370
#             Elements or compositional components to plot.
371
#         indexes :  :class:`list`, `None`
372
#             Numerical indexes of x-axis positions for the plotted components.
373
#         ax : :class:`matplotlib.axes.Axes`, :code:`None`
374
#             The subplot to draw on.
375
#         index_order
376
#             Function to order spider plot indexes (e.g. by incompatibility).
377
#         autoscale : :class:`bool`
378
#             Whether to autoscale the y-axis limits for standard spider plots.
379
#         mode : :class:`str`, :code:`["plot", "fill", "binkde", "ckde", "kde", "hist"]`
380
#             Mode for plot. Plot will produce a line-scatter diagram. Fill will return
381
#             a filled range. Density will return a conditional density diagram.
382
#         scatter_kw : :class:`dict`
383
#             Keyword parameters to be passed to the scatter plotting function.
384
#         line_kw : :class:`dict`
385
#             Keyword parameters to be passed to the line plotting function.
386
#         {otherparams}
387

388
#         Returns
389
#         -------
390
#         :class:`matplotlib.axes.Axes`
391
#             Axes on which the spider diagram is plotted.
392

393
#         Todo
394
#         ----
395
#             * Add 'compositional data' filter for default components if None is given
396

397
#         """
398
#         obj = to_frame(self._obj)
399

400
#         if components is None:  # default to plotting elemental data
401
#             components = [
402
#                 el for el in obj.columns if el in geochem.ind.common_elements()
403
#             ]
404

405
#         assert len(components) != 0
406

407
#         if index_order is not None:
408
#             if isinstance(index_order, str):
409
#                 try:
410
#                     index_order = geochem.ind.ordering[index_order]
411
#                 except KeyError:
412
#                     msg = (
413
#                         "Ordering not applied, as parameter '{}' not recognized."
414
#                         " Select from: {}"
415
#                     ).format(index_order, ", ".join(list(geochem.ind.ordering.keys())))
416
#                     logger.warning(msg)
417
#                 components = index_order(components)
418
#             else:
419
#                 components = index_order(components)
420

421
#         # ax = init_axes(ax=ax, **kwargs)
422

423
#         if hasattr(ax, "_pyrolite_components"):
424
#             # TODO: handle spider diagrams which have specified components
425
#             pass
426

427
#         ax = spider.spider(
428
#             obj.reindex(columns=components).astype(float).values,
429
#             indexes=indexes,
430
#             ax=ax,
431
#             mode=mode,
432
#             autoscale=autoscale,
433
#             scatter_kw=scatter_kw,
434
#             line_kw=line_kw,
435
#             **kwargs,
436
#         )
437
#         ax._pyrolite_components = components
438
#         ax.set_xticklabels(components, rotation=60)
439
#         return ax
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc