• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

morganjwilliams / pyrolite / 17569160869

09 Sep 2025 01:41AM UTC coverage: 91.465% (-0.1%) from 91.614%
17569160869

push

github

morganjwilliams
Add uncertainties, add optional deps for pyproject.toml; WIP demo NB

6226 of 6807 relevant lines covered (91.46%)

10.97 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.35
/pyrolite/util/plot/helpers.py
1
"""
2
matplotlib helper functions for common drawing tasks.
3
"""
4

5
import matplotlib.patches
12✔
6
import matplotlib.pyplot as plt
12✔
7
import numpy as np
12✔
8
import scipy.spatial
12✔
9
from .center import visual_center
12✔
10

11
from ..log import Handle
12✔
12
from ..math import eigsorted, nancov
12✔
13
from ..missing import cooccurence_pattern
12✔
14
from ..text import int_to_alpha
12✔
15
from .axes import add_colorbar, init_axes, subaxes
12✔
16
from .interpolation import interpolated_patch_path
12✔
17

18
logger = Handle(__name__)
12✔
19

20
try:
12✔
21
    from sklearn.decomposition import PCA
12✔
22
except ImportError:
×
23
    msg = "scikit-learn not installed"
×
24
    logger.warning(msg)
×
25

26

27
def alphalabel_subplots(ax, fmt="{}", xy=(0.03, 0.95), ha="left", va="top", **kwargs):
    """
    Add alphabetical labels to a successive series of subplots with a specified format.

    Parameters
    ----------
    ax : :class:`list` | :class:`numpy.ndarray` | :class:`numpy.flatiter`
        Axes to label, in desired order.
    fmt : :class:`str`
        Format string to use. To add e.g. parentheses, you could specify :code:`"({})"`.
    xy : :class:`tuple`
        Position of the labels in axes coordinates.
    ha : :class:`str`
        Horizontal alignment of the labels (:code:`{"left", "right"}`).
    va : :class:`str`
        Vertical alignment of the labels (:code:`{"top", "bottom"}`).
    **kwargs
        Additional keyword arguments passed to each ``Axes.annotate`` call.

    Returns
    -------
    :class:`list`
        The annotations which were added, one per axes, in label order.
    """
    # materialise the axes up front, as iterators (e.g. numpy.flatiter) are consumed
    flat = np.array(ax).flatten()
    annotations = []
    for ix, a in enumerate(flat):
        label = fmt.format(int_to_alpha(ix))
        annotations.append(
            a.annotate(label, xy=xy, xycoords=a.transAxes, ha=ha, va=va, **kwargs)
        )
    return annotations
52

53

54
def get_centroid(poly):
    """
    Centroid of a closed polygon using the Shoelace formula.

    Parameters
    ----------
    poly : :class:`matplotlib.patches.Polygon`
        Polygon to obtain the centroid of. Any object whose ``get_xy`` method
        returns an (n, 2) array of vertices is accepted. The vertex list may
        either repeat the first vertex at the end or leave the ring open.

    Returns
    -------
    cx, cy : :class:`tuple`
        Centroid coordinates.

    Notes
    -----
    The vertices are treated as a ring: the last vertex connects back to the
    first. For an explicitly closed vertex list, that closing edge has zero
    length and contributes nothing, so both forms give the same result.
    A degenerate (zero-area) polygon yields non-finite coordinates.
    """
    x, y = np.asarray(poly.get_xy(), dtype=float).T
    # successor of each vertex, wrapping the last one back to the first
    x_next, y_next = np.roll(x, -1), np.roll(y, -1)
    cross = x * y_next - x_next * y
    area = cross.sum() / 2  # signed area; negative for clockwise ordering
    cx = ((x + x_next) * cross).sum() / (6 * area)
    cy = ((y + y_next) * cross).sum() / (6 * area)
    return cx, cy
81

82

83
def get_visual_center(poly, vertical_exaggeration=1):
    """
    Visual center of a closed polygon.

    The polygon is stretched vertically by `vertical_exaggeration` before the
    visual center is located, so the result reflects the polygon's on-screen
    shape. The y coordinate is then mapped back to data units.

    Parameters
    ----------
    poly : :class:`matplotlib.patches.Polygon`
        Polygon for which to obtain the visual center.

    vertical_exaggeration : :class:`float`
        Apparent vertical exaggeration of the plot
        (pixels per unit in y direction divided by pixels
        per unit in the x direction).

    Returns
    -------
    cx, cy : :class:`tuple`
        Visual center coordinates, in the (unstretched) data coordinates of `poly`.
    """
    stretch = [1.0, vertical_exaggeration]
    # wrapped in an extra leading dimension before handing to visual_center
    # (presumably a list of rings; confirm against .center.visual_center)
    stretched_rings = np.array([poly.get_xy() * stretch])
    center_x, stretched_y = visual_center(stretched_rings)
    return (center_x, stretched_y / vertical_exaggeration)
105

106

107
def rect_from_centre(x, y, dx=0, dy=0, **kwargs):
    """
    Create a rectangular patch centred about the point (`x`, `y`).

    Parameters
    ----------
    x, y : :class:`float`
        Coordinates of the centre of the rectangle.
    dx, dy : :class:`float`
        Half-width and half-height of the rectangle. NaN values are treated as 0.
    **kwargs
        Passed to :class:`matplotlib.patches.Rectangle`.

    Returns
    -------
    :class:`matplotlib.patches.Rectangle` | :code:`None`
        The rectangle, or :code:`None` if either `x` or `y` is NaN.
    """
    centre_is_nan = [np.isnan(coord) for coord in (x, y)]
    if any(centre_is_nan):
        return None
    half_width = 0 if np.isnan(dx) else dx
    half_height = 0 if np.isnan(dy) else dy
    lower_left = (x - half_width, y - half_height)
    return matplotlib.patches.Rectangle(
        lower_left, 2 * half_width, 2 * half_height, **kwargs
    )
120

121

122
def draw_vector(v0, v1, ax=None, **kwargs):
    """
    Plots an arrow representing the direction and magnitude of a principal
    component on a biaxial plot.

    Modified after Jake VanderPlas' Python Data Science Handbook
    https://jakevdp.github.io/PythonDataScienceHandbook/ \
    05.09-principal-component-analysis.html

    Parameters
    ----------
    v0 : array-like
        Start point (tail) of the arrow.
    v1 : array-like
        End point (head) of the arrow.
    ax : :class:`matplotlib.axes.Axes`, optional
        Axes on which to draw. If :code:`None`, a new figure and axes are created.
    **kwargs
        Properties merged into (and overriding) the default ``arrowprops``.

    Returns
    -------
    :class:`matplotlib.text.Annotation`
        The annotation holding the arrow.

    Todo
    -----
        Update for ternary plots.
    """
    if ax is None:
        _, ax = plt.subplots(1)
    arrowprops = dict(arrowstyle="->", linewidth=2, shrinkA=0, shrinkB=0)
    arrowprops.update(kwargs)
    # empty text; the arrow is drawn from xytext (v0) to xy (v1)
    return ax.annotate("", v1, v0, arrowprops=arrowprops)
140

141

142
def vector_to_line(
    mu: np.ndarray,
    vector: np.ndarray,
    variance: float,
    spans: int = 4,
    expand: int = 10,
):
    """
    Creates an array of points representing a line along a vector - typically
    for principal component analysis. Modified after Jake VanderPlas' Python Data
    Science Handbook https://jakevdp.github.io/PythonDataScienceHandbook/ \
    05.09-principal-component-analysis.html

    Parameters
    ----------
    mu : :class:`numpy.ndarray`
        Point about which the line is centred (e.g. the data mean).
    vector : :class:`numpy.ndarray`
        Direction of the line.
    variance : :class:`float`
        Variance along `vector`; its square root sets the step length.
    spans : :class:`int`
        Extent of the line either side of `mu`, in units of
        ``sqrt(variance) * vector``.
    expand : :class:`int`
        Number of points per unit of `spans`.

    Returns
    -------
    :class:`numpy.ndarray`
        Array of shape ``(expand * spans + 1, len(vector))``.
    """
    length = np.sqrt(variance)
    # evenly spaced multipliers from -spans to +spans, as a column vector
    parts = np.linspace(-spans, spans, expand * spans + 1)[:, np.newaxis]
    return length * parts * np.asarray(vector) + mu
156

157

158
def plot_stdev_ellipses(
    comp, nstds=4, scale=100, resolution=1000, transform=None, ax=None, **kwargs
):
    """
    Plot covariance ellipses at a number of standard deviations from the mean.

    Parameters
    ----------
    comp : :class:`numpy.ndarray`
        Composition to use. The covariance eigenvalues are unpacked into the two
        ellipse half-axes, so this is expected to have two columns.
    nstds : :class:`int`
        Number of standard deviations from the mean for which to plot the ellipses.
    scale : :class:`float`
        Not used by this function; retained for backwards compatibility.
    resolution : :class:`int`
        Resolution passed to ``interpolated_patch_path`` when converting each
        ellipse into a path.
    transform : :class:`callable`
        Function for transformation of data prior to plotting (to either 2D or 3D).
        If it produces three columns and `ax` is not given, ternary axes are created.
    ax : :class:`matplotlib.axes.Axes`
        Axes to plot on. If :code:`None`, a new figure and axes are created.
    **kwargs
        Passed to :class:`matplotlib.patches.PathPatch`. The edge colour, alpha and
        line width are always overridden.

    Returns
    -------
    ax :  :class:`matplotlib.axes.Axes`
    """
    mean, cov = np.nanmean(comp, axis=0), nancov(comp)
    vals, vecs = eigsorted(cov)
    # angle (degrees) of each column of ``vecs``; the first one orients the ellipses
    theta = np.degrees(np.arctan2(*vecs[::-1]))
    angle = theta[0]  # Ellipse expects a scalar angle, not a length-1 array

    has_transform = callable(transform)
    if ax is None:
        projection = None
        if has_transform and transform(comp).shape[1] == 3:
            projection = "ternary"
        _, ax = plt.subplots(1, subplot_kw=dict(projection=projection))

    for nstd in np.arange(1, nstds + 1)[::-1]:  # backwards for svg construction
        # here we use the absolute eigenvalues
        xsig, ysig = nstd * np.sqrt(np.abs(vals))  # n sigmas
        ell = matplotlib.patches.Ellipse(
            xy=mean.flatten(), width=2 * xsig, height=2 * ysig, angle=angle
        )
        points = interpolated_patch_path(ell, resolution=resolution).vertices

        if has_transform:
            points = transform(points)  # transform to compositional data

        if points.shape[1] == 3:
            ax_transform = (ax.transData + ax.transTernaryAxes.inverted()).inverted()
            points = ax_transform.transform(points)  # transform to axes coords

        patch = matplotlib.patches.PathPatch(matplotlib.path.Path(points), **kwargs)
        patch.set_edgecolor("k")
        patch.set_alpha(1.0 / nstd)  # outer (larger-sigma) ellipses are fainter
        patch.set_linewidth(0.5)
        ax.add_artist(patch)
    return ax
214

215

216
def plot_pca_vectors(
    comp,
    nstds=2,
    scale=100.0,
    transform=None,
    ax=None,
    colors=None,
    linestyles=None,
    **kwargs,
):
    """
    Plot lines along the first two principal components of a dataset.

    Each line is centred on the PCA mean and extends `nstds` units of
    ``sqrt(explained_variance)`` either side of it.

    Parameters
    -------------
    comp : :class:`numpy.ndarray`
        Composition to use.
    nstds : :class:`int`
        Multiplier for magnitude of individual principal component vectors.
    scale : :class:`float`
        Factor applied to all plotted points (after `transform`). For integration
        with python-ternary.
    transform : :class:`callable`
        Function for transformation of data prior to plotting (to either 2D or 3D).
    ax : :class:`matplotlib.axes.Axes`
        Axes to plot on. If :code:`None`, a new figure and axes are created.
    colors : :class:`list`, optional
        Two colours, one per component; must have length 2 if given.
    linestyles : :class:`list`, optional
        Two line styles, one per component; must have length 2 if given.
    **kwargs
        Passed to ``ax.plot`` for both lines.

    Returns
    -------
    ax :  :class:`matplotlib.axes.Axes`

    Todo
    -----
        * Minor reimplementation of the sklearn PCA to avoid dependency.

            https://en.wikipedia.org/wiki/Principal_component_analysis
    """
    pca = PCA(n_components=2)
    pca.fit(comp)

    if ax is None:
        _, ax = plt.subplots(1)

    if linestyles is None:
        linestyles = [None, None]
    else:
        assert len(linestyles) == 2
    if colors is None:
        colors = [None, None]
    else:
        assert len(colors) == 2

    components = zip(pca.explained_variance_, pca.components_, linestyles, colors)
    for variance, direction, style, colour in components:
        line = vector_to_line(pca.mean_, direction, variance, spans=nstds)
        if callable(transform):
            line = transform(line)
        line *= scale
        plot_kw = dict(kwargs)
        if colour is not None:
            plot_kw["color"] = colour
        if style is not None:
            plot_kw["ls"] = style
        ax.plot(*line.T, **plot_kw)
    return ax
281

282

283
def plot_2dhull(data, ax=None, splines=False, s=0, **plotkwargs):
    """
    Plots a 2D convex hull around an array of xy data points.

    Parameters
    ----------
    data : array-like
        Points of shape (n, 2).
    ax : :class:`matplotlib.axes.Axes`, optional
        Axes on which to plot. If :code:`None`, a new figure and axes are created.
    splines : :class:`bool`
        If True, draw a smooth periodic spline through the hull vertices instead
        of straight segments.
    s : :class:`float`
        Smoothing condition passed to :func:`scipy.interpolate.splprep`.
    **plotkwargs
        Passed to ``ax.plot``.

    Returns
    -------
    :class:`list`
        The lines returned by ``ax.plot``.
    """
    if ax is None:
        _, ax = plt.subplots(1)
    data = np.asarray(data)
    chull = scipy.spatial.ConvexHull(data, incremental=True)
    x, y = data[chull.vertices].T
    # repeat the first vertex to close the ring; splprep(per=True) also ignores
    # the final point, so without this the last hull vertex would be dropped
    x_closed, y_closed = np.append(x, x[0]), np.append(y, y[0])
    if not splines:
        return ax.plot(x_closed, y_closed, **plotkwargs)

    # https://stackoverflow.com/questions/33962717/interpolating-a-closed-curve-using-scipy
    # local ``from`` import: ``import scipy.spatial`` need not load scipy.interpolate,
    # and binding ``scipy`` locally here would shadow the module-level name above
    from scipy import interpolate

    tck, _ = interpolate.splprep([x_closed, y_closed], per=True, s=s)
    xi, yi = interpolate.splev(np.linspace(0, 1, 1000), tck)
    return ax.plot(xi, yi, **plotkwargs)
299

300

301
def plot_cooccurence(arr, ax=None, normalize=True, log=False, colorbar=False, **kwargs):
    """
    Plot the co-occurence frequency matrix for a given input.

    Parameters
    -----------
    arr : array-like
        Data for which to compute the co-occurence pattern (converted with
        :func:`numpy.array` and passed to ``cooccurence_pattern``).
    ax : :class:`matplotlib.axes.Axes`, :code:`None`
        The subplot to draw on. If :code:`None`, a new figure and axes are created.
    normalize : :class:`bool`
        Whether to normalize the cooccurence to compare disparate variables.
    log : :class:`bool`
        Whether to take the log of the cooccurence.
    colorbar : :class:`bool`
        Whether to append a colorbar.
    **kwargs
        Passed to ``ax.pcolor`` (and also to ``add_colorbar`` when `colorbar` is set).

    Returns
    --------
    :class:`matplotlib.axes.Axes`
        Axes on which the cooccurence plot is added.
    """
    arr = np.array(arr)
    if ax is None:
        # widen the default figure slightly to leave room for a colorbar
        width = 4.2 if colorbar else 4.0
        _, ax = plt.subplots(1, figsize=(width, 4))
    co_occur = cooccurence_pattern(arr, normalize=normalize, log=log)
    heatmap = ax.pcolor(co_occur, **kwargs)
    # centre the ticks on the cells
    ax.set_yticks(np.arange(co_occur.shape[0]) + 0.5, minor=False)
    ax.set_xticks(np.arange(co_occur.shape[1]) + 0.5, minor=False)
    ax.invert_yaxis()
    ax.xaxis.tick_top()
    if colorbar:
        # NOTE(review): pcolor kwargs are forwarded here too — confirm
        # add_colorbar accepts them.
        add_colorbar(heatmap, **kwargs)
    return ax
333

334

335
def _nan_margin_bins(values, nbins=50):
    """
    Histogram bin edges for a marginal axis of :func:`nan_scatter`.

    `nbins` edges run from the non-NaN minimum of `values` to the non-NaN
    maximum plus ``(max - min) / nbins``.
    """
    lo, hi = np.nanmin(values), np.nanmax(values)
    width = (hi - lo) / nbins
    return np.linspace(lo, hi + width, nbins)


def _nan_margin_jitter(values):
    """Random positions around 10 (Gaussian, sd 5), one per element of `values`."""
    return 10 * np.ones_like(values) + 5 * np.random.randn(len(values))


def nan_scatter(xdata, ydata, ax=None, axes_width=0.2, **kwargs):
    """
    Scatter plot with additional marginal axes to plot data for which data is partially
    missing. Additional keyword arguments are passed to matplotlib.

    Points with a y value but no x value are shown on a left-hand marginal axes,
    and points with an x value but no y value on a bottom marginal axes. Each is
    shown as a histogram plus a randomly jittered strip.

    Parameters
    ----------
    xdata : :class:`numpy.ndarray`
        X data. Must support boolean-mask indexing.
    ydata : :class:`numpy.ndarray` | :class:`pandas.Series`
        Y data. Must support boolean-mask indexing.
    ax : :class:`matplotlib.axes.Axes`
        Axes on which to plot. If :code:`None`, a new figure and axes are created.
    axes_width : :class:`float`
        Width of the marginal axes.
    **kwargs
        Passed to the main scatter and to the marginal ``hist`` and ``scatter`` calls.

    Returns
    -------
    :class:`matplotlib.axes.Axes`
        Axes on which the nan_scatter is plotted.
    """
    if ax is None:
        _, ax = plt.subplots(1)

    ax.scatter(xdata, ydata, **kwargs)

    if hasattr(ax, "divider"):  # Don't rebuild axes; reuse those from a prior call
        div = ax.divider
        nanaxx, nanaxy = div.nanaxx, div.nanaxy
    else:  # Build axes
        nanaxx = subaxes(ax, side="bottom", width=axes_width)
        nanaxx.invert_yaxis()
        nanaxy = subaxes(ax, side="left", width=axes_width)
        nanaxy.invert_xaxis()
        # assign for later use (ax.divider presumably set up by subaxes)
        ax.divider.nanaxx = nanaxx
        ax.divider.nanaxy = nanaxy

    # values present in one variable but missing in the other
    nanxdata = xdata[(np.isnan(ydata) & np.isfinite(xdata))]
    nanydata = ydata[(np.isnan(xdata) & np.isfinite(ydata))]

    # y values lacking an x: left-hand marginal axes
    nanaxy.hist(nanydata, bins=_nan_margin_bins(ydata), orientation="horizontal", **kwargs)
    nanaxy.scatter(_nan_margin_jitter(nanydata), nanydata, zorder=-1, **kwargs)

    # x values lacking a y: bottom marginal axes
    nanaxx.hist(nanxdata, bins=_nan_margin_bins(xdata), **kwargs)
    nanaxx.scatter(nanxdata, _nan_margin_jitter(nanxdata), zorder=-1, **kwargs)

    return ax
404

405

406
###############################################################################
407
# Helpers for pyrolite.comp.codata.sphere and related functions
408
from pyrolite.comp.codata import inverse_sphere
12✔
409

410

411
def _get_spherical_vector(phis):
    """
    Get a line aligned to the vector corresponding to a specific combination
    of angles.

    Parameters
    ----------
    phis : :class:`numpy.ndarray`
        Angles to convert, passed to :func:`~pyrolite.comp.codata.inverse_sphere`
        (e.g. an array of shape (1, 2) for a single 3D vector).

    Returns
    -------
    :class:`numpy.ndarray`
        Vertically stacked rows: zeros (the origin), the vector
        ``sqrt(inverse_sphere(phis))``, and that vector scaled by 1.5.
    """
    direction = np.sqrt(inverse_sphere(phis))
    origin = np.zeros_like(direction)
    extended = direction * 1.5
    return np.vstack((origin, direction, extended))
426

427

428
def _plot_spherical_vector(ax, phis, marker="D", markevery=(1, 2), ls="--", **kwargs):
    """
    Plot the vector corresponding to angles `phis` on a specified axis.

    The line runs from the origin through the vector to 1.5 times its length,
    as produced by ``_get_spherical_vector``.

    Parameters
    ----------
    ax : :class:`matplotlib.axes.Axes3D`
        Axes on which to plot.
    phis : :class:`numpy.ndarray`
        Angles defining the vector.
    marker, markevery, ls
        Marker style, marker placement and line style passed to ``ax.plot``.
    **kwargs
        Additional keyword arguments passed to ``ax.plot``.
    """
    coordinates = _get_spherical_vector(phis).T
    ax.plot(*coordinates, marker=marker, markevery=markevery, ls=ls, **kwargs)
438

439

440
def _get_spherical_arc(thetas0, thetas1, resolution=100):
    """
    Get a 3D arc on a sphere between two points.

    Points are linearly interpolated along the chord between the two vectors
    and then normalised to unit length, projecting the chord onto the sphere.

    Parameters
    ----------
    thetas0 : :class:`numpy.ndarray`
        Angles corresponding to first unit vector.
    thetas1 : :class:`numpy.ndarray`
        Angles corresponding to second unit vector.
    resolution : :class:`int`
        Number of segments in the arc (the arc has ``resolution + 1`` points).

    Returns
    -------
    :class:`numpy.ndarray`
        Array with one row per point along the arc.
    """
    # NOTE(review): the input points are not checked to lie on the sphere
    start = _get_spherical_vector(thetas0)[1]
    end = _get_spherical_vector(thetas1)[1]
    fractions = np.linspace(0, 1, resolution + 1)[:, None]
    chord_points = start + fractions * (end - start)
    norms = np.sqrt((chord_points**2).sum(axis=1))
    return chord_points / norms[:, None]
463

464

465
def init_spherical_octant(
    angle_indicated=30, labels=None, view_init=(25, 55), fontsize=10, **kwargs
):
    """
    Initialize a figure with a 3D octant of a unit sphere, labelled with regard
    to the angles used by the compositional data transformation function
    (:func:`~pyrolite.comp.codata.sphere`).

    Parameters
    -----------
    angle_indicated : :class:`float`
        Angle (in degrees) at which example vectors and arcs are drawn to indicate
        the angular coordinates. Pass :code:`None` to omit them.
    labels : :class:`list`
        Optional specification of data/axes labels. The first three label the
        x, y and z axes; the last two are used in the labels of the indicated angles.
    view_init : :class:`tuple`
        Arguments unpacked into ``ax.view_init``.
    fontsize : :class:`float`
        Font size for the axis and angle labels.
    **kwargs
        Passed to ``init_axes``.

    Returns
    -------
    ax : :class:`matplotlib.axes.Axes3D`
        Initialized 3D axis.
    """
    ax = init_axes(subplot_kw=dict(projection="3d"), **kwargs)
    ax.view_init(*view_init)

    for set_label, text in zip((ax.set_xlabel, ax.set_ylabel, ax.set_zlabel), "xyz"):
        set_label(text)
    # hide the background panes and grid so only the octant itself is drawn
    spatial_axes = (ax.xaxis, ax.yaxis, ax.zaxis)
    for axis in spatial_axes:
        axis.pane.fill = False
    for axis in spatial_axes:
        axis.pane.set_edgecolor("w")
    ax.grid(False)

    if labels is not None:
        angle_labels = [
            r"$\theta_{" + labels[-2] + "}$",
            r"$\theta_{" + labels[-1] + "}$",
        ]
    else:
        labels = ["x", "y", "z"]
        angle_labels = [r"$\theta_2$", r"$\theta_3$"]

    # each axis is a line from the origin to 1.5, with a marker at 1
    axis_line = np.array([[0, 1, 1.5], [0, 0, 0]])
    for row_order in ([0, 1, 1], [1, 0, 1], [1, 1, 0]):  # x, y and z axes
        ax.plot(*axis_line[row_order], lw=2, color="k", marker="D", markevery=(1, 2))
    for ix, position in enumerate(np.eye(3) * 1.6):
        ax.text(*position, labels[ix], fontsize=fontsize)

    if angle_indicated is None:
        return ax

    angle = np.deg2rad(angle_indicated)

    def _angles(*phis):
        # a fresh (1, 2) array of angles for each use
        return np.array([phis])

    # theta 2 ##################################################################
    _plot_spherical_vector(ax, _angles(angle, np.pi / 2), color="purple")
    theta2_arc = _get_spherical_arc(_angles(angle, np.pi / 2), _angles(0, np.pi / 2))
    ax.plot(*theta2_arc.T, color="purple")
    theta2_pos = (
        _get_spherical_vector(_angles(angle, np.pi / 2))[1] + np.array([0, 1, 0]) / 2
    )
    ax.text(*theta2_pos, angle_labels[0], color="purple", fontsize=fontsize)

    # theta 3 ##################################################################
    _plot_spherical_vector(ax, _angles(angle, angle), color="g")
    theta3_arc = _get_spherical_arc(_angles(np.pi / 2, 0), _angles(angle, angle))
    ax.plot(*theta3_arc.T, color="green")
    theta3_pos = _get_spherical_vector(_angles(angle, angle))[1] + np.array([0, 0, 1]) / 2
    ax.text(*theta3_pos, angle_labels[1], ha="left", color="green", fontsize=fontsize)

    return ax
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc