• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

morganjwilliams / pyrolite / 5564819928

pending completion
5564819928

push

github

morganjwilliams
Merge branch 'release/0.3.3' into main

249 of 270 new or added lines in 48 files covered. (92.22%)

217 existing lines in 33 files now uncovered.

5971 of 6605 relevant lines covered (90.4%)

10.84 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.13
/pyrolite/util/plot/helpers.py
1
"""
2
Matplotlib helper functions for common drawing tasks.
3
"""
4
import matplotlib.patches
12✔
5
import matplotlib.pyplot as plt
12✔
6
import numpy as np
12✔
7
import scipy.spatial
12✔
8

9
from ..log import Handle
12✔
10
from ..math import eigsorted, nancov
12✔
11
from ..missing import cooccurence_pattern
12✔
12
from ..text import int_to_alpha
12✔
13
from .axes import add_colorbar, init_axes, subaxes
12✔
14
from .interpolation import interpolated_patch_path
12✔
15

16
logger = Handle(__name__)
12✔
17

18
try:
12✔
19
    from sklearn.decomposition import PCA
12✔
20
except ImportError:
×
21
    msg = "scikit-learn not installed"
×
22
    logger.warning(msg)
×
23

24

25
def alphalabel_subplots(ax, fmt="{}", xy=(0.03, 0.95), ha="left", va="top", **kwargs):
    """
    Add alphabetical labels to a successive series of subplots with a specified format.

    Parameters
    -----------
    ax : :class:`list` | :class:`numpy.ndarray` | :class:`numpy.flatiter`
        Axes to label, in desired order.
    fmt : :class:`str`
        Format string to use. To add e.g. parentheses, you could specify :code:`"({})"`.
    xy : :class:`tuple`
        Position of the labels in axes coordinates.
    ha : :class:`str`
        Horizontal alignment of the labels (:code:`{"left", "right"}`).
    va : :class:`str`
        Vertical alignment of the labels (:code:`{"top", "bottom"}`).
    **kwargs
        Additional keyword arguments passed to :meth:`matplotlib.axes.Axes.annotate`.
    """
    # Materialise the axes into a flat array up front; an iterator such as a
    # numpy.flatiter would otherwise be consumed.
    flat = np.array(ax).flatten()
    # Annotating is a side effect, so use a plain loop rather than a list
    # comprehension that builds and throws away a list.
    for ix, a in enumerate(flat):
        a.annotate(
            fmt.format(int_to_alpha(ix)),
            xy=xy,
            xycoords=a.transAxes,
            ha=ha,
            va=va,
            **kwargs,
        )
50

51

52
def get_centroid(poly):
    """
    Centroid of a polygon using the Shoelace formula.

    The vertices may be given either closed (last vertex repeating the first)
    or open; the closing edge is always included by wrapping from the last
    vertex back to the first.

    Parameters
    ----------
    poly : :class:`matplotlib.patches.Polygon`
        Polygon to obtain the centroid of (anything with a :code:`get_xy()`
        method returning an (n, 2) array of vertices).

    Returns
    -------
    cx, cy : :class:`tuple`
        Centroid coordinates.
    """
    x, y = np.asarray(poly.get_xy(), dtype=float).T
    # "Next" vertex of each edge, wrapping the last vertex back to the first.
    # For an already-closed vertex list the wrap-around edge has zero length
    # and contributes nothing, so closed and open inputs give the same result.
    xn, yn = np.roll(x, -1), np.roll(y, -1)
    cross = x * yn - xn * y
    A = cross.sum() / 2  # signed area; the sign cancels in the divisions below
    cx = ((x + xn) * cross).sum() / (6 * A)
    cy = ((y + yn) * cross).sum() / (6 * A)
    return cx, cy
79

80

81
def rect_from_centre(x, y, dx=0, dy=0, **kwargs):
    """
    Build a rectangular patch centred on the point (x, y).

    Parameters
    ----------
    x, y : :class:`float`
        Centre of the rectangle.
    dx, dy : :class:`float`
        Half-width and half-height of the rectangle; NaN values are treated as 0.
    **kwargs
        Passed on to :class:`matplotlib.patches.Rectangle`.

    Returns
    -------
    :class:`matplotlib.patches.Rectangle` | :code:`None`
        The patch, or :code:`None` when either centre coordinate is NaN.
    """
    x_missing, y_missing = np.isnan(x), np.isnan(y)
    if x_missing or y_missing:
        return None
    half_w = 0 if np.isnan(dx) else dx
    half_h = 0 if np.isnan(dy) else dy
    lower_left = (x - half_w, y - half_h)
    return matplotlib.patches.Rectangle(lower_left, 2 * half_w, 2 * half_h, **kwargs)
94

95

96
def draw_vector(v0, v1, ax=None, **kwargs):
    """
    Plot an arrow representing the direction and magnitude of a principal
    component on a biaxial plot.

    Modified after Jake VanderPlas' Python Data Science Handbook
    https://jakevdp.github.io/PythonDataScienceHandbook/ \
    05.09-principal-component-analysis.html

    Parameters
    ----------
    v0 : array-like
        Tail of the arrow, in data coordinates.
    v1 : array-like
        Head of the arrow, in data coordinates.
    ax : :class:`matplotlib.axes.Axes`, optional
        Axes to draw on. Defaults to the current axes (:func:`matplotlib.pyplot.gca`).
    **kwargs
        Arrow properties overriding the defaults
        (:code:`arrowstyle="->", linewidth=2, shrinkA=0, shrinkB=0`).

    Todo
    -----
        Update for ternary plots.

    """
    if ax is None:
        ax = plt.gca()
    arrowprops = dict(arrowstyle="->", linewidth=2, shrinkA=0, shrinkB=0)
    arrowprops.update(kwargs)
    # annotate(text, xy, xytext): empty text, arrow drawn from xytext (v0) to xy (v1)
    ax.annotate("", v1, v0, arrowprops=arrowprops)
114

115

116
def vector_to_line(
    mu: np.ndarray, vector: np.ndarray, variance: float, spans: int = 4, expand: int = 10
):
    """
    Creates an array of points representing a line along a vector - typically
    for principal component analysis. Modified after Jake VanderPlas' Python Data
    Science Handbook https://jakevdp.github.io/PythonDataScienceHandbook/ \
    05.09-principal-component-analysis.html

    Parameters
    ----------
    mu : :class:`numpy.ndarray`
        Centre of the line (e.g. the mean of the data).
    vector : :class:`numpy.ndarray`
        Direction of the line (e.g. a principal component).
    variance : :class:`float`
        Variance along `vector`; the line is scaled by its square root.
    spans : :class:`int`
        Number of standard deviations the line extends on either side of `mu`.
    expand : :class:`int`
        Point-density factor; the line has :code:`expand * spans + 1` points.

    Returns
    -------
    :class:`numpy.ndarray`
        Points along the line, one per row (shape
        :code:`(expand * spans + 1, len(vector))` for a 1D `vector`).
    """
    length = np.sqrt(variance)
    parts = np.linspace(-spans, spans, expand * spans + 1)
    # column of offsets broadcast against the direction vector
    return length * parts[:, np.newaxis] * vector + mu
130

131

132
def plot_stdev_ellipses(
    comp, nstds=4, scale=100, resolution=1000, transform=None, ax=None, **kwargs
):
    """
    Plot covariance ellipses at a number of standard deviations from the mean.

    Parameters
    -------------
    comp : :class:`numpy.ndarray`
        Composition to use.
    nstds : :class:`int`
        Number of standard deviations from the mean for which to plot the ellipses.
    scale : :class:`float`
        Scale applying to all x-y data points. For integration with python-ternary.
        Note: not currently used by this function.
    resolution : :class:`int`
        Resolution used when interpolating each ellipse into a path.
    transform : :class:`callable`
        Function for transformation of data prior to plotting (to either 2D or 3D).
        If the transformed data has three columns, a ternary projection is used.
    ax : :class:`matplotlib.axes.Axes`
        Axes to plot on. A new figure and axes are created if not given.
    **kwargs
        Passed to :class:`matplotlib.patches.PathPatch`; edge colour, alpha and
        line width are subsequently overridden.

    Returns
    -------
    ax :  :class:`matplotlib.axes.Axes`
    """
    mean, cov = np.nanmean(comp, axis=0), nancov(comp)
    # NOTE(review): assumes eigsorted returns eigenvectors as columns ordered to
    # match `vals` - TODO confirm. theta holds the orientation of each column.
    vals, vecs = eigsorted(cov)
    theta = np.degrees(np.arctan2(*vecs[::-1]))
    use_transform = callable(transform)

    if ax is None:
        projection = None
        if use_transform and transform(comp).shape[1] == 3:
            projection = "ternary"
        _, ax = plt.subplots(1, subplot_kw=dict(projection=projection))

    for nstd in np.arange(1, nstds + 1)[::-1]:  # backwards for svg construction
        # here we use the absolute eigenvalues
        xsig, ysig = nstd * np.sqrt(np.abs(vals))  # n sigmas
        ell = matplotlib.patches.Ellipse(
            # Ellipse expects a scalar angle; theta[:1] would be a 1-element array
            xy=mean.flatten(), width=2 * xsig, height=2 * ysig, angle=theta[0]
        )
        points = interpolated_patch_path(ell, resolution=resolution).vertices

        if use_transform:
            points = transform(points)  # transform to compositional data

        if points.shape[1] == 3:
            ax_transform = (ax.transData + ax.transTernaryAxes.inverted()).inverted()
            points = ax_transform.transform(points)  # transform to axes coords

        patch = matplotlib.patches.PathPatch(matplotlib.path.Path(points), **kwargs)
        patch.set_edgecolor("k")
        patch.set_alpha(1.0 / nstd)
        patch.set_linewidth(0.5)
        ax.add_artist(patch)
    return ax
188

189

190
def plot_pca_vectors(
    comp,
    nstds=2,
    scale=100.0,
    transform=None,
    ax=None,
    colors=None,
    linestyles=None,
    **kwargs
):
    """
    Plot vectors corresponding to principal components and their magnitudes.

    Parameters
    -------------
    comp : :class:`numpy.ndarray`
        Composition to use.
    nstds : :class:`int`
        Multiplier for magnitude of individual principal component vectors.
    scale : :class:`float`
        Scale applying to all x-y data points. For integration with python-ternary.
    transform : :class:`callable`
        Function for transformation of data prior to plotting (to either 2D or 3D).
    ax : :class:`matplotlib.axes.Axes`
        Axes to plot on. A new figure and axes are created if not given.
    colors : :class:`list`, optional
        Two colours, one per principal component.
    linestyles : :class:`list`, optional
        Two linestyles, one per principal component.
    **kwargs
        Passed to :meth:`matplotlib.axes.Axes.plot`.

    Returns
    -------
    ax :  :class:`matplotlib.axes.Axes`

    Raises
    ------
    ValueError
        If `colors` or `linestyles` is given without exactly two entries.

    Todo
    -----
        * Minor reimplementation of the sklearn PCA to avoid dependency.

            https://en.wikipedia.org/wiki/Principal_component_analysis
    """
    # Validate before fitting so bad arguments fail fast; an explicit raise
    # (unlike ``assert``) is not stripped when Python runs with ``-O``.
    for name, value in (("linestyles", linestyles), ("colors", colors)):
        if value is not None and len(value) != 2:
            raise ValueError(
                f"{name} should have exactly two entries "
                f"(one per component), got {len(value)}."
            )

    pca = PCA(n_components=2)
    pca.fit(comp)

    if ax is None:
        _, ax = plt.subplots(1)

    linestyles = [None, None] if linestyles is None else linestyles
    colors = [None, None] if colors is None else colors
    for variance, vector, linestyle, color in zip(
        pca.explained_variance_, pca.components_, linestyles, colors
    ):
        line = vector_to_line(pca.mean_, vector, variance, spans=nstds)
        if callable(transform):
            line = transform(line)
        line *= scale
        kw = {**kwargs}
        if color is not None:
            kw["color"] = color
        if linestyle is not None:
            kw["ls"] = linestyle
        ax.plot(*line.T, **kw)
    return ax
255

256

257
def plot_2dhull(data, ax=None, splines=False, s=0, **plotkwargs):
    """
    Plots a 2D convex hull around an array of xy data points.

    Parameters
    ----------
    data : array-like
        Array of shape (n, 2) holding the x-y points.
    ax : :class:`matplotlib.axes.Axes`, optional
        Axes to plot on. A new figure and axes are created if not given.
    splines : :class:`bool`
        Whether to draw a smooth periodic spline through the hull vertices
        instead of straight hull edges.
    s : :class:`float`
        Smoothing factor passed to :func:`scipy.interpolate.splprep`
        (0 interpolates the hull vertices exactly).
    **plotkwargs
        Passed to :meth:`matplotlib.axes.Axes.plot`.

    Returns
    -------
    :class:`list`
        Lines returned by :meth:`matplotlib.axes.Axes.plot`.
    """
    data = np.asarray(data)  # allow lists and other array-likes
    if ax is None:
        _, ax = plt.subplots(1)
    chull = scipy.spatial.ConvexHull(data)
    x, y = data[chull.vertices].T
    # Close the hull by repeating the first vertex. For straight edges this
    # draws the final segment. Per the scipy docs, splprep(per=True) treats the
    # last point as coincident with the first and does not use its value, so
    # without closing, the final hull vertex would be dropped from the spline.
    x, y = np.append(x, x[0]), np.append(y, y[0])
    if not splines:
        return ax.plot(x, y, **plotkwargs)

    # https://stackoverflow.com/questions/33962717/interpolating-a-closed-curve-using-scipy
    # Explicit import: `import scipy.spatial` does not guarantee that
    # scipy.interpolate is loaded. A `from` import avoids rebinding `scipy`
    # as a function-local name.
    from scipy.interpolate import splev, splprep

    tck, _ = splprep([x, y], per=True, s=s)
    xi, yi = splev(np.linspace(0, 1, 1000), tck)
    return ax.plot(xi, yi, **plotkwargs)
273

274

275
def plot_cooccurence(arr, ax=None, normalize=True, log=False, colorbar=False, **kwargs):
    """
    Plot the co-occurence frequency matrix for a given input.

    Parameters
    -----------
    arr : array-like
        Input data; converted with :func:`numpy.array` and passed to
        :func:`~pyrolite.util.missing.cooccurence_pattern`.
    ax : :class:`matplotlib.axes.Axes`, :code:`None`
        The subplot to draw on. A new 4-inch-high figure is created if not given.
    normalize : :class:`bool`
        Whether to normalize the cooccurence to compare disparate variables.
    log : :class:`bool`
        Whether to take the log of the cooccurence.
    colorbar : :class:`bool`
        Whether to append a colorbar.
    **kwargs
        Passed to :meth:`matplotlib.axes.Axes.pcolor` and, when `colorbar` is
        set, also to `add_colorbar`.

    Returns
    --------
    :class:`matplotlib.axes.Axes`
        Axes on which the cooccurence plot is added.
    """
    arr = np.array(arr)
    if ax is None:
        extra_width = (0.0, 0.2)[colorbar]  # leave room for the colorbar
        _, ax = plt.subplots(1, figsize=(4 + extra_width, 4))
    co_occur = cooccurence_pattern(arr, normalize=normalize, log=log)
    heatmap = ax.pcolor(co_occur, **kwargs)
    # centre the ticks on each matrix cell
    nrows, ncols = co_occur.shape[0], co_occur.shape[1]
    ax.set_yticks(np.arange(nrows) + 0.5, minor=False)
    ax.set_xticks(np.arange(ncols) + 0.5, minor=False)
    # matrix-style orientation: first row at the top, column ticks above
    ax.invert_yaxis()
    ax.xaxis.tick_top()
    if colorbar:
        add_colorbar(heatmap, **kwargs)
    return ax
307

308

309
def nan_scatter(xdata, ydata, ax=None, axes_width=0.2, **kwargs):
    """
    Scatter plot with additional marginal axes to plot data for which data is partially
    missing. Additional keyword arguments are passed to matplotlib.

    Points with a value for x but not y are shown (as a histogram and a jittered
    scatter) on a marginal axis below the plot. Points with a value for y but
    not x are shown on a marginal axis to the left.

    Parameters
    ----------
    xdata : :class:`numpy.ndarray`
        X data
    ydata : :class:`numpy.ndarray` | :class:`pandas.Series`
        Y data
    ax : :class:`matplotlib.axes.Axes`
        Axes on which to plot. A new figure and axes are created if not given.
    axes_width : :class:`float`
        Width of the marginal axes (only used when they are first created).
    **kwargs
        Passed to every `scatter` and `hist` call, on the main and marginal axes.

    Returns
    -------
    :class:`matplotlib.axes.Axes`
        Axes on which the nan_scatter is plotted.

    """
    if ax is None:
        _, ax = plt.subplots(1)

    ax.scatter(xdata, ydata, **kwargs)

    if hasattr(ax, "divider"):
        # marginal axes were built by a previous call; reuse them
        nanaxx, nanaxy = ax.divider.nanaxx, ax.divider.nanaxy
    else:
        nanaxx = subaxes(ax, side="bottom", width=axes_width)
        nanaxx.invert_yaxis()
        nanaxy = subaxes(ax, side="left", width=axes_width)
        nanaxy.invert_xaxis()
        # stash on the divider so later calls on these axes can find them
        ax.divider.nanaxx = nanaxx
        ax.divider.nanaxy = nanaxy

    x_only = xdata[(np.isnan(ydata) & np.isfinite(xdata))]  # y missing, x present
    y_only = ydata[(np.isnan(xdata) & np.isfinite(ydata))]  # x missing, y present

    def _edges(values, nbins=50):
        # nbins edges from the minimum to one bin-width past the maximum
        lo, hi = np.nanmin(values), np.nanmax(values)
        return np.linspace(lo, hi + (hi - lo) / nbins, nbins)

    nanaxy.hist(y_only, bins=_edges(ydata), orientation="horizontal", **kwargs)
    nanaxy.scatter(
        10 * np.ones_like(y_only) + 5 * np.random.randn(len(y_only)),  # jitter
        y_only,
        zorder=-1,
        **kwargs,
    )

    nanaxx.hist(x_only, bins=_edges(xdata), **kwargs)
    nanaxx.scatter(
        x_only,
        10 * np.ones_like(x_only) + 5 * np.random.randn(len(x_only)),  # jitter
        zorder=-1,
        **kwargs,
    )

    return ax
378

379

380
###############################################################################
381
# Helpers for pyrolite.comp.codata.sphere and related functions
382
from pyrolite.comp.codata import inverse_sphere
12✔
383

384

385
def _get_spherical_vector(phis):
    """
    Get a line aligned to the unit vector corresponding to a specific combination
    of angles.

    Parameters
    ----------
    phis : :class:`numpy.ndarray`
        Angles, passed to :func:`~pyrolite.comp.codata.inverse_sphere`.

    Returns
    -------
    :class:`numpy.ndarray`
        Three points stacked row-wise: the origin, the vector itself (the square
        root of the :func:`inverse_sphere` output) and the vector scaled by 1.5.
    """
    direction = np.sqrt(inverse_sphere(phis))
    points = [np.zeros_like(direction), direction, 1.5 * direction]
    return np.vstack(points)
400

401

402
def _plot_spherical_vector(ax, phis, marker="D", markevery=(1, 2), ls="--", **kwargs):
    """
    Plot the unit vector corresponding to angles `phis` on a specified axis.

    Parameters
    ----------
    ax : :class:`matplotlib.axes.Axes3D`
        Axes on which to plot.
    phis : :class:`numpy.ndarray`
        Angles defining the vector (see `_get_spherical_vector`).
    marker, markevery, ls
        Line styling for `ax.plot`; with the default :code:`markevery=(1, 2)` the
        marker sits on the second of the three points (the vector tip).
    **kwargs
        Further keyword arguments for `ax.plot`.
    """
    points = _get_spherical_vector(phis)
    coords = points.T  # one row per spatial coordinate
    ax.plot(*coords, marker=marker, markevery=markevery, ls=ls, **kwargs)
412

413

414
def _get_spherical_arc(thetas0, thetas1, resolution=100):
    """
    Get a 3D arc on a sphere between two points.

    The arc is built by linearly interpolating between the two vectors and
    dividing each interpolated point by its norm. The points therefore lie on the
    great circle through both vectors, but are not evenly spaced in angle.

    Parameters
    ----------
    thetas0 : :class:`numpy.ndarray`
        Angles corresponding to first unit vector.
    thetas1 : :class:`numpy.ndarray`
        Angles corresponding to second unit vector.
    resolution : :class:`int`
        Resolution of the line to be used/number of points in the line.

    Returns
    -------
    :class:`numpy.ndarray`
        The :code:`resolution + 1` normalised points along the arc, one per row.
    """
    # NOTE(review): the inputs are not checked to lie on the sphere
    start = _get_spherical_vector(thetas0)[1]
    end = _get_spherical_vector(thetas1)[1]
    fractions = np.linspace(0, 1, resolution + 1)[:, None]
    chord = start + fractions * (end - start)
    norms = np.sqrt((chord**2).sum(axis=1))
    return chord / norms[:, None]
437

438

439
def init_spherical_octant(
    angle_indicated=30, labels=None, view_init=(25, 55), fontsize=10, **kwargs
):
    """
    Initialize a figure with a 3D octant of a unit sphere, appropriately labeled
    with regard to angles corresponding to the handling of the respective
    compositional data transformation function (:func:`~pyrolite.comp.codata.sphere`).

    Parameters
    -----------
    angle_indicated : :class:`float`
        Angle (in degrees) relative to the axes used to indicate the relative
        positioning. If :code:`None`, the indicator vectors and arcs are not drawn.
    labels : :class:`list`
        Optional specification of data/axes labels. The first three label the x, y
        and z axes. The last two label the arcs indicating which angles are
        represented.
    view_init : :class:`tuple`
        Elevation and azimuth passed to :meth:`Axes3D.view_init`.
    fontsize : :class:`float`
        Font size for the axis and angle labels.
    **kwargs
        Passed to `init_axes`.

    Returns
    -------
    ax : :class:`matplotlib.axes.Axes3D`
        Initialized 3D axis.
    """
    ax = init_axes(subplot_kw=dict(projection="3d"), **kwargs)

    ax.view_init(*view_init)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")
    # unfilled panes with white edges, and no grid
    panes = [ax.xaxis.pane, ax.yaxis.pane, ax.zaxis.pane]
    for pane in panes:
        pane.fill = False
    for pane in panes:
        pane.set_edgecolor("w")
    ax.grid(False)

    if labels is None:
        labels = ["x", "y", "z"]
        angle_labels = [r"$\theta_2$", r"$\theta_3$"]
    else:
        angle_labels = [r"$\theta_{" + labels[i] + "}$" for i in (-2, -1)]

    # axes lines: origin -> 1 (marked) -> 1.5 along each axis
    lines = np.array([[0, 1, 1.5], [0, 0, 0]])
    axis_style = dict(lw=2, color="k", marker="D", markevery=(1, 2))
    for rows in ([0, 1, 1], [1, 0, 1], [1, 1, 0]):  # x, y, z axes
        ax.plot(*lines[rows], **axis_style)
    # axes labels, just beyond the end of each axis line
    for ix, position in enumerate(np.eye(3) * 1.6):
        ax.text(*position, labels[ix], fontsize=fontsize)

    if angle_indicated is None:
        return ax

    a = np.deg2rad(angle_indicated)

    # theta 2 ##################################################################
    _plot_spherical_vector(ax, np.array([[a, np.pi / 2]]), color="purple")
    theta2_arc = _get_spherical_arc(
        np.array([[a, np.pi / 2]]), np.array([[0, np.pi / 2]])
    )
    ax.plot(*theta2_arc.T, color="purple")
    theta2_pos = (
        _get_spherical_vector(np.array([[a, np.pi / 2]]))[1] + np.array([0, 1, 0]) / 2
    )
    ax.text(*theta2_pos, angle_labels[0], color="purple", fontsize=fontsize)

    # theta 3 ##################################################################
    _plot_spherical_vector(ax, np.array([[a, a]]), color="g")
    theta3_arc = _get_spherical_arc(np.array([[np.pi / 2, 0]]), np.array([[a, a]]))
    ax.plot(*theta3_arc.T, color="green")
    theta3_pos = _get_spherical_vector(np.array([[a, a]]))[1] + np.array([0, 0, 1]) / 2
    ax.text(*theta3_pos, angle_labels[1], ha="left", color="green", fontsize=fontsize)

    return ax
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc