• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pymc-devs / pymc3 / 9391

pending completion
9391

Pull #3638

travis-ci

web-flow
Drop first dimension when computing determinant of the Jacobian of the transformation.
Pull Request #3638: Simple stick breaking (Formerly #3620)

23 of 23 new or added lines in 1 file covered. (100.0%)

52178 of 100270 relevant lines covered (52.04%)

2.04 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/pymc3/tests/backend_fixtures.py
1
import numpy as np
×
2
import numpy.testing as npt
×
3
import os
×
4
import shutil
×
5
import collections
×
6

7
from pymc3.tests import models
×
8
from pymc3.backends import base
×
9
import pytest
×
10
import theano
×
11

12

13
class ModelBackendSetupTestCase:
    """Set up a backend trace.

    Provides the attributes
    - test_point
    - model
    - strace
    - draws
    - chain

    Children must define
    - backend
    - name
    - shape

    Children may define
    - sampler_vars
        A list of dicts mapping sampler-stat name to dtype.
    """

    def setup_method(self):
        self.test_point, self.model, _ = models.beta_bernoulli(self.shape)
        with self.model:
            self.strace = self.backend(self.name)
        self.draws, self.chain = 3, 0
        if not hasattr(self, 'sampler_vars'):
            self.sampler_vars = None
        if self.sampler_vars is not None:
            assert self.strace.supports_sampler_stats
            self.strace.setup(self.draws, self.chain, self.sampler_vars)
        else:
            self.strace.setup(self.draws, self.chain)

    def test_append_invalid(self):
        # ``bool`` is used rather than ``np.bool``: the latter was only an
        # alias of the builtin and has been removed from NumPy (1.24+).
        if self.sampler_vars is not None:
            # Re-running setup without the sampler vars must be rejected.
            with pytest.raises(ValueError):
                self.strace.setup(self.draws, self.chain)
            # So must setup with an extra, undeclared sampler.
            with pytest.raises(ValueError):
                extra_vars = self.sampler_vars + [{'a': bool}]
                self.strace.setup(self.draws, self.chain, extra_vars)
        else:
            with pytest.raises((ValueError, TypeError)):
                self.strace.setup(self.draws, self.chain, [{'a': bool}])

    def test_append(self):
        if self.sampler_vars is None:
            self.strace.setup(self.draws, self.chain)
            assert len(self.strace) == 0
        else:
            self.strace.setup(self.draws, self.chain, self.sampler_vars)
            assert len(self.strace) == 0

    def test_double_close(self):
        # Closing an already-closed trace must not raise.
        self.strace.close()
        self.strace.close()

    def teardown_method(self):
        if self.name is not None:
            remove_file_or_directory(self.name)
70

71

72
class StatsTestCase:
    """Test for init and setup of backends with sampler stats.

    Provides the attributes
    - test_point
    - model
    - draws
    - chain

    Children must define
    - backend
    - name
    - shape
    """
    def setup_method(self):
        self.test_point, self.model, _ = models.beta_bernoulli(self.shape)
        self.draws, self.chain = 3, 0

    def test_bad_dtype(self):
        # The same stat name 'a' is declared by two samplers with conflicting
        # dtypes in ``bad_vars``, and with matching dtypes in ``good_vars``.
        # ``bool`` replaces ``np.bool``, an alias of the builtin that has been
        # removed from NumPy (1.24+).
        bad_vars = [{'a': np.float64}, {'a': bool}]
        good_vars = [{'a': np.float64}, {'a': np.float64}]
        with self.model:
            strace = self.backend(self.name)
        with pytest.raises((ValueError, TypeError)):
            strace.setup(self.draws, self.chain, bad_vars)
        strace.setup(self.draws, self.chain, good_vars)
        if strace.supports_sampler_stats:
            assert strace.stat_names == {'a'}
        else:
            # NOTE(review): the unconditional setup above would presumably
            # already raise for backends without sampler-stat support;
            # confirm this branch is reachable.
            with pytest.raises((ValueError, TypeError)):
                strace.setup(self.draws, self.chain, good_vars)

    def teardown_method(self):
        if self.name is not None:
            remove_file_or_directory(self.name)
106

107

108
class ModelBackendSampledTestCase:
    """Setup and sample a backend trace.

    Provides the attributes
    - test_point
    - model
    - chain_vars
        The unobserved RVs recorded by the backend.
    - mtrace (MultiTrace object)
    - draws
    - expected
        Expected values mapped to chain number and variable name.
    - expected_stats (only when sampler_vars is set)
        Per chain, a list with one dict of expected stat arrays per sampler.
    - stat_dtypes
    - stats_counts
        How many samplers declare each stat name.

    Children must define
    - backend
    - name
    - shape

    Children may define
    - sampler_vars
    - write_partial_chain
    """
    @classmethod
    def setup_class(cls):
        cls.test_point, cls.model, _ = models.beta_bernoulli(cls.shape)

        # With ``write_partial_chain`` the first unobserved RV is not given
        # to the backend, so only part of the model is recorded.
        if hasattr(cls, 'write_partial_chain') and cls.write_partial_chain is True:
            cls.chain_vars = cls.model.unobserved_RVs[1:]
        else:
            cls.chain_vars = cls.model.unobserved_RVs

        with cls.model:
            strace0 = cls.backend(cls.name, vars=cls.chain_vars)
            strace1 = cls.backend(cls.name, vars=cls.chain_vars)

        if not hasattr(cls, 'sampler_vars'):
            cls.sampler_vars = None

        cls.draws = 5
        if cls.sampler_vars is not None:
            strace0.setup(cls.draws, chain=0, sampler_vars=cls.sampler_vars)
            strace1.setup(cls.draws, chain=1, sampler_vars=cls.sampler_vars)
        else:
            strace0.setup(cls.draws, chain=0)
            strace1.setup(cls.draws, chain=1)

        varnames = list(cls.test_point.keys())
        shapes = {varname: value.shape
                  for varname, value in cls.test_point.items()}
        dtypes = {varname: value.dtype
                  for varname, value in cls.test_point.items()}

        # Chain 0 holds 0, 1, 2, ... laid out in (draws,) + shape; chain 1
        # holds the same values times 100 so the chains are distinguishable.
        cls.expected = {0: {}, 1: {}}
        for varname in varnames:
            mcmc_shape = (cls.draws,) + shapes[varname]
            values = np.arange(cls.draws * np.prod(shapes[varname]),
                               dtype=dtypes[varname])
            cls.expected[0][varname] = values.reshape(mcmc_shape)
            cls.expected[1][varname] = values.reshape(mcmc_shape) * 100

        if cls.sampler_vars is not None:
            cls.expected_stats = {0: [], 1: []}
            for sampler_dtypes in cls.sampler_vars:
                stats = {}
                # Both chains share the same expected stats dict.
                cls.expected_stats[0].append(stats)
                cls.expected_stats[1].append(stats)
                for key, dtype in sampler_dtypes.items():
                    # Compare normalised dtypes: ``np.bool`` no longer exists
                    # in NumPy 1.24+, and this also matches ``np.bool_``.
                    if np.dtype(dtype) == np.dtype(bool):
                        # Boolean stats are all False rather than a range.
                        stats[key] = np.zeros(cls.draws, dtype=dtype)
                    else:
                        stats[key] = np.arange(cls.draws, dtype=dtype)

        for idx in range(cls.draws):
            point0 = {varname: cls.expected[0][varname][idx, ...]
                      for varname in varnames}
            point1 = {varname: cls.expected[1][varname][idx, ...]
                      for varname in varnames}
            if cls.sampler_vars is not None:
                stats0 = [{key: val[idx] for key, val in stats.items()}
                          for stats in cls.expected_stats[0]]
                stats1 = [{key: val[idx] for key, val in stats.items()}
                          for stats in cls.expected_stats[1]]
                strace0.record(point=point0, sampler_stats=stats0)
                strace1.record(point=point1, sampler_stats=stats1)
            else:
                strace0.record(point=point0)
                strace1.record(point=point1)
        strace0.close()
        strace1.close()
        cls.mtrace = base.MultiTrace([strace0, strace1])

        cls.stat_dtypes = {}
        cls.stats_counts = collections.Counter()
        for stats in cls.sampler_vars or []:
            cls.stat_dtypes.update(stats)
            cls.stats_counts.update(stats.keys())

    @classmethod
    def teardown_class(cls):
        if cls.name is not None:
            remove_file_or_directory(cls.name)

    def test_varnames_nonempty(self):
        # Make sure the test_point has variables names because many
        # tests rely on looping through these and would pass silently
        # if the loop is never entered.
        assert list(self.test_point.keys())

    def test_stat_names(self):
        names = set()
        for sampler_dtypes in self.sampler_vars or []:
            names.update(sampler_dtypes.keys())
        assert self.mtrace.stat_names == names
221

222

223
class SamplingTestCase(ModelBackendSetupTestCase):
    """Exercise recording draws into a backend trace and closing it.

    Subclasses are expected to set
    - backend
    - name
    - shape
    """

    def record_point(self, val):
        """Record one draw where every element of every variable is ``val``.

        If sampler stats are configured, every declared stat is recorded as
        ``val`` converted with that stat's dtype.
        """
        point = {}
        for varname, value in self.test_point.items():
            point[varname] = np.tile(val, value.shape)

        if self.sampler_vars is None:
            self.strace.record(point=point)
            return

        sampler_stats = []
        for stat_dtypes in self.sampler_vars:
            sampler_stats.append(
                {key: dtype(val) for key, dtype in stat_dtypes.items()})
        self.strace.record(point=point, sampler_stats=sampler_stats)

    @pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
    def test_standard_close(self):
        for draw in range(self.draws):
            self.record_point(draw)
        self.strace.close()

        last_idx = self.draws - 1
        for varname in self.test_point:
            var_shape = self.strace.var_shapes[varname]
            # The first draw was recorded with 0 and the last with last_idx.
            first = self.strace.get_values(varname)[0, ...]
            npt.assert_equal(first, np.zeros(var_shape))
            last = self.strace.get_values(varname)[last_idx, ...]
            npt.assert_equal(last, np.tile(last_idx, var_shape))

        if not self.sampler_vars:
            return
        for statname in self.strace.stat_names:
            stat_values = self.strace.get_sampler_stats(statname)
            assert stat_values.shape[0] == self.draws

    def test_missing_stats(self):
        # Recording a point without the declared sampler stats must fail.
        if self.sampler_vars is None:
            return
        with pytest.raises(ValueError):
            self.strace.record(point=self.test_point)

    def test_clean_interrupt(self):
        # Closing after a single draw must leave exactly one stored draw.
        self.record_point(0)
        self.strace.close()
        for varname in self.test_point:
            assert self.strace.get_values(varname).shape[0] == 1
        for statname in self.strace.stat_names:
            assert self.strace.get_sampler_stats(statname).shape[0] == 1
271

272

273
class SelectionTestCase(ModelBackendSampledTestCase):
    """Test backend selection.

    Values, points and slices read back from ``self.mtrace`` are compared
    against ``self.expected`` as built in ``setup_class``.

    Children must define
    - backend
    - name
    - shape
    """

    @pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
    def test_get_values_default(self):
        for varname in self.test_point.keys():
            expected = np.concatenate([self.expected[chain][varname]
                                       for chain in [0, 1]])
            result = self.mtrace.get_values(varname)
            npt.assert_equal(result, expected)

    @pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
    def test_get_values_nocombine_burn_keyword(self):
        burn = 2
        for varname in self.test_point.keys():
            expected = [self.expected[0][varname][burn:],
                        self.expected[1][varname][burn:]]
            result = self.mtrace.get_values(varname, burn=burn, combine=False)
            npt.assert_equal(result, expected)

    def test_len(self):
        assert len(self.mtrace) == self.draws

    @pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
    def test_dtypes(self):
        for varname in self.test_point.keys():
            assert self.expected[0][varname].dtype == \
                self.mtrace.get_values(varname, chains=0).dtype

        for statname in self.mtrace.stat_names:
            assert self.stat_dtypes[statname] == \
                self.mtrace.get_sampler_stats(statname, chains=0).dtype

    def test_get_values_nocombine_thin_keyword(self):
        thin = 2
        for varname in self.test_point.keys():
            expected = [self.expected[0][varname][::thin],
                        self.expected[1][varname][::thin]]
            result = self.mtrace.get_values(varname, thin=thin, combine=False)
            npt.assert_equal(result, expected)

    def test_get_point(self):
        idx = 2
        # No ``chain`` argument: this test expects ``point`` to read from
        # chain 1, the last chain recorded in setup_class.
        result = self.mtrace.point(idx)
        for varname in self.test_point.keys():
            expected = self.expected[1][varname][idx]
            npt.assert_equal(result[varname], expected)

    def test_get_slice(self):
        expected = []
        for chain in [0, 1]:
            expected.append({varname: self.expected[chain][varname][2:]
                             for varname in self.mtrace.varnames})
        result = self.mtrace[2:]
        for chain in [0, 1]:
            for varname in self.test_point.keys():
                npt.assert_equal(result.get_values(varname, chains=[chain]),
                                 expected[chain][varname])

    def test_get_slice_step(self):
        result = self.mtrace[:]
        assert len(result) == self.draws

        # NOTE(review): with draws=5, ``[::2]`` selects 3 draws, yet the
        # expected length is ``draws // 2`` (2). This pins the backend's
        # current ``len`` of a stepped slice; confirm that is intended.
        result = self.mtrace[::2]
        assert len(result) == self.draws // 2

    def test_get_slice_neg_step(self):
        # Backends that cannot slice with a negative step opt out by
        # defining ``skip_test_get_slice_neg_step``.
        if hasattr(self, 'skip_test_get_slice_neg_step'):
            return

        result = self.mtrace[::-1]
        assert len(result) == self.draws

        result = self.mtrace[::-2]
        assert len(result) == self.draws // 2

    def test_get_neg_slice(self):
        expected = []
        for chain in [0, 1]:
            expected.append({varname: self.expected[chain][varname][-2:]
                             for varname in self.mtrace.varnames})
        result = self.mtrace[-2:]
        for chain in [0, 1]:
            for varname in self.test_point.keys():
                npt.assert_equal(result.get_values(varname, chains=[chain]),
                                 expected[chain][varname])

    def test_get_values_one_chain(self):
        for varname in self.test_point.keys():
            expected = self.expected[0][varname]
            result = self.mtrace.get_values(varname, chains=[0])
            npt.assert_equal(result, expected)

    def test_get_values_nocombine_chains_reversed(self):
        for varname in self.test_point.keys():
            expected = [self.expected[1][varname], self.expected[0][varname]]
            result = self.mtrace.get_values(varname, chains=[1, 0],
                                            combine=False)
            npt.assert_equal(result, expected)

    def test_nchains(self):
        # setup_class records exactly two chains (0 and 1). The comparison
        # must be asserted, otherwise this test can never fail.
        assert self.mtrace.nchains == 2

    def test_get_values_one_chain_int_arg(self):
        for varname in self.test_point.keys():
            npt.assert_equal(self.mtrace.get_values(varname, chains=[0]),
                             self.mtrace.get_values(varname, chains=0))

    def test_get_values_combine(self):
        for varname in self.test_point.keys():
            expected = np.concatenate([self.expected[chain][varname]
                                       for chain in [0, 1]])
            result = self.mtrace.get_values(varname, combine=True)
            npt.assert_equal(result, expected)

    def test_get_values_combine_burn_arg(self):
        burn = 2
        for varname in self.test_point.keys():
            expected = np.concatenate([self.expected[chain][varname][burn:]
                                       for chain in [0, 1]])
            result = self.mtrace.get_values(varname, combine=True, burn=burn)
            npt.assert_equal(result, expected)

    def test_get_values_combine_thin_arg(self):
        thin = 2
        for varname in self.test_point.keys():
            expected = np.concatenate([self.expected[chain][varname][::thin]
                                       for chain in [0, 1]])
            result = self.mtrace.get_values(varname, combine=True, thin=thin)
            npt.assert_equal(result, expected)

    def test_getitem_equivalence(self):
        mtrace = self.mtrace
        for varname in self.test_point.keys():
            npt.assert_equal(mtrace[varname],
                             mtrace.get_values(varname, combine=True))
            npt.assert_equal(mtrace[varname, 2:],
                             mtrace.get_values(varname, burn=2,
                                               combine=True))
            npt.assert_equal(mtrace[varname, 2::2],
                             mtrace.get_values(varname, burn=2, thin=2,
                                               combine=True))

    def test_selection_method_equivalence(self):
        varname = self.mtrace.varnames[0]
        mtrace = self.mtrace
        npt.assert_equal(mtrace.get_values(varname),
                         mtrace[varname])
        npt.assert_equal(mtrace[varname],
                         mtrace.__getattr__(varname))
431

432

433
class DumpLoadTestCase(ModelBackendSampledTestCase):
    """Test equality of a dumped and loaded trace with original.

    Children must define
    - backend
    - load_func
        Function to load dumped backend
    - name
    - shape
    """
    @classmethod
    def setup_class(cls):
        super().setup_class()
        try:
            with cls.model:
                cls.dumped = cls.load_func(cls.name)
        except BaseException:
            # Remove whatever was written under ``cls.name`` before
            # propagating the loading failure unchanged.
            remove_file_or_directory(cls.name)
            raise

    @classmethod
    def teardown_class(cls):
        remove_file_or_directory(cls.name)

    def test_nchains(self):
        assert self.mtrace.nchains == self.dumped.nchains

    def test_varnames(self):
        trace_names = list(sorted(self.mtrace.varnames))
        dumped_names = list(sorted(self.dumped.varnames))
        assert trace_names == dumped_names

    def test_values(self):
        trace = self.mtrace
        dumped = self.dumped
        # ``chain_vars`` holds RV objects; ``get_values`` is documented to
        # take the variable name, so look values up by name.
        varnames = [var.name for var in self.chain_vars]
        for chain in trace.chains:
            for varname in varnames:
                data = trace.get_values(varname, chains=[chain])
                dumped_data = dumped.get_values(varname, chains=[chain])
                npt.assert_equal(data, dumped_data)
473

474

475
class BackendEqualityTestCase(ModelBackendSampledTestCase):
    """Test equality of attributes from two backends.

    ``setup_class`` runs the parent setup once per backend. The results are
    stored as ``mtrace0`` (backend0/name0) and ``mtrace1`` (backend1/name1).
    Afterwards ``cls.backend``/``cls.name``/``cls.mtrace`` refer to the second
    backend.

    Children must define
    - backend0
    - backend1
    - name0
    - name1
    - shape
    """
    @classmethod
    def setup_class(cls):
        cls.backend = cls.backend0
        cls.name = cls.name0
        super().setup_class()
        cls.mtrace0 = cls.mtrace

        cls.backend = cls.backend1
        cls.name = cls.name1
        super().setup_class()
        cls.mtrace1 = cls.mtrace

    @classmethod
    def teardown_class(cls):
        for name in [cls.name0, cls.name1]:
            if name is not None:
                remove_file_or_directory(name)

    def test_chain_length(self):
        assert self.mtrace0.nchains == self.mtrace1.nchains
        assert len(self.mtrace0) == len(self.mtrace1)

    @pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
    def test_dtype(self):
        for varname in self.test_point.keys():
            assert self.mtrace0.get_values(varname, chains=0).dtype == \
                self.mtrace1.get_values(varname, chains=0).dtype

    def test_number_of_draws(self):
        for varname in self.test_point.keys():
            values0 = self.mtrace0.get_values(varname, combine=False,
                                              squeeze=False)
            values1 = self.mtrace1.get_values(varname, combine=False,
                                              squeeze=False)
            assert values0[0].shape[0] == self.draws
            assert values1[0].shape[0] == self.draws

    def test_get_item(self):
        for varname in self.test_point.keys():
            npt.assert_equal(self.mtrace0[varname], self.mtrace1[varname])

    def test_get_values(self):
        for varname in self.test_point.keys():
            for cf in [False, True]:
                npt.assert_equal(self.mtrace0.get_values(varname, combine=cf),
                                 self.mtrace1.get_values(varname, combine=cf))

    def test_get_values_no_squeeze(self):
        for varname in self.test_point.keys():
            npt.assert_equal(self.mtrace0.get_values(varname, combine=False,
                                                     squeeze=False),
                             self.mtrace1.get_values(varname, combine=False,
                                                     squeeze=False))

    def test_get_values_combine_and_no_squeeze(self):
        for varname in self.test_point.keys():
            npt.assert_equal(self.mtrace0.get_values(varname, combine=True,
                                                     squeeze=False),
                             self.mtrace1.get_values(varname, combine=True,
                                                     squeeze=False))

    def test_get_values_with_burn(self):
        for varname in self.test_point.keys():
            for cf in [False, True]:
                npt.assert_equal(self.mtrace0.get_values(varname, combine=cf,
                                                         burn=3),
                                 self.mtrace1.get_values(varname, combine=cf,
                                                         burn=3))
                # Burn to one value.
                npt.assert_equal(self.mtrace0.get_values(varname, combine=cf,
                                                         burn=self.draws - 1),
                                 self.mtrace1.get_values(varname, combine=cf,
                                                         burn=self.draws - 1))

    def test_get_values_with_thin(self):
        for varname in self.test_point.keys():
            for cf in [False, True]:
                npt.assert_equal(self.mtrace0.get_values(varname, combine=cf,
                                                         thin=2),
                                 self.mtrace1.get_values(varname, combine=cf,
                                                         thin=2))

    def test_get_values_with_burn_and_thin(self):
        for varname in self.test_point.keys():
            for cf in [False, True]:
                npt.assert_equal(self.mtrace0.get_values(varname, combine=cf,
                                                         burn=2, thin=2),
                                 self.mtrace1.get_values(varname, combine=cf,
                                                         burn=2, thin=2))

    def test_get_values_with_chains_arg(self):
        for varname in self.test_point.keys():
            for cf in [False, True]:
                npt.assert_equal(self.mtrace0.get_values(varname, chains=[0],
                                                         combine=cf),
                                 self.mtrace1.get_values(varname, chains=[0],
                                                         combine=cf))

    def test_get_point(self):
        # Compare the last recorded draw; derived from ``draws`` instead of
        # a hard-coded index.
        last_idx = self.draws - 1
        npoint, spoint = self.mtrace0[last_idx], self.mtrace1[last_idx]
        for varname in self.test_point.keys():
            npt.assert_equal(npoint[varname], spoint[varname])

    def test_point_with_chain_arg(self):
        last_idx = self.draws - 1
        npoint = self.mtrace0.point(last_idx, chain=0)
        spoint = self.mtrace1.point(last_idx, chain=0)
        for varname in self.test_point.keys():
            npt.assert_equal(npoint[varname], spoint[varname])
593

594

595
def remove_file_or_directory(name):
    """Delete the path ``name``, whether it is a single file or a directory tree.

    A plain unlink is attempted first. If it raises any ``OSError`` (for
    example because ``name`` is a directory or does not exist), the path is
    passed to ``shutil.rmtree`` with errors ignored. As a result, calling this
    on a missing path does not raise.
    """
    try:
        os.remove(name)
        return
    except OSError:
        pass
    shutil.rmtree(name, ignore_errors=True)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc