• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pymc-devs / pymc3 / 9362

pending completion
9362

Pull #3597

travis-ci

web-flow
Fix test bugs.
Pull Request #3597: WIP: Second try to vectorize sample_posterior_predictive.

513 of 513 new or added lines in 16 files covered. (100.0%)

12617 of 20534 relevant lines covered (61.44%)

0.61 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.01
/pymc3/tests/test_distributions.py
1
import itertools
1✔
2
import sys
1✔
3

4
from .helpers import SeededTest, select_by_precision
1✔
5
from ..vartypes import continuous_types
1✔
6
from ..model import Model, Point, Deterministic
1✔
7
from ..blocking import DictToVarBijection
1✔
8
from ..distributions import (
1✔
9
    DensityDist, Categorical, Multinomial, VonMises, Dirichlet,
10
    MvStudentT, MvNormal, MatrixNormal, ZeroInflatedPoisson,
11
    ZeroInflatedNegativeBinomial, Constant, Poisson, Bernoulli, Beta,
12
    BetaBinomial, HalfStudentT, StudentT, Weibull, Pareto,
13
    InverseGamma, Gamma, Cauchy, HalfCauchy, Lognormal, Laplace,
14
    NegativeBinomial, Geometric, Exponential, ExGaussian, Normal, TruncatedNormal,
15
    Flat, LKJCorr, Wald, ChiSquared, HalfNormal, DiscreteUniform,
16
    Bound, Uniform, Triangular, Binomial, SkewNormal, DiscreteWeibull,
17
    Gumbel, Logistic, OrderedLogistic, LogitNormal, Interpolated,
18
    ZeroInflatedBinomial, HalfFlat, AR1, KroneckerNormal, Rice,
19
    Kumaraswamy
20
)
21

22
from ..distributions import continuous
1✔
23
from pymc3.theanof import floatX
1✔
24
from numpy import array, inf, log, exp
1✔
25
from numpy.testing import assert_almost_equal, assert_allclose, assert_equal
1✔
26
import numpy.random as nr
1✔
27
import numpy as np
1✔
28
import pytest
1✔
29

30
from scipy import integrate
1✔
31
import scipy.stats.distributions as sp
1✔
32
import scipy.stats
1✔
33
from scipy.special import logit
1✔
34
import theano
1✔
35
import theano.tensor as tt
1✔
36
from ..math import kronecker
1✔
37

38
def get_lkj_cases():
    """
    Return hard-coded reference cases for the LKJ log-probability.

    Log probabilities calculated using the formulas in:
    http://www.sciencedirect.com/science/article/pii/S0047259X09000876

    Returns:
        list of 4-tuples ``(array, int, int, float)``.  Presumably these are
        (value, eta, n, expected logp) -- confirm against the test that
        consumes ``LKJ_CASES``.
    """
    tri = np.array([0.7, 0.0, -0.7])
    return [
        (tri, 1, 3, 1.5963125911388549),
        (tri, 3, 3, -7.7963493376312742),
        (tri, 0, 3, -np.inf),
        (np.array([1.1, 0.0, -0.7]), 1, 3, -np.inf),
        (np.array([0.7, 0.0, -1.1]), 1, 3, -np.inf)
    ]
51

52

53
# Reference cases built once at import time by get_lkj_cases().
LKJ_CASES = get_lkj_cases()
1✔
54

55

56
class Domain:
    """A finite set of test values plus the edges of the range they cover.

    Args:
        vals: sequence of values.  When ``edges`` is not given, the first and
            last entries become the lower/upper edges and are dropped from
            ``vals``.
        dtype: dtype passed to ``numpy.array``.  When omitted, non-integer
            values are cast to ``theano.config.floatX``.
        edges: optional ``(lower, upper)`` pair; when given, ``vals`` is kept
            untrimmed.
        shape: optional shape; defaults to the shape of the first value.

    Attributes are ``vals``, ``shape``, ``lower``, ``upper`` and ``dtype``.
    """

    def __init__(self, vals, dtype=None, edges=None, shape=None):
        avals = array(vals, dtype=dtype)
        if dtype is None and not str(avals.dtype).startswith('int'):
            avals = avals.astype(theano.config.floatX)
        vals = [array(v, dtype=avals.dtype) for v in vals]

        if edges is None:
            edges = array(vals[0]), array(vals[-1])
            vals = vals[1:-1]
        if shape is None:
            shape = avals[0].shape

        self.vals = vals
        self.shape = shape

        self.lower, self.upper = edges
        self.dtype = avals.dtype

    def __add__(self, other):
        """Return a new Domain with ``other`` added to every value and edge."""
        return Domain(
            [v + other for v in self.vals],
            self.dtype,
            (self.lower + other, self.upper + other),
            self.shape)

    def __mul__(self, other):
        """Return a new Domain with every value (and, if possible, edge) scaled."""
        try:
            return Domain(
                [v * other for v in self.vals],
                self.dtype,
                (self.lower * other, self.upper * other),
                self.shape)
        except TypeError:
            # Edges that cannot be multiplied (e.g. ``edges=(None, None)``)
            # are carried over unchanged.
            return Domain(
                [v * other for v in self.vals],
                self.dtype,
                (self.lower, self.upper),
                self.shape)

    def __neg__(self):
        """Return the mirrored Domain.

        Negation flips the ordering, so the new lower edge is ``-upper`` and
        the new upper edge is ``-lower``.  The order of ``vals`` is kept.
        """
        return Domain(
            [-v for v in self.vals],
            self.dtype,
            (-self.upper, -self.lower),
            self.shape)
102

103

104
def product(domains, n_samples=-1):
    """Get the points of a product of domains.

    Args:
        domains: a dictionary of (name, object) pairs, where the objects
                 must be "domain-like", as in, have a `.vals` property
        n_samples: int, maximum samples to return.  -1 to return whole product

    Returns:
        list of the cartesian product of the domains.  Each element is a
        single-use ``zip`` iterator of ``(name, value)`` pairs, so it can be
        consumed only once (e.g. by ``dict``).  An empty ``domains`` gives
        an empty list.  When the product has more than ``n_samples`` points,
        a random subset drawn without replacement via ``numpy.random`` is
        returned.
    """
    try:
        names, domains = zip(*domains.items())
    except ValueError:  # domains.items() is empty
        return []
    all_vals = [zip(names, val) for val in itertools.product(*[d.vals for d in domains])]
    if n_samples > 0 and len(all_vals) > n_samples:
        # Return a list here too so both branches have the same return type.
        return [all_vals[j] for j in nr.choice(len(all_vals), n_samples, replace=False)]
    return all_vals
1✔
123

124

125
# Shared value domains used by the tests below.  Unless ``edges`` is passed,
# Domain takes the first and last entry of each list as the lower/upper edge
# and keeps only the interior entries as test values.
R = Domain([-inf, -2.1, -1, -.01, .0, .01, 1, 2.1, inf])
Rplus = Domain([0, .01, .1, .9, .99, 1, 1.5, 2, 100, inf])
Rplusbig = Domain([0, .5, .9, .99, 1, 1.5, 2, 20, inf])
Rminusbig = Domain([-inf, -2, -1.5, -1, -.99, -.9, -.5, -0.01, 0])
Unit = Domain([0, .001, .1, .5, .75, .99, 1])

Circ = Domain([-np.pi, -2.1, -1, -.01, .0, .01, 1, 2.1, np.pi])

Runif = Domain([-1, -.4, 0, .4, 1])
Rdunif = Domain([-10, 0, 10.])
Rplusunif = Domain([0, .5, inf])
Rplusdunif = Domain([2, 10, 100], 'int64')

I = Domain([-1000, -3, -2, -1, 0, 1, 2, 3, 1000], 'int64')

NatSmall = Domain([0, 3, 4, 5, 1000], 'int64')
Nat = Domain([0, 1, 2, 3, 2000], 'int64')
NatBig = Domain([0, 1, 2, 3, 5000, 50000], 'int64')
PosNat = Domain([1, 2, 3, 2000], 'int64')

# The outer 0 and 1 become the edges, leaving [0, 1] as the test values.
Bool = Domain([0, 0, 1, 1], 'int64')
1✔
146

147

148
def build_model(distfam, valuedomain, vardomains, extra_args=None):
    """Build a Model holding one ``distfam`` variable named ``'value'``.

    Args:
        distfam: distribution class (or callable) to instantiate.
        valuedomain: domain-like object; its ``.shape`` is used for ``'value'``.
        vardomains: dict mapping parameter name to a domain-like object.  A
            ``Flat`` variable is created for each, using the domain's dtype
            and shape and its first value as test value.
        extra_args: optional dict of additional keyword arguments passed to
            ``distfam``; these override same-named entries from ``vardomains``.

    Returns:
        the constructed Model.
    """
    if extra_args is None:
        extra_args = {}
    with Model() as m:
        vals = {}
        for v, dom in vardomains.items():
            vals[v] = Flat(v, dtype=dom.dtype, shape=dom.shape,
                           testval=dom.vals[0])
        vals.update(extra_args)
        distfam('value', shape=valuedomain.shape, transform=None, **vals)
    return m
1✔
159

160

161
def integrate_nd(f, domain, shape, dtype):
    """Integrate (or sum) ``f`` over ``domain`` for small fixed shapes.

    Args:
        f: callable.  Called with a scalar for shapes ``()``/``(1,)`` and
            with a list of coordinates ``[x0, x1(, x2)]`` for shapes
            ``(2,)``/``(3,)``.
        domain: domain-like object providing ``.lower`` and ``.upper``
            (indexable per dimension for the 2-d and 3-d cases).
        shape: shape of the value; one of ``()``, ``(1,)``, ``(2,)``, ``(3,)``.
        dtype: dtype of the value; members of ``continuous_types`` are
            integrated numerically, anything else is summed over the integer
            range ``lower..upper`` inclusive.

    Returns:
        the integral / sum as a number.

    Raises:
        ValueError: for any other shape.
    """
    if shape == () or shape == (1,):
        if dtype in continuous_types:
            return integrate.quad(f, domain.lower, domain.upper, epsabs=1e-8)[0]
        else:
            return sum(f(j) for j in range(domain.lower, domain.upper + 1))
    elif shape == (2,):
        # dblquad calls the integrand as func(y, x): the inner variable
        # (bounded by the two callables, i.e. dimension 1) comes first, the
        # outer variable (dimension 0) last.  Reorder so f gets [x0, x1].
        def f2(y, x):
            return f([x, y])

        return integrate.dblquad(f2, domain.lower[0], domain.upper[0],
                                 lambda _: domain.lower[1],
                                 lambda _: domain.upper[1])[0]
    elif shape == (3,):
        # tplquad calls the integrand as func(z, y, x), innermost first.
        def f3(z, y, x):
            return f([x, y, z])

        return integrate.tplquad(f3, domain.lower[0], domain.upper[0],
                                 lambda _: domain.lower[1],
                                 lambda _: domain.upper[1],
                                 lambda _, __: domain.lower[2],
                                 lambda _, __: domain.upper[2])[0]
    else:
        raise ValueError("Dont know how to integrate shape: " + str(shape))
×
185

186

187
def multinomial_logpdf(value, n, p):
    """Reference multinomial log-pmf.

    Args:
        value: array of counts.
        n: total count.
        p: array of probabilities, same length as ``value``.

    Returns:
        the log-probability, or ``-inf`` when the counts do not sum to ``n``
        or any count lies outside ``[0, n]``.
    """
    if value.sum() == n and (0 <= value).all() and (value <= n).all():
        logpdf = scipy.special.gammaln(n + 1)
        logpdf -= scipy.special.gammaln(value + 1).sum()
        logpdf += logpow(p, value).sum()
        return logpdf
    else:
        return -inf
1✔
195

196

197
def beta_mu_sigma(value, mu, sigma):
    """Reference Beta log-pdf in the (mu, sigma) parametrization.

    Computes ``kappa = mu * (1 - mu) / sigma**2 - 1`` and evaluates the scipy
    Beta log-pdf with ``a = mu * kappa`` and ``b = (1 - mu) * kappa``.

    Returns:
        the log-density, or ``-inf`` when ``kappa`` is not positive.
    """
    kappa = mu * (1 - mu) / sigma**2 - 1
    if kappa > 0:
        return sp.beta.logpdf(value, mu * kappa, (1 - mu) * kappa)
    else:
        return -inf
1✔
203

204

205
class ProductDomain:
    """Domain-like object formed from the cartesian product of several domains.

    ``vals`` is the list of all tuples in the product of the component
    ``vals``; ``lower``/``upper`` are lists of the component edges; ``shape``
    is ``(len(domains),)`` followed by the first component's shape; ``dtype``
    is taken from the first component.
    """

    def __init__(self, domains):
        self.vals = list(itertools.product(*[d.vals for d in domains]))
        self.shape = (len(domains),) + domains[0].shape
        self.lower = [d.lower for d in domains]
        self.upper = [d.upper for d in domains]
        self.dtype = domains[0].dtype
1✔
212

213

214
def Vector(D, n):
    """Return the ProductDomain of ``n`` copies of the domain ``D``."""
    return ProductDomain([D] * n)
1✔
216

217

218
def SortedVector(n):
    """Return a Domain of 10 sorted length-``n`` standard-normal draws.

    Note: reseeds the global NumPy RNG with 42 so the values are reproducible.
    Edges are set to ``(None, None)``, so all 10 vectors are kept as values.
    """
    vals = []
    np.random.seed(42)
    for _ in range(10):
        vals.append(np.sort(np.random.randn(n)))
    return Domain(vals, edges=(None, None))
1✔
224

225

226
def UnitSortedVector(n):
    """Return a Domain of 10 sorted length-``n`` uniform [0, 1) draws.

    Note: reseeds the global NumPy RNG with 42 so the values are reproducible.
    Edges are set to ``(None, None)``, so all 10 vectors are kept as values.
    """
    vals = []
    np.random.seed(42)
    for _ in range(10):
        vals.append(np.sort(np.random.rand(n)))
    return Domain(vals, edges=(None, None))
1✔
232

233

234
def RealMatrix(n, m):
    """Return a Domain of 10 standard-normal ``(n, m)`` matrices.

    Note: reseeds the global NumPy RNG with 42 so the values are reproducible.
    Edges are set to ``(None, None)``, so all 10 matrices are kept as values.
    """
    vals = []
    np.random.seed(42)
    for _ in range(10):
        vals.append(np.random.randn(n, m))
    return Domain(vals, edges=(None, None))
1✔
240

241

242
def simplex_values(n):
    """Yield length-``n`` arrays whose entries sum to one.

    For ``n == 1`` the only point is ``[1.0]``.  Otherwise each value ``v``
    of ``Unit.vals`` is used as the first coordinate and the remaining mass
    ``1 - v`` is spread according to every point of the ``(n - 1)`` simplex.
    """
    if n == 1:
        yield array([1.0])
    else:
        for v in Unit.vals:
            for vals in simplex_values(n - 1):
                yield np.concatenate([[v], (1 - v) * vals])
1✔
249

250

251
def normal_logpdf_tau(value, mu, tau):
    """Reference MvNormal log-pdf given a precision matrix ``tau``.

    The precision is inverted and passed on as a covariance matrix.
    """
    return normal_logpdf_cov(value, mu, np.linalg.inv(tau)).sum()
1✔
253

254

255
def normal_logpdf_cov(value, mu, cov):
    """Reference MvNormal log-pdf given a covariance matrix, summed over points."""
    return scipy.stats.multivariate_normal.logpdf(value, mu, cov).sum()
1✔
257

258

259
def normal_logpdf_chol(value, mu, chol):
    """Reference MvNormal log-pdf given a lower Cholesky factor.

    The covariance is rebuilt as ``chol @ chol.T``.
    """
    return normal_logpdf_cov(value, mu, np.dot(chol, chol.T)).sum()
1✔
261

262

263
def normal_logpdf_chol_upper(value, mu, chol):
    """Reference MvNormal log-pdf given an upper Cholesky factor.

    The covariance is rebuilt as ``chol.T @ chol``.
    """
    return normal_logpdf_cov(value, mu, np.dot(chol.T, chol)).sum()
1✔
265

266

267
def matrix_normal_logpdf_cov(value, mu, rowcov, colcov):
    """Reference MatrixNormal log-pdf given row and column covariances."""
    return scipy.stats.matrix_normal.logpdf(value, mu, rowcov, colcov)
1✔
269

270

271
def matrix_normal_logpdf_chol(value, mu, rowchol, colchol):
    """Reference MatrixNormal log-pdf given row/column lower Cholesky factors.

    Each covariance is rebuilt as ``chol @ chol.T``.
    """
    return matrix_normal_logpdf_cov(value, mu, np.dot(rowchol, rowchol.T),
                                    np.dot(colchol, colchol.T))
274

275

276
def kron_normal_logpdf_cov(value, mu, covs, sigma):
    """Reference KroneckerNormal log-pdf given the covariance factors.

    Args:
        value: point(s) at which to evaluate the density.
        mu: mean vector.
        covs: sequence of covariance matrices combined with ``kronecker``.
        sigma: optional noise scale; when not None, ``sigma**2`` times the
            identity is added to the full covariance.

    Returns:
        the summed log-density.
    """
    cov = kronecker(*covs).eval()
    if sigma is not None:
        cov += sigma**2 * np.eye(*cov.shape)
    return scipy.stats.multivariate_normal.logpdf(value, mu, cov).sum()
1✔
281

282

283
def kron_normal_logpdf_chol(value, mu, chols, sigma):
    """Reference KroneckerNormal log-pdf given lower Cholesky factors.

    Each factor is expanded to ``chol @ chol.T`` before delegating to
    ``kron_normal_logpdf_cov``.
    """
    covs = [np.dot(chol, chol.T) for chol in chols]
    return kron_normal_logpdf_cov(value, mu, covs, sigma=sigma)
1✔
286

287

288
def kron_normal_logpdf_evd(value, mu, evds, sigma):
    """Reference KroneckerNormal log-pdf given eigendecompositions.

    Args:
        value: point(s) at which to evaluate the density.
        mu: mean vector.
        evds: sequence of ``(eigs, Q)`` pairs; each covariance is rebuilt as
            ``Q @ diag(eigs) @ Q.T``.
        sigma: optional noise scale forwarded to ``kron_normal_logpdf_cov``.
    """
    covs = []
    for eigs, Q in evds:
        # Inputs may be symbolic (exposing ``.eval()``) or plain arrays;
        # plain arrays have no ``.eval`` and are used as they are.
        try:
            eigs = eigs.eval()
        except AttributeError:
            pass
        try:
            Q = Q.eval()
        except AttributeError:
            pass
        covs.append(np.dot(Q, np.dot(np.diag(eigs), Q.T)))
    return kron_normal_logpdf_cov(value, mu, covs, sigma)
1✔
301

302

303
def betafn(a):
    """Return the log of the multivariate Beta function of ``a``.

    Computed over the last axis as ``sum(gammaln(a)) - gammaln(sum(a))`` and
    cast with ``floatX``.
    """
    return floatX(scipy.special.gammaln(a).sum(-1) - scipy.special.gammaln(a.sum(-1)))
1✔
305

306

307
def logpow(v, p):
    """Return ``p * log(v)`` elementwise, with 0 wherever ``v == 0``.

    The zero case avoids ``-inf``/``nan`` results from ``log(0)``; NumPy may
    still emit a runtime warning while evaluating the discarded branch.
    """
    return np.choose(v == 0, [p * np.log(v), 0])
1✔
309

310

311
def discrete_weibull_logpmf(value, q, beta):
    """Reference discrete Weibull log-pmf.

    Computes ``log(q ** (value ** beta) - q ** ((value + 1) ** beta))`` with
    every operand cast through ``floatX``.
    """
    return floatX(np.log(np.power(floatX(q),
                                  np.power(floatX(value), floatX(beta)))
                  - np.power(floatX(q), np.power(floatX(value + 1),
                                                 floatX(beta)))))
316

317

318
def dirichlet_logpdf(value, a):
    """Reference Dirichlet log-pdf for concentration ``a``.

    Sums ``(a - 1) * log(value)`` over the last axis, subtracts the log Beta
    normaliser, and sums the result over any remaining axes.
    """
    return floatX((-betafn(a) + logpow(value, a - 1).sum(-1)).sum())
1✔
320

321

322
def categorical_logpdf(value, p):
    """Reference Categorical log-pmf.

    Args:
        value: integer category index.
        p: array of probabilities with the categories on the last axis.

    Returns:
        ``log(p[..., value])`` cast with ``floatX``, or ``-inf`` when
        ``value`` is not a valid category index.
    """
    # Valid indices are 0 .. n_categories - 1; the categories live on the
    # last axis, which is the axis indexed below.
    if value >= 0 and value < np.shape(p)[-1]:
        return floatX(np.log(np.moveaxis(p, -1, 0)[value]))
    else:
        return -inf
×
327

328
def mvt_logpdf(value, nu, Sigma, mu=0):
    """Reference multivariate Student-t log-pdf, summed over points.

    Args:
        value: a point or a 2-d array of points (one per row).
        nu: degrees of freedom.
        Sigma: scale matrix; its Cholesky factor is used for whitening and
            for the log-determinant term.
        mu: location, subtracted from ``value``.

    Returns:
        the sum of the log-densities of all points.
    """
    d = len(Sigma)
    dist = np.atleast_2d(value) - mu
    chol = np.linalg.cholesky(Sigma)
    trafo = np.linalg.solve(chol, dist.T).T
    # Sum of the log-diagonal of the Cholesky factor is 0.5 * log|Sigma|.
    logdet = np.log(np.diag(chol)).sum()

    lgamma = scipy.special.gammaln
    norm = lgamma((nu + d) / 2.) - 0.5 * d * np.log(nu * np.pi) - lgamma(nu / 2.)
    logp = norm - logdet - (nu + d) / 2. * np.log1p((trafo * trafo).sum(-1) / nu)
    return logp.sum()
1✔
339

340
def AR1_logpdf(value, k, tau_e):
    """Reference AR(1) log-pdf.

    The first element is scored under a zero-mean normal with scale
    ``1 / sqrt(tau_e)``; each later element is scored under a normal centred
    at ``k`` times its predecessor with the same scale.
    """
    return (sp.norm(loc=0, scale=1/np.sqrt(tau_e)).logpdf(value[0]) +
            sp.norm(loc=k*value[:-1], scale=1/np.sqrt(tau_e)).logpdf(value[1:]).sum())
343

344
def invlogit(x, eps=sys.float_info.epsilon):
    """Return the logistic sigmoid of ``x`` squeezed into ``[eps, 1 - eps]``."""
    return (1. - 2. * eps) / (1. + np.exp(-x)) + eps
1✔
346

347
def orderedlogistic_logpdf(value, eta, cutpoints):
    """Reference OrderedLogistic log-pmf.

    The cutpoints are padded with ``-inf`` and ``+inf``; the probability of
    each category is the difference of ``invlogit(eta - c)`` between
    consecutive padded cutpoints.

    Returns:
        ``log`` of the probability of category ``value``, or ``-inf`` when
        any category probability is negative.
    """
    c = np.concatenate(([-np.inf], cutpoints, [np.inf]))
    ps = np.array([invlogit(eta - cc) - invlogit(eta - cc1)
                   for cc, cc1 in zip(c[:-1], c[1:])])
    p = ps[value]
    return np.where(np.all(ps >= 0), np.log(p), -np.inf)
1✔
353

354
class Simplex:
    """Domain-like object holding points on the ``n``-simplex.

    ``vals`` is the list produced by ``simplex_values(n)``, ``shape`` is
    ``(n,)`` and ``dtype`` is taken from ``Unit``.
    """

    def __init__(self, n):
        self.vals = list(simplex_values(n))
        self.shape = (n,)
        self.dtype = Unit.dtype
1✔
359

360

361
class MultiSimplex:
    """Domain-like object holding stacks of simplex points.

    Each value is an ``(n_independent, n_dependent)`` array whose rows are
    points from ``simplex_values(n_dependent)``; all combinations of rows are
    enumerated.  ``dtype`` is taken from ``Unit``.
    """

    def __init__(self, n_dependent, n_independent):
        self.vals = []
        for simplex_value in itertools.product(simplex_values(n_dependent), repeat=n_independent):
            self.vals.append(np.vstack(simplex_value))
        self.shape = (n_independent, n_dependent)
        self.dtype = Unit.dtype
1✔
368

369

370
def PdMatrix(n):
    """Return the predefined ``n x n`` matrix domain for ``n`` in 1..3.

    Raises:
        ValueError: if ``n`` is not 1, 2 or 3.
    """
    if n == 1:
        return PdMatrix1
    elif n == 2:
        return PdMatrix2
    elif n == 3:
        return PdMatrix3
    else:
        raise ValueError("n out of bounds")
×
379

380
# Matrix domains for sizes 1-3, selected by PdMatrix(n).
# ``edges=(None, None)`` keeps both listed matrices as test values.
PdMatrix1 = Domain([np.eye(1), [[.5]]], edges=(None, None))

PdMatrix2 = Domain([np.eye(2), [[.5, .05], [.05, 4.5]]], edges=(None, None))

PdMatrix3 = Domain(
    [np.eye(3), [[.5, .1, 0], [.1, 1, 0], [0, 0, 2.5]]], edges=(None, None))
386

387

388
# Lower-triangular factor domains for sizes 1-3, selected by PdMatrixChol(n).
PdMatrixChol1 = Domain([np.eye(1), [[0.001]]], edges=(None, None))
PdMatrixChol2 = Domain([np.eye(2), [[0.1, 0], [10, 1]]], edges=(None, None))
PdMatrixChol3 = Domain([np.eye(3), [[0.1, 0, 0], [10, 100, 0], [0, 1, 10]]],
                       edges=(None, None))
392

393

394
def PdMatrixChol(n):
    """Return the predefined ``n x n`` lower-triangular domain for ``n`` in 1..3.

    Raises:
        ValueError: if ``n`` is not 1, 2 or 3.
    """
    if n == 1:
        return PdMatrixChol1
    elif n == 2:
        return PdMatrixChol2
    elif n == 3:
        return PdMatrixChol3
    else:
        raise ValueError("n out of bounds")
×
403

404

405
# Upper-triangular factor domains for sizes 1-3, selected by
# PdMatrixCholUpper(n).
PdMatrixCholUpper1 = Domain([np.eye(1), [[0.001]]], edges=(None, None))
PdMatrixCholUpper2 = Domain([np.eye(2), [[0.1, 10], [0, 1]]], edges=(None, None))
PdMatrixCholUpper3 = Domain([np.eye(3), [[0.1, 10, 0], [0, 100, 1], [0, 0, 10]]],
                            edges=(None, None))
409

410

411
def PdMatrixCholUpper(n):
    """Return the predefined ``n x n`` upper-triangular domain for ``n`` in 1..3.

    Raises:
        ValueError: if ``n`` is not 1, 2 or 3.
    """
    if n == 1:
        return PdMatrixCholUpper1
    elif n == 2:
        return PdMatrixCholUpper2
    elif n == 3:
        return PdMatrixCholUpper3
    else:
        raise ValueError("n out of bounds")
×
420

421

422
def RandomPdMatrix(n):
    """Return a random symmetric positive-definite ``n x n`` matrix.

    Built as ``A @ A.T + n * I`` with ``A`` drawn uniformly from [0, 1) using
    the global NumPy RNG.
    """
    A = np.random.rand(n, n)
    return np.dot(A, A.T) + n * np.identity(n)
1✔
425

426

427
class TestMatchesScipy(SeededTest):
1✔
428
    def pymc3_matches_scipy(self, pymc3_dist, domain, paramdomains, scipy_dist,
                            decimal=None, extra_args=None, scipy_args=None):
        """Check that a distribution's logp matches a reference function.

        Args:
            pymc3_dist: distribution class under test.
            domain: domain of the ``'value'`` variable.
            paramdomains: dict mapping parameter name to its domain.
            scipy_dist: reference callable taking ``value`` and the
                parameters as keyword arguments and returning the log-density.
            decimal: precision forwarded to ``check_logp``.
            extra_args: extra keyword arguments for the distribution.
            scipy_args: extra keyword arguments for ``scipy_dist``.
        """
        if extra_args is None:
            extra_args = {}
        if scipy_args is None:
            scipy_args = {}
        model = build_model(pymc3_dist, domain, paramdomains, extra_args)
        value = model.named_vars['value']

        def logp(args):
            # Note: the point dict is updated in place with scipy_args.
            args.update(scipy_args)
            return scipy_dist(**args)
        self.check_logp(model, value, domain, paramdomains, logp, decimal=decimal)
1✔
441

442
    def check_logp(self, model, value, domain, paramdomains, logp_reference, decimal=None):
        """Compare ``model.fastlogp`` with ``logp_reference`` on sampled points.

        Up to 100 points from the product of ``paramdomains`` and ``domain``
        (under the name ``'value'``) are checked with ``assert_almost_equal``.
        When ``decimal`` is None, the precision is chosen by
        ``select_by_precision(float64=6, float32=3)``.
        """
        domains = paramdomains.copy()
        domains['value'] = domain
        logp = model.fastlogp
        # Resolve the default precision once, as check_logcdf does.
        if decimal is None:
            decimal = select_by_precision(float64=6, float32=3)
        for pt in product(domains, n_samples=100):
            pt = Point(pt, model=model)
            assert_almost_equal(logp(pt), logp_reference(pt), decimal=decimal, err_msg=str(pt))
1✔
451

452
    def check_logcdf(self, pymc3_dist, domain, paramdomains, scipy_logcdf, decimal=None):
        """Compare a distribution's ``logcdf`` with a reference on sampled points.

        Up to 100 points from the product of ``paramdomains`` and ``domain``
        (under the name ``'value'``) are checked.  ``scipy_logcdf`` receives
        ``value`` and the parameters as keyword arguments; the distribution is
        built with ``pymc3_dist.dist(**params)`` and its ``logcdf`` test value
        is compared with ``assert_almost_equal``.
        """
        domains = paramdomains.copy()
        domains['value'] = domain
        if decimal is None:
            decimal = select_by_precision(float64=6, float32=3)
        for pt in product(domains, n_samples=100):
            params = dict(pt)
            scipy_cdf = scipy_logcdf(**params)
            value = params.pop('value')
            dist = pymc3_dist.dist(**params)
            assert_almost_equal(dist.logcdf(value).tag.test_value, scipy_cdf,
                                decimal=decimal, err_msg=str(pt))
464

465
    def check_int_to_1(self, model, value, domain, paramdomains):
        """Check that the model density integrates (or sums) to one.

        For up to 10 parameter points, ``exp(model.logpt)`` is mapped to a
        function of ``value`` alone and passed to ``integrate_nd`` over
        ``domain``; the result must be almost equal to 1.
        """
        pdf = model.fastfn(exp(model.logpt))
        for pt in product(paramdomains, n_samples=10):
            pt = Point(pt, value=value.tag.test_value, model=model)
            bij = DictToVarBijection(value, (), pt)
            pdfx = bij.mapf(pdf)
            area = integrate_nd(pdfx, domain, value.dshape, value.dtype)
            assert_almost_equal(area, 1, err_msg=str(pt))
1✔
473

474
    def checkd(self, distfam, valuedomain, vardomains, checks=None, extra_args=None):
        """Build a model for ``distfam`` and run each check on it.

        Args:
            distfam: distribution class under test.
            valuedomain: domain of the ``'value'`` variable.
            vardomains: dict mapping parameter name to its domain.
            checks: iterable of callables invoked as
                ``check(model, value_var, valuedomain, vardomains)``;
                defaults to ``(self.check_int_to_1,)``.
            extra_args: extra keyword arguments for the distribution.
        """
        if checks is None:
            checks = (self.check_int_to_1, )

        if extra_args is None:
            extra_args = {}
        m = build_model(distfam, valuedomain, vardomains, extra_args=extra_args)
        for check in checks:
            check(m, m.named_vars['value'], valuedomain, vardomains)
1✔
483

484
    def test_uniform(self):
        """Uniform logp and logcdf match scipy's ``uniform``."""
        self.pymc3_matches_scipy(
            Uniform, Runif, {'lower': -Rplusunif, 'upper': Rplusunif},
            lambda value, lower, upper: sp.uniform.logpdf(value, lower, upper - lower))
        self.check_logcdf(Uniform, Runif, {'lower': -Rplusunif, 'upper': Rplusunif},
                          lambda value, lower, upper: sp.uniform.logcdf(value, lower, upper - lower))
490

491
    def test_triangular(self):
        """Triangular logp and logcdf match scipy's ``triang``."""
        self.pymc3_matches_scipy(
            Triangular, Runif, {'lower': -Rplusunif, 'c': Runif, 'upper': Rplusunif},
            lambda value, c, lower, upper: sp.triang.logpdf(value, c-lower, lower, upper-lower))
        self.check_logcdf(Triangular, Runif, {'lower': -Rplusunif, 'c': Runif, 'upper': Rplusunif},
                          lambda value, c, lower, upper: sp.triang.logcdf(value, c-lower, lower, upper-lower))
497

498
    def test_bound_normal(self):
        """A Normal bounded below at 0 matches scipy's ``norm`` on Rplus.

        Also checks that the logp of a value below the bound is infinite.
        """
        PositiveNormal = Bound(Normal, lower=0.)
        self.pymc3_matches_scipy(PositiveNormal, Rplus, {'mu': Rplus, 'sigma': Rplus},
                                 lambda value, mu, sigma: sp.norm.logpdf(value, mu, sigma),
                                 decimal=select_by_precision(float64=6, float32=-1))
        with Model():
            x = PositiveNormal('x', mu=0, sigma=1, transform=None)
        assert np.isinf(x.logp({'x':-1}))
1✔
505

506
    def test_discrete_unif(self):
        """DiscreteUniform logp matches scipy's ``randint`` (upper bound inclusive)."""
        self.pymc3_matches_scipy(
            DiscreteUniform, Rdunif, {'lower': -Rplusdunif, 'upper': Rplusdunif},
            lambda value, lower, upper: sp.randint.logpmf(value, lower, upper + 1))
510

511
    def test_flat(self):
        """Flat has logp 0, test value 0 and logcdf log(0.5) at finite values."""
        self.pymc3_matches_scipy(Flat, Runif, {}, lambda value: 0)
        with Model():
            x = Flat('a')
            assert_allclose(x.tag.test_value, 0)
        self.check_logcdf(Flat, Runif, {}, lambda value: np.log(0.5))
        # Check infinite cases individually.
        assert 0. == Flat.dist().logcdf(np.inf).tag.test_value
        assert -np.inf == Flat.dist().logcdf(-np.inf).tag.test_value
1✔
520

521
    def test_half_flat(self):
        """HalfFlat has logp 0, test value 1 and logcdf -inf at finite values."""
        self.pymc3_matches_scipy(HalfFlat, Rplus, {}, lambda value: 0)
        with Model():
            x = HalfFlat('a', shape=2)
            assert_allclose(x.tag.test_value, 1)
            assert x.tag.test_value.shape == (2,)
        self.check_logcdf(HalfFlat, Runif, {}, lambda value: -np.inf)
        # Check infinite cases individually.
        assert 0. == HalfFlat.dist().logcdf(np.inf).tag.test_value
        assert -np.inf == HalfFlat.dist().logcdf(-np.inf).tag.test_value
1✔
531

532
    def test_normal(self):
        """Normal logp and logcdf match scipy's ``norm``."""
        self.pymc3_matches_scipy(Normal, R, {'mu': R, 'sigma': Rplus},
                                 lambda value, mu, sigma: sp.norm.logpdf(value, mu, sigma),
                                 decimal=select_by_precision(float64=6, float32=1)
                                 )
        self.check_logcdf(Normal, R, {'mu': R, 'sigma': Rplus},
                          lambda value, mu, sigma: sp.norm.logcdf(value, mu, sigma))
539

540
    def test_truncated_normal(self):
        """TruncatedNormal logp matches scipy's ``truncnorm``."""
        def scipy_logp(value, mu, sigma, lower, upper):
            # scipy expects the bounds standardised by loc and scale.
            return sp.truncnorm.logpdf(
                value, (lower-mu)/sigma, (upper-mu)/sigma, loc=mu, scale=sigma)

        self.pymc3_matches_scipy(
            TruncatedNormal, R,
            {'mu': R, 'sigma': Rplusbig, 'lower': -Rplusbig, 'upper': Rplusbig},
            scipy_logp,
            decimal=select_by_precision(float64=6, float32=1)
        )
551

552
    def test_half_normal(self):
        """HalfNormal logp and logcdf match scipy's ``halfnorm``."""
        self.pymc3_matches_scipy(HalfNormal, Rplus, {'sigma': Rplus},
                                 lambda value, sigma: sp.halfnorm.logpdf(value, scale=sigma),
                                 decimal=select_by_precision(float64=6, float32=-1)
                                 )
        self.check_logcdf(HalfNormal, Rplus, {'sigma': Rplus},
                          lambda value, sigma: sp.halfnorm.logcdf(value, scale=sigma))
559

560
    def test_chi_squared(self):
        """ChiSquared logp matches scipy's ``chi2``."""
        self.pymc3_matches_scipy(ChiSquared, Rplus, {'nu': Rplusdunif},
                                 lambda value, nu: sp.chi2.logpdf(value, df=nu))
563

564
    @pytest.mark.xfail(reason="Poor CDF in SciPy. See scipy/scipy#869 for details.")
    def test_wald_scipy(self):
        """Wald logp and logcdf against scipy's ``invgauss`` (expected to fail)."""
        self.pymc3_matches_scipy(Wald, Rplus, {'mu': Rplus, 'alpha': Rplus},
                                 lambda value, mu, alpha: sp.invgauss.logpdf(value, mu=mu, loc=alpha),
                                 decimal=select_by_precision(float64=6, float32=1)
                                 )
        self.check_logcdf(Wald, Rplus, {'mu': Rplus, 'alpha': Rplus},
                          lambda value, mu, alpha: sp.invgauss.logcdf(value, mu=mu, loc=alpha))
572

573
    @pytest.mark.parametrize('value,mu,lam,phi,alpha,logp', [
        (.5, .001, .5, None, 0., -124500.7257914),
        (1., .5, .001, None, 0., -4.3733162),
        (2., 1., None, None, 0., -2.2086593),
        (5., 2., 2.5, None, 0., -3.4374500),
        (7.5, 5., None, 1., 0., -3.2199074),
        (15., 10., None, .75, 0., -4.0360623),
        (50., 15., None, .66666, 0., -6.1801249),
        (.5, .001, 0.5, None, 0., -124500.7257914),
        (1., .5, .001, None, .5, -3.3330954),
        (2., 1., None, None, 1., -0.9189385),
        (5., 2., 2.5, None, 2., -2.2128783),
        (7.5, 5., None, 1., 2.5, -2.5283764),
        (15., 10., None, .75, 5., -3.3653647),
        (50., 15., None, .666666, 10., -5.6481874)
    ])
    def test_wald(self, value, mu, lam, phi, alpha, logp):
        """Wald logp matches precomputed reference values."""
        # Log probabilities calculated using the dIG function from the R package gamlss.
        # See e.g., doi: 10.1111/j.1467-9876.2005.00510.x, or
        # http://www.gamlss.org/.
        with Model() as model:
            Wald('wald', mu=mu, lam=lam, phi=phi, alpha=alpha, transform=None)
        pt = {'wald': value}
        decimals = select_by_precision(float64=6, float32=1)
        assert_almost_equal(model.fastlogp(pt), logp, decimal=decimals, err_msg=str(pt))
1✔
598

599
    def test_beta(self):
        """Beta logp (both parametrizations) and logcdf match scipy's ``beta``."""
        self.pymc3_matches_scipy(Beta, Unit, {'alpha': Rplus, 'beta': Rplus},
                                 lambda value, alpha, beta: sp.beta.logpdf(value, alpha, beta))
        self.pymc3_matches_scipy(Beta, Unit, {'mu': Unit, 'sigma': Rplus}, beta_mu_sigma)
        self.check_logcdf(Beta, Unit, {'alpha': Rplus, 'beta': Rplus},
                          lambda value, alpha, beta: sp.beta.logcdf(value, alpha, beta))
605

606
    def test_kumaraswamy(self):
        """Kumaraswamy logp matches a hand-written reference density."""
        # Scipy does not have a built-in Kumaraswamy pdf
        def scipy_log_pdf(value, a, b):
            return np.log(a) + np.log(b) + (a - 1) * np.log(value) + (b - 1) * np.log(1 - value ** a)
        self.pymc3_matches_scipy(Kumaraswamy, Unit, {'a': Rplus, 'b': Rplus}, scipy_log_pdf)
1✔
611

612
    def test_exponential(self):
        """Exponential logp and logcdf match scipy's ``expon`` with scale 1/lam."""
        self.pymc3_matches_scipy(Exponential, Rplus, {'lam': Rplus},
                                 lambda value, lam: sp.expon.logpdf(value, 0, 1 / lam))
        self.check_logcdf(Exponential, Rplus, {'lam': Rplus},
                          lambda value, lam: sp.expon.logcdf(value, 0, 1 / lam))
617

618
    def test_geometric(self):
        """Geometric logp matches the log of scipy's ``geom`` pmf."""
        self.pymc3_matches_scipy(Geometric, Nat, {'p': Unit},
                                 lambda value, p: np.log(sp.geom.pmf(value, p)))
621

622
    def test_negative_binomial(self):
        """NegativeBinomial (mu, alpha) logp matches scipy's ``nbinom``."""
        def test_fun(value, mu, alpha):
            # Convert (mu, alpha) to scipy's (n, p) parametrization.
            return sp.nbinom.logpmf(value, alpha, 1 - mu / (mu + alpha))
        self.pymc3_matches_scipy(NegativeBinomial, Nat, {
                            'mu': Rplus, 'alpha': Rplus}, test_fun)
627

628
    def test_laplace(self):
        """Laplace logp and logcdf match scipy's ``laplace``."""
        self.pymc3_matches_scipy(Laplace, R, {'mu': R, 'b': Rplus},
                                 lambda value, mu, b: sp.laplace.logpdf(value, mu, b))
        self.check_logcdf(Laplace, R, {'mu': R, 'b': Rplus},
                          lambda value, mu, b: sp.laplace.logcdf(value, mu, b))
633

634
    def test_lognormal(self):
        """Lognormal (mu, tau) logp and logcdf match scipy's ``lognorm``."""
        self.pymc3_matches_scipy(
            Lognormal, Rplus, {'mu': R, 'tau': Rplusbig},
            lambda value, mu, tau: floatX(sp.lognorm.logpdf(value, tau**-.5, 0, np.exp(mu))))
        self.check_logcdf(Lognormal, Rplus, {'mu': R, 'tau': Rplusbig},
                          lambda value, mu, tau: sp.lognorm.logcdf(value, tau**-.5, 0, np.exp(mu)))
640

641
    def test_t(self):
        """StudentT (nu, mu, lam) logp and logcdf match scipy's ``t``."""
        self.pymc3_matches_scipy(StudentT, R, {'nu': Rplus, 'mu': R, 'lam': Rplus},
                                 lambda value, nu, mu, lam: sp.t.logpdf(value, nu, mu, lam**-0.5))
        self.check_logcdf(StudentT, R, {'nu': Rplus, 'mu': R, 'lam': Rplus},
                          lambda value, nu, mu, lam: sp.t.logcdf(value, nu, mu, lam**-0.5))
646

647
    def test_cauchy(self):
        """Cauchy logp and logcdf match scipy's ``cauchy``."""
        self.pymc3_matches_scipy(Cauchy, R, {'alpha': R, 'beta': Rplusbig},
                                 lambda value, alpha, beta: sp.cauchy.logpdf(value, alpha, beta))
        self.check_logcdf(Cauchy, R, {'alpha': R, 'beta': Rplusbig},
                          lambda value, alpha, beta: sp.cauchy.logcdf(value, alpha, beta))
652

653
    def test_half_cauchy(self):
        """HalfCauchy logp and logcdf match scipy's ``halfcauchy``."""
        self.pymc3_matches_scipy(HalfCauchy, Rplus, {'beta': Rplusbig},
                                 lambda value, beta: sp.halfcauchy.logpdf(value, scale=beta))
        self.check_logcdf(HalfCauchy, Rplus, {'beta': Rplusbig},
                          lambda value, beta: sp.halfcauchy.logcdf(value, scale=beta))
658

659
    def test_gamma(self):
        """Gamma logp (both parametrizations) and logcdf match scipy's ``gamma``."""
        self.pymc3_matches_scipy(
            Gamma, Rplus, {'alpha': Rplusbig, 'beta': Rplusbig},
            lambda value, alpha, beta: sp.gamma.logpdf(value, alpha, scale=1.0 / beta))

        def test_fun(value, mu, sigma):
            # (mu, sigma) -> shape mu**2 / sigma**2, rate mu / sigma**2.
            return sp.gamma.logpdf(value, mu**2 / sigma**2, scale=1.0 / (mu / sigma**2))
        self.pymc3_matches_scipy(
            Gamma, Rplus, {'mu': Rplusbig, 'sigma': Rplusbig}, test_fun)

        self.check_logcdf(
            Gamma, Rplus, {'alpha': Rplusbig, 'beta': Rplusbig},
            lambda value, alpha, beta: sp.gamma.logcdf(value, alpha, scale=1.0/beta))
672

673
    def test_inverse_gamma(self):
        """InverseGamma (alpha, beta) logp matches scipy's ``invgamma``."""
        self.pymc3_matches_scipy(
            InverseGamma, Rplus, {'alpha': Rplus, 'beta': Rplus},
            lambda value, alpha, beta: sp.invgamma.logpdf(value, alpha, scale=beta))
677

678
    @pytest.mark.xfail(condition=(theano.config.floatX == "float32"),
                           reason="Fails on float32 due to scaling issues")
    def test_inverse_gamma_alt_params(self):
        """InverseGamma (mu, sigma) logp matches scipy's ``invgamma``."""
        def test_fun(value, mu, sigma):
            alpha, beta = InverseGamma._get_alpha_beta(None, None, mu, sigma)
            return sp.invgamma.logpdf(value, alpha, scale=beta)
        self.pymc3_matches_scipy(
            InverseGamma, Rplus, {'mu': Rplus, 'sigma': Rplus}, test_fun)
686

687
    def test_pareto(self):
        """Pareto logp and logcdf match scipy's ``pareto`` with scale m."""
        self.pymc3_matches_scipy(Pareto, Rplus, {'alpha': Rplusbig, 'm': Rplusbig},
                                 lambda value, alpha, m: sp.pareto.logpdf(value, alpha, scale=m))
        self.check_logcdf(Pareto, Rplus, {'alpha': Rplusbig, 'm': Rplusbig},
                          lambda value, alpha, m: sp.pareto.logcdf(value, alpha, scale=m))
692

693
    @pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32 due to inf issues")
    def test_weibull(self):
        """Weibull logp and logcdf match scipy's ``exponweib`` with a=1."""
        self.pymc3_matches_scipy(Weibull, Rplus, {'alpha': Rplusbig, 'beta': Rplusbig},
                                 lambda value, alpha, beta: sp.exponweib.logpdf(value, 1, alpha, scale=beta),
                                 )
        self.check_logcdf(Weibull, Rplus, {'alpha': Rplusbig, 'beta': Rplusbig},
                          lambda value, alpha, beta:
                          sp.exponweib.logcdf(value, 1, alpha, scale=beta),)
701

702
    def test_half_studentt(self):
        """HalfStudentT logp matches scipy's ``halfcauchy``."""
        # this is only testing for nu=1 (halfcauchy)
        self.pymc3_matches_scipy(HalfStudentT, Rplus, {'sigma': Rplus},
                                 lambda value, sigma: sp.halfcauchy.logpdf(value, 0, sigma))
706

707
    def test_skew_normal(self):
        """Compare SkewNormal logp with scipy's ``skewnorm``."""

        def reference_logpdf(value, alpha, mu, sigma):
            return sp.skewnorm.logpdf(value, alpha, mu, sigma)

        param_domains = {'mu': R, 'sigma': Rplusbig, 'alpha': R}
        self.pymc3_matches_scipy(SkewNormal, R, param_domains, reference_logpdf)
710

711
    def test_binomial(self):
        """Compare Binomial logp with scipy's ``binom.logpmf``."""

        def reference_logpmf(value, n, p):
            return sp.binom.logpmf(value, n, p)

        param_domains = {'n': NatSmall, 'p': Unit}
        self.pymc3_matches_scipy(Binomial, Nat, param_domains, reference_logpmf)
714

715
    # Too lazy to propagate decimal parameter through the whole chain of deps
716
    @pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
1✔
717
    def test_beta_binomial(self):
        """Run the generic ``checkd`` checks on BetaBinomial over the naturals."""
        param_domains = {'alpha': Rplus, 'beta': Rplus, 'n': NatSmall}
        self.checkd(BetaBinomial, Nat, param_domains)
1✔
719

720
    def test_bernoulli(self):
        """Compare Bernoulli logp with scipy for both ``logit_p`` and ``p``."""

        def logpmf_from_logit(value, logit_p):
            # scipy only takes a probability, so map the logit through expit first.
            return sp.bernoulli.logpmf(value, scipy.special.expit(logit_p))

        def logpmf_from_p(value, p):
            return sp.bernoulli.logpmf(value, p)

        self.pymc3_matches_scipy(Bernoulli, Bool, {'logit_p': R}, logpmf_from_logit)
        self.pymc3_matches_scipy(Bernoulli, Bool, {'p': Unit}, logpmf_from_p)
726

727

728
    def test_discrete_weibull(self):
        """Compare DiscreteWeibull logp with the ``discrete_weibull_logpmf`` reference."""
        param_domains = {'q': Unit, 'beta': Rplusdunif}
        self.pymc3_matches_scipy(
            DiscreteWeibull, Nat, param_domains, discrete_weibull_logpmf)
731

732
    def test_poisson(self):
        """Compare Poisson logp with scipy's ``poisson.logpmf``."""

        def reference_logpmf(value, mu):
            return sp.poisson.logpmf(value, mu)

        self.pymc3_matches_scipy(Poisson, Nat, {'mu': Rplus}, reference_logpmf)
735

736
    def test_bound_poisson(self):
        """A Poisson bounded below at 1 matches scipy on PosNat and gives infinite logp at 0."""
        NonZeroPoisson = Bound(Poisson, lower=1.)

        def poisson_logpmf(value, mu):
            return sp.poisson.logpmf(value, mu)

        self.pymc3_matches_scipy(
            NonZeroPoisson, PosNat, {'mu': Rplus}, poisson_logpmf)

        with Model():
            x = NonZeroPoisson('x', mu=4)
        # Zero lies below the lower bound of 1.
        assert np.isinf(x.logp({'x': 0}))
1✔
743

744
    def test_constantdist(self):
        """Constant logp is the log of the indicator ``c == value``."""

        def indicator_log(value, c):
            return np.log(c == value)

        self.pymc3_matches_scipy(Constant, I, {'c': I}, indicator_log)
747

748
    # Too lazy to propagate decimal parameter through the whole chain of deps
749
    @pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
1✔
750
    def test_zeroinflatedpoisson(self):
        """Run the generic ``checkd`` checks on ZeroInflatedPoisson."""
        param_domains = {'theta': Rplus, 'psi': Unit}
        self.checkd(ZeroInflatedPoisson, Nat, param_domains)
1✔
752

753
    # Too lazy to propagate decimal parameter through the whole chain of deps
754
    @pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
1✔
755
    def test_zeroinflatednegativebinomial(self):
        """Run the generic ``checkd`` checks on ZeroInflatedNegativeBinomial."""
        param_domains = {'mu': Rplusbig, 'alpha': Rplusbig, 'psi': Unit}
        self.checkd(ZeroInflatedNegativeBinomial, Nat, param_domains)
758

759
    # Too lazy to propagate decimal parameter through the whole chain of deps
760
    @pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
1✔
761
    def test_zeroinflatedbinomial(self):
        """Run the generic ``checkd`` checks on ZeroInflatedBinomial."""
        param_domains = {'n': NatSmall, 'p': Unit, 'psi': Unit}
        self.checkd(ZeroInflatedBinomial, Nat, param_domains)
764

765
    @pytest.mark.parametrize('n', [1, 2, 3])
1✔
766
    def test_mvnormal(self, n):
        """Compare MvNormal logp with the reference log-densities.

        Covers the ``tau``, ``cov`` and ``chol`` parametrizations, each for a
        5-row matrix of values and for a single vector, and finally an upper
        Cholesky factor passed with ``lower=False``.
        """

        def check(value_domain, param_domains, reference, **tolerance):
            # Thin forwarder so each case below reads as one line of data.
            self.pymc3_matches_scipy(
                MvNormal, value_domain, param_domains, reference, **tolerance)

        # Precision-matrix parametrization.
        check(RealMatrix(5, n), {'mu': Vector(R, n), 'tau': PdMatrix(n)},
              normal_logpdf_tau)
        check(Vector(R, n), {'mu': Vector(R, n), 'tau': PdMatrix(n)},
              normal_logpdf_tau)

        # Covariance-matrix parametrization.
        check(RealMatrix(5, n), {'mu': Vector(R, n), 'cov': PdMatrix(n)},
              normal_logpdf_cov)
        check(Vector(R, n), {'mu': Vector(R, n), 'cov': PdMatrix(n)},
              normal_logpdf_cov)

        # Cholesky parametrization, with a looser tolerance on float32.
        check(RealMatrix(5, n), {'mu': Vector(R, n), 'chol': PdMatrixChol(n)},
              normal_logpdf_chol,
              decimal=select_by_precision(float64=6, float32=-1))
        check(Vector(R, n), {'mu': Vector(R, n), 'chol': PdMatrixChol(n)},
              normal_logpdf_chol,
              decimal=select_by_precision(float64=6, float32=0))

        def MvNormalUpper(*args, **kwargs):
            return MvNormal(*args, lower=False, **kwargs)

        self.pymc3_matches_scipy(
            MvNormalUpper, Vector(R, n),
            {'mu': Vector(R, n), 'chol': PdMatrixCholUpper(n)},
            normal_logpdf_chol_upper,
            decimal=select_by_precision(float64=6, float32=0))
795

796
    @pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32 due to inf issues")
1✔
797
    def test_mvnormal_indef(self):
        """An indefinite matrix yields -inf logp and a non-finite gradient.

        The same two assertions are made with the matrix passed as ``cov`` and
        then as ``tau``.
        """
        cov_val = np.array([[1, 0.5], [0.5, -2]])

        cov = tt.matrix('cov')
        cov.tag.test_value = np.eye(2)
        mu = floatX(np.zeros(2))
        x = tt.vector('x')
        x.tag.test_value = np.zeros(2)

        for matrix_kwarg in ('cov', 'tau'):
            logp = MvNormal.dist(mu=mu, **{matrix_kwarg: cov}).logp(x)
            f_logp = theano.function([cov, x], logp)
            assert f_logp(cov_val, np.ones(2)) == -np.inf

            dlogp = tt.grad(logp, cov)
            f_dlogp = theano.function([cov, x], dlogp)
            gradient = f_dlogp(cov_val, np.ones(2))
            assert not np.all(np.isfinite(gradient))
1✔
817

818
    def test_mvnormal_init_fail(self):
        """MvNormal raises ValueError with no matrix argument or with both cov and tau."""
        with Model():
            # Neither cov, tau nor chol supplied.
            with pytest.raises(ValueError):
                MvNormal('x', mu=np.zeros(3), shape=3)
            # cov and tau supplied together.
            with pytest.raises(ValueError):
                MvNormal('x', mu=np.zeros(3), cov=np.eye(3), tau=np.eye(3), shape=3)
1✔
824

825
    @pytest.mark.parametrize('n', [1, 2, 3])
1✔
826
    def test_matrixnormal(self, n):
        """Compare MatrixNormal logp with the reference log-densities.

        Two cases use row/column covariances and two use row/column Cholesky
        factors, with square and non-square value shapes.
        """
        mat_scale = 1e3  # To reduce logp magnitude
        mean_scale = .1

        def check(value_domain, param_domains, reference, **tolerance):
            self.pymc3_matches_scipy(
                MatrixNormal, value_domain, param_domains, reference, **tolerance)

        check(RealMatrix(n, n),
              {'mu': RealMatrix(n, n)*mean_scale,
               'rowcov': PdMatrix(n)*mat_scale,
               'colcov': PdMatrix(n)*mat_scale},
              matrix_normal_logpdf_cov)
        check(RealMatrix(2, n),
              {'mu': RealMatrix(2, n)*mean_scale,
               'rowcov': PdMatrix(2)*mat_scale,
               'colcov': PdMatrix(n)*mat_scale},
              matrix_normal_logpdf_cov)
        check(RealMatrix(3, n),
              {'mu': RealMatrix(3, n)*mean_scale,
               'rowchol': PdMatrixChol(3)*mat_scale,
               'colchol': PdMatrixChol(n)*mat_scale},
              matrix_normal_logpdf_chol,
              decimal=select_by_precision(float64=6, float32=-1))
        check(RealMatrix(n, 3),
              {'mu': RealMatrix(n, 3)*mean_scale,
               'rowchol': PdMatrixChol(n)*mat_scale,
               'colchol': PdMatrixChol(3)*mat_scale},
              matrix_normal_logpdf_chol,
              decimal=select_by_precision(float64=6, float32=0))
851

852
    @pytest.mark.parametrize('n', [2, 3])
1✔
853
    @pytest.mark.parametrize('m', [3])
1✔
854
    @pytest.mark.parametrize('sigma', [None, 1.0])
1✔
855
    def test_kroneckernormal(self, n, m, sigma):
        """Compare KroneckerNormal logp with the reference log-densities.

        The Kronecker factors are supplied as covariances, Cholesky factors and
        eigendecompositions in turn, first for a length ``n*m`` vector of values
        and then for a ``(2, n*m)`` matrix.
        """
        np.random.seed(5)
        N = n*m
        covs = [RandomPdMatrix(n), RandomPdMatrix(m)]
        chols = [np.linalg.cholesky(cov) for cov in covs]
        evds = [np.linalg.eigh(cov) for cov in covs]
        dom = Domain([np.random.randn(N)*0.1], edges=(None, None), shape=N)
        mu = Domain([np.random.randn(N)*0.1], edges=(None, None), shape=N)

        std_args = {'mu': mu}
        cov_args = {'covs': covs}
        chol_args = {'chols': chols}
        evd_args = {'evds': evds}
        if sigma is None or sigma == 0:
            # No usable noise term: hand sigma over unchanged as an extra argument.
            for extra in (cov_args, chol_args, evd_args):
                extra['sigma'] = sigma
        else:
            std_args['sigma'] = Domain([sigma], edges=(None, None))

        reference_cases = (
            (kron_normal_logpdf_cov, cov_args),
            (kron_normal_logpdf_chol, chol_args),
            (kron_normal_logpdf_evd, evd_args),
        )

        def check_all(value_domain):
            for reference, extra in reference_cases:
                self.pymc3_matches_scipy(
                    KroneckerNormal, value_domain, std_args, reference,
                    extra_args=extra, scipy_args=extra)

        check_all(dom)

        # Repeat with a batch of two value rows.
        dom = Domain([np.random.randn(2, N)*0.1], edges=(None, None), shape=(2, N))
        check_all(dom)
895

896
    @pytest.mark.parametrize('n', [1, 2])
1✔
897
    def test_mvt(self, n):
        """Compare MvStudentT logp with ``mvt_logpdf`` for a vector and a 2-row matrix."""

        def check(value_domain):
            param_domains = {'nu': Rplus, 'Sigma': PdMatrix(n), 'mu': Vector(R, n)}
            self.pymc3_matches_scipy(MvStudentT, value_domain, param_domains, mvt_logpdf)

        check(Vector(R, n))
        check(RealMatrix(2, n))
904

905
    @pytest.mark.parametrize('n', [2, 3, 4])
1✔
906
    def test_AR1(self, n):
        """Compare AR1 logp with the ``AR1_logpdf`` reference on length-``n`` vectors."""
        param_domains = {'k': Unit, 'tau_e': Rplus}
        self.pymc3_matches_scipy(AR1, Vector(R, n), param_domains, AR1_logpdf)
1✔
908

909

910
    @pytest.mark.parametrize('n', [2, 3])
1✔
911
    def test_wishart(self, n):
        """Placeholder: the Wishart gradient check is disabled, so nothing is asserted."""
        # The disabled check compared the autodiff gradient with a numerically
        # differentiated one. The Wishart's constraints are strict enough that a
        # small perturbation breaks the symmetry, so the gradient cannot be
        # determined numerically. numdifftools was also removed in June 2019,
        # so an alternative would be needed before re-enabling it.
        #
        # self.checkd(Wishart, PdMatrix(n), {'n': Domain([2, 3, 4, 2000]), 'V': PdMatrix(n)},
        #             checks=[self.check_dlogp])
        return None
1✔
921

922
    @pytest.mark.parametrize('x,eta,n,lp', LKJ_CASES)
1✔
923
    def test_lkj(self, x, eta, n, lp):
        """LKJCorr logp at ``x`` equals the tabulated value ``lp``."""
        with Model() as model:
            LKJCorr('lkj', eta=eta, n=n, transform=None)

        point = {'lkj': x}
        computed = model.fastlogp(point)
        assert_almost_equal(
            computed, lp,
            decimal=select_by_precision(float64=6, float32=4),
            err_msg=str(point))
1✔
930

931
    @pytest.mark.parametrize('n', [2, 3])
1✔
932
    def test_dirichlet(self, n):
        """Compare Dirichlet logp with ``dirichlet_logpdf`` on the ``n``-simplex."""
        value_domain = Simplex(n)
        param_domains = {'a': Vector(Rplus, n)}
        self.pymc3_matches_scipy(Dirichlet, value_domain, param_domains, dirichlet_logpdf)
935

936
    def test_dirichlet_2D(self):
        """Compare Dirichlet logp with ``dirichlet_logpdf`` for a 2x2 concentration."""
        value_domain = MultiSimplex(2, 2)
        param_domains = {'a': Vector(Vector(Rplus, 2), 2)}
        self.pymc3_matches_scipy(Dirichlet, value_domain, param_domains, dirichlet_logpdf)
939

940
    @pytest.mark.parametrize('n', [2, 3])
1✔
941
    def test_multinomial(self, n):
        """Compare Multinomial logp with ``multinomial_logpdf`` for ``n`` categories."""
        value_domain = Vector(Nat, n)
        param_domains = {'p': Simplex(n), 'n': Nat}
        self.pymc3_matches_scipy(
            Multinomial, value_domain, param_domains, multinomial_logpdf)
944

945
    @pytest.mark.parametrize('p,n', [
1✔
946
        [[.25, .25, .25, .25], 1],
947
        [[.3, .6, .05, .05], 2],
948
        [[.3, .6, .05, .05], 10],
949
    ])
950
    def test_multinomial_mode(self, p, n):
        """The Multinomial mode sums to ``n``, for a single and a stacked ``p``."""
        single_p = np.array(p)
        with Model():
            m = Multinomial('m', n, single_p, single_p.shape)
        mode = m.distribution.mode.eval()
        assert_allclose(mode.sum(), n)

        # Two identical rows: every row of the mode must sum to n.
        stacked_p = np.array([p, p])
        with Model():
            m = Multinomial('m', n, stacked_p, stacked_p.shape)
        mode = m.distribution.mode.eval()
        assert_allclose(mode.sum(axis=-1), n)
1✔
959

960
    @pytest.mark.parametrize('p, shape, n', [
1✔
961
        [[.25, .25, .25, .25], 4, 2],
962
        [[.25, .25, .25, .25], (1, 4), 3],
963
        # 3: expect to fail
964
        # [[.25, .25, .25, .25], (10, 4)],
965
        [[.25, .25, .25, .25], (10, 1, 4), 5],
966
        # 5: expect to fail
967
        # [[[.25, .25, .25, .25]], (2, 4), [7, 11]],
968
        [[[.25, .25, .25, .25],
969
         [.25, .25, .25, .25]], (2, 4), 13],
970
        [[[.25, .25, .25, .25],
971
         [.25, .25, .25, .25]], (1, 2, 4), [23, 29]],
972
        [[[.25, .25, .25, .25],
973
         [.25, .25, .25, .25]], (10, 2, 4), [31, 37]],
974
        [[[.25, .25, .25, .25],
975
         [.25, .25, .25, .25]], (2, 4), [17, 19]],
976
    ])
977
    def test_multinomial_random(self, p, shape, n):
        """Drawing from a Multinomial with the given ``p``/``shape``/``n`` does not raise."""
        probabilities = np.asarray(p)
        with Model():
            m = Multinomial('m', n=n, p=probabilities, shape=shape)
        # Only the absence of an exception is checked; the draw is discarded.
        m.random()
1✔
982

983
    def test_multinomial_mode_with_shape(self):
        """With per-row ``n`` and ``p``, each row of the mode sums to its own ``n``."""
        n = [1, 10]
        p = np.asarray([[.25, .25, .25, .25], [.26, .26, .26, .22]])
        with Model():
            m = Multinomial('m', n=n, p=p, shape=(2, 4))
        row_totals = m.distribution.mode.eval().sum(axis=-1)
        assert_allclose(row_totals, n)
1✔
989

990
    def test_multinomial_vec(self):
        """A batched Multinomial agrees with scipy and with row-by-row evaluation."""
        vals = np.array([[2, 4, 4], [3, 3, 4]])
        p = np.array([0.2, 0.3, 0.5])
        n = 10

        with Model() as model_single:
            Multinomial('m', n=n, p=p, shape=len(p))

        with Model() as model_many:
            Multinomial('m', n=n, p=p, shape=vals.shape)

        # Row-by-row logp from the single-vector model matches scipy.
        assert_almost_equal(
            scipy.stats.multinomial.logpmf(vals, n, p),
            np.asarray([model_single.fastlogp({'m': row}) for row in vals]),
            decimal=4)

        # Element-wise logp of the batched model matches scipy.
        elemwise = model_many.free_RVs[0].logp_elemwise({'m': vals}).squeeze()
        assert_almost_equal(
            scipy.stats.multinomial.logpmf(vals, n, p),
            elemwise,
            decimal=4)

        # The batched model's total logp is the sum over the rows.
        row_total = sum([model_single.fastlogp({'m': row}) for row in vals])
        assert_almost_equal(row_total, model_many.fastlogp({'m': vals}), decimal=4)
1012

1013
    def test_multinomial_vec_1d_n(self):
        """Batched Multinomial with a per-row ``n`` and a shared ``p``."""
        vals = np.array([[2, 4, 4], [4, 3, 4]])
        p = np.array([0.2, 0.3, 0.5])
        ns = np.array([10, 11])

        with Model() as model:
            Multinomial('m', n=ns, p=p, shape=vals.shape)

        expected = sum([multinomial_logpdf(row, count, p)
                        for row, count in zip(vals, ns)])
        assert_almost_equal(expected, model.fastlogp({'m': vals}), decimal=4)
1024

1025
    def test_multinomial_vec_1d_n_2d_p(self):
        """Batched Multinomial with a per-row ``n`` and a per-row ``p``."""
        vals = np.array([[2, 4, 4], [4, 3, 4]])
        ps = np.array([[0.2, 0.3, 0.5],
                       [0.9, 0.09, 0.01]])
        ns = np.array([10, 11])

        with Model() as model:
            Multinomial('m', n=ns, p=ps, shape=vals.shape)

        expected = sum([multinomial_logpdf(row, count, probs)
                        for row, count, probs in zip(vals, ns, ps)])
        assert_almost_equal(expected, model.fastlogp({'m': vals}), decimal=4)
1037

1038
    def test_multinomial_vec_2d_p(self):
        """Batched Multinomial with a shared ``n`` and a per-row ``p``."""
        vals = np.array([[2, 4, 4], [3, 3, 4]])
        ps = np.array([[0.2, 0.3, 0.5],
                       [0.3, 0.3, 0.4]])
        n = 10

        with Model() as model:
            Multinomial('m', n=n, p=ps, shape=vals.shape)

        expected = sum([multinomial_logpdf(row, n, probs)
                        for row, probs in zip(vals, ps)])
        assert_almost_equal(expected, model.fastlogp({'m': vals}), decimal=4)
1050

1051
    def test_categorical_bounds(self):
        """Values just outside the category range get an infinite logp."""
        with Model():
            x = Categorical('x', p=np.array([0.2, 0.3, 0.5]))
            # Valid categories are 0, 1 and 2.
            for out_of_range in (-1, 3):
                assert np.isinf(x.logp({'x': out_of_range}))
1✔
1056

1057
    def test_categorical_valid_p(self):
        """A ``p`` with negative entries makes the logp of every category infinite."""
        invalid_ps = (
            # Contains a negative value and does not sum to 1.
            np.array([-0.2, 0.3, 0.5]),
            # Sums to 1 but contains a negative value.
            np.array([-0.2, 0.7, 0.5]),
            # Hard edge case from #2082: early automatic normalization of p's
            # sum would hide the negative entries when one or two values are
            # negative and the rest are zero.
            np.array([-1, -1, 0, 0]),
        )
        for bad_p in invalid_ps:
            with Model():
                x = Categorical('x', p=bad_p)
                for category in range(len(bad_p)):
                    assert np.isinf(x.logp({'x': category}))
1✔
1079

1080
    @pytest.mark.parametrize('n', [2, 3, 4])
1✔
1081
    def test_categorical(self, n):
        """Compare Categorical logp with ``categorical_logpdf`` for ``n`` categories."""

        def reference_logpdf(value, p):
            return categorical_logpdf(value, p)

        value_domain = Domain(range(n), 'int64')
        self.pymc3_matches_scipy(
            Categorical, value_domain, {'p': Simplex(n)}, reference_logpdf)
1084

1085
    @pytest.mark.parametrize('n', [2, 3, 4])
1✔
1086
    def test_orderedlogistic(self, n):
        """Compare OrderedLogistic logp with ``orderedlogistic_logpdf`` for ``n`` categories."""

        def reference_logpdf(value, eta, cutpoints):
            return orderedlogistic_logpdf(value, eta, cutpoints)

        value_domain = Domain(range(n), 'int64')
        # n categories are separated by n-1 cutpoints.
        param_domains = {'eta': R, 'cutpoints': Vector(R, n-1)}
        self.pymc3_matches_scipy(
            OrderedLogistic, value_domain, param_domains, reference_logpdf)
1090

1091
    def test_densitydist(self):
        """Run the generic ``checkd`` checks on a DensityDist with a custom logp."""

        def custom_logp(x):
            return -log(2 * .5) - abs(x - .5) / .5

        self.checkd(DensityDist, R, {}, extra_args={'logp': custom_logp})
1✔
1095

1096
    def test_get_tau_sigma(self):
        """``get_tau_sigma(sigma=...)`` returns ``(1 / sigma**2, sigma)``."""
        sigma = np.array([2])
        expected = [1. / sigma**2, sigma]
        result = continuous.get_tau_sigma(sigma=sigma)
        assert_almost_equal(result, expected)
1✔
1099

1100
    @pytest.mark.parametrize('value,mu,sigma,nu,logp', [
1✔
1101
        (0.5, -50.000, 0.500, 0.500, -99.8068528),
1102
        (1.0, -1.000, 0.001, 0.001, -1992.5922447),
1103
        (2.0, 0.001, 1.000, 1.000, -1.6720416),
1104
        (5.0, 0.500, 2.500, 2.500, -2.4543644),
1105
        (7.5, 2.000, 5.000, 5.000, -2.8259429),
1106
        (15.0, 5.000, 7.500, 7.500, -3.3093854),
1107
        (50.0, 50.000, 10.000, 10.000, -3.6436067),
1108
        (1000.0, 500.000, 10.000, 20.000, -27.8707323)
1109
    ])
1110
    def test_ex_gaussian(self, value, mu, sigma, nu, logp):
        """ExGaussian logp matches reference values.

        The expected log probabilities were calculated with the dexGAUS function
        from the R package gamlss. See e.g., doi: 10.1111/j.1467-9876.2005.00510.x,
        or http://www.gamlss.org/.
        """
        with Model() as model:
            ExGaussian('eg', mu=mu, sigma=sigma, nu=nu)

        point = {'eg': value}
        assert_almost_equal(
            model.fastlogp(point), logp,
            decimal=select_by_precision(float64=6, float32=2),
            err_msg=str(point))
1✔
1117

1118
    @pytest.mark.parametrize('value,mu,sigma,nu,logcdf', [
1✔
1119
        (0.5, -50.000, 0.500, 0.500, 0.0000000),
1120
        (1.0, -1.000, 0.001, 0.001, 0.0000000),
1121
        (2.0, 0.001, 1.000, 1.000, -0.2365674),
1122
        (5.0, 0.500, 2.500, 2.500, -0.2886489),
1123
        (7.5, 2.000, 5.000, 5.000, -0.5655104),
1124
        (15.0, 5.000, 7.500, 7.500, -0.4545255),
1125
        (50.0, 50.000, 10.000, 10.000, -1.433714),
1126
        (1000.0, 500.000, 10.000, 20.000, -1.573708e-11),
1127
    ])
1128
    def test_ex_gaussian_cdf(self, value, mu, sigma, nu, logcdf):
        """ExGaussian logcdf matches reference values.

        The expected log probabilities were calculated with the pexGAUS function
        from the R package gamlss. See e.g., doi: 10.1111/j.1467-9876.2005.00510.x,
        or http://www.gamlss.org/.
        """
        dist = ExGaussian.dist(mu=mu, sigma=sigma, nu=nu)
        computed = dist.logcdf(value).tag.test_value
        case = (value, mu, sigma, nu, logcdf)
        assert_almost_equal(
            computed, logcdf,
            decimal=select_by_precision(float64=6, float32=2),
            err_msg=str(case))
1136

1137
    @pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
1✔
1138
    def test_vonmises(self):
        """Compare VonMises logp with scipy's ``vonmises``."""

        def reference_logpdf(value, mu, kappa):
            return floatX(sp.vonmises.logpdf(value, kappa, loc=mu))

        param_domains = {'mu': Circ, 'kappa': Rplus}
        self.pymc3_matches_scipy(VonMises, R, param_domains, reference_logpdf)
1142

1143
    def test_gumbel(self):
        """Compare Gumbel logp and logcdf with scipy's ``gumbel_r``."""

        def reference_logpdf(value, mu, beta):
            return floatX(sp.gumbel_r.logpdf(value, loc=mu, scale=beta))

        def reference_logcdf(value, mu, beta):
            return floatX(sp.gumbel_r.logcdf(value, loc=mu, scale=beta))

        self.pymc3_matches_scipy(
            Gumbel, R, {'mu': R, 'beta': Rplusbig}, reference_logpdf)
        self.check_logcdf(
            Gumbel, R, {'mu': R, 'beta': Rplusbig}, reference_logcdf)
1✔
1151

1152
    def test_logistic(self):
        """Compare Logistic logp and logcdf with scipy's ``logistic``."""

        def reference_logpdf(value, mu, s):
            return sp.logistic.logpdf(value, mu, s)

        def reference_logcdf(value, mu, s):
            return sp.logistic.logcdf(value, mu, s)

        self.pymc3_matches_scipy(
            Logistic, R, {'mu': R, 's': Rplus}, reference_logpdf,
            decimal=select_by_precision(float64=6, float32=1))
        self.check_logcdf(
            Logistic, R, {'mu': R, 's': Rplus}, reference_logcdf,
            decimal=select_by_precision(float64=6, float32=1))
1159

1160
    def test_logitnormal(self):
        """Compare LogitNormal logp with a normal density on the logit scale."""

        def reference_logpdf(value, mu, sigma):
            # Normal log-density of logit(value), minus log(value) + log(1 - value).
            return (sp.norm.logpdf(logit(value), mu, sigma)
                    - (np.log(value) + np.log1p(-value)))

        self.pymc3_matches_scipy(
            LogitNormal, Unit, {'mu': R, 'sigma': Rplus}, reference_logpdf,
            decimal=select_by_precision(float64=6, float32=1))
1165

1166
    def test_multidimensional_beta_construction(self):
        """Constructing a Beta with a 2-D ``shape`` does not raise."""
        requested_shape = (10, 20)
        with Model():
            Beta('beta', alpha=1., beta=1., shape=requested_shape)
1✔
1169

1170
    def test_rice(self):
        """Compare Rice logp with scipy's ``rice`` for the ``nu`` and ``b`` parametrizations."""

        def logpdf_from_nu(value, nu, sigma):
            # scipy's shape parameter is b = nu / sigma.
            return sp.rice.logpdf(value, b=nu / sigma, loc=0, scale=sigma)

        def logpdf_from_b(value, b, sigma):
            return sp.rice.logpdf(value, b=b, loc=0, scale=sigma)

        self.pymc3_matches_scipy(
            Rice, Rplus, {'nu': Rplus, 'sigma': Rplusbig}, logpdf_from_nu)
        self.pymc3_matches_scipy(
            Rice, Rplus, {'b': Rplus, 'sigma': Rplusbig}, logpdf_from_b)
1175

1176
    @pytest.mark.xfail(condition=(theano.config.floatX == "float32"), reason="Fails on float32")
1✔
1177
    def test_interpolated(self):
        """Interpolated built from a tabulated normal pdf matches the normal logpdf.

        For every (mu, sigma) pair the pdf is tabulated on mu +/- 5 sigma; the
        reference is the normal logpdf inside that interval and -inf outside it.
        """
        for mu, sigma in itertools.product(R.vals, Rplus.vals):
            #pylint: disable=cell-var-from-loop
            xmin = mu - 5 * sigma
            xmax = mu + 5 * sigma

            class TestedInterpolated (Interpolated):
                def __init__(self, **kwargs):
                    grid = np.linspace(xmin, xmax, 100000)
                    density = sp.norm.pdf(grid, loc=mu, scale=sigma)
                    super().__init__(x_points=grid, pdf_points=density, **kwargs)

            def ref_pdf(value):
                inside = np.logical_and(value >= xmin, value <= xmax)
                return np.where(
                    inside,
                    sp.norm.logpdf(value, mu, sigma),
                    -np.inf * np.ones(value.shape)
                )

            self.pymc3_matches_scipy(TestedInterpolated, R, {}, ref_pdf)
1✔
1198

1199

1200
def test_bound():
    """Exercise ``Bound`` with no, scalar, array and symbolic limits.

    Checks logp outside the bounds, default values, whether a transform is
    attached, random draws, and that bounded variables can be created inside a
    model with named or positional arguments.
    """
    np.random.seed(42)

    # No limits: no transform, default at the mean, random draws work.
    UnboundNormal = Bound(Normal)
    dist = UnboundNormal.dist(mu=0, sigma=1)
    assert dist.transform is None
    assert dist.default() == 0.
    assert isinstance(dist.random(), np.ndarray)

    # Lower limit only.
    LowerNormal = Bound(Normal, lower=1)
    dist = LowerNormal.dist(mu=0, sigma=1)
    assert dist.logp(0).eval() == -np.inf
    assert dist.default() > 1
    assert dist.transform is not None
    assert np.all(dist.random() > 1)

    # Upper limit only.
    UpperNormal = Bound(Normal, upper=-1)
    dist = UpperNormal.dist(mu=0, sigma=1)
    assert dist.logp(-0.5).eval() == -np.inf
    assert dist.default() < -1
    assert dist.transform is not None
    assert np.all(dist.random() < -1)

    # Array-valued limits: logp and defaults work, random sampling is refused.
    ArrayNormal = Bound(Normal, lower=[1, 2], upper=[2, 3])
    dist = ArrayNormal.dist(mu=0, sigma=1, shape=2)
    minus_inf_pair = -np.array([np.inf, np.inf])
    assert_equal(dist.logp([0.5, 3.5]).eval(), minus_inf_pair)
    assert_equal(dist.default(), np.array([1.5, 2.5]))
    assert dist.transform is not None
    with pytest.raises(ValueError) as err:
        dist.random()
    err.match('Drawing samples from distributions with array-valued')

    with Model():
        a = ArrayNormal('c', shape=2)
        assert_equal(a.tag.test_value, np.array([1.5, 2.5]))

    # Symbolic lower limit combined with a scalar upper limit.
    lower = tt.vector('lower')
    lower.tag.test_value = np.array([1, 2]).astype(theano.config.floatX)
    upper = 3
    ArrayNormal = Bound(Normal, lower=lower, upper=upper)
    dist = ArrayNormal.dist(mu=0, sigma=1, shape=2)
    logp = dist.logp([0.5, 3.5]).eval({lower: lower.tag.test_value})
    assert_equal(logp, -np.array([np.inf, np.inf]))
    assert_equal(dist.default(), np.array([2, 2.5]))
    assert dist.transform is not None

    with Model():
        a = ArrayNormal('c', shape=2)
        assert_equal(a.tag.test_value, np.array([2, 2.5]))

    # Bounded discrete draws keep an integer dtype and respect the limits.
    integer_dtypes = [np.int16, np.int32, np.int64]

    rand = Bound(Binomial, lower=10).dist(n=20, p=0.3).random()
    assert rand.dtype in integer_dtypes
    assert rand >= 10

    rand = Bound(Binomial, upper=10).dist(n=20, p=0.8).random()
    assert rand.dtype in integer_dtypes
    assert rand <= 10

    rand = Bound(Binomial, lower=5, upper=8).dist(n=10, p=0.6).random()
    assert rand.dtype in integer_dtypes
    assert rand >= 5 and rand <= 8

    # Construction inside a model: only the absence of an exception is checked.
    with Model():
        BoundPoisson = Bound(Poisson, upper=6)
        BoundPoisson(name="y", mu=1)

    with Model():
        Bound(Normal, upper=6)("y", mu=2., sd=1.)
        Bound(Normal, upper=6)("x", 2., 1.)

    with Model():
        Bound(Poisson, upper=6)("y", mu=2.)
        Bound(Poisson, upper=6)("x", 2.)
1✔
1273

1274

1275
class TestLatex:
    """Check the LaTeX representation of individual variables and of the model."""

    @classmethod
    def setup_class(cls):
        """Build one small regression model shared by every test in the class.

        pytest calls ``setup_class`` with the class object, so it is declared
        as a classmethod and stores the model, the variables and their expected
        LaTeX strings as class attributes.
        """
        # True parameter values
        alpha, sigma = 1, 1
        beta = [1, 2.5]

        # Size of dataset
        size = 100

        # Predictor variable
        X = np.random.normal(size=(size, 2)).dot(np.array([[1, 0], [0, 0.2]]))

        # Simulate outcome variable
        Y = alpha + X.dot(beta) + np.random.randn(size)*sigma
        with Model() as cls.model:
            # Priors for unknown model parameters
            alpha = Normal('alpha', mu=0, sigma=10)
            b = Normal('beta', mu=0, sigma=10, shape=(2,), observed=beta)
            sigma = HalfNormal('sigma', sigma=1)

            # Test Cholesky parameterization
            Z = MvNormal('Z', mu=np.zeros(2), chol=np.eye(2), shape=(2,))

            # Expected value of outcome
            mu = Deterministic('mu', floatX(alpha + tt.dot(X, b)))

            # Likelihood (sampling distribution) of observations
            Y_obs = Normal('Y_obs', mu=mu, sigma=sigma, observed=Y)
        # The two sequences below are paired element by element in the tests.
        cls.distributions = [alpha, sigma, mu, b, Z, Y_obs]
        cls.expected = (
            r'$\text{alpha} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$',
            r'$\text{sigma} \sim \text{HalfNormal}(\mathit{sigma}=1.0)$',
            r'$\text{mu} \sim \text{Deterministic}(\text{alpha},~\text{Constant},~\text{beta})$',
            r'$\text{beta} \sim \text{Normal}(\mathit{mu}=0.0,~\mathit{sigma}=10.0)$',
            r'$Z \sim \text{MvNormal}(\mathit{mu}=array, \mathit{chol}=array)$',
            r'$\text{Y_obs} \sim \text{Normal}(\mathit{mu}=\text{mu},~\mathit{sigma}=f(\text{sigma}))$'
        )

    def test__repr_latex_(self):
        """Each variable renders as expected and appears in the model's LaTeX."""
        for distribution, tex in zip(self.distributions, self.expected):
            assert distribution._repr_latex_() == tex

        model_tex = self.model._repr_latex_()

        for tex in self.expected:  # make sure each variable is in the model
            # Compare the name and the distribution part separately.
            for segment in tex.strip('$').split(r'\sim'):
                assert segment in model_tex

    def test___latex__(self):
        """``__latex__`` returns the same string as ``_repr_latex_``."""
        for distribution, tex in zip(self.distributions, self.expected):
            assert distribution._repr_latex_() == distribution.__latex__()
        assert self.model._repr_latex_() == self.model.__latex__()
1✔
1328

1329

1330
def test_discrete_trafo():
    """Passing ``transform`` to a discrete distribution raises ValueError.

    Checked both for ``Binomial.dist`` and for a Binomial created in a model.
    """
    expected_message = 'Transformations for discrete distributions'

    with pytest.raises(ValueError) as err:
        Binomial.dist(n=5, p=0.5, transform='log')
    err.match(expected_message)

    with Model():
        with pytest.raises(ValueError) as err:
            Binomial('a', n=5, p=0.5, transform='log')
        err.match(expected_message)
1✔
1338

1339

1340
@pytest.mark.parametrize("shape", [tuple(), (1,), (3, 1), (3, 2)], ids=str)
1✔
1341
def test_orderedlogistic_dimensions(shape):
    """OrderedLogistic and Categorical agree on ``p`` dimensions and logp (issue #3535).

    ``p`` is uniform over 10 categories, and the cutpoints are the logits of
    0.1, ..., 0.9, so each observation contributes -1 to the base-10 logp.
    """
    # Factor converting natural-log probabilities to base 10.
    loge = np.log10(np.exp(1))
    size = 7
    p = np.ones(shape + (10,)) / 10
    interior_quantiles = np.linspace(0, 1, 11)[1:-1]
    cutpoints = np.tile(logit(interior_quantiles), shape + (1,))
    # NOTE(review): randint's upper bound is exclusive, so every observation
    # is 0 here - confirm whether a wider range of categories was intended.
    obs = np.random.randint(0, 1, size=(size,) + shape)

    with Model():
        ol = OrderedLogistic(
            "ol", eta=np.zeros(shape), cutpoints=cutpoints, shape=shape, observed=obs)
        c = Categorical("c", p=p, shape=shape, observed=obs)

    ologp = ol.logp({"ol": 1}) * loge
    clogp = c.logp({"c": 1}) * loge
    expected = -np.prod((size,) + shape)
    expected_ndim = len(shape) + 1

    assert c.distribution.p.ndim == expected_ndim
    assert np.allclose(clogp, expected)
    assert ol.distribution.p.ndim == expected_ndim
    assert np.allclose(ologp, expected)
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc