• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pymc-devs / pymc3 / 9336

pending completion
9336

Pull #3590

travis-ci

web-flow
Update pymc3/ode/ode.py

Co-Authored-By: Thomas Wiecki <thomas.wiecki@gmail.com>
Pull Request #3590: Add Differential Equation API

297 of 297 new or added lines in 5 files covered. (100.0%)

17939 of 20051 relevant lines covered (89.47%)

3.71 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

87.15
/pymc3/backends/base.py
1
"""Base backend for traces
2

3
See the docstring for pymc3.backends for more information (including
4
creating custom backends).
5
"""
6
import itertools as itl
9✔
7
import logging
9✔
8

9
import numpy as np
9✔
10
import warnings
9✔
11
import theano.tensor as tt
9✔
12

13
from ..model import modelcontext
9✔
14
from .report import SamplerReport, merge_reports
9✔
15

16
# Logger for this module; logging.getLogger returns the same object for the
# same name, so this is shared with the rest of the 'pymc3' package logging.
logger = logging.getLogger(name='pymc3')
17

18

19
class BackendError(Exception):
    """Exception type for trace-backend errors.

    It is not raised anywhere in this module; backend implementations
    elsewhere presumably raise it.
    """
21

22

23
class BaseTrace:
    """Base trace object

    Parameters
    ----------
    name : str
        Name of backend
    model : Model
        If None, the model is taken from the `with` context.
    vars : list of variables
        Sampling values will be stored for these variables. If None,
        `model.unobserved_RVs` is used.
    test_point : dict
        Values overlaid on a copy of `model.test_point` before variable
        shapes and dtypes are computed (e.g. when variable shapes changed).
    """

    supports_sampler_stats = False

    def __init__(self, name, model=None, vars=None, test_point=None):
        self.name = name

        model = modelcontext(model)
        self.model = model
        if vars is None:
            vars = model.unobserved_RVs
        self.vars = vars
        self.varnames = [var.name for var in vars]
        self.fn = model.fastfn(vars)

        # Get variable shapes. Most backends will need this
        # information.
        if test_point is None:
            test_point = model.test_point
        else:
            # Overlay the caller's values on a copy instead of updating
            # `model.test_point` directly.
            test_point_ = model.test_point.copy()
            test_point_.update(test_point)
            test_point = test_point_
        var_values = list(zip(self.varnames, self.fn(test_point)))
        self.var_shapes = {var: value.shape
                           for var, value in var_values}
        self.var_dtypes = {var: value.dtype
                           for var, value in var_values}
        self.chain = None
        self._is_base_setup = False
        self.sampler_vars = None
        self._warnings = []

    def _add_warnings(self, warnings):
        """Append an iterable of warnings to this trace's warning list.

        Note: the parameter name shadows the `warnings` module, but only
        inside this method.
        """
        self._warnings.extend(warnings)

    # Sampling methods

    def _set_sampler_vars(self, sampler_vars):
        """Validate and store the per-sampler statistics description.

        Parameters
        ----------
        sampler_vars : list of dicts (name -> dtype) or None

        Raises
        ------
        ValueError
            If statistics are given but the backend does not support them,
            if they differ from the ones fixed by an earlier `setup`, or if
            one statistic name appears with two different dtypes.
        """
        if sampler_vars is not None and not self.supports_sampler_stats:
            raise ValueError("Backend does not support sampler stats.")

        if self._is_base_setup and self.sampler_vars != sampler_vars:
            raise ValueError("Can't change sampler_vars")

        if sampler_vars is None:
            self.sampler_vars = None
            return

        # A statistic reported by several samplers must share one dtype.
        dtypes = {}
        for stats in sampler_vars:
            for key, dtype in stats.items():
                if dtypes.setdefault(key, dtype) != dtype:
                    raise ValueError("Sampler statistic %s appears with "
                                     "different types." % key)

        self.sampler_vars = sampler_vars

    def setup(self, draws, chain, sampler_vars=None):
        """Perform chain-specific setup.

        Parameters
        ----------
        draws : int
            Expected number of draws
        chain : int
            Chain number
        sampler_vars : list of dictionaries (name -> dtype), optional
            Diagnostics / statistics for each sampler. Before passing this
            to a backend, you should check that the `supports_sampler_stats`
            flag is set.
        """
        self._set_sampler_vars(sampler_vars)
        self._is_base_setup = True

    def record(self, point, sampler_states=None):
        """Record results of a sampling iteration.

        Parameters
        ----------
        point : dict
            Values mapped to variable names
        sampler_states : list of dicts
            The diagnostic values for each sampler
        """
        raise NotImplementedError

    def close(self):
        """Close the database backend.

        This is called after sampling has finished.
        """
        pass

    # Selection methods

    def __getitem__(self, idx):
        """Slice the trace (slice `idx`) or return one point (integer `idx`).

        Raises
        ------
        ValueError
            If `idx` is neither a slice nor convertible with `int()`.
        """
        if isinstance(idx, slice):
            return self._slice(idx)

        # Guard only the conversion so that errors raised by `point`
        # itself are not misreported as a bad index type.
        try:
            idx = int(idx)
        except (ValueError, TypeError):  # Passed variable or variable name.
            raise ValueError('Can only index with slice or integer') from None
        return self.point(idx)

    def __len__(self):
        raise NotImplementedError

    def get_values(self, varname, burn=0, thin=1):
        """Get values from trace.

        Parameters
        ----------
        varname : str
        burn : int
        thin : int

        Returns
        -------
        A NumPy array
        """
        raise NotImplementedError

    def get_sampler_stats(self, stat_name, sampler_idx=None, burn=0, thin=1):
        """Get sampler statistics from the trace.

        Parameters
        ----------
        stat_name : str
        sampler_idx : int or None
        burn : int
        thin : int

        Returns
        -------
        If `sampler_idx` is specified, the statistic with the given name
        for that sampler. Otherwise the values of every sampler that
        provides the statistic, stacked along a new last axis. For scalar
        statistics that is shape `(n, m)`, where `n` is the number of
        samples and `m` the number of such samplers. The last axis is
        dropped when only one sampler provides the statistic.

        Raises
        ------
        ValueError
            If the backend does not support sampler stats.
        KeyError
            If no sampler provides `stat_name` (including when no sampler
            stats were set up at all).
        """
        if not self.supports_sampler_stats:
            raise ValueError("This backend does not support sampler stats")

        if sampler_idx is not None:
            return self._get_sampler_stats(stat_name, sampler_idx, burn, thin)

        # `sampler_vars` is None until setup() receives statistics; treat
        # that as "no samplers" so the lookup below raises KeyError.
        sampler_idxs = [i for i, s in enumerate(self.sampler_vars or [])
                        if stat_name in s]
        if not sampler_idxs:
            raise KeyError("Unknown sampler stat %s" % stat_name)

        vals = np.stack([self._get_sampler_stats(stat_name, i, burn, thin)
                         for i in sampler_idxs], axis=-1)
        if vals.shape[-1] == 1:
            return vals[..., 0]
        else:
            return vals

    def _get_sampler_stats(self, stat_name, sampler_idx, burn, thin):
        """Get sampler statistics."""
        raise NotImplementedError()

    def _slice(self, idx):
        """Slice trace object."""
        raise NotImplementedError()

    def point(self, idx):
        """Return dictionary of point values at `idx` for current chain
        with variables names as keys.
        """
        raise NotImplementedError()

    @property
    def stat_names(self):
        """Set of all sampler statistic names (empty if unsupported)."""
        if self.supports_sampler_stats:
            names = set()
            for stats in self.sampler_vars or []:
                names.update(stats.keys())
            return names
        else:
            return set()
219

220

221
class MultiTrace:
    """Main interface for accessing values from MCMC results.

    The core method to select values is `get_values`. The method
    to select sampler statistics is `get_sampler_stats`. Both kinds of
    values can also be accessed by indexing the MultiTrace object.
    Indexing can behave in four ways:

    1. Indexing with a variable or variable name (str) returns all
       values for that variable, combining values for all chains.

       >>> trace[varname]

       Slicing after the variable name can be used to burn and thin
       the samples (only the slice's start and step are used).

       >>> trace[varname, 1000:]

       For convenience during interactive use, values can also be
       accessed using the variable as an attribute.

       >>> trace.varname

    2. Indexing with an integer returns a dictionary with values for
       each variable at the given index (corresponding to a single
       sampling iteration).

    3. Slicing with a range returns a new trace with the number of draws
       corresponding to the range.

    4. Indexing with the name of a sampler statistic that is not also
       the name of a variable returns those values from all chains.
       If there is more than one sampler that provides that statistic,
       the values are stacked along a new axis.

    For any methods that require a single trace (e.g., taking the length
    of the MultiTrace instance, which returns the number of draws), the
    trace with the highest chain number is always used.

    Attributes
    ----------
    nchains : int
        Number of chains in the `MultiTrace`.
    chains : `List[int]`
        Sorted list of chain indices
    report : SamplerReport
        Report on the sampling process.
    varnames : `List[str]`
        List of variable names in the trace(s)
    """

    def __init__(self, straces):
        self._straces = {}
        for strace in straces:
            if strace.chain in self._straces:
                raise ValueError("Chains are not unique.")
            self._straces[strace.chain] = strace

        self._report = SamplerReport()
        # Iterate the stored traces rather than `straces` a second time, so
        # a one-shot iterable (e.g. a generator) still contributes warnings.
        for strace in self._straces.values():
            if hasattr(strace, '_warnings'):
                self._report._add_warnings(strace._warnings, strace.chain)

    def __repr__(self):
        template = '<{}: {} chains, {} iterations, {} variables>'
        return template.format(self.__class__.__name__,
                               self.nchains, len(self), len(self.varnames))

    @property
    def nchains(self):
        return len(self._straces)

    @property
    def chains(self):
        return list(sorted(self._straces.keys()))

    @property
    def report(self):
        return self._report

    def __getitem__(self, idx):
        """Dispatch on `idx`: slice, integer point, or variable/stat name.

        Raises
        ------
        KeyError
            If `idx` names neither a variable nor a sampler statistic.
        """
        if isinstance(idx, slice):
            return self._slice(idx)

        try:
            point_idx = int(idx)
        except (ValueError, TypeError):  # Passed variable or variable name.
            pass
        else:
            # Only the conversion is guarded, so errors raised by `point`
            # itself are not mistaken for a variable-name lookup.
            return self.point(point_idx)

        if isinstance(idx, tuple):
            var, vslice = idx
            # Only the slice's start (burn) and step (thin) are used.
            burn, thin = vslice.start, vslice.step
            if burn is None:
                burn = 0
            if thin is None:
                thin = 1
        else:
            var = idx
            burn, thin = 0, 1

        var = str(var)
        if var in self.varnames:
            if var in self.stat_names:
                warnings.warn("Attribute access on a trace object is ambigous. "
                              "Sampler statistic and model variable share a name. Use "
                              "trace.get_values or trace.get_sampler_stats.")
            return self.get_values(var, burn=burn, thin=thin)
        if var in self.stat_names:
            return self.get_sampler_stats(var, burn=burn, thin=thin)
        raise KeyError("Unknown variable %s" % var)

    # Names that __getattr__ must never resolve as variables/statistics.
    _attrs = {'_straces', 'varnames', 'chains', 'stat_names',
              'supports_sampler_stats', '_report'}

    def __getattr__(self, name):
        # Avoid infinite recursion when called before __init__
        # variables are set up (e.g., when pickling).
        if name in self._attrs:
            raise AttributeError

        name = str(name)
        if name in self.varnames:
            if name in self.stat_names:
                warnings.warn("Attribute access on a trace object is ambigous. "
                              "Sampler statistic and model variable share a name. Use "
                              "trace.get_values or trace.get_sampler_stats.")
            return self.get_values(name)
        if name in self.stat_names:
            return self.get_sampler_stats(name)
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, name))

    def __len__(self):
        chain = self.chains[-1]
        return len(self._straces[chain])

    @property
    def varnames(self):
        chain = self.chains[-1]
        return self._straces[chain].varnames

    @property
    def stat_names(self):
        """Union of sampler statistic names; all chains must agree."""
        if not self._straces:
            return set()
        sampler_vars = [s.sampler_vars for s in self._straces.values()]
        if not all(svars == sampler_vars[0] for svars in sampler_vars):
            raise ValueError("Inividual chains contain different sampler stats")
        names = set()
        for trace in self._straces.values():
            if trace.sampler_vars is None:
                continue
            for stats in trace.sampler_vars:
                names.update(stats.keys())
        return names

    def add_values(self, vals, overwrite=False) -> None:
        """Add variables to traces.

        Parameters
        ----------
        vals : dict (str: array-like)
            The keys should be the names of the new variables. The values are
            array-like objects (lists are accepted). The length of each value
            should match the number of total samples already in the trace
            `(chains * iterations)`, otherwise a warning is raised. Values are
            split across chains in ascending chain order, the same order in
            which `get_values` concatenates chains.
        overwrite : bool
            If `False` (default) a ValueError is raised if the variable already exists.
            Change to `True` to overwrite the values of variables

        Returns
        -------
            None.
        """
        for k, v in vals.items():
            exists = k in self.varnames
            if exists and not overwrite:
                raise ValueError("Variable name {} already exists.".format(k))

            chain_ids = self.chains
            n_draws = len(self)
            l_samples = n_draws * len(chain_ids)
            v = np.asarray(v)
            l_v = len(v)
            if l_v != l_samples:
                warnings.warn("The length of the values you are trying to "
                              "add ({}) does not match the number ({}) of "
                              "total samples in the trace "
                              "(chains * iterations)".format(l_v, l_samples))

            # Split the flat leading axis into (chain, draw, value). Only the
            # trailing value axis is dropped when it has length 1. A full
            # np.squeeze would also collapse a single chain or single draw
            # axis, after which indexing by chain picks the wrong samples.
            per_chain = v.reshape(len(chain_ids), n_draws, -1)
            if per_chain.shape[-1] == 1:
                per_chain = per_chain[..., 0]

            # Mutate bookkeeping only after the reshape succeeded, so a bad
            # input cannot leave a name without samples.
            if exists:
                self.varnames.remove(k)
            self.varnames.append(k)

            for idx, chain in enumerate(chain_ids):
                strace = self._straces[chain]
                if not exists:
                    dummy = tt.as_tensor_variable([], k)
                    strace.vars.append(dummy)
                strace.samples[k] = per_chain[idx]

    def remove_values(self, name):
        """Remove variables from traces.

        Parameters
        ----------
        name : str
            Name of the variable to remove. Raises KeyError if the variable is not present
        """
        varnames = self.varnames
        if name not in varnames:
            raise KeyError("Unknown variable {}".format(name))
        varnames.remove(name)
        for strace in self._straces.values():
            kept = [va for va in strace.vars if va.name != name]
            if len(kept) != len(strace.vars):
                # Filter into the same list object (in case other code holds
                # a reference) rather than removing while iterating over it.
                strace.vars[:] = kept
                del strace.samples[name]

    def get_values(self, varname, burn=0, thin=1, combine=True, chains=None,
                   squeeze=True):
        """Get values from traces.

        Parameters
        ----------
        varname : str
        burn : int
        thin : int
        combine : bool
            If True, results from `chains` will be concatenated.
        chains : int or list of ints
            Chains to retrieve. If None, all chains are used. A single
            chain value can also be given.
        squeeze : bool
            Return a single array element if the resulting list of
            values only has one element. If False, the result will
            always be a list of arrays, even if `combine` is True.

        Returns
        -------
        A list of NumPy arrays or a single NumPy array (depending on
        `squeeze`).
        """
        if chains is None:
            chains = self.chains
        varname = str(varname)
        # Decide iterability up front so a TypeError raised by a backend's
        # get_values is not mistaken for "a single chain was passed".
        try:
            chains = list(chains)
        except TypeError:  # Single chain passed.
            chains = [chains]
        results = [self._straces[chain].get_values(varname, burn, thin)
                   for chain in chains]
        return _squeeze_cat(results, combine, squeeze)

    def get_sampler_stats(self, stat_name, burn=0, thin=1, combine=True,
                          chains=None, squeeze=True):
        """Get sampler statistics from the trace.

        Parameters
        ----------
        stat_name : str
        burn : int
        thin : int
        combine : bool
            If True, results from `chains` will be concatenated.
        chains : int or list of ints
            Chains to retrieve. If None, all chains are used. A single
            chain value can also be given.
        squeeze : bool
            Return a single array element if the resulting list of
            values only has one element. If False, the result will
            always be a list of arrays, even if `combine` is True.

        Returns
        -------
        For each chain, the result of that chain's
        `get_sampler_stats(stat_name, None, burn, thin)` (for `BaseTrace`
        backends the samplers providing the statistic are stacked along a
        last axis). These per-chain results are then combined and squeezed
        as in `get_values`.

        Raises
        ------
        KeyError
            If `stat_name` is not a known sampler statistic.
        """
        if stat_name not in self.stat_names:
            raise KeyError("Unknown sampler statistic %s" % stat_name)

        if chains is None:
            chains = self.chains
        try:
            chains = iter(chains)
        except TypeError:
            chains = [chains]

        results = [self._straces[chain].get_sampler_stats(stat_name, None, burn, thin)
                   for chain in chains]
        return _squeeze_cat(results, combine, squeeze)

    def _slice(self, slice):
        """Return a new MultiTrace object sliced according to `slice`."""
        new_traces = [trace._slice(slice) for trace in self._straces.values()]
        trace = MultiTrace(new_traces)
        idxs = slice.indices(len(self))
        trace._report = self._report._slice(*idxs)
        return trace

    def point(self, idx, chain=None):
        """Return a dictionary of point values at `idx`.

        Parameters
        ----------
        idx : int
        chain : int
            If a chain is not given, the highest chain number is used.
        """
        if chain is None:
            chain = self.chains[-1]
        return self._straces[chain].point(idx)

    def points(self, chains=None):
        """Return an iterator over all or some of the sample points

        Parameters
        ----------
        chains : list of int or None
            The chains whose points should be included in the iterator.  If
            chains is not given, include points from all chains.
        """
        if chains is None:
            chains = self.chains

        return itl.chain.from_iterable(self._straces[chain] for chain in chains)
543

544

545
def merge_traces(mtraces):
    """Merge MultiTrace objects.

    Parameters
    ----------
    mtraces : list of MultiTraces
        Each instance should have unique chain numbers.

    Raises
    ------
    A ValueError is raised if any traces have overlapping chain numbers.
    The check is done before anything is merged, so `mtraces[0]` is left
    unchanged when it fails.

    Returns
    -------
    A MultiTrace instance with merged chains. This is `mtraces[0]`, updated
    in place with the chains of the others and a merged report.
    """
    base_mtrace = mtraces[0]

    # Validate every chain number first so that a conflict cannot leave the
    # base trace half-merged (some chains added, report not merged).
    seen = set(base_mtrace._straces)
    for new_mtrace in mtraces[1:]:
        for new_chain in new_mtrace._straces:
            if new_chain in seen:
                raise ValueError("Chains are not unique.")
            seen.add(new_chain)

    for new_mtrace in mtraces[1:]:
        for new_chain, strace in new_mtrace._straces.items():
            base_mtrace._straces[new_chain] = strace
    base_mtrace._report = merge_reports([trace.report for trace in mtraces])
    return base_mtrace
569

570

571
def _squeeze_cat(results, combine, squeeze):
9✔
572
    """Squeeze and concatenate the results depending on values of
573
    `combine` and `squeeze`."""
574
    if combine:
9✔
575
        results = np.concatenate(results)
9✔
576
        if not squeeze:
9✔
577
            results = [results]
2✔
578
    else:
579
        if squeeze and len(results) == 1:
9✔
580
            results = results[0]
2✔
581
    return results
9✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc