• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pymc-devs / pymc3 / 8566

pending completion
8566

Pull #3327

travis-ci

web-flow
Added assertion to unused argument error
Pull Request #3327: WIP: Merge nuts_kwargs and step_kwargs into kwargs

14 of 14 new or added lines in 2 files covered. (100.0%)

9961 of 20108 relevant lines covered (49.54%)

1.88 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

73.68
/pymc3/step_methods/hmc/integration.py
1
from collections import namedtuple
6✔
2

3
import numpy as np
6✔
4
from scipy import linalg
6✔
5

6

7
# Snapshot of the integrator at one point of a trajectory: position ``q``,
# momentum ``p``, velocity ``v`` (as computed by the potential), gradient of
# the log-density at ``q`` (``q_grad``), total ``energy`` (kinetic energy
# minus log-density) and the log-density itself (``model_logp``).
State = namedtuple("State", ["q", "p", "v", "q_grad", "energy", "model_logp"])
6✔
8

9

10
class IntegrationError(RuntimeError):
    """Signals that a leapfrog step failed numerically.

    Raised by the integrator when the step hits a linear-algebra error or
    scipy reports infs/nans in its input arrays.
    """
6✔
12

13

14
class CpuLeapfrogIntegrator:
    """Leapfrog integrator for Hamiltonian dynamics, evaluated on the CPU.

    Parameters
    ----------
    potential
        Kinetic-energy object. It must expose ``dtype``,
        ``velocity(p, out=None)``, ``energy(p, velocity=...)`` and
        ``velocity_energy(p, v)``.
    logp_dlogp_func
        Callable with a ``dtype`` attribute. ``logp_dlogp_func(q)`` must
        return ``(logp, dlogp)``. ``logp_dlogp_func(q, grad_out)`` must
        return ``logp`` and is expected to fill ``grad_out`` with the
        gradient in place, because the integrator reads that buffer
        afterwards.

    Raises
    ------
    ValueError
        If the dtypes of ``potential`` and ``logp_dlogp_func`` differ.
    """

    def __init__(self, potential, logp_dlogp_func):
        self._potential = potential
        self._logp_dlogp_func = logp_dlogp_func
        self._dtype = self._logp_dlogp_func.dtype
        if self._potential.dtype != self._dtype:
            # The trailing space keeps the two literal fragments from running
            # together ("(%s)don't") when they are concatenated.
            raise ValueError("dtypes of potential (%s) and logp function (%s) "
                             "don't match."
                             % (self._potential.dtype, self._dtype))

    def compute_state(self, q, p):
        """Compute Hamiltonian quantities for a position and momentum.

        Parameters
        ----------
        q, p: arrays
            Position and momentum. Both must have the integrator's dtype.

        Returns
        -------
        State namedtuple ``(q, p, v, q_grad, energy, model_logp)``, where
        ``energy`` is the kinetic energy minus ``logp``.

        Raises
        ------
        ValueError
            If ``q`` or ``p`` has a dtype other than the integrator's dtype.
        """
        if q.dtype != self._dtype or p.dtype != self._dtype:
            raise ValueError('Invalid dtype. Must be %s' % self._dtype)
        logp, dlogp = self._logp_dlogp_func(q)
        v = self._potential.velocity(p)
        kinetic = self._potential.energy(p, velocity=v)
        energy = kinetic - logp
        return State(q, p, v, dlogp, energy, logp)

    def step(self, epsilon, state, out=None):
        """Leapfrog integrator step.

        Half a momentum update, full position update, half momentum update.

        Parameters
        ----------
        epsilon: float, > 0
            step scale
        state: State namedtuple,
            current position data
        out: (optional) State namedtuple,
            preallocated arrays to write to in place. Its ``q``, ``p``, ``v``
            and ``q_grad`` arrays are overwritten with the new values. Its
            scalar fields are left untouched, because a namedtuple cannot be
            modified.

        Returns
        -------
        A State namedtuple for the new point. When ``out`` is given, the
        returned State holds ``out``'s arrays together with the newly
        computed ``energy`` and ``model_logp``.

        Raises
        ------
        IntegrationError
            If the step raises ``scipy.linalg.LinAlgError``, or a ValueError
            whose message reports infs or nans in an array. Any other
            ValueError is re-raised unchanged.
        """
        try:
            return self._step(epsilon, state, out=out)
        except linalg.LinAlgError as err:
            msg = "LinAlgError during leapfrog step."
            raise IntegrationError(msg) from err
        except ValueError as err:
            # Raised by many scipy.linalg functions
            scipy_msg = "array must not contain infs or nans"
            # str() guards against ValueErrors whose first arg is not a string.
            if err.args and scipy_msg in str(err.args[0]).lower():
                msg = "Infs or nans in scipy.linalg during leapfrog step."
                raise IntegrationError(msg) from err
            raise

    def _step(self, epsilon, state, out=None):
        """Perform one leapfrog step without translating numerical errors."""
        pot = self._potential
        axpy = linalg.blas.get_blas_funcs('axpy', dtype=self._dtype)

        # Only the position, momentum and gradient of the current state are
        # needed. Its velocity, energy and logp are recomputed below.
        q, p, _, q_grad, _, _ = state
        if out is None:
            q_new = q.copy()
            p_new = p.copy()
            v_new = np.empty_like(q)
            q_new_grad = np.empty_like(q)
        else:
            # Reuse the caller's buffers. A State has six fields; the scalar
            # energy/logp entries cannot be written back into an immutable
            # tuple, so they are carried by the State returned below.
            q_new, p_new, v_new, q_new_grad, _, _ = out
            q_new[:] = q
            p_new[:] = p

        dt = 0.5 * epsilon

        # p is already stored in p_new
        # p_new = p + dt * q_grad
        axpy(q_grad, p_new, a=dt)

        pot.velocity(p_new, out=v_new)
        # q is already stored in q_new
        # q_new = q + epsilon * v_new
        axpy(v_new, q_new, a=epsilon)

        # Returns logp at q_new and fills q_new_grad with its gradient.
        logp = self._logp_dlogp_func(q_new, q_new_grad)

        # p_new = p_new + dt * q_new_grad
        axpy(q_new_grad, p_new, a=dt)

        kinetic = pot.velocity_energy(p_new, v_new)
        energy = kinetic - logp

        return State(q_new, p_new, v_new, q_new_grad, energy, logp)
4✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc