• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pymc-devs / pymc3 / 7965

pending completion
7965

Pull #3086

travis-ci

web-flow
Merge branch 'master' into release_3.4.2
Pull Request #3086: Release 3.5

1 of 1 new or added line in 1 file covered. (100.0%)

17647 of 19785 relevant lines covered (89.19%)

4.31 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

64.94
/pymc3/plots/pairplot.py
1
import warnings
10✔
2

3
try:
10✔
4
    import matplotlib.pyplot as plt
10✔
5
    from matplotlib import gridspec
10✔
6
except ImportError:  # mpl is optional
×
7
    pass
×
8
from ..util import get_default_varnames, is_transformed_name, get_untransformed_name
10✔
9
from .artists import get_trace_dict, scale_text
10✔
10

11

12
def pairplot(trace, varnames=None, figsize=None, text_size=None,
             gs=None, ax=None, hexbin=False, plot_transformed=False,
             divergences=False, kwargs_divergence=None,
             sub_varnames=None, **kwargs):
    """
    Plot a scatter or hexbin matrix of the sampled parameters.

    Parameters
    ----------

    trace : result of MCMC run
    varnames : list of variable names
        Variables to be plotted, if None all variables are plotted
    figsize : figure size tuple
        If None, size is (8 + numvars, 8 + numvars)
    text_size : int
        Text size for labels
    gs : Grid spec
        Matplotlib Grid spec. The matrix is only drawn (on a new figure) when
        both ``gs`` and ``ax`` are None; a grid spec passed in is returned
        unchanged and nothing is drawn into it.
    ax : axes
        Matplotlib axes. Only used when exactly two variables are plotted, in
        which case the single panel is drawn on it.
    hexbin : Boolean
        If True draws an hexbin plot
    plot_transformed : bool
        Flag for plotting automatically transformed variables in addition to
        original variables (defaults to False). Applies when varnames = None.
        When a list of varnames is passed, transformed variables can be passed
        using their names.
    divergences : Boolean
        If True divergences will be plotted in a different color. If the
        trace has no ``'diverging'`` entry a warning is emitted and the plot
        is drawn without the divergence overlay.
    kwargs_divergence : dicts, optional
        Additional keywords passed to ax.scatter for divergences
    sub_varnames : list
        Additional varnames passed for plotting subsets of multidimensional
        variables
    **kwargs
        Passed to ``ax.scatter`` (or ``ax.hexbin`` when ``hexbin`` is True).

    Returns
    -------

    ax : matplotlib axes
        The passed ``ax``, or the last panel drawn in the matrix.
    gs : matplotlib gridspec
        The grid spec used for the matrix (the passed ``gs``, or None when a
        two-variable plot was drawn on a passed ``ax``).

    Raises
    ------

    ValueError
        If fewer than two variables are selected for plotting.
    """
    trace_dict, varnames = _select_variables(
        trace, varnames, plot_transformed, sub_varnames)

    if text_size is None:
        # NOTE(review): figsize may still be None here, because its default is
        # only filled in below -- confirm whether scale_text should see the
        # default figure size instead.
        text_size = scale_text(figsize, text_size=text_size)

    if kwargs_divergence is None:
        kwargs_divergence = {}

    numvars = len(varnames)

    if figsize is None:
        figsize = (8 + numvars, 8 + numvars)

    if numvars < 2:
        # ValueError subclasses Exception, so existing `except Exception`
        # callers keep working.
        raise ValueError(
            'Number of variables to be plotted must be 2 or greater.')

    if numvars == 2 and ax is not None:
        diverge = _divergence_mask(trace) if divergences else None
        _draw_panel(ax, trace_dict[varnames[0]], trace_dict[varnames[1]],
                    hexbin, diverge, kwargs, kwargs_divergence)
        ax.set_xlabel('{}'.format(varnames[0]), fontsize=text_size)
        ax.set_ylabel('{}'.format(varnames[1]), fontsize=text_size)
        ax.tick_params(labelsize=text_size)

    if gs is None and ax is None:
        plt.figure(figsize=figsize)
        gs = gridspec.GridSpec(numvars - 1, numvars - 1)
        # Look up divergences once for the whole matrix; a missing entry
        # warns once and the matrix is still drawn in full.
        diverge = _divergence_mask(trace) if divergences else None
        last_row = numvars - 2

        for i in range(numvars - 1):
            xvals = trace_dict[varnames[i]]

            # Lower triangle only: column i holds variable i on the x axis,
            # row j holds variable j + 1 on the y axis.
            for j in range(i, numvars - 1):
                yvals = trace_dict[varnames[j + 1]]
                ax = plt.subplot(gs[j, i])
                _draw_panel(ax, xvals, yvals, hexbin, diverge,
                            kwargs, kwargs_divergence)

                # Only the bottom row keeps x ticks/labels and only the
                # first column keeps y ticks/labels.
                if j != last_row:
                    ax.set_xticks([])
                else:
                    ax.set_xlabel('{}'.format(varnames[i]),
                                  fontsize=text_size)
                if i != 0:
                    ax.set_yticks([])
                else:
                    ax.set_ylabel('{}'.format(varnames[j + 1]),
                                  fontsize=text_size)

                ax.tick_params(labelsize=text_size)

    plt.tight_layout()
    return ax, gs


def _select_variables(trace, varnames, plot_transformed, sub_varnames):
    """Resolve which variables to plot and fetch their samples.

    Returns a ``(trace_dict, varnames)`` pair: the mapping returned by
    ``get_trace_dict`` and the ordered names to plot from it.
    """
    if varnames is not None:
        trace_dict = get_trace_dict(trace, varnames)
        return trace_dict, list(trace_dict.keys())

    if sub_varnames is not None:
        # Fetch every variable, transformed ones included, regardless of
        # plot_transformed; the caller's sub_varnames choose what is plotted.
        trace_dict = get_trace_dict(
            trace, get_default_varnames(trace.varnames, True))
        return trace_dict, sub_varnames

    if plot_transformed:
        # Drop the untransformed twin of every transformed variable so each
        # quantity is plotted once, on its transformed scale.
        twins = {get_untransformed_name(name)
                 for name in trace.varnames if is_transformed_name(name)}
        candidates = [name for name in trace.varnames if name not in twins]
    else:
        candidates = trace.varnames

    trace_dict = get_trace_dict(
        trace, get_default_varnames(candidates, plot_transformed))
    return trace_dict, list(trace_dict.keys())


def _divergence_mask(trace):
    """Return a mask selecting divergent draws, or None if none are recorded.

    Emits a warning when ``trace['diverging']`` raises KeyError.
    """
    try:
        divergent = trace['diverging']
    except KeyError:
        warnings.warn('No divergences were found.')
        return None
    return divergent == 1


def _draw_panel(ax, xvals, yvals, hexbin, diverge, kwargs, kwargs_divergence):
    """Draw one hexbin/scatter panel, overlaying divergent draws if masked."""
    if hexbin:
        ax.hexbin(xvals, yvals, mincnt=1, **kwargs)
    else:
        ax.scatter(xvals, yvals, **kwargs)
    # Compare against None explicitly: the mask may be an array, whose
    # truth value is ambiguous.
    if diverge is not None:
        ax.scatter(xvals[diverge], yvals[diverge], **kwargs_divergence)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc