• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pymc-devs / pymc3 / 9408

pending completion
9408

Pull #3652

travis-ci

web-flow
Metaclass-based solution.

Alternative approach to solving the context stack issues.
The previous version was a dead end, because there was no effective way to check the appropriateness of the classes on a single context stack.
So in this version, I split the context stack in two: one stack for pm.Model instances and one for pm.distributions.distributions._DrawValuesContext.
This works better, but involved replacing the Context *class* with the ContextMeta parameterized metaclass.
Pull Request #3652: Fix context stack

114 of 114 new or added lines in 5 files covered. (100.0%)

52467 of 100945 relevant lines covered (51.98%)

2.18 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/pymc3/tests/test_ndarray_backend.py
1
import numpy as np
×
2
import numpy.testing as npt
×
3
from pymc3.tests import backend_fixtures as bf
×
4
from pymc3.backends import base, ndarray
×
5
import pymc3 as pm
×
6
import pytest
×
7

8

9
# Specifications assigned to ``sampler_vars`` by the test classes in this
# module. Each one is a list of dicts mapping a name to a scalar type.
# NOTE(review): presumably one dict per sampler, keyed by stat name -- confirm
# against backend_fixtures.
STATS1 = [{
    'a': np.float64,
    # Builtin ``bool`` is used instead of the ``np.bool`` alias: the alias was
    # deprecated in NumPy 1.20 and removed in 1.24, and it only ever pointed
    # at the builtin, so the dtype requested is unchanged.
    'b': bool,
}]

STATS2 = [{
    'a': np.float64,
}, {
    'a': np.float64,
    'b': np.int64,
}]
20

21

22
class TestNDArray0dSampling(bf.SamplingTestCase):
    """Configure ``bf.SamplingTestCase`` for the NDArray backend, shape ``()``.

    No methods are defined here; anything that runs is inherited from the
    base class in ``backend_fixtures``.
    """

    shape = ()  # empty tuple: the "0d" case in the class name
    name = None
    backend = ndarray.NDArray
×
26

27

28
class TestNDArray0dSamplingStats1(bf.SamplingTestCase):
    """Configure ``bf.SamplingTestCase`` for NDArray, shape ``()``, with STATS1.

    Same configuration as the plain 0-d sampling case plus
    ``sampler_vars = STATS1``. No methods are defined here.
    """

    shape = ()
    sampler_vars = STATS1
    name = None
    backend = ndarray.NDArray
×
33

34

35
class TestNDArray0dSamplingStats2(bf.SamplingTestCase):
    """Configure ``bf.SamplingTestCase`` for NDArray, shape ``()``, with STATS2.

    ``STATS2`` is the two-dict specification defined at module level.
    No methods are defined here.
    """

    shape = ()
    sampler_vars = STATS2
    name = None
    backend = ndarray.NDArray
×
40

41

42
class TestNDArray1dSampling(bf.SamplingTestCase):
    """Configure ``bf.SamplingTestCase`` for the NDArray backend, shape ``2``.

    No methods are defined here; anything that runs is inherited.
    """

    shape = 2  # a plain int rather than a tuple
    name = None
    backend = ndarray.NDArray
×
46

47

48
class TestNDArray2dSampling(bf.SamplingTestCase):
    """Configure ``bf.SamplingTestCase`` for the NDArray backend, shape ``(2, 3)``.

    No methods are defined here; anything that runs is inherited.
    """

    shape = (2, 3)
    name = None
    backend = ndarray.NDArray
×
52

53

54
class TestNDArrayStats(bf.StatsTestCase):
    """Configure ``bf.StatsTestCase`` for the NDArray backend, shape ``(2, 3)``.

    No methods are defined here; anything that runs is inherited.
    """

    shape = (2, 3)
    name = None
    backend = ndarray.NDArray
×
58

59

60
class TestNDArray0dSelection(bf.SelectionTestCase):
    """Configure ``bf.SelectionTestCase`` for NDArray, shape ``()``, with STATS1.

    No methods are defined here; anything that runs is inherited.
    """

    sampler_vars = STATS1
    shape = ()
    name = None
    backend = ndarray.NDArray
×
65

66

67
class TestNDArray0dSelection2(bf.SelectionTestCase):
    """Configure ``bf.SelectionTestCase`` for NDArray, shape ``()``, with STATS2.

    No methods are defined here; anything that runs is inherited.
    """

    sampler_vars = STATS2
    shape = ()
    name = None
    backend = ndarray.NDArray
×
72

73

74
class TestNDArray0dSelectionStats1(bf.SelectionTestCase):
    """Configure ``bf.SelectionTestCase`` for NDArray, shape ``()``, with STATS2.

    No methods are defined here; anything that runs is inherited.
    """

    # NOTE(review): the class name says "Stats1" but the value is STATS2,
    # which makes this configuration identical to TestNDArray0dSelection2.
    # Kept as found -- confirm whether STATS1 was intended.
    sampler_vars = STATS2
    shape = ()
    name = None
    backend = ndarray.NDArray
×
79

80

81
class TestNDArray0dSelectionStats2(bf.SelectionTestCase):
    """Configure ``bf.SelectionTestCase`` for the NDArray backend, shape ``()``.

    No methods are defined here; anything that runs is inherited.
    """

    # NOTE(review): despite the "Stats2" name, no ``sampler_vars`` is set
    # here, so whatever the base class provides applies. Kept as found.
    shape = ()
    name = None
    backend = ndarray.NDArray
×
85

86

87
class TestNDArray1dSelection(bf.SelectionTestCase):
    """Configure ``bf.SelectionTestCase`` for the NDArray backend, shape ``2``.

    No methods are defined here; anything that runs is inherited.
    """

    shape = 2  # a plain int rather than a tuple
    name = None
    backend = ndarray.NDArray
×
91

92

93
class TestNDArray2dSelection(bf.SelectionTestCase):
    """Configure ``bf.SelectionTestCase`` for the NDArray backend, shape ``(2, 3)``.

    No methods are defined here; anything that runs is inherited.
    """

    shape = (2, 3)
    name = None
    backend = ndarray.NDArray
×
97

98

99
class TestMultiTrace(bf.ModelBackendSetupTestCase):
    """Error-path checks for ``base.MultiTrace`` and ``base.merge_traces``.

    Every test expects a ``ValueError``. ``setup_method`` provides two
    single-chain traces, ``self.strace0`` and ``self.strace1``, both taken
    from the base-class setup.
    """

    backend = ndarray.NDArray
    shape = ()
    name = None

    def setup_method(self):
        """Run the base setup twice, keeping each resulting ``self.strace``."""
        # First pass: remember the trace the base class builds.
        super().setup_method()
        self.strace0 = self.strace

        # Second pass: the base class assigns ``self.strace`` again.
        # NOTE(review): presumably both traces end up with the same chain
        # index, which is what the "nonunique" tests rely on -- confirm.
        super().setup_method()
        self.strace1 = self.strace

    def _recorded_strace(self, n_draws):
        """Return a closed backend trace holding ``n_draws`` copies of the test point.

        The trace is created, set up for chain 1, filled and closed inside
        ``self.model``.
        """
        with self.model:
            strace = self.backend(self.name)
            strace.setup(n_draws, 1)
            for _ in range(n_draws):
                strace.record(self.test_point)
            strace.close()
        return strace

    def test_multitrace_nonunique(self):
        """Wrapping both setup traces in one ``MultiTrace`` raises ``ValueError``."""
        pair = [self.strace0, self.strace1]
        with pytest.raises(ValueError):
            base.MultiTrace(pair)

    def test_merge_traces_no_traces(self):
        """Merging an empty list raises ``ValueError``."""
        with pytest.raises(ValueError):
            base.merge_traces([])

    def test_merge_traces_diff_lengths(self):
        """Merging with a trace of ``2 * self.draws`` draws raises ``ValueError``."""
        # NOTE(review): this freshly recorded trace is built but never used;
        # the first MultiTrace wraps ``self.strace0`` from setup instead.
        # Kept exactly as found.
        unused_short = self._recorded_strace(self.draws)
        first = base.MultiTrace([self.strace0])

        doubled = self._recorded_strace(2 * self.draws)
        second = base.MultiTrace([doubled])

        with pytest.raises(ValueError):
            base.merge_traces([first, second])

    def test_merge_traces_nonunique(self):
        """Merging MultiTraces built from the two setup traces raises ``ValueError``."""
        first = base.MultiTrace([self.strace0])
        second = base.MultiTrace([self.strace1])

        with pytest.raises(ValueError):
            base.merge_traces([first, second])
×
145

146

147
class TestMultiTrace_add_remove_values(bf.ModelBackendSampledTestCase):
    """Check ``add_values`` / ``remove_values`` on the sampled ``self.mtrace``."""

    backend = ndarray.NDArray
    shape = ()
    name = None

    def test_add_values(self):
        """Add a copy of the first variable under a new name, then remove it.

        After the add, ``varnames`` has one more entry and the new entry
        equals the copied variable; after the remove, ``varnames`` is back
        to its original length and the new name is gone.
        """
        trace = self.mtrace
        names_before = list(trace.varnames)
        source_var = names_before[0]
        added = 'new_var'

        trace.add_values({added: trace[source_var]})
        assert len(names_before) == len(trace.varnames) - 1
        assert added in trace.varnames
        assert np.all(trace[source_var] == trace[added])

        trace.remove_values(added)
        assert len(names_before) == len(trace.varnames)
        assert added not in trace.varnames
×
164

165

166
class TestSqueezeCat:
    """Exercise ``base._squeeze_cat`` over its two boolean flags.

    NOTE(review): going by the test names, the second positional argument
    is "combine" and the third is "squeeze" -- confirm against ``base``.
    """

    def setup_method(self):
        """Create two disjoint 10-element integer ranges as inputs."""
        self.x = np.arange(10)       # 0..9
        self.y = np.arange(10, 20)   # 10..19

    def test_combine_false_squeeze_false(self):
        """Both flags off: the two-array list comes back unchanged."""
        got = base._squeeze_cat([self.x, self.y], False, False)
        npt.assert_equal(got, [self.x, self.y])

    def test_combine_true_squeeze_false(self):
        """Combine only: a one-element list holding the concatenation."""
        joined = np.concatenate([self.x, self.y])
        got = base._squeeze_cat([self.x, self.y], True, False)
        npt.assert_equal(got, [joined])

    def test_combine_false_squeeze_true_more_than_one_item(self):
        """Squeeze only, two inputs: the list still comes back as a list."""
        got = base._squeeze_cat([self.x, self.y], False, True)
        npt.assert_equal(got, [self.x, self.y])

    def test_combine_false_squeeze_true_one_item(self):
        """Squeeze only, one input: the bare array comes back."""
        got = base._squeeze_cat([self.x], False, True)
        npt.assert_equal(got, self.x)

    def test_combine_true_squeeze_true(self):
        """Both flags on: the bare concatenated array comes back."""
        joined = np.concatenate([self.x, self.y])
        got = base._squeeze_cat([self.x, self.y], True, True)
        npt.assert_equal(got, joined)
×
196

197
class TestSaveLoad:
    """Round-trip tests for ``pm.save_trace`` and ``pm.load_trace``.

    ``setup_class`` samples once from the model built by ``model()`` and
    stores the result as ``cls.trace`` for all tests to reuse.
    """

    @staticmethod
    def model():
        """Build and return a new model with ``x``, observed ``y`` and ``z``.

        ``y`` is centred on ``x`` and observed at 2; ``z`` is centred on
        ``x + y``.
        """
        with pm.Model() as built:
            x = pm.Normal('x', 0, 1)
            y = pm.Normal('y', x, 1, observed=2)
            pm.Normal('z', x + y, 1)
        return built

    @classmethod
    def setup_class(cls):
        """Sample the shared trace once, inside a fresh model context."""
        with TestSaveLoad.model():
            cls.trace = pm.sample()

    def _save_class_trace(self, tmpdir_factory):
        """Save ``self.trace`` into a new temp directory and return its path.

        Also asserts that ``pm.save_trace`` returns the directory it was
        given.
        """
        target = str(tmpdir_factory.mktemp('data'))
        returned = pm.save_trace(self.trace, target, overwrite=True)

        assert returned == target
        return target

    def test_save_new_model(self, tmpdir_factory):
        """Overwrite a saved trace with one from a different model.

        Saving into the occupied directory without ``overwrite=True`` must
        raise ``OSError``; with it, the new trace loads back with equal
        ``w`` values.
        """
        target = self._save_class_trace(tmpdir_factory)

        with pm.Model() as other_model:
            pm.Normal('w', 0, 1)
            other_trace = pm.sample()

        with pytest.raises(OSError):
            pm.save_trace(other_trace, target)

        pm.save_trace(other_trace, target, overwrite=True)
        with other_model:
            reloaded = pm.load_trace(target)

        assert (other_trace['w'] == reloaded['w']).all()

    def test_save_and_load(self, tmpdir_factory):
        """A saved trace loads back with equal ``x`` and ``z`` values."""
        target = self._save_class_trace(tmpdir_factory)

        reloaded = pm.load_trace(target, model=TestSaveLoad.model())

        for var_name in ('x', 'z'):
            assert (self.trace[var_name] == reloaded[var_name]).all()

    def test_bad_load(self, tmpdir_factory):
        """Loading a directory with nothing saved in it raises ``pm.TraceDirectoryError``."""
        empty_dir = str(tmpdir_factory.mktemp('data'))
        with pytest.raises(pm.TraceDirectoryError):
            pm.load_trace(empty_dir, model=TestSaveLoad.model())

    def test_sample_posterior_predictive(self, tmpdir_factory):
        """Posterior predictive draws match for the in-memory and reloaded trace.

        NumPy's global RNG is reseeded with the same value before each run
        so the two sets of draws can be compared for exact equality.
        """
        target = self._save_class_trace(tmpdir_factory)
        seed = 10

        np.random.seed(seed)
        with TestSaveLoad.model():
            from_memory = pm.sample_posterior_predictive(self.trace)

        np.random.seed(seed)
        with TestSaveLoad.model():
            loaded = pm.load_trace(target)
            from_disk = pm.sample_posterior_predictive(loaded)

        for key, draws in from_memory.items():
            assert (draws == from_disk[key]).all()
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc