• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

morganjwilliams / pyrolite / 5497027953

pending completion
5497027953

push

github

morganjwilliams
Merge branch 'release/0.3.3' into main

249 of 270 new or added lines in 48 files covered. (92.22%)

217 existing lines in 33 files now uncovered.

5967 of 6601 relevant lines covered (90.4%)

10.84 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.19
/pyrolite/util/plot/axes.py
1
"""
Functions for creating, ordering and modifying :class:`~matplotlib.axes.Axes`.
"""
4
import warnings
12✔
5

6
import matplotlib.pyplot as plt
12✔
7
import numpy as np
12✔
8
from mpl_toolkits.axes_grid1 import make_axes_locatable
12✔
9

10
from ..log import Handle
12✔
11
from ..meta import subkwargs
12✔
12

13
logger = Handle(__name__)  # module-level log handle (pyrolite.util.log.Handle) keyed on this module's name
12✔
14

15

16
def get_ordered_axes(fig):
    """
    Return the axes of a figure, respecting any ordering recorded by pyrolite.

    Functions in this module which swap axes out of a figure (e.g. converting them
    to ternary projections) store the intended ordering as
    :code:`fig.orderedaxes`. Where that record exists it is returned; otherwise
    the figure's own :code:`fig.axes` list is used.

    Parameters
    -----------
    fig : :class:`matplotlib.figure.Figure`
        Figure to get the axes from.

    Returns
    --------
    :class:`list`
        Axes of the figure, in order.
    """
    try:
        # figure previously modified by pyrolite functions
        return fig.orderedaxes
    except AttributeError:
        # untouched figure: fall back to matplotlib's own ordering
        return fig.axes
26

27

28
def get_axes_index(ax):
    """
    Get the subplot specification triple for an axes in a regular grid.

    Parameters
    -----------
    ax : :class:`matplotlib.axes.Axes`
        Axis to get the gridspec index for. Must be part of a gridspec
        (i.e. a subplot).

    Returns
    -----------
    :class:`tuple`
        ``(nrows, ncols, index)`` for the gridspec, where ``index`` is the
        1-based position of `ax` within the figure's ordered axes (see
        :func:`get_ordered_axes`). This is the form accepted by
        :meth:`matplotlib.figure.Figure.add_subplot`.
    """
    # public GridSpec API, rather than the private _nrows/_ncols attributes
    nrows, ncols = ax.get_gridspec().get_geometry()
    # NOTE(review): this is the position amongst *all* ordered figure axes, which
    # only matches the grid position when the figure holds nothing but the
    # regular subplots (e.g. no colorbar axes) - confirm for callers.
    index = get_ordered_axes(ax.figure).index(ax) + 1
    return nrows, ncols, index
46

47

48
def replace_with_ternary_axis(ax):
    """
    Replace a specified axis with a ternary equivalent.

    The new ternary axes is created in the same subplot slot, the original axes
    is removed from the figure, and the figure's ordering record
    (:code:`fig.orderedaxes`) is updated so that the new axes takes the place of
    the old one.

    Parameters
    ------------
    ax : :class:`~matplotlib.axes.Axes`
        Axes to replace. Any content it holds is discarded: a warning is issued
        if it contains artists, and a log message if it is empty but has been
        rescaled from the default extent.

    Returns
    ------------
    tax : :class:`~mpltern.ternary.TernaryAxes`
        The newly created ternary axes.

    Notes
    ------
    Axes which are already ternary are replaced too (by a fresh, empty ternary
    axes), without any warning.
    """
    if ax.name != "ternary" and not check_default_axes(ax):
        if not check_empty(ax):
            warnings.warn(
                "Non-empty, non-default bivariate axes being replaced with ternary axes."
            )
        else:
            logger.info("Non-default bivariate axes being replaced with ternary axes.")
    fig = ax.figure
    axes = get_ordered_axes(fig)
    idx = axes.index(ax)
    tax = fig.add_subplot(*get_axes_index(ax), projection="ternary")
    fig.add_axes(tax)  # make sure the axis is added to fig.children
    fig.delaxes(ax)  # remove the original axes
    # swap the new ternary axes into the original's slot in the figure ordering
    fig.orderedaxes = [tax if ix == idx else a for ix, a in enumerate(axes)]
    return tax
79

80

81
def label_axes(ax, labels=None, **kwargs):
    """
    Convenience function for labelling rectilinear and ternary axes.

    Parameters
    -----------
    ax : :class:`~matplotlib.axes.Axes`
        Axes to label.
    labels : :class:`list`
        List of labels: [x, y] | or [t, l, r]. Three labels are only accepted for
        ternary axes; any iterable of labels may be passed.
    **kwargs
        Passed to each label-setting method (e.g.
        :meth:`~matplotlib.axes.Axes.set_xlabel`).

    Raises
    -------
    NotImplementedError
        If the number of labels does not suit the axes (two labels, or three for
        a ternary axes).
    """
    # avoid a shared mutable default; materialise so len() works on any iterable
    labels = [] if labels is None else list(labels)
    if ax.name == "ternary" and len(labels) == 3:
        tvar, lvar, rvar = labels
        ax.set_tlabel(tvar, **kwargs)
        ax.set_llabel(lvar, **kwargs)
        ax.set_rlabel(rvar, **kwargs)
    elif len(labels) == 2:
        xvar, yvar = labels
        ax.set_xlabel(xvar, **kwargs)
        ax.set_ylabel(yvar, **kwargs)
    else:
        raise NotImplementedError(
            "Expected 2 labels, or 3 labels for ternary axes; "
            "got {} label(s) for '{}' axes.".format(len(labels), ax.name)
        )
103

104

105
def axes_to_ternary(ax):
    """
    Set axes to ternary projection after axis creation. Each converted axes is
    replaced in the figure, but keeps its original position in the figure's
    ordering; the full, ordered list of axes is returned.

    Parameters
    -----------
    ax : :class:`~matplotlib.axes.Axes` | :class:`list` (:class:`~matplotlib.axes.Axes`)
        Axis (or axes) to convert projection for. Lists, tuples and (possibly
        multi-dimensional) arrays of axes, such as those returned by
        :func:`matplotlib.pyplot.subplots`, are accepted.

    Returns
    ---------
    axes : :class:`list` (:class:`~matplotlib.axes.Axes`, :class:`~mpltern.ternary.TernaryAxes`)
        All axes of the figure (:code:`fig.orderedaxes`), in order, with the
        converted ternary axes in place of the originals.
    """
    if isinstance(ax, np.ndarray):
        # arrays from plt.subplots(nrows, ncols) may be 2D; iterate single axes
        ax = list(ax.flat)

    if isinstance(ax, (list, tuple)):  # multiple Axes specified
        fig = ax[0].figure
        for a in ax:  # axes to set to ternary
            replace_with_ternary_axis(a)
    else:  # a single Axes is passed
        fig = ax.figure
        replace_with_ternary_axis(ax)
    return fig.orderedaxes
130

131

132
def check_default_axes(ax):
    """
    Check whether an axis is still in its default state: showing the default
    unit extent and holding no artists.

    Parameters
    -----------
    ax : :class:`matplotlib.axes.Axes`
        Axes to check for artists and scaling.

    Returns
    -------
    :class:`bool`
        :code:`True` only if the axis limits are (close to) ``[0, 1, 0, 1]`` and
        :func:`check_empty` reports no artists.
    """
    default_extent = np.array([0, 1, 0, 1])
    rescaled = not np.allclose(ax.axis(), default_extent)
    if rescaled:
        return False
    return check_empty(ax)
151

152

153
def check_empty(ax):
    """
    Simple test to check whether an axis is empty of artists.

    Parameters
    -----------
    ax : :class:`matplotlib.axes.Axes`
        Axes to check for artists.

    Returns
    -------
    :class:`bool`
        :code:`True` if the axes holds no lines, collections, patches, artists,
        texts or images; :code:`False` otherwise.
    """
    artist_groups = (
        ax.lines,
        ax.collections,
        ax.patches,
        ax.artists,
        ax.texts,
        ax.images,
    )
    # test each group's length rather than concatenating them, which would build
    # a throwaway list and requires every group type to support `+` with the others
    return not any(len(group) for group in artist_groups)
170

171

172
def init_axes(ax=None, projection=None, minsize=1.0, **kwargs):
    """
    Get or create an Axes from an optionally-specified starting Axes.

    Parameters
    -----------
    ax : :class:`~matplotlib.axes.Axes`
        Specified starting axes, optional.
    projection : :class:`str`
        Whether to create a projected (e.g. ternary) axes. Note that where an
        existing non-ternary `ax` is passed, it is converted to a *ternary* axes
        whatever projection is requested.
    minsize : :class:`float`
        Minimum figure dimension (inches), applied where :code:`figsize` is given.
    **kwargs
        Keyword arguments for figure creation (filtered with :func:`subkwargs`
        against :func:`~matplotlib.pyplot.subplots` and
        :func:`~matplotlib.pyplot.figure`).

    Returns
    --------
    ax : :class:`~matplotlib.axes.Axes`
    """
    if "figsize" in kwargs:
        fs = kwargs["figsize"]
        kwargs["figsize"] = (max(fs[0], minsize), max(fs[1], minsize))  # minimum figsize

    if projection is not None:  # e.g. ternary
        if ax is None:
            _, ax = plt.subplots(
                1,
                subplot_kw=dict(projection=projection),
                **subkwargs(kwargs, plt.subplots, plt.figure)
            )
        elif ax.name != "ternary":
            # if an axis is converted previously, but the original axes reference
            # is used again, it will no longer be found amongst the figure axes
            current_axes = get_ordered_axes(ax.figure)
            try:
                ix = current_axes.index(ax)
            except ValueError:  # ax is not in list
                # ASSUMPTION due to mis-referencing:
                # take the first ternary one
                ax = [a for a in current_axes if a.name == "ternary"][0]
            else:
                # kept outside the `try` so genuine errors raised during conversion
                # are not mistaken for a stale axes reference
                axes = axes_to_ternary(ax)  # returns list of axes
                ax = axes[ix]
    elif ax is None:
        _, ax = plt.subplots(1, **subkwargs(kwargs, plt.subplots, plt.figure))
    return ax
221

222

223
def share_axes(axes, which="xy"):
    """
    Link the x, y or both axes across a group of :class:`~matplotlib.axes.Axes`.

    Every axes in the group is linked to the first one.

    Parameters
    -----------
    axes : :class:`list`
        List (or other iterable) of axes to link.
    which : :class:`str`
        Which axes to link: :code:`'x'`, :code:`'y'`, or :code:`'xy'` /
        :code:`'both'` for both. The x-axes are linked if the string contains
        :code:`'x'`, and the y-axes if it contains :code:`'y'`.
    """
    if which == "both":
        which = "xy"
    axes = list(axes)
    if not axes:
        return
    reference, others = axes[0], axes[1:]
    if "x" in which:
        for a in others:
            a.sharex(reference)
    if "y" in which:
        for a in others:
            a.sharey(reference)
241

242

243
def get_twins(ax, which="y"):
    """
    Get twin axes of a specified axis.

    Twins are the axes sharing the x and/or y axis with `ax` which occupy exactly
    the same bounding box.

    Parameters
    -----------
    ax : :class:`matplotlib.axes.Axes`
        Axes to get twins for.
    which : :class:`str`
        Which twins to get (shared :code:`'x'`, shared :code:`'y'` or the concatenation
        of both, :code:`'xy'`).

    Returns
    --------
    :class:`list`
        Twin axes, without duplicates, in the order first found (y-shared
        siblings before x-shared ones).

    Notes
    ------
    This function was designed to assist in avoiding creating a series of duplicate
    axes when replotting on an existing axis using a function which would typically
    create a twin axis.
    """
    siblings = []
    if "y" in which:
        siblings += ax.get_shared_y_axes().get_siblings(ax)
    if "x" in which:
        siblings += ax.get_shared_x_axes().get_siblings(ax)
    # dict.fromkeys de-duplicates while keeping a deterministic (first-seen) order,
    # unlike a set
    twins = dict.fromkeys(
        a for a in siblings if a is not ax and a.bbox.bounds == ax.bbox.bounds
    )
    return list(twins)
273

274

275
def subaxes(ax, side="bottom", width=0.2, moveticks=True):
    """
    Append a sub-axes to one side of an axes.

    The sub-axes shares the axis running along the side it is attached to (the
    x-axis for :code:`'bottom'`/:code:`'top'`, the y-axis for
    :code:`'left'`/:code:`'right'`) and hides the other one.

    Parameters
    -----------
    ax : :class:`matplotlib.axes.Axes`
        Axes to append a sub-axes to.
    side : :class:`str`
        Which side to append the axes on (:code:`'bottom'`, :code:`'top'`,
        :code:`'left'` or :code:`'right'`).
    width : :class:`float` | :class:`str`
        Size of the sub-axes, passed to
        :meth:`~mpl_toolkits.axes_grid1.axes_divider.AxesDivider.append_axes`;
        a float is interpreted in inches, while a string such as :code:`'20%'`
        gives a fraction of the parent axes.
    moveticks : :class:`bool`
        Whether to move ticks to the outer axes.

    Returns
    -------
    :class:`matplotlib.axes.Axes`
        Subaxes instance.
    """
    div = make_axes_locatable(ax)
    ax.divider = div

    if side in ["bottom", "top"]:
        which = "x"
        subax = div.append_axes(side, width, pad=0, sharex=ax)
        # the sub-axes only shares the x-axis: hide its own y-axis
        subax.yaxis.set_visible(False)
        subax.spines["left"].set_visible(False)
        subax.spines["right"].set_visible(False)
        tick_kwargs = dict(bottom=False, top=False, labelbottom=False)
    else:
        which = "y"
        # mirror of the branch above: share the y-axis and hide the x-axis
        subax = div.append_axes(side, width, pad=0, sharey=ax)
        subax.xaxis.set_visible(False)
        subax.spines["top"].set_visible(False)
        subax.spines["bottom"].set_visible(False)
        # Axes.tick_params(axis="y") discards bottom/top/labelbottom, so the
        # y-axis equivalents are required here
        tick_kwargs = dict(left=False, right=False, labelleft=False)
    div.subax = subax

    share_axes([ax, subax], which=which)
    if moveticks:
        ax.tick_params(axis=which, which="both", **tick_kwargs)
    return subax
320

321

322
def add_colorbar(mappable, **kwargs):
    """
    Adds a colorbar to a given mappable object.

    Source: http://joseph-long.com/writing/colorbars/

    Parameters
    ----------
    mappable
        The Image, ContourSet, etc. to which the colorbar applies.
    ax : :class:`matplotlib.axes.Axes`, optional
        Axes to attach the colorbar to; defaults to the axes of the mappable
        (:code:`mappable.axes`, or :code:`mappable.ax`). This keyword is also
        forwarded to :meth:`~matplotlib.figure.Figure.colorbar`.
    position : :class:`str`
        Side on which to append the colorbar, default :code:`'right'`
        (non-ternary axes only).
    size : :class:`str` | :class:`float`
        Size of the colorbar axes, default :code:`'5%'` (non-ternary axes only).
    pad : :class:`float`
        Padding between the axes and the colorbar, default :code:`0.05`
        (non-ternary axes only).
    **kwargs
        Other keyword arguments passed to
        :meth:`~matplotlib.figure.Figure.colorbar`.

    Returns
    -------
    :class:`matplotlib.colorbar.Colorbar`

    Raises
    ------
    ValueError
        If no axes can be determined for the colorbar.

    Todo
    ----
    *  Where no mappable specificed, get most recent axes, and check for collections etc
    """
    ax = kwargs.get("ax", None)
    if ax is None:
        # fall back to the axes the mappable belongs to
        if hasattr(mappable, "axes"):
            ax = mappable.axes
        elif hasattr(mappable, "ax"):
            ax = mappable.ax
    if ax is None:
        raise ValueError(
            "Unable to determine the axes for the colorbar; pass `ax`, "
            "or add the mappable to an axes."
        )

    position = kwargs.pop("position", "right")
    size = kwargs.pop("size", "5%")
    pad = kwargs.pop("pad", 0.05)

    if ax.name == "ternary":
        # ternary axes use a fixed inset; position/size/pad are not used here
        cax = ax.inset_axes([1.05, 0.1, 0.05, 0.9], transform=ax.transAxes)
    else:
        divider = make_axes_locatable(ax)
        cax = divider.append_axes(position, size=size, pad=pad)
    return ax.figure.colorbar(mappable, cax=cax, **kwargs)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc