• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

JohannesBuchner / askcarl / 15449997508

04 Jun 2025 06:31PM UTC coverage: 95.749% (-0.1%) from 95.847%
15449997508

push

github

JohannesBuchner
bug fix, and simplify code

8 of 8 new or added lines in 2 files covered. (100.0%)

1 existing line in 1 file now uncovered.

856 of 894 relevant lines covered (95.75%)

1.91 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.7
/tests/test_density.py
1
import os
2✔
2
import numpy as np
2✔
3
from numpy import array
2✔
4
from scipy.stats import norm, multivariate_normal
2✔
5
from scipy.integrate import dblquad
2✔
6
from scipy.special import logsumexp
2✔
7
from numpy.testing import assert_allclose
2✔
8
from hypothesis import given, strategies as st, example, settings, HealthCheck
2✔
9
from hypothesis.extra.numpy import arrays
2✔
10
import pypmc.density.mixture
2✔
11
import pytest
2✔
12
import sklearn.mixture
2✔
13

14
import askcarl
2✔
15
from askcarl.utils import cov_to_prec_cholesky
2✔
16

17
import jax
2✔
18
jax.config.update("jax_enable_x64", True)
2✔
19

20

21
def test_stackoverflow_example():
    """Compare askcarl against a hand-computed conditional Gaussian in 6-d.

    The first two axes are treated as upper limits (their CDF is wanted),
    the last four as given values (their PDF is wanted). The reference
    ``cdf * pdf`` product is built from the textbook conditional-Gaussian
    formulas and cross-checked with a brute-force ``dblquad`` integral of
    the joint density, then compared with ``askcarl.pdfcdf`` and the
    ``askcarl.Gaussian`` conditional methods.
    """
    ndim = 6  # dimensionality
    n_given = 4  # number of given coordinates
    n_upper = ndim - n_given  # number of other coordinates (dblquad check needs 2)
    mask = np.array([False, False, True, True, True, True])

    # draw order (point, covariance, mean) fixes the random numbers used
    rng = np.random.default_rng(238492432)
    x = rng.random(ndim)
    A = rng.random(size=(ndim, ndim))
    A = A + A.T + np.eye(ndim)*ndim
    mu = rng.random(ndim)
    joint = multivariate_normal(mean=mu, cov=A)
    x_row = x.reshape((1, -1))

    # the leading axes are the "other" coordinates, the trailing ones are "given"
    upper, given = slice(None, n_upper), slice(n_upper, None)
    s11 = A[upper, upper]  # upper bound covariance
    s12 = A[upper, given]  # mixed 1
    s21 = A[given, upper]  # mixed 2
    s22 = A[given, given]  # given value covariance
    mu1, mu2 = mu[upper], mu[given]
    x1, x2 = x[upper], x[given]

    print("input: upper", x1, mu1, "given", x2, mu2)
    print("cov_cross:", s12, s21)

    # conditional mean and covariance of the upper axes given x2
    inv_s22 = np.linalg.inv(s22)
    print("inv_s22:", n_given, inv_s22, x2)
    assert inv_s22.shape == (n_given, n_given)
    gain = s12 @ inv_s22
    shift = gain @ (x2 - mu2)
    print(shift.shape)
    mu_c = mu1 + shift
    assert mu_c.shape == (n_upper,)
    explained = gain @ s21
    print("newcov shape:", explained.shape, explained)
    A_c = s11 - explained
    assert A_c.shape == (n_upper, n_upper)
    dist = multivariate_normal(mean=mu_c, cov=A_c)
    print("truth:", mu_c, A_c)
    given_dist = multivariate_normal(mean=mu2, cov=s22)
    pdf_part = given_dist.pdf(x2)
    logpdf_part = given_dist.logpdf(x2)

    # Brute-force check (assumes exactly two upper-limit axes)
    def joint_density(second, first):
        return joint.pdf(np.concatenate(([first, second], x2)))

    # joint probability
    p1 = dblquad(joint_density, -np.inf, x[0], -np.inf, x[1])[0]
    # marginal probability
    p2 = dblquad(joint_density, -np.inf, np.inf, -np.inf, np.inf)[0]

    print("comparison:", p1, p2, dist.cdf(x1), pdf_part)
    # These should match (approximately)
    assert_allclose(dist.cdf(x1) * pdf_part, p1, atol=1e-6)

    c1 = askcarl.pdfcdf(x_row, mask, mean=mu, cov=A)
    print("truth eval:", x1, dist.mean, dist.cov, dist.cdf(x1), c1)
    assert_allclose(dist.cdf(x1) * pdf_part, c1, atol=1e-6)

    g = askcarl.Gaussian(mean=mu, cov=A)
    c2 = g.conditional_pdf(x_row, mask)
    assert_allclose(dist.cdf(x1) * pdf_part, c2, atol=1e-6)

    logc2 = g.conditional_logpdf(x_row, mask)
    assert_allclose(dist.logcdf(x1) + logpdf_part, logc2, atol=1e-4)
2✔
91

92
def valid_QR(vectors):
    """Tell whether `vectors` has a well-scaled QR decomposition.

    Truthy only when the Q factor has the same shape as `vectors` and every
    diagonal entry of the R factor has a magnitude strictly between
    1e-3 and 1000.
    """
    q_factor, r_factor = np.linalg.qr(vectors)
    pivots = np.abs(np.diag(r_factor))
    return (
        q_factor.shape == vectors.shape
        and np.all(pivots > 1e-3)
        and np.all(pivots < 1000)
    )
2✔
95

96
def make_covariance_matrix_via_QR(normalisations, vectors):
    """Build a symmetric matrix from arbitrary vectors and per-axis weights.

    `vectors` is QR-decomposed; the columns of Q are rescaled by the
    diagonal of R to form a basis B, and the result is
    ``B @ diag(normalisations) @ B.T``.
    """
    q_factor, r_factor = np.linalg.qr(vectors)
    column_scales = np.diag(np.diag(r_factor))
    basis = q_factor @ column_scales
    weights = np.diag(normalisations)
    return basis @ weights @ basis.T
2✔
101

102
def valid_covariance_matrix(A, min_std=1e-6):
    """Return True if `A` is usable as a covariance matrix in these tests.

    `A` is rejected when it contains non-finite entries, when any diagonal
    entry is at or below `min_std` (note: the threshold is applied to the
    diagonal values themselves), when inverting it twice raises
    ``LinAlgError``, or when scipy's ``multivariate_normal`` refuses it with
    a ``ValueError``.
    """
    if not np.all(np.isfinite(A)):
        return False
    if np.any(np.diag(A) <= min_std):
        return False

    # the matrix and its inverse must both be invertible
    try:
        inverse = np.linalg.inv(A)
        np.linalg.inv(inverse)
    except np.linalg.LinAlgError:
        return False

    # scipy performs its own validation on construction
    try:
        multivariate_normal(mean=np.zeros(len(A)), cov=A)
    except ValueError:
        return False
    else:
        return True
2✔
121

122
@settings(max_examples=100, deadline=None)
2✔
123
@given(
2✔
124
    mu=arrays(np.float64, (6,), elements=st.floats(-10, 10)),
125
    x=arrays(np.float64, (6,), elements=st.floats(-10, 10)),
126
    eigval=arrays(np.float64, (6,), elements=st.floats(1e-6, 10)),
127
    vectors=arrays(np.float64, (6,6), elements=st.floats(-10, 10)).filter(valid_QR),
128
)
129
@example(
2✔
130
    mu=array([ 0.5    , -9.     , -2.     ,  0.99999,  0.99999,  0.99999]),
131
    x=array([0.00000000e+00, 3.86915453e+00, 0.00000000e+00, 0.00000000e+00,
132
           0.00000000e+00, 1.00000000e-05]),
133
    eigval=array([1.e-06, 1.e-06, 1.e-06, 1.e-06, 1.e-06, 1.e-06]),
134
    vectors=array([[ 0.00000000e+00, -2.00000000e+00, -2.00000000e+00,
135
            -2.00000000e+00, -2.00000000e+00, -2.00000000e+00],
136
           [-2.00000000e+00, -2.00000000e+00, -2.00000000e+00,
137
            -1.17549435e-38, -2.00000000e+00, -2.00000000e+00],
138
           [-2.00000000e+00, -2.00000000e+00, -2.00000000e+00,
139
            -2.00000000e+00, -2.00000000e+00, -2.00000000e+00],
140
           [-2.00000000e+00, -2.00000000e+00, -2.00000000e+00,
141
            -2.00000000e+00, -1.40129846e-45, -2.00000000e+00],
142
           [-2.00000000e+00,  3.33333333e-01, -2.00000000e+00,
143
            -2.00000000e+00, -2.00000000e+00, -2.00000000e+00],
144
           [-2.00000000e+00, -2.00000000e+00,  5.00000000e-01,
145
            -2.00000000e+00, -2.00000000e+00, -2.00000000e+00]]),
146
).via('discovered failure')
147
@example(
2✔
148
    mu=array([ 0., -9., 10.,  0.,  0.,  0.]),
149
    x=array([0., 4., 0., 0., 0., 0.]),
150
    eigval=array([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
151
    vectors=array([[0., 1., 1., 1., 1., 1.],
152
           [1., 1., 1., 0., 1., 1.],
153
           [1., 1., 1., 1., 1., 1.],
154
           [1., 1., 1., 1., 0., 1.],
155
           [1., 0., 1., 1., 1., 1.],
156
           [1., 1., 0., 1., 1., 1.]]),
157
).via('discovered failure')
158
@example(
2✔
159
    mu=array([0.     , 1.     , 0.5    , 0.03125, 1.     , 0.03125]),
160
    x=array([ 1.00000000e+01,  6.10351562e-05, -4.24959109e+00,  3.26712313e+00,
161
           -1.00000000e+01,  0.00000000e+00]),
162
    eigval=array([0.5, 1. , 0.5, 0.5, 0.5, 0.5]),
163
    vectors=array([[ 0.  ,  0.  ,  0.  , -0.25, -0.25, -0.25],
164
           [-0.25, -0.25, -0.25, -0.25,  0.  , -0.25],
165
           [ 0.  , -0.25, -0.25, -0.25, -0.25, -0.25],
166
           [-0.25, -0.25, -0.25,  0.  , -0.25, -0.25],
167
           [-0.25, -0.25,  0.  , -0.25, -0.25, -0.25],
168
           [-0.25, -0.25, -0.25, -0.25, -0.25, -0.25]]),
169
).via('discovered failure')
170
@example(
2✔
171
    mu=array([ 0.        ,  0.        , 10.        ,  0.        ,  6.64641649,
172
           -1.1       ]),
173
    x=array([10., 10., 10., 10., 10., 10.]),
174
    eigval=array([1.00000000e-06, 1.00000000e+00, 4.16143782e-01, 1.00000000e-06,
175
           6.99209529e-01, 2.42501010e-01]),
176
    vectors=array([[-1.00000000e-005, -5.67020051e+000, -5.67020051e+000,
177
            -5.67020051e+000, -5.67020051e+000, -5.67020051e+000],
178
           [-5.67020051e+000, -5.67020051e+000, -5.67020051e+000,
179
            -5.67020051e+000, -5.67020051e+000,  1.90000000e+000],
180
           [-5.67020051e+000, -5.67020051e+000, -5.67020051e+000,
181
             5.00000000e-001, -5.67020051e+000, -5.67020051e+000],
182
           [-5.67020051e+000, -1.19209290e-007,  2.22044605e-016,
183
            -5.67020051e+000, -5.67020051e+000, -5.67020051e+000],
184
           [-5.67020051e+000, -5.67020051e+000, -5.67020051e+000,
185
            -5.67020051e+000, -5.67020051e+000, -5.67020051e+000],
186
           [-5.67020051e+000,  1.11253693e-308, -5.67020051e+000,
187
            -5.67020051e+000, -5.67020051e+000, -1.17549435e-038]]),
188
).via('discovered failure')
189
def test_stackoverflow_like_examples(mu, x, eigval, vectors):
    """Repeat the 6-d conditional-Gaussian check on generated inputs.

    A covariance is built from `eigval` and `vectors`; inputs whose
    covariance is rejected by ``valid_covariance_matrix`` are skipped.
    The accepted matrix is then symmetrised and given a dominant diagonal,
    the reference ``cdf * pdf`` value is derived analytically and checked
    against ``dblquad``, and askcarl must agree with it (with and without
    a precomputed precision Cholesky factor).
    """
    mask = np.array([False, False, True, True, True, True])

    A = make_covariance_matrix_via_QR(eigval, vectors)
    print("Cov:", A)
    stdevs = np.diag(A)**0.5
    print("stdevs:", stdevs)
    # tolerance grows with the largest scale and the largest offset from the mean
    largest_offset = np.abs(x - mu).max()
    atol = max(stdevs) * 1e-4 * (1 + largest_offset) + 1e-6
    print("atol:", atol)
    if not valid_covariance_matrix(A):
        return

    ndim = 6  # dimensionality
    n_given = 4  # number of given coordinates
    n_upper = ndim - n_given  # number of other coordinates (dblquad check needs 2)

    A = A + A.T + np.eye(ndim)*ndim
    joint = multivariate_normal(mean=mu, cov=A)
    x_row = x.reshape((1, -1))

    # the leading axes are the "other" coordinates, the trailing ones are "given"
    upper, given = slice(None, n_upper), slice(n_upper, None)
    s11 = A[upper, upper]  # upper bound covariance
    s12 = A[upper, given]  # mixed 1
    s21 = A[given, upper]  # mixed 2
    s22 = A[given, given]  # given value covariance
    mu1, mu2 = mu[upper], mu[given]
    x1, x2 = x[upper], x[given]

    print("input: upper", x1, mu1, "given", x2, mu2)
    print("cov_cross:", s12, s21)

    # conditional mean and covariance of the upper axes given x2
    inv_s22 = np.linalg.inv(s22)
    print("inv_s22:", n_given, inv_s22, x2)
    assert inv_s22.shape == (n_given, n_given)
    gain = s12 @ inv_s22
    shift = gain @ (x2 - mu2)
    print(shift.shape)
    mu_c = mu1 + shift
    assert mu_c.shape == (n_upper,)
    explained = gain @ s21
    print("newcov shape:", explained.shape, explained)
    A_c = s11 - explained
    assert A_c.shape == (n_upper, n_upper)
    dist = multivariate_normal(mean=mu_c, cov=A_c)
    print("truth:", mu_c, A_c)
    pdf_part = multivariate_normal(mean=mu2, cov=s22).pdf(x2)

    # Brute-force check (assumes exactly two upper-limit axes)
    def joint_density(second, first):
        return joint.pdf(np.concatenate(([first, second], x2)))

    # joint probability
    p1 = dblquad(joint_density, -np.inf, x[0], -np.inf, x[1])[0]

    print("p1:", p1)
    # These should match (approximately)
    assert_allclose(dist.cdf(x1) * pdf_part, p1, atol=atol, rtol=1e-2)

    c1 = askcarl.pdfcdf(x_row, mask, mean=mu, cov=A)
    assert_allclose(dist.cdf(x1) * pdf_part, c1, atol=atol)

    g_plain = askcarl.Gaussian(mean=mu, cov=A)
    c2_plain = g_plain.conditional_pdf(x_row, mask)
    assert_allclose(dist.cdf(x1) * pdf_part, c2_plain, atol=atol)

    # same check with the precision Cholesky factor supplied by the caller
    precision_cholesky = cov_to_prec_cholesky(A)
    g_prec = askcarl.Gaussian(mean=mu, cov=A, precision_cholesky=precision_cholesky)
    c2_prec = g_prec.conditional_pdf(x_row, mask)
    assert_allclose(dist.cdf(x1) * pdf_part, c2_prec, atol=atol)
2✔
257

258

259
def test_trivial_example():
    """Check a 1-d unit Gaussian at the origin against scipy's ``norm``.

    With the mask set to True the conditional pdf must equal the normal
    pdf; with the mask set to False it must equal the normal cdf.
    """
    origin = np.zeros((1, 1))
    as_value = np.array([True])
    unit = askcarl.Gaussian(mean=np.zeros(1), cov=np.eye(1))
    assert_allclose(norm(0, 1).pdf(origin), unit.conditional_pdf(origin, as_value))

    print("zero")
    origin = np.zeros((1, 1))
    as_limit = np.array([False])
    unit = askcarl.Gaussian(mean=np.zeros(1), cov=np.eye(1))
    assert_allclose(norm(0, 1).cdf(origin), unit.conditional_pdf(origin, as_limit))
2✔
268

269

270
def test_trivial_mixture():
    """A one-component mixture must give the same pdf as the single Gaussian.

    Both are evaluated at the origin of a 1-d unit Gaussian with an
    all-True mask and compared with scipy's ``norm(0, 1).pdf``.
    """
    origin = np.zeros((1, 1))
    all_values = np.ones((1, 1), dtype=bool)
    p_truth = norm(0, 1).pdf(origin)[0]

    single = askcarl.Gaussian(mean=np.zeros(1), cov=np.eye(1))
    assert_allclose(p_truth, single.pdf(origin, all_values))

    mix = askcarl.GaussianMixture(means=[np.zeros(1)], covs=[np.eye(1)], weights=[1.0])
    assert_allclose(p_truth, mix.pdf(origin, all_values))
2✔
278

279
# Strategy to generate arbitrary dimensionality mean and covariance
280
@st.composite
def mean_and_cov(draw):
    """Hypothesis strategy yielding ``(dim, mean, covariance)``.

    The dimensionality is drawn from 1..10, the mean entries from
    [-10, 10], and the covariance is assembled by
    ``make_covariance_matrix_via_QR`` from drawn weights in [1e-6, 10] and
    a drawn square matrix that passes ``valid_QR``.
    """
    coordinate = st.floats(-10, 10)

    dim = draw(st.integers(min_value=1, max_value=10))
    mu = draw(arrays(np.float64, (dim,), elements=coordinate))
    eigval = draw(arrays(np.float64, (dim,), elements=st.floats(1e-6, 10)))
    square = arrays(np.float64, (dim, dim), elements=coordinate)
    vectors = draw(square.filter(valid_QR))
    return dim, mu, make_covariance_matrix_via_QR(eigval, vectors)
2✔
288

289

290
@given(mean_and_cov())
def test_single(mean_cov):
    """An askcarl Gaussian must match scipy's ``multivariate_normal``.

    Inputs whose covariance is rejected by ``valid_covariance_matrix`` are
    skipped. For a random point, the conditional pdf (with and without an
    explicit all-True mask), the pdf and the logpdf are compared with the
    scipy reference.
    """
    ndim, mu, cov = mean_cov
    if not valid_covariance_matrix(cov):
        return
    assert mu.shape == (ndim,), (mu, mu.shape, ndim)
    assert cov.shape == (ndim,ndim), (cov, cov.shape, ndim)

    rv = askcarl.Gaussian(mu, cov)
    rv_truth = multivariate_normal(mu, cov)

    # one random test point with the same dimensionality as `mu`
    xi = np.random.randn(1, len(mu))
    expected_pdf = rv_truth.pdf(xi[0])
    flat_mask = np.ones(ndim, dtype=bool)
    row_mask = np.ones((1, ndim), dtype=bool)

    assert_allclose(rv.conditional_pdf(xi), expected_pdf)
    assert_allclose(rv.conditional_pdf(xi, flat_mask), expected_pdf)
    assert_allclose(rv.pdf(xi, row_mask), expected_pdf)
    assert_allclose(rv.logpdf(xi, row_mask), rv_truth.logpdf(xi[0]))
2✔
311

312
@st.composite
def mean_and_diag_stdevs2(draw):
    """Hypothesis strategy yielding ``(dim, mean, stdevs, x, i)``.

    The dimensionality is at least 2 (2..10). Mean and point entries are
    drawn from [-1e6, 1e6], standard deviations from [1e-6, 1e6], and `i`
    is an axis index in ``0..dim-1``.
    """
    wide = st.floats(-1e6, 1e6)

    dim = draw(st.integers(min_value=2, max_value=10))
    shape = (dim,)
    mu = draw(arrays(np.float64, shape, elements=wide))
    stdevs = draw(arrays(np.float64, shape, elements=st.floats(1e-6, 1e6)))
    x = draw(arrays(np.float64, shape, elements=wide))
    i = draw(st.integers(min_value=0, max_value=dim - 1))
    return dim, mu, stdevs, x, i
2✔
321

322

323
@given(mean_and_diag_stdevs2())
2✔
324
@settings(deadline=None)
2✔
325
@example(
2✔
326
    mean_and_cov=(2, array([1., 0.]), array([1., 1.]), array([0., 0.]), 1),
327
).via('discovered failure')
328
@example(
2✔
329
    mean_and_cov=(2, array([0., 0.]), array([2., 2.]), array([77., 77.]), 0),
330
).via('discovered failure')
331
@example(
2✔
332
    mean_and_cov=(2, array([0., 0.]), array([1., 1.]), array([39., 39.]), 0),
333
).via('discovered failure')
334
def test_single_with_UL(mean_and_cov):
    """Check one upper-limit axis on a diagonal-covariance Gaussian.

    Axis `i` is masked out as an upper limit and set to +1e200 in the first
    row and -1e200 in the second. The expected result is the pdf of the
    remaining axes multiplied by 1 and 0 respectively (0 and -inf added in
    log space). Inputs rejected by ``valid_covariance_matrix`` are skipped.
    """
    ndim, mu, stdevs, x, i = mean_and_cov
    cov = np.diag(stdevs**2)
    assert mu.shape == (ndim,), (mu, mu.shape, ndim)
    assert cov.shape == (ndim,ndim), (cov, cov.shape, ndim)
    if not valid_covariance_matrix(cov):
        return

    print("inputs:", mu, stdevs, cov)
    rv = askcarl.Gaussian(mu, cov)

    # every axis is an exact value except axis i
    mask = np.ones(ndim, dtype=bool)
    mask[i] = False
    rv_truth = multivariate_normal(mu[mask], np.diag(stdevs[mask]**2))

    xi = np.array([x, x])
    assert 0 <= i < ndim
    # set high/low upper limit
    xi[0,i] = 1e200
    xi[1,i] = -1e200
    measured = xi[:,mask]
    row_masks = np.array([mask,mask])

    pa = rv.conditional_pdf(xi, mask)
    pa_expected = np.array([1, 0]) * rv_truth.pdf(measured)
    print("for expectation:", xi[0,mask], mu[mask], stdevs[mask], pa, pa_expected)
    assert_allclose(pa, pa_expected, atol=1e-100)

    pb = rv.pdf(xi, row_masks)
    assert_allclose(pb, pa_expected, atol=1e-100)

    logpa_expected = np.array([0, -np.inf]) + rv_truth.logpdf(measured)
    logpa = rv.logpdf(xi, row_masks)
    assert_allclose(logpa, logpa_expected)
2✔
367

368
@pytest.mark.parametrize("n_components", [1, 3, 10])
@pytest.mark.parametrize("covariance_type", ['full', 'tied', 'diag', 'spherical'])
def test_import(n_components, covariance_type):
    """Round-trip a fitted sklearn mixture through askcarl and pypmc.

    A sklearn ``GaussianMixture`` is fitted to three normal clusters and
    imported with ``from_sklearn``. Its means, covariances and weights are
    used to build a pypmc mixture, which is imported with ``from_pypmc``;
    both imports must carry the same means and covariances.
    """
    # (location, scale, number of samples) of each 3-d cluster
    cluster_specs = [(3, 3, 10000), (0, 1, 3000), (3, 1, 10000)]
    a = np.vstack([
        np.random.normal(loc, scale, size=(count, 3))
        for loc, scale, count in cluster_specs
    ])
    assert a.shape == (23000, 3), a.shape

    skgmm = sklearn.mixture.GaussianMixture(
        n_components=n_components, covariance_type=covariance_type)
    skgmm.fit(a)
    askcarl_fromsklearn = askcarl.GaussianMixture.from_sklearn(skgmm)

    means = [g.mean for g in askcarl_fromsklearn.components]
    covs = [g.cov for g in askcarl_fromsklearn.components]
    print(means)
    print([np.diag(cov) for cov in covs])
    if n_components == 3 and covariance_type in ('full', 'diag'):
        # the fit should recover cluster centres near 3 and near 0
        for centre in (3, 0):
            assert any(np.allclose(mean, centre, atol=0.1) for mean in means)

    target_mixture = pypmc.density.mixture.create_gaussian_mixture(
        means, covs, askcarl_fromsklearn.weights)
    askcarl_frompypmc = askcarl.GaussianMixture.from_pypmc(target_mixture)

    means2 = [g.mean for g in askcarl_frompypmc.components]
    covs2 = [g.cov for g in askcarl_frompypmc.components]

    assert_allclose(means2, means)
    assert_allclose(covs2, covs)
2✔
399

400
@st.composite
def mixture_strategy(draw):
    """Hypothesis strategy yielding ``(dim, ncomponents, means, covs, weights, x)``.

    Dimensionality, number of test points and number of components are each
    drawn from 1..10. All means are drawn first, then one covariance per
    component via ``make_covariance_matrix_via_QR``. Weights are drawn from
    [0, 1] with at least one positive entry and normalised to sum to one;
    `x` has one row per test point.
    """
    coordinate = st.floats(-10, 10)

    def has_positive_entry(candidate):
        return (candidate > 0).any()

    dim = draw(st.integers(min_value=1, max_value=10))
    ntest = draw(st.integers(min_value=1, max_value=10))
    ncomponents = draw(st.integers(min_value=1, max_value=10))

    means = []
    for _ in range(ncomponents):
        means.append(draw(arrays(np.float64, (dim,), elements=coordinate)))

    covs = []
    for _ in range(ncomponents):
        scales = draw(arrays(np.float64, (dim,), elements=st.floats(1e-6, 10)))
        basis = draw(arrays(np.float64, (dim, dim), elements=coordinate).filter(valid_QR))
        covs.append(make_covariance_matrix_via_QR(scales, basis))

    weights = draw(
        arrays(np.float64, (ncomponents,), elements=st.floats(0, 1))
        .filter(has_positive_entry))
    weights /= weights.sum()
    x = draw(arrays(np.float64, (ntest, dim), elements=coordinate))
    return dim, ncomponents, means, covs, weights, x
2✔
414

415
from  sklearn.mixture._gaussian_mixture import _estimate_log_gaussian_prob
2✔
416

417
@settings(suppress_health_check=[HealthCheck.filter_too_much], max_examples=1000, deadline=None)
2✔
418
@given(mixture_strategy())
2✔
419
@example(
2✔
420
    mixture=(5,
421
     2,
422
     [array([0., 0., 0., 0., 0.]), array([0., 0., 0., 0., 0.])],
423
     [array([[ 1.49244777, -0.07387876, -0.02245019, -0.25755223,  0.35388117],
424
             [-0.07387876,  1.96311104,  0.58596818,  0.48862124,  0.96229954],
425
             [-0.02245019,  0.58596818,  1.89882532,  0.54004981,  0.97515668],
426
             [-0.25755223,  0.48862124,  0.54004981,  2.05494777,  0.91638117],
427
             [ 0.35388117,  0.96229954,  0.97515668,  0.91638117,  1.1461626 ]]),
428
      array([[0.90301624, 0.46133735, 0.34701563, 0.13561725, 0.34582217],
429
             [0.46133735, 0.55336326, 0.4630311 , 0.36990653, 0.46231502],
430
             [0.34701563, 0.4630311 , 0.89632843, 0.14126309, 0.3479933 ],
431
             [0.13561725, 0.36990653, 0.14126309, 0.55367394, 0.13887617],
432
             [0.34582217, 0.46231502, 0.3479933 , 0.13887617, 0.89943142]])],
433
     array([0., 1.]),
434
     array([[0., 7., 0., 0., 0.]])),
435
).via('discovered failure')
436
@example(
2✔
437
    mixture=(5,
438
     2,
439
     [array([0., 0., 0., 0., 0.]), array([0., 8., 0., 0., 0.])],
440
     [array([[ 1.49244777e-06, -7.38787586e-08, -2.24501872e-08,
441
              -2.57552228e-07,  3.53881174e-07],
442
             [-7.38787586e-08,  1.96311104e-06,  5.85968180e-07,
443
               4.88621241e-07,  9.62299541e-07],
444
             [-2.24501872e-08,  5.85968180e-07,  1.89882532e-06,
445
               5.40049813e-07,  9.75156684e-07],
446
             [-2.57552228e-07,  4.88621241e-07,  5.40049813e-07,
447
               2.05494777e-06,  9.16381174e-07],
448
             [ 3.53881174e-07,  9.62299541e-07,  9.75156684e-07,
449
               9.16381174e-07,  1.14616260e-06]]),
450
      array([[ 0.2641    ,  0.05401538, -0.03798462,  0.14390769, -0.08798462],
451
             [ 0.05401538,  0.28357929,  0.23057929,  0.16052426,  0.21807929],
452
             [-0.03798462,  0.23057929,  0.46757929,  0.06852426,  0.08007929],
453
             [ 0.14390769,  0.16052426,  0.06852426,  0.31735444,  0.01852426],
454
             [-0.08798462,  0.21807929,  0.08007929,  0.01852426,  0.50507929]])],
455
     array([0., 1.]),
456
     array([[0., 0., 0., 0., 0.],
457
            [0., 0., 0., 0., 0.],
458
            [0., 0., 0., 0., 0.],
459
            [0., 0., 0., 0., 0.],
460
            [0., 0., 0., 0., 0.],
461
            [0., 0., 0., 0., 0.]])),
462
).via('discovered failure')
463
@example(
2✔
464
    mixture=(7,
465
     1,
466
     [array([0., 0., 0., 0., 0., 0., 0.])],
467
     [array([[156.25637361, -31.26969455,  12.40439959, -31.24564618,
468
              -31.22989933, -31.24225165, -31.2688804 ],
469
             [-31.26969455,   6.33006684,  -2.20457805,   6.2198964 ,
470
                6.16861271,   6.20884123,   6.34227887],
471
             [ 12.40439959,  -2.20457805,   2.43400622,  -2.56530348,
472
               -2.80150635,  -2.61622149,  -2.21679021],
473
             [-31.24564618,   6.2198964 ,  -2.56530348,   7.16206397,
474
                5.80788728,   5.8481158 ,   6.20768424],
475
             [-31.22989933,   6.16861271,  -2.80150635,   5.80788728,
476
                7.34003101,   5.75696927,   6.15640055],
477
             [-31.24225165,   6.20884123,  -2.61622149,   5.8481158 ,
478
                5.75696927,   7.23169779,   6.19662906],
479
             [-31.2688804 ,   6.34227887,  -2.21679021,   6.20768424,
480
                6.15640055,   6.19662906,   6.36588918]])],
481
     array([1.]),
482
     array([[0., 0., 0., 0., 0., 0., 0.]])),
483
).via('discovered failure')
484
@example(
2✔
485
    mixture=(1, 1, [array([0.])], [array([[1.]])], array([1.]), array([[0.]])),
486
).via('discovered failure')
487
@example(
2✔
488
    mixture=(8,
489
     1,
490
     [array([0., 0., 0., 0., 0., 0., 0., 0.])],
491
     [array([[ 1.53365164e-06, -4.56057424e-07, -2.47157040e-07,
492
              -1.88188005e-07,  1.47416388e-07,  1.47416391e-07,
493
               4.49153299e-07,  1.47416391e-07],
494
             [-4.56057424e-07,  1.76562629e+00,  9.99999625e-01,
495
               9.99999679e-01, -3.26562499e+00,  1.98437501e+00,
496
               2.53124938e+00,  1.98437501e+00],
497
             [-2.47157040e-07,  9.99999625e-01,  1.00000099e+00,
498
               9.99999580e-01,  9.99999915e-01,  9.99999915e-01,
499
               1.00000006e+00,  9.99999915e-01],
500
             [-1.88188005e-07,  9.99999679e-01,  9.99999580e-01,
501
               1.00000091e+00,  9.99999927e-01,  9.99999927e-01,
502
               1.00000005e+00,  9.99999927e-01],
503
             [ 1.47416388e-07, -3.26562499e+00,  9.99999915e-01,
504
               9.99999927e-01,  2.47656250e+01, -4.48437497e+00,
505
              -7.53124995e+00, -4.48437497e+00],
506
             [ 1.47416391e-07,  1.98437501e+00,  9.99999915e-01,
507
               9.99999927e-01, -4.48437497e+00,  2.26562529e+00,
508
               2.96875004e+00,  2.26562479e+00],
509
             [ 4.49153299e-07,  2.53124938e+00,  1.00000006e+00,
510
               1.00000005e+00, -7.53124995e+00,  2.96875004e+00,
511
               4.06250038e+00,  2.96875004e+00],
512
             [ 1.47416391e-07,  1.98437501e+00,  9.99999915e-01,
513
               9.99999927e-01, -4.48437497e+00,  2.26562479e+00,
514
               2.96875004e+00,  2.26562529e+00]])],
515
     array([1.]),
516
     array([[0., 0., 0., 0., 0., 0., 0., 0.]])),
517
).via('discovered failure')
518
@example(
2✔
519
    mixture=(7, 1,
520
        [array([0., 0., 0., 0., 0., 0., 0.])],
521
        [array([[ 7.52684752, -4.9545883 , -0.31361835,  0.52269766,  0.52269929,
522
             6.92515709, -9.70648636],
523
           [-4.9545883 ,  3.26139136,  0.20644122, -0.3440688 , -0.34406688,
524
            -4.55852195,  6.38935566],
525
           [-0.31361835,  0.20644122,  0.01308347, -0.02177316, -0.0217791 ,
526
            -0.28854817,  0.40443697],
527
           [ 0.52269766, -0.3440688 , -0.02177316,  0.03630094,  0.0362984 ,
528
             0.48091352, -0.67406172],
529
           [ 0.52269929, -0.34406688, -0.0217791 ,  0.0362984 ,  0.03630217,
530
             0.48091543, -0.67405981],
531
           [ 6.92515709, -4.55852195, -0.28854817,  0.48091352,  0.48091543,
532
             6.37156541, -8.9305567 ],
533
           [-9.70648636,  6.38935566,  0.40443697, -0.67406172, -0.67405981,
534
            -8.9305567 , 12.51732134]])],
535
       array([1.]),
536
       array([[0., 0., 0., 0., 0., 0., 0.]])),
537
).via('discovered failure')
538
@example(
2✔
539
    mixture=(3,
540
     1,
541
     [array([0., 0., 0.])],
542
     [array([[32.00000005, 35.99999994,  8.00000004],
543
             [35.99999994, 40.50000006,  8.99999996],
544
             [ 8.00000004,  8.99999996,  2.00000005]])],
545
     array([1.]),
546
     array([[0., 0., 0.]])),
547
).via('discovered failure')
548
@example(
2✔
549
    mixture=(3,
550
     1,
551
     [array([0., 0., 0.])],
552
     [array([[32.00000007, 35.99999993,  8.00000001],
553
             [35.99999993, 40.50000006,  8.99999998],
554
             [ 8.00000001,  8.99999998,  2.00000008]])],
555
     array([1.]),
556
     array([[0., 0., 0.]])),
557
).via('discovered failure')
558
@example(
2✔
559
    mixture=(4,
560
     1,
561
     [array([0., 0., 0., 0.])],
562
     [array([[ 198.43139701,  208.34459689,  178.58081708, -148.82075727],
563
             [ 208.34459689,  586.13441827,  514.06354813, -278.71597676],
564
             [ 178.58081708,  514.06354813,  450.99268609, -242.78684396],
565
             [-148.82075727, -278.71597676, -242.78684396,  152.43362983]])],
566
     array([1.]),
567
     array([[0., 0., 0., 0.]])),
568
).via('discovered failure')
569
def test_mixture(mixture):
    """Compare askcarl mixture densities with scipy, pypmc and sklearn.

    `mixture` is a ``(ndim, ncomponents, means, covs, weights, x)`` tuple.
    Inputs containing a covariance rejected by ``valid_covariance_matrix``
    are skipped. With an all-True mask, the mixture pdf/logpdf must equal
    the weighted sum of scipy ``multivariate_normal`` densities and agree
    with pypmc's evaluation. An sklearn mixture with the same parameters is
    also set up, but the comparison of its values is currently disabled
    (see the TODO at the end).
    """
    ndim, ncomponents, means, covs, weights, x = mixture
    mask = np.ones(x.shape, dtype=bool)

    if not all(valid_covariance_matrix(cov) for cov in covs):
        return

    print("inputs:", ndim, ncomponents, means, covs, weights, x, mask)
    gmm = askcarl.GaussianMixture(weights, means, covs)
    #assert_allclose(gmm.log_weights, np.log(weights))
    askcarl_p = gmm.pdf(x, mask=mask)
    askcarl_logp = gmm.logpdf(x, mask=mask)
    gaussians = [multivariate_normal(mean, cov) for mean, cov in zip(means, covs)]
    if len(gaussians) == 1:
        # a single component must behave exactly like one Gaussian
        assert_allclose(gmm.log_weights, 0)
        assert_allclose(gmm.components[0].pdf(x, mask), gaussians[0].pdf(x))
        assert_allclose(gmm.components[0].logpdf(x, mask), gaussians[0].logpdf(x))
        assert_allclose(askcarl_p, gaussians[0].pdf(x))
        assert_allclose(askcarl_logp, gaussians[0].logpdf(x))

    # reference: weighted sum of the scipy component densities
    pdf_expected = sum(w * g.pdf(x) for g, w in zip(gaussians, weights))
    logpdf_expected = logsumexp([np.log(w) + g.logpdf(x) for g, w in zip(gaussians, weights)], axis=0)
    assert_allclose(askcarl_p, pdf_expected)
    assert_allclose(askcarl_logp, logpdf_expected)

    # compare results of GMM to pypmc
    target_mixture = pypmc.density.mixture.create_gaussian_mixture(
        means, covs, weights)
    pypmc_logp = np.array([target_mixture.evaluate(xi) for xi in x])
    assert_allclose(askcarl_p, np.exp(pypmc_logp), atol=1e-300, rtol=1e-4)
    # log-densities are only compared where they are not extremely negative,
    # selecting once by the pypmc values and once by the askcarl values
    pypmc_ok = pypmc_logp > -100000
    assert_allclose(askcarl_logp[pypmc_ok], pypmc_logp[pypmc_ok], atol=1)
    askcarl_ok = askcarl_logp > -100000
    # FIX: this used to compare askcarl_logp with itself and could never fail
    assert_allclose(askcarl_logp[askcarl_ok], pypmc_logp[askcarl_ok], atol=1)

    precisions = [np.linalg.inv(cov) for cov in covs]
    # compare results of GMM to sklearn
    try:
        skgmm = sklearn.mixture.GaussianMixture(
            n_components=ncomponents, weights_init=weights,
            means_init=means, precisions_init=precisions)
        skgmm._initialize(np.zeros((1, 1)), None)
    except np.linalg.LinAlgError:
        return
    skgmm._set_parameters((weights, np.array(means), covs, skgmm.precisions_cholesky_))
    assert_allclose(skgmm.weights_, weights)
    assert_allclose(skgmm.means_, means)
    assert_allclose(skgmm.covariances_, covs)
    print(x, skgmm.means_, skgmm.precisions_cholesky_)
    sk_logp = logsumexp(
        np.log(weights).reshape((1, -1)) + _estimate_log_gaussian_prob(x, skgmm.means_, skgmm.precisions_cholesky_, 'full'),
            axis=1)
    print(sk_logp)
    assert sk_logp.shape == (len(x),), (sk_logp.shape, len(x))
    print(skgmm.weights_, askcarl_logp, askcarl_p)
    # kept so the sklearn code path is still exercised while the value
    # comparison below is disabled
    sk_p = skgmm.predict_proba(x)
    # TODO: https://github.com/scikit-learn/scikit-learn/issues/29989
    # commented out for now
    # assert_allclose(askcarl_logp, sk_logp, atol=1e-2, rtol=1e-2)
    # sk_logp1, sk_logp2 = skgmm._estimate_log_prob_resp(x)
    # assert_allclose(sk_logp1, sk_logp)
    # assert_allclose(askcarl_p, sk_p, atol=1e-300, rtol=1e-4)
629

STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc