• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

morganjwilliams / pyrolite / 17569160869

09 Sep 2025 01:41AM UTC coverage: 91.465% (-0.1%) from 91.614%
17569160869

push

github

morganjwilliams
Add uncertainties, add optional deps for pyproject.toml; WIP demo NB

6226 of 6807 relevant lines covered (91.46%)

10.97 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.19
/pyrolite/util/plot/axes.py
1
"""
2
Functions for creating, ordering and modifying :class:`~matplotlib.axes.Axes`.
3
"""
4

5
import warnings
12✔
6

7
import matplotlib.pyplot as plt
12✔
8
import numpy as np
12✔
9
from mpl_toolkits.axes_grid1 import make_axes_locatable
12✔
10

11
from ..log import Handle
12✔
12
from ..meta import subkwargs
12✔
13

14
# Module-level logger for this module, built with the package's ``Handle`` helper.
logger = Handle(__name__)
12✔
15

16

17
def get_ordered_axes(fig):
    """
    Return the axes of a figure in their intended order.

    Figures which have been modified by pyrolite functions carry an
    ``orderedaxes`` attribute recording the ordering; where that attribute is
    present it is returned in preference to the figure's own ``axes``.

    Parameters
    -----------
    fig : :class:`matplotlib.figure.Figure`
        Figure to get the axes from.

    Returns
    --------
    :class:`list`
        ``fig.orderedaxes`` if the attribute exists, otherwise ``fig.axes``.
    """
    previously_modified = hasattr(fig, "orderedaxes")
    return fig.orderedaxes if previously_modified else fig.axes
12✔
27

28

29
def get_axes_index(ax):
    """
    Get the subplot position of an axes within a regular grid.

    Parameters
    -----------
    ax : :class:`matplotlib.axes.Axes`
        Axes to get the gridspec index for.

    Returns
    -----------
    :class:`tuple`
        Number of rows, number of columns and the one-based index of the axes,
        i.e. the three positional arguments accepted by
        :meth:`matplotlib.figure.Figure.add_subplot`.

    Notes
    ------
    The index is taken from the position of the axes in the figure's ordered
    axes list rather than from the subplot specification; this assumes that
    the list follows the grid order.
    """
    # use the public accessor rather than the private _nrows/_ncols attributes
    nrow, ncol = ax.get_gridspec().get_geometry()
    index = get_ordered_axes(ax.figure).index(ax)
    return nrow, ncol, index + 1  # subplot indexes are one-based
12✔
47

48

49
def replace_with_ternary_axis(ax):
    """
    Replace a specified axis with a ternary equivalent.

    A new axes with the ``"ternary"`` projection is added to the figure at the
    subplot position of the original, the original axes is removed from the
    figure, and the figure's ``orderedaxes`` attribute is set so that the new
    axes takes the place of the original in the ordering.

    Parameters
    ------------
    ax : :class:`~matplotlib.axes.Axes`
        Axes to replace.

    Returns
    ------------
    tax : :class:`~mpltern.ternary.TernaryAxes`

    Warns
    ------
    UserWarning
        If the axes being replaced is not ternary and contains artists.
    """
    if ax.name != "ternary" and not check_default_axes(ax):
        if check_empty(ax):
            # rescaled but empty: nothing is lost, so only log it
            logger.info("Non-default bivariate axes being replaced with ternary axes.")
        else:
            # existing artists will not be carried over to the new axes
            warnings.warn(
                "Non-empty, non-default bivariate axes being replaced with ternary axes."
            )

    fig = ax.figure
    axes = get_ordered_axes(fig)  # ordering as it was before the replacement
    idx = axes.index(ax)
    # NOTE(review): the "ternary" projection is presumably registered by
    # mpltern - confirm it has been imported before this is called
    tax = fig.add_subplot(*get_axes_index(ax), projection="ternary")
    fig.add_axes(tax)  # make sure the axis is added to fig.children
    fig.delaxes(ax)  # remove the original axes
    # update figure ordered axes, putting the ternary axes in the original slot
    fig.orderedaxes = [tax if ix == idx else a for ix, a in enumerate(axes)]
    return tax
12✔
80

81

82
def label_axes(ax, labels=None, **kwargs):
    """
    Convenience function for labelling rectilinear and ternary axes.

    Parameters
    -----------
    ax : :class:`~matplotlib.axes.Axes`
        Axes to label.
    labels : :class:`list`
        List of labels: [x, y] | or [t, l, r]
    **kwargs
        Passed through to each of the label-setting methods called.

    Raises
    -------
    NotImplementedError
        If the number of labels is not two, or three for a ternary axes.
    """
    # avoid a shared mutable default; no labels is treated as an empty list
    labels = [] if labels is None else labels

    if (ax.name == "ternary") and (len(labels) == 3):
        tvar, lvar, rvar = labels
        ax.set_tlabel(tvar, **kwargs)
        ax.set_llabel(lvar, **kwargs)
        ax.set_rlabel(rvar, **kwargs)
    elif len(labels) == 2:
        xvar, yvar = labels
        ax.set_xlabel(xvar, **kwargs)
        ax.set_ylabel(yvar, **kwargs)
    else:
        raise NotImplementedError(
            "Cannot label '{}' axes with {} labels; expected [x, y] or "
            "[t, l, r].".format(ax.name, len(labels))
        )
×
104

105

106
def axes_to_ternary(ax):
    """
    Set axes to ternary projection after axis creation. As currently implemented,
    note that this will replace and reorder axes as accessed from the figure (the
    ternary axis will then be at the end), and as such this returns a list of axes
    in the correct order.

    Parameters
    -----------
    ax : :class:`~matplotlib.axes.Axes` | :class:`list` (:class:`~matplotlib.axes.Axes`)
        Axis (or axes) to convert projection for. Lists, tuples and arrays of
        axes are accepted, including multi-dimensional arrays of axes.

    Returns
    ---------
    axes : :class:`list` (:class:`~matplotlib.axes.Axes`, :class:`~mpltern.ternary.TernaryAxes`)
        The ``orderedaxes`` attribute of the figure the first axes belongs to.

    Raises
    -------
    IndexError
        If an empty collection of axes is passed.
    """
    if isinstance(ax, (list, np.ndarray, tuple)):  # multiple Axes specified
        # flatten so that nested collections (e.g. a 2D array of axes for a
        # grid of subplots) are converted one Axes at a time
        targets = list(np.asarray(ax, dtype=object).ravel())
    else:  # a single Axes is passed
        targets = [ax]

    fig = targets[0].figure
    for a in targets:  # axes to set to ternary
        replace_with_ternary_axis(a)
    return fig.orderedaxes
12✔
131

132

133
def check_default_axes(ax):
    """
    Check whether an axes is still in its default state: empty of artists,
    with limits that have not been rescaled from the default extent.

    Parameters
    -----------
    ax : :class:`matplotlib.axes.Axes`
        Axes to check for artists and scaling.

    Returns
    -------
    :class:`bool`
        :code:`False` where the limits differ from (0, 1, 0, 1); otherwise
        the result of :func:`check_empty`.
    """
    default_extent = np.array([0, 1, 0, 1])
    if not np.allclose(ax.axis(), default_extent):
        return False  # rescaled, so not default regardless of content
    return check_empty(ax)
×
152

153

154
def check_empty(ax):
    """
    Check whether an axes contains no artists.

    The lines, collections, patches, artists, texts and images of the axes
    are considered.

    Parameters
    -----------
    ax : :class:`matplotlib.axes.Axes`
        Axes to check for artists.

    Returns
    -------
    :class:`bool`
        :code:`True` where none of the artist containers hold anything.
    """
    contents = (
        ax.lines + ax.collections + ax.patches + ax.artists + ax.texts + ax.images
    )
    return not contents
×
171

172

173
def init_axes(ax=None, projection=None, minsize=1.0, **kwargs):
    """
    Get or create an Axes from an optionally-specified starting Axes.

    Parameters
    -----------
    ax : :class:`~matplotlib.axes.Axes`
        Specified starting axes, optional.
    projection : :class:`str`
        Whether to create a projected (e.g. ternary) axes.
    minsize : :class:`float`
        Minimum figure dimension (inches).
    **kwargs
        Keyword arguments filtered with ``subkwargs`` and passed to
        :func:`matplotlib.pyplot.subplots` where a new axes is created.

    Returns
    --------
    ax : :class:`~matplotlib.axes.Axes`

    Notes
    ------
    Where an existing non-ternary axes is given together with any projection,
    the axes is replaced with a ternary axes.
    """
    if "figsize" in kwargs:
        requested = kwargs["figsize"]
        # enforce the minimum figure dimensions
        kwargs["figsize"] = (
            max(requested[0], minsize),
            max(requested[1], minsize),
        )

    if ax is None:  # nothing to start from, so create a new figure and axes
        if projection is None:
            _, ax = plt.subplots(1, **subkwargs(kwargs, plt.subplots, plt.figure))
        else:  # e.g. ternary
            _, ax = plt.subplots(
                1,
                subplot_kw={"projection": projection},
                **subkwargs(kwargs, plt.subplots, plt.figure),
            )
        return ax

    if projection is None or ax.name == "ternary":
        return ax  # the axes passed can be used as it is

    # if an axis is converted previously, but the original axes reference
    # is used again, we'll end up with an error
    current_axes = get_ordered_axes(ax.figure)
    try:
        position = current_axes.index(ax)
        ax = axes_to_ternary(ax)[position]  # axes_to_ternary returns a list
    except ValueError:  # ax is not in list
        # ASSUMPTION due to mis-referencing:
        # take the first ternary one
        ax = [a for a in current_axes if a.name == "ternary"][0]
    return ax
12✔
222

223

224
def share_axes(axes, which="xy"):
    """
    Link the x, y or both axes across a group of :class:`~matplotlib.axes.Axes`.

    Every axes after the first is set to share with the first axes.

    Parameters
    -----------
    axes : :class:`list`
        List of axes to link.
    which : :class:`str`
        Which axes to link. The x-axes are linked if this contains :code:`x`, and
        the y-axes are linked if this contains :code:`y`; :code:`"both"` is
        equivalent to :code:`"xy"`.
    """
    if which == "both":
        which = "xy"
    if "x" in which:
        for a in axes[1:]:
            a.sharex(axes[0])
    if "y" in which:
        for a in axes[1:]:
            a.sharey(axes[0])
12✔
242

243

244
def get_twins(ax, which="y"):
    """
    Get twin axes of a specified axis.

    Twins are taken to be the other axes which share an axis with the
    specified axes and have the same bounding box bounds.

    Parameters
    -----------
    ax : :class:`matplotlib.axes.Axes`
        Axes to get twins for.
    which : :class:`str`
        Which twins to get (shared :code:`'x'`, shared :code:`'y'` or the
        concatenation of both, :code:`'xy'`).

    Returns
    --------
    :class:`list`
        Unique twin axes, in the order in which they were found (y-siblings
        before x-siblings).

    Notes
    ------
    This function was designed to assist in avoiding creating a series of duplicate
    axes when replotting on an existing axis using a function which would typically
    create a twin axis.
    """
    siblings = []
    if "y" in which:
        siblings += ax.get_shared_y_axes().get_siblings(ax)
    if "x" in which:
        siblings += ax.get_shared_x_axes().get_siblings(ax)

    twins = []
    for sibling in siblings:
        # skip the axes itself, and anything found already (an axes sharing
        # both x and y will be listed twice where which="xy")
        if sibling is ax or any(sibling is twin for twin in twins):
            continue
        if sibling.bbox.bounds == ax.bbox.bounds:  # occupies the same area
            twins.append(sibling)
    return twins
274

275

276
def subaxes(ax, side="bottom", width=0.2, moveticks=True):
    """
    Append a sub-axes to one side of an axes.

    The sub-axes shares the axis which runs along the side it is appended to
    (x for bottom/top, y otherwise), and the axis and spines perpendicular to
    this are hidden. The divider used is stored as ``ax.divider``, and the
    sub-axes as ``ax.divider.subax``.

    Parameters
    -----------
    ax : :class:`matplotlib.axes.Axes`
        Axes to append a sub-axes to.
    side : :class:`str`
        Which side to append the axes on.
    width : :class:`float`
        Size to give to the subaxes, passed to the divider's ``append_axes``.
        NOTE(review): described previously as a fraction of the width, but
        matplotlib appears to treat a float as a fixed size and only
        percentage strings as fractions - confirm the intended units.
    moveticks : :class:`bool`
        Whether to move ticks to the outer axes.

    Returns
    -------
    :class:`matplotlib.axes.Axes`
        Subaxes instance.
    """
    div = make_axes_locatable(ax)
    ax.divider = div

    if side in ["bottom", "top"]:
        which = "x"
        subax = div.append_axes(side, width, pad=0, sharex=ax)
        subax.yaxis.set_visible(False)
        hidden_spines = ("left", "right")
    else:
        which = "y"
        # a sub-axes at the left or right shares the y-axis, and it is the
        # x-axis which is hidden
        subax = div.append_axes(side, width, pad=0, sharey=ax)
        subax.xaxis.set_visible(False)
        hidden_spines = ("top", "bottom")

    div.subax = subax
    for spine in hidden_spines:
        subax.spines[spine].set_visible(False)

    share_axes([ax, subax], which=which)
    if moveticks:
        # remove ticks from the shared axis of the main axes; the keywords
        # used need to correspond to the sides relevant for that axis
        if which == "x":
            ax.tick_params(
                axis="x", which="both", bottom=False, top=False, labelbottom=False
            )
        else:
            ax.tick_params(
                axis="y", which="both", left=False, right=False, labelleft=False
            )
    return subax
12✔
321

322

323
def add_colorbar(mappable, **kwargs):
    """
    Adds a colorbar to a given mappable object.

    For ternary axes the colorbar is drawn in an inset axes placed beside the
    axes; otherwise a new axes is appended to the side of the parent axes.

    Source: http://joseph-long.com/writing/colorbars/

    Parameters
    ----------
    mappable
        The Image, ContourSet, etc. to which the colorbar applies.
    ax : :class:`matplotlib.axes.Axes`
        Axes to add the colorbar to, optional. Where not given, this is taken
        from the ``axes`` (or ``ax``) attribute of the mappable.
    position : :class:`str`
        Side of the axes to append the colorbar axes to (default
        :code:`"right"`); not used for ternary axes.
    size : :class:`str`
        Size of the colorbar axes (default :code:`"5%"`); not used for ternary
        axes.
    pad : :class:`float`
        Padding between the axes and the colorbar axes (default 0.05); not used
        for ternary axes.
    **kwargs
        Remaining keyword arguments are passed to the figure's ``colorbar``
        method.

    Returns
    -------
    :class:`matplotlib.colorbar.Colorbar`

    Raises
    ------
    ValueError
        If no axes is specified and none can be found from the mappable.

    Todo
    ----
    *  Where no mappable specified, get most recent axes, and check for collections etc
    """
    # "ax" is deliberately left in kwargs, so it is also passed on to colorbar
    ax = kwargs.get("ax", None)
    if hasattr(mappable, "axes"):
        ax = ax or mappable.axes
    elif hasattr(mappable, "ax"):
        ax = ax or mappable.ax
    if ax is None:
        raise ValueError(
            "Unable to determine the axes for the colorbar; specify it using "
            "the 'ax' keyword argument."
        )

    # these are used to position the colorbar axes, and aren't for colorbar
    position = kwargs.pop("position", "right")
    size = kwargs.pop("size", "5%")
    pad = kwargs.pop("pad", 0.05)

    if ax.name == "ternary":
        # place the colorbar in an inset axes just beyond the right-hand edge
        cax = ax.inset_axes([1.05, 0.1, 0.05, 0.9], transform=ax.transAxes)
    else:
        cax = make_axes_locatable(ax).append_axes(position, size=size, pad=pad)
    return ax.figure.colorbar(mappable, cax=cax, **kwargs)
12✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc