pymc-devs / pymc3 / build 9386 (pending completion)

Pull Request #3634: [WIP] DifferentialEquation Op improvements (travis-ci, committed via web-flow)
WIP: refactoring the DifferentialEquation Op
+ full support for test_values
+ explicit input/output types
+ 2D return shape
+ optional return of sensitivities
+ gradient without helper Op

99 of 99 new or added lines in 3 files covered. (100.0%)
42609 of 80628 relevant lines covered (52.85%)
1.86 hits per line

Source File: /pymc3/model.py (73.58% covered)
import collections
import functools
import itertools
import threading
import warnings
from typing import Optional

import numpy as np
from pandas import Series
import scipy.sparse as sps
import theano.sparse as sparse
from theano import theano, tensor as tt
from theano.tensor.var import TensorVariable
from theano.compile import SharedVariable

from pymc3.theanof import set_theano_conf, floatX
import pymc3 as pm
from pymc3.math import flatten_list
from .memoize import memoize, WithMemoization
from .theanof import gradient, hessian, inputvars, generator
from .vartypes import typefilter, discrete_types, continuous_types, isgenerator
from .blocking import DictToArrayBijection, ArrayOrdering
from .util import get_transformed_name

__all__ = [
    'Model', 'Factor', 'compilef', 'fn', 'fastfn', 'modelcontext',
    'Point', 'Deterministic', 'Potential', 'set_data'
]

FlatView = collections.namedtuple('FlatView', 'input, replacements, view')


class PyMC3Variable(TensorVariable):
    """Class to wrap Theano TensorVariable for custom behavior."""

    # Implement matrix multiplication infix operator: X @ w
    __matmul__ = tt.dot

    def __rmatmul__(self, other):
        return tt.dot(other, self)

class InstanceMethod:
    """Class for hiding references to instance methods so they can be pickled.

    >>> self.method = InstanceMethod(some_object, 'method_name')
    """

    def __init__(self, obj, method_name):
        self.obj = obj
        self.method_name = method_name

    def __call__(self, *args, **kwargs):
        return getattr(self.obj, self.method_name)(*args, **kwargs)
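
# --- Illustrative note (not part of the original file) ---
# InstanceMethod avoids pickling a bound method directly: it stores the owning
# object plus the method *name*, and re-resolves the method on each call. A
# minimal sketch of how FreeRV uses it further below (names as in this file):
#
#     rv.random = InstanceMethod(distribution, 'random')
#     rv.random(size=10)   # forwards to distribution.random(size=10)
#
# Pickling `rv` then only needs `distribution` and the string 'random'.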

def incorporate_methods(source, destination, methods, default=None,
                        wrapper=None, override=False):
    """
    Add attributes to a destination object that point to
    methods of a source object.

    Parameters
    ----------
    source : object
        The source object containing the methods.
    destination : object
        The destination object for the methods.
    methods : list of str
        Names of methods to incorporate.
    default : object
        The value used if the source does not have one of the listed methods.
    wrapper : function
        An optional function to allow the source method to be
        wrapped. Should take the form my_wrapper(source, method_name)
        and return a single value.
    override : bool
        If the destination object already has a method/attribute
        an AttributeError will be raised if override is False (the default).
    """
    for method in methods:
        if hasattr(destination, method) and not override:
            raise AttributeError("Cannot add method {!r} ".format(method) +
                                 "to destination object as it already exists. "
                                 "To prevent this error set 'override=True'.")
        if hasattr(source, method):
            if wrapper is None:
                setattr(destination, method, getattr(source, method))
            else:
                setattr(destination, method, wrapper(source, method))
        else:
            # Fall back to `default` (None unless overridden), as documented.
            setattr(destination, method, default)
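
# --- Illustrative note (not part of the original file) ---
# A hypothetical use, mirroring how FreeRV wires up `random` below:
#
#     incorporate_methods(source=dist, destination=rv,
#                         methods=['random'], wrapper=InstanceMethod)
#
# gives `rv.random` if `dist` defines it, and sets it to `default` otherwise.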

def get_named_nodes_and_relations(graph):
    """Get the named nodes in a theano graph (i.e. nodes whose name
    attribute is not None) along with their relationships (i.e. the
    node's named parents and named children, skipping unnamed
    intermediate nodes).

    Parameters
    ----------
    graph : a theano node

    Returns
    -------
    leaf_nodes : A dictionary of name:node pairs, of the named nodes that
        are also leaves of the graph.
    node_parents : A dictionary of node:set([parents]) pairs. Each key is
        a theano named node, and the corresponding value is the set of
        theano named nodes that are parents of the node. These parental
        relations skip unnamed intermediate nodes.
    node_children : A dictionary of node:set([children]) pairs. Each key
        is a theano named node, and the corresponding value is the set
        of theano named nodes that are children of the node. These child
        relations skip unnamed intermediate nodes.
    """
    if graph.name is not None:
        node_parents = {graph: set()}
        node_children = {graph: set()}
    else:
        node_parents = {}
        node_children = {}
    return _get_named_nodes_and_relations(graph, None, {}, node_parents, node_children)


def _get_named_nodes_and_relations(graph, parent, leaf_nodes,
                                   node_parents, node_children):
    if getattr(graph, 'owner', None) is None:  # Leaf node
        if graph.name is not None:  # Named leaf node
            leaf_nodes.update({graph.name: graph})
            if parent is not None:  # Is None for the root node
                try:
                    node_parents[graph].add(parent)
                except KeyError:
                    node_parents[graph] = {parent}
                node_children[parent].add(graph)
            else:
                node_parents[graph] = set()
            # Flag that the leaf node has no children
            node_children[graph] = set()
    else:  # Intermediate node
        if graph.name is not None:  # Intermediate named node
            if parent is not None:  # Is only None for the root node
                try:
                    node_parents[graph].add(parent)
                except KeyError:
                    node_parents[graph] = {parent}
                node_children[parent].add(graph)
            else:
                node_parents[graph] = set()
            # The current node will be set as the parent of the next
            # nodes only if it is a named node
            parent = graph
            # Init the node's children to an empty set
            node_children[graph] = set()
        for i in graph.owner.inputs:
            temp_nodes, temp_inter, temp_tree = \
                _get_named_nodes_and_relations(i, parent, leaf_nodes,
                                               node_parents, node_children)
            leaf_nodes.update(temp_nodes)
            node_parents.update(temp_inter)
            node_children.update(temp_tree)
    return leaf_nodes, node_parents, node_children
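
# --- Illustrative note (not part of the original file) ---
# A minimal sketch of what the traversal returns for a tiny named graph:
#
#     x = tt.scalar('x')
#     y = (x + 1) * 2          # unnamed intermediate nodes
#     y.name = 'y'
#     leaves, parents, children = get_named_nodes_and_relations(y)
#     # leaves   == {'x': x}
#     # parents  == {y: set(), x: {y}}   (unnamed nodes are skipped)
#     # children == {y: {x}, x: set()}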

class Context:
    """Functionality for objects that put themselves in a context using
    the `with` statement.
    """
    contexts = threading.local()

    def __enter__(self):
        type(self).get_contexts().append(self)
        # self._theano_config is set in Model.__new__
        if hasattr(self, '_theano_config'):
            self._old_theano_config = set_theano_conf(self._theano_config)
        return self

    def __exit__(self, typ, value, traceback):
        type(self).get_contexts().pop()
        # self._theano_config is set in Model.__new__
        if hasattr(self, '_old_theano_config'):
            set_theano_conf(self._old_theano_config)

    @classmethod
    def get_contexts(cls):
        # No race condition here: cls.contexts is a thread-local object.
        # Be sure not to override `contexts` in a subclass, however!
        if not hasattr(cls.contexts, 'stack'):
            cls.contexts.stack = []
        return cls.contexts.stack

    @classmethod
    def get_context(cls):
        """Return the deepest context on the stack."""
        try:
            return cls.get_contexts()[-1]
        except IndexError:
            raise TypeError("No context on context stack")


def modelcontext(model: Optional['Model']) -> 'Model':
    """Return the given model, or try to find it in the context if none
    was supplied.
    """
    if model is None:
        return Model.get_context()
    return model

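
# --- Illustrative note (not part of the original file) ---
# `modelcontext` is what lets distributions be declared without an explicit
# model argument. A minimal sketch:
#
#     with pm.Model() as m:
#         assert modelcontext(None) is m   # found on the context stack
#     assert modelcontext(m) is m          # an explicit model wins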

class Factor:
    """Common functionality for objects with a log probability density
    associated with them.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @property
    def logp(self):
        """Compiled log probability density function"""
        return self.model.fn(self.logpt)

    @property
    def logp_elemwise(self):
        return self.model.fn(self.logp_elemwiset)

    def dlogp(self, vars=None):
        """Compiled log probability density gradient function"""
        return self.model.fn(gradient(self.logpt, vars))

    def d2logp(self, vars=None):
        """Compiled log probability density hessian function"""
        return self.model.fn(hessian(self.logpt, vars))

    @property
    def logp_nojac(self):
        return self.model.fn(self.logp_nojact)

    def dlogp_nojac(self, vars=None):
        """Compiled log density gradient function, without jacobian terms."""
        return self.model.fn(gradient(self.logp_nojact, vars))

    def d2logp_nojac(self, vars=None):
        """Compiled log density hessian function, without jacobian terms."""
        return self.model.fn(hessian(self.logp_nojact, vars))

    @property
    def fastlogp(self):
        """Compiled log probability density function"""
        return self.model.fastfn(self.logpt)

    def fastdlogp(self, vars=None):
        """Compiled log probability density gradient function"""
        return self.model.fastfn(gradient(self.logpt, vars))

    def fastd2logp(self, vars=None):
        """Compiled log probability density hessian function"""
        return self.model.fastfn(hessian(self.logpt, vars))

    @property
    def fastlogp_nojac(self):
        return self.model.fastfn(self.logp_nojact)

    def fastdlogp_nojac(self, vars=None):
        """Compiled log density gradient function, without jacobian terms."""
        return self.model.fastfn(gradient(self.logp_nojact, vars))

    def fastd2logp_nojac(self, vars=None):
        """Compiled log density hessian function, without jacobian terms."""
        return self.model.fastfn(hessian(self.logp_nojact, vars))

    @property
    def logpt(self):
        """Theano scalar of log-probability of the model"""
        if getattr(self, 'total_size', None) is not None:
            logp = self.logp_sum_unscaledt * self.scaling
        else:
            logp = self.logp_sum_unscaledt
        if self.name is not None:
            logp.name = '__logp_%s' % self.name
        return logp

    @property
    def logp_nojact(self):
        """Theano scalar of log-probability, excluding jacobian terms."""
        if getattr(self, 'total_size', None) is not None:
            logp = tt.sum(self.logp_nojac_unscaledt) * self.scaling
        else:
            logp = tt.sum(self.logp_nojac_unscaledt)
        if self.name is not None:
            logp.name = '__logp_%s' % self.name
        return logp
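
# --- Illustrative note (not part of the original file) ---
# The split between `logp_sum_unscaledt` and `logpt` is what makes minibatch
# scaling work: when total_size is set, the summed elementwise log-density is
# multiplied by `scaling = total_size / batch_size` (see `_get_scaling` below),
# giving an unbiased estimate of the full-data log-likelihood. Roughly:
#
#     logpt = scaling * logp_sum_unscaledt     # if total_size is given
#     logpt = logp_sum_unscaledt               # otherwise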

class InitContextMeta(type):
    """Metaclass that executes `__init__` of an instance in its context"""
    def __call__(cls, *args, **kwargs):
        instance = cls.__new__(cls, *args, **kwargs)
        with instance:  # appends context
            instance.__init__(*args, **kwargs)
        return instance


def withparent(meth):
    """Helper wrapper that passes calls on to the parent instance"""
    def wrapped(self, *args, **kwargs):
        res = meth(self, *args, **kwargs)
        if getattr(self, 'parent', None) is not None:
            getattr(self.parent, meth.__name__)(*args, **kwargs)
        return res
    # Unfortunately the functools wrapper fails
    # when decorating built-in methods, so we
    # need to fix that improper behaviour ourselves.
    wrapped.__name__ = meth.__name__
    return wrapped


class treelist(list):
    """A list that passes mutating extend operations used in Model
    on to a parent list instance.
    Extending a treelist will also extend its parent.
    """
    def __init__(self, iterable=(), parent=None):
        super().__init__(iterable)
        assert isinstance(parent, list) or parent is None
        self.parent = parent
        if self.parent is not None:
            self.parent.extend(self)
    # typechecking here works badly
    append = withparent(list.append)
    __iadd__ = withparent(list.__iadd__)
    extend = withparent(list.extend)

    def tree_contains(self, item):
        if isinstance(self.parent, treedict):
            return (list.__contains__(self, item) or
                    self.parent.tree_contains(item))
        elif isinstance(self.parent, list):
            return (list.__contains__(self, item) or
                    self.parent.__contains__(item))
        else:
            return list.__contains__(self, item)

    def __setitem__(self, key, value):
        raise NotImplementedError('Method is removed as we are not'
                                  ' able to determine '
                                  'appropriate logic for it')

    def __imul__(self, other):
        t0 = len(self)
        list.__imul__(self, other)
        if self.parent is not None:
            self.parent.extend(self[t0:])
        return self  # in-place operators must return self


class treedict(dict):
    """A dict that passes mutating update operations used in Model
    on to a parent dict instance.
    Updating a treedict will also update its parent.
    """
    def __init__(self, iterable=(), parent=None, **kwargs):
        super().__init__(iterable, **kwargs)
        assert isinstance(parent, dict) or parent is None
        self.parent = parent
        if self.parent is not None:
            self.parent.update(self)
    # typechecking here works badly
    __setitem__ = withparent(dict.__setitem__)
    update = withparent(dict.update)

    def tree_contains(self, item):
        # needed for `add_random_variable` method
        if isinstance(self.parent, treedict):
            return (dict.__contains__(self, item) or
                    self.parent.tree_contains(item))
        elif isinstance(self.parent, dict):
            return (dict.__contains__(self, item) or
                    self.parent.__contains__(item))
        else:
            return dict.__contains__(self, item)
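
# --- Illustrative note (not part of the original file) ---
# These containers are what make nested (named) models share variables with
# their parents. A minimal sketch:
#
#     root = treelist()
#     child = treelist(parent=root)
#     child.append('x')
#     # 'x' is now in both `child` and `root`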

class ValueGradFunction:
    """Create a theano function that computes a value and its gradient.

    Parameters
    ----------
    cost : theano variable
        The value that we compute with its gradient.
    grad_vars : list of named theano variables or None
        The arguments with respect to which the gradient is computed.
    extra_vars : list of named theano variables or None
        Other arguments of the function that are assumed constant. They
        are stored in shared variables and can be set using
        `set_extra_values`.
    dtype : str, default=theano.config.floatX
        The dtype of the arrays.
    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, default='no'
        Casting rule for casting `grad_vars` to the array dtype.
        See `numpy.can_cast` for a description of the options.
        Keep in mind that we cast the variables to the array *and*
        back from the array dtype to the variable dtype.
    kwargs
        Extra arguments are passed on to `theano.function`.

    Attributes
    ----------
    size : int
        The number of elements in the parameter array.
    profile : theano profiling object or None
        The profiling object of the theano function that computes value and
        gradient. This is None unless `profile=True` was set in the
        kwargs.
    """
    def __init__(self, cost, grad_vars, extra_vars=None, dtype=None,
                 casting='no', **kwargs):
        from .distributions import TensorType

        if extra_vars is None:
            extra_vars = []

        names = [arg.name for arg in grad_vars + extra_vars]
        if any(name is None for name in names):
            raise ValueError('Arguments must be named.')
        if len(set(names)) != len(names):
            raise ValueError('Names of the arguments are not unique.')

        if cost.ndim > 0:
            raise ValueError('Cost must be a scalar.')

        self._grad_vars = grad_vars
        self._extra_vars = extra_vars
        self._extra_var_names = set(var.name for var in extra_vars)
        self._cost = cost
        self._ordering = ArrayOrdering(grad_vars)
        self.size = self._ordering.size
        self._extra_are_set = False
        if dtype is None:
            dtype = theano.config.floatX
        self.dtype = dtype
        for var in self._grad_vars:
            if not np.can_cast(var.dtype, self.dtype, casting):
                raise TypeError('Invalid dtype for variable %s. Can not '
                                'cast to %s with casting rule %s.'
                                % (var.name, self.dtype, casting))
            if not np.issubdtype(var.dtype, np.floating):
                raise TypeError('Invalid dtype for variable %s. Must be '
                                'floating point but is %s.'
                                % (var.name, var.dtype))

        givens = []
        self._extra_vars_shared = {}
        for var in extra_vars:
            shared = theano.shared(var.tag.test_value, var.name + '_shared__')
            # test TensorType compatibility
            if hasattr(var.tag.test_value, 'shape'):
                testtype = TensorType(var.dtype, var.tag.test_value.shape)

                if testtype != shared.type:
                    shared.type = testtype
            self._extra_vars_shared[var.name] = shared
            givens.append((var, shared))

        self._vars_joined, self._cost_joined = self._build_joined(
            self._cost, grad_vars, self._ordering.vmap)

        grad = tt.grad(self._cost_joined, self._vars_joined)
        grad.name = '__grad'

        inputs = [self._vars_joined]

        self._theano_function = theano.function(
            inputs, [self._cost_joined, grad], givens=givens, **kwargs)

    def set_extra_values(self, extra_vars):
        self._extra_are_set = True
        for var in self._extra_vars:
            self._extra_vars_shared[var.name].set_value(extra_vars[var.name])

    def get_extra_values(self):
        if not self._extra_are_set:
            raise ValueError('Extra values are not set.')

        return {var.name: self._extra_vars_shared[var.name].get_value()
                for var in self._extra_vars}

    def __call__(self, array, grad_out=None, extra_vars=None):
        if extra_vars is not None:
            self.set_extra_values(extra_vars)

        if not self._extra_are_set:
            raise ValueError('Extra values are not set.')

        if array.shape != (self.size,):
            raise ValueError('Invalid shape for array. Must be %s but is %s.'
                             % ((self.size,), array.shape))

        if grad_out is None:
            out = np.empty_like(array)
        else:
            out = grad_out

        logp, dlogp = self._theano_function(array)
        if grad_out is None:
            return logp, dlogp
        else:
            np.copyto(out, dlogp)
            return logp

    @property
    def profile(self):
        """Profiling information of the underlying theano function."""
        return self._theano_function.profile

    def dict_to_array(self, point):
        """Convert a dictionary with values for grad_vars to an array."""
        array = np.empty(self.size, dtype=self.dtype)
        for varmap in self._ordering.vmap:
            array[varmap.slc] = point[varmap.var].ravel().astype(self.dtype)
        return array

    def array_to_dict(self, array):
        """Convert an array to a dictionary containing the grad_vars."""
        if array.shape != (self.size,):
            raise ValueError('Array should have shape (%s,) but has %s'
                             % (self.size, array.shape))
        if array.dtype != self.dtype:
            raise ValueError('Array has invalid dtype. Should be %s but is %s'
                             % (self.dtype, array.dtype))
        point = {}
        for varmap in self._ordering.vmap:
            data = array[varmap.slc].reshape(varmap.shp)
            point[varmap.var] = data.astype(varmap.dtyp)

        return point

    def array_to_full_dict(self, array):
        """Convert an array to a dictionary with grad_vars and extra_vars."""
        point = self.array_to_dict(array)
        for name, var in self._extra_vars_shared.items():
            point[name] = var.get_value()
        return point

    def _build_joined(self, cost, args, vmap):
        args_joined = tt.vector('__args_joined')
        args_joined.tag.test_value = np.zeros(self.size, dtype=self.dtype)

        joined_slices = {}
        for vmap in vmap:
            sliced = args_joined[vmap.slc].reshape(vmap.shp)
            sliced.name = vmap.var
            joined_slices[vmap.var] = sliced

        replace = {var: joined_slices[var.name] for var in args}
        return args_joined, theano.clone(cost, replace=replace)

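
# --- Illustrative note (not part of the original file) ---
# Samplers use this class through `model.logp_dlogp_function` (defined below);
# a minimal sketch of the calling convention:
#
#     f = model.logp_dlogp_function(grad_vars=model.cont_vars)
#     f.set_extra_values({})                  # no extra (constant) vars here
#     x = f.dict_to_array(model.test_point)   # flat parameter vector
#     logp, dlogp = f(x)                      # value and gradient together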

class Model(Context, Factor, WithMemoization, metaclass=InitContextMeta):
    """Encapsulates the variables and likelihood factors of a model.

    The Model class can be used to create class-based models. To create
    a class-based model you should inherit from :class:`~.Model` and
    override :meth:`~.__init__` with arbitrary definitions (do not
    forget to call the base class :meth:`__init__` first).

    Parameters
    ----------
    name : str
        name that will be used as a prefix for the names of all random
        variables defined within the model
    model : Model
        instance of Model that is supposed to be a parent for the new
        instance. If ``None``, the context will be used. All variables
        defined within the instance will be passed to the parent instance,
        so that a 'nested' model contributes to the variables and
        likelihood factors of the parent model.
    theano_config : dict
        A dictionary of theano config values that should be set
        temporarily in the model context. See the documentation
        of theano for a complete list. If it is None, the config key
        ``compute_test_value`` is set to `raise`.

    Examples
    --------

    How to define a custom model

    .. code-block:: python

        class CustomModel(Model):
            # 1) override init
            def __init__(self, mean=0, sigma=1, name='', model=None):
                # 2) call super's init first, passing model and name
                # to it. name will be the prefix for all variables here; if
                # no name is specified for the model there will be no prefix
                super().__init__(name, model)
                # now you are in the context of the instance,
                # `modelcontext` will return self. You can define
                # variables in several ways; note that all variables
                # will get the model's name prefix

                # 3) you can create variables with the Var method
                self.Var('v1', Normal.dist(mu=mean, sigma=sigma))
                # this will create a variable named like '{prefix_}v1'
                # and assign attribute 'v1' to the instance. The created
                # variable can be accessed with self.v1 or self['v1']

                # 4) this syntax will also work as we are in the
                # context of the instance itself; names are given as usual
                Normal('v2', mu=mean, sigma=sigma)

                # something more complex is allowed, too
                half_cauchy = HalfCauchy('sd', beta=10, testval=1.)
                Normal('v3', mu=mean, sigma=half_cauchy)

                # Deterministic variables can be used in the usual way
                Deterministic('v3_sq', self.v3 ** 2)

                # Potentials too
                Potential('p1', tt.constant(1))

        # After defining a class CustomModel you can use it in several
        # ways

        # I:
        #   state the model within a context
        with Model() as model:
            CustomModel()
            # arbitrary actions

        # II:
        #   use the new class as the entry point into a context
        with CustomModel() as model:
            Normal('new_normal_var', mu=1, sigma=1)

        # III:
        #   just get a model instance with all that was defined in it
        model = CustomModel()

        # IV:
        #   use many custom models within one context
        with Model() as model:
            CustomModel(mean=1, name='first')
            CustomModel(mean=2, name='second')
    """
    def __new__(cls, *args, **kwargs):
        # resolves the parent instance
        instance = super().__new__(cls)
        if kwargs.get('model') is not None:
            instance._parent = kwargs.get('model')
        elif cls.get_contexts():
            instance._parent = cls.get_contexts()[-1]
        else:
            instance._parent = None
        theano_config = kwargs.get('theano_config', None)
        if theano_config is None or 'compute_test_value' not in theano_config:
            theano_config = {'compute_test_value': 'raise'}
        instance._theano_config = theano_config
        return instance

    def __init__(self, name='', model=None, theano_config=None):
        self.name = name
        if self.parent is not None:
            self.named_vars = treedict(parent=self.parent.named_vars)
            self.free_RVs = treelist(parent=self.parent.free_RVs)
            self.observed_RVs = treelist(parent=self.parent.observed_RVs)
            self.deterministics = treelist(parent=self.parent.deterministics)
            self.potentials = treelist(parent=self.parent.potentials)
            self.missing_values = treelist(parent=self.parent.missing_values)
        else:
            self.named_vars = treedict()
            self.free_RVs = treelist()
            self.observed_RVs = treelist()
            self.deterministics = treelist()
            self.potentials = treelist()
            self.missing_values = treelist()

    @property
    def model(self):
        return self

    @property
    def parent(self):
        return self._parent

    @property
    def root(self):
        model = self
        while not model.isroot:
            model = model.parent
        return model

    @property
    def isroot(self):
        return self.parent is None

    @property
    @memoize(bound=True)
    def bijection(self):
        vars = inputvars(self.vars)

        bij = DictToArrayBijection(ArrayOrdering(vars),
                                   self.test_point)

        return bij

    @property
    def dict_to_array(self):
        return self.bijection.map

    @property
    def ndim(self):
        return sum(var.dsize for var in self.free_RVs)

    @property
    def logp_array(self):
        return self.bijection.mapf(self.fastlogp)

    @property
    def dlogp_array(self):
        vars = inputvars(self.cont_vars)
        return self.bijection.mapf(self.fastdlogp(vars))

    def logp_dlogp_function(self, grad_vars=None, **kwargs):
        if grad_vars is None:
            grad_vars = list(typefilter(self.free_RVs, continuous_types))
        else:
            for var in grad_vars:
                if var.dtype not in continuous_types:
                    raise ValueError("Can only compute the gradient of "
                                     "continuous types: %s" % var)
        varnames = [var.name for var in grad_vars]
        extra_vars = [var for var in self.free_RVs if var.name not in varnames]
        return ValueGradFunction(self.logpt, grad_vars, extra_vars, **kwargs)

    @property
    def logpt(self):
        """Theano scalar of log-probability of the model"""
        with self:
            factors = [var.logpt for var in self.basic_RVs] + self.potentials
            logp = tt.sum([tt.sum(factor) for factor in factors])
            if self.name:
                logp.name = '__logp_%s' % self.name
            else:
                logp.name = '__logp'
            return logp

    @property
    def logp_nojact(self):
        """Theano scalar of log-probability of the model, without the
        Jacobian terms of any transformed random variables.
        Note that if there is no transformed variable in the model,
        logp_nojact will be the same as logpt, as there is no need for
        a Jacobian correction.
        """
        with self:
            factors = [var.logp_nojact for var in self.basic_RVs] + self.potentials
            logp = tt.sum([tt.sum(factor) for factor in factors])
            if self.name:
                logp.name = '__logp_nojac_%s' % self.name
            else:
                logp.name = '__logp_nojac'
            return logp

    @property
    def varlogpt(self):
        """Theano scalar of log-probability of the unobserved random variables
           (excluding deterministic)."""
        with self:
            factors = [var.logpt for var in self.free_RVs]
            return tt.sum(factors)

    @property
    def datalogpt(self):
        with self:
            factors = [var.logpt for var in self.observed_RVs]
            factors += [tt.sum(factor) for factor in self.potentials]
            return tt.sum(factors)

    @property
    def vars(self):
        """List of unobserved random variables used as inputs to the model
        (which excludes deterministics).
        """
        return self.free_RVs

    @property
    def basic_RVs(self):
        """List of random variables the model is defined in terms of
        (which excludes deterministics).
        """
        return self.free_RVs + self.observed_RVs

    @property
    def unobserved_RVs(self):
        """List of all random variables, including deterministic ones."""
        return self.vars + self.deterministics

    @property
    def test_point(self):
        """Test point used to check that the model doesn't generate errors"""
        return Point(((var, var.tag.test_value) for var in self.vars),
                     model=self)

    @property
    def disc_vars(self):
        """All the discrete variables in the model"""
        return list(typefilter(self.vars, discrete_types))

    @property
    def cont_vars(self):
        """All the continuous variables in the model"""
        return list(typefilter(self.vars, continuous_types))

    def Var(self, name, dist, data=None, total_size=None):
        """Create and add (un)observed random variable to the model with an
        appropriate prior distribution.

        Parameters
        ----------
        name : str
        dist : distribution for the random variable
        data : array_like (optional)
           If data is provided, the variable is observed. If None,
           the variable is unobserved.
        total_size : scalar
            upscales logp of variable with ``coef = total_size/var.shape[0]``

        Returns
        -------
        FreeRV or ObservedRV
        """
        name = self.name_for(name)
        if data is None:
            if getattr(dist, "transform", None) is None:
                with self:
                    var = FreeRV(name=name, distribution=dist,
                                 total_size=total_size, model=self)
                self.free_RVs.append(var)
            else:
                with self:
                    var = TransformedRV(name=name, distribution=dist,
                                        transform=dist.transform,
                                        total_size=total_size,
                                        model=self)
                pm._log.debug('Applied {transform}-transform to {name}'
                              ' and added transformed {orig_name} to model.'.format(
                                transform=dist.transform.name,
                                name=name,
                                orig_name=get_transformed_name(name, dist.transform)))
                self.deterministics.append(var)
                self.add_random_variable(var)
                return var
        elif isinstance(data, dict):
            with self:
                var = MultiObservedRV(name=name, data=data, distribution=dist,
                                      total_size=total_size, model=self)
            self.observed_RVs.append(var)
            if var.missing_values:
                self.free_RVs += var.missing_values
                self.missing_values += var.missing_values
                for v in var.missing_values:
                    self.named_vars[v.name] = v
        else:
            with self:
                var = ObservedRV(name=name, data=data,
                                 distribution=dist,
                                 total_size=total_size, model=self)
            self.observed_RVs.append(var)
            if var.missing_values:
                self.free_RVs.append(var.missing_values)
                self.missing_values.append(var.missing_values)
                self.named_vars[var.missing_values.name] = var.missing_values

        self.add_random_variable(var)
        return var

    def add_random_variable(self, var):
        """Add a random variable to the named variables of the model."""
        if self.named_vars.tree_contains(var.name):
            raise ValueError(
                "Variable name {} already exists.".format(var.name))
        self.named_vars[var.name] = var
        if not hasattr(self, self.name_of(var.name)):
            setattr(self, self.name_of(var.name), var)

    @property
    def prefix(self):
        return '%s_' % self.name if self.name else ''

    def name_for(self, name):
        """Check whether the name has the model's prefix and add it if needed."""
        if self.prefix:
            if not name.startswith(self.prefix):
                return '{}{}'.format(self.prefix, name)
            else:
                return name
        else:
            return name

    def name_of(self, name):
        """Check whether the name has the model's prefix and strip it if needed."""
        if not self.prefix or not name:
            return name
        elif name.startswith(self.prefix):
            return name[len(self.prefix):]
        else:
            return name

    def __getitem__(self, key):
        try:
            return self.named_vars[key]
        except KeyError as e:
            try:
                return self.named_vars[self.name_for(key)]
            except KeyError:
                raise e

    def makefn(self, outs, mode=None, *args, **kwargs):
        """Compiles a Theano function which returns ``outs`` and takes the variable
        ancestors of ``outs`` as inputs.

        Parameters
        ----------
        outs : Theano variable or iterable of Theano variables
        mode : Theano compilation mode

        Returns
        -------
        Compiled Theano function
        """
        with self:
            return theano.function(self.vars, outs,
                                   allow_input_downcast=True,
                                   on_unused_input='ignore',
                                   accept_inplace=True,
                                   mode=mode, *args, **kwargs)

    def fn(self, outs, mode=None, *args, **kwargs):
        """Compiles a Theano function which returns the values of ``outs``
        and takes values of model vars as arguments.

        Parameters
        ----------
        outs : Theano variable or iterable of Theano variables
        mode : Theano compilation mode

        Returns
        -------
        Compiled Theano function
        """
        return LoosePointFunc(self.makefn(outs, mode, *args, **kwargs), self)

    def fastfn(self, outs, mode=None, *args, **kwargs):
        """Compiles a Theano function which returns ``outs`` and takes values
        of model vars as a dict as an argument.

        Parameters
        ----------
        outs : Theano variable or iterable of Theano variables
        mode : Theano compilation mode

        Returns
        -------
        Compiled Theano function as point function.
        """
        f = self.makefn(outs, mode, *args, **kwargs)
        return FastPointFunc(f)

    def profile(self, outs, n=1000, point=None, profile=True, *args, **kwargs):
        """Compiles and profiles a Theano function which returns ``outs`` and
        takes values of model vars as a dict as an argument.

        Parameters
        ----------
        outs : Theano variable or iterable of Theano variables
        n : int, default 1000
            Number of iterations to run
        point : point
            Point to pass to the function
        profile : True or ProfileStats
        args, kwargs
            Compilation args

        Returns
        -------
        ProfileStats
            Use .summary() to print stats.
        """
        f = self.makefn(outs, profile=profile, *args, **kwargs)
        if point is None:
            point = self.test_point

        for _ in range(n):
            f(**point)

        return f.profile

    def flatten(self, vars=None, order=None, inputvar=None):
        """Flattens the model's input and returns:

        FlatView with
            * input vector variable
            * replacements ``input_var -> vars``
            * view `{variable: VarMap}`

        Parameters
        ----------
        vars : list of variables or None
            if None, then all model.free_RVs are used for flattening input
        order : ArrayOrdering
            Optional, use predefined ordering
        inputvar : tt.vector
            Optional, use predefined inputvar

        Returns
        -------
        flat_view
        """
        if vars is None:
            vars = self.free_RVs
        if order is None:
            order = ArrayOrdering(vars)
        if inputvar is None:
            inputvar = tt.vector('flat_view', dtype=theano.config.floatX)
            if theano.config.compute_test_value != 'off':
                if vars:
                    inputvar.tag.test_value = flatten_list(vars).tag.test_value
                else:
                    inputvar.tag.test_value = np.asarray([], inputvar.dtype)
        replacements = {self.named_vars[name]: inputvar[slc].reshape(shape).astype(dtype)
                        for name, slc, shape, dtype in order.vmap}
        view = {vm.var: vm for vm in order.vmap}
        flat_view = FlatView(inputvar, replacements, view)
        return flat_view

    def check_test_point(self, test_point=None, round_vals=2):
        """Checks the log probability of the test_point for all random
        variables in the model.

        Parameters
        ----------
        test_point : Point
            Point to be evaluated.
            If None, the model's test_point is used.
        round_vals : int
            Number of decimals to round log-probabilities to

        Returns
        -------
        Pandas Series
        """
        if test_point is None:
            test_point = self.test_point

        # Evaluate the supplied point, not unconditionally self.test_point.
        return Series({RV.name: np.round(RV.logp(test_point), round_vals)
                       for RV in self.basic_RVs},
                      name='Log-probability of test_point')

    def _repr_latex_(self, name=None, dist=None):
        tex_vars = []
        for rv in itertools.chain(self.unobserved_RVs, self.observed_RVs):
            rv_tex = rv.__latex__()
            if rv_tex is not None:
                array_rv = rv_tex.replace(r'\sim', r'&\sim &').strip('$')
                tex_vars.append(array_rv)
        return r'''$$
            \begin{{array}}{{rcl}}
            {}
            \end{{array}}
            $$'''.format('\\\\'.join(tex_vars))

    __latex__ = _repr_latex_

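
# --- Illustrative note (not part of the original file) ---
# A minimal sketch of named-variable lookup and test-point checking:
#
#     with pm.Model() as m:
#         x = pm.Normal('x', mu=0, sigma=1)
#     assert m['x'] is x        # __getitem__ resolves named_vars
#     m.check_test_point()      # Series of per-RV log-probs at the test point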

def set_data(new_data, model=None):
    """Sets the value of one or more data container variables.

    Parameters
    ----------
    new_data : dict
        New values for the data containers. The keys of the dictionary are
        the names of the variables in the model and the values are the
        objects with which to update them.
    model : Model (optional if in `with` context)

    Examples
    --------

    .. code:: ipython

        >>> import pymc3 as pm
        >>> with pm.Model() as model:
        ...     x = pm.Data('x', [1., 2., 3.])
        ...     y = pm.Data('y', [1., 2., 3.])
        ...     beta = pm.Normal('beta', 0, 1)
        ...     obs = pm.Normal('obs', x * beta, 1, observed=y)
        ...     trace = pm.sample(1000, tune=1000)

    Set the value of `x` to predict on new data.

    .. code:: ipython

        >>> with model:
        ...     pm.set_data({'x': [5, 6, 9]})
        ...     y_test = pm.sample_posterior_predictive(trace)
        >>> y_test['obs'].mean(axis=0)
        array([4.6088569 , 5.54128318, 8.32953844])
    """
    model = modelcontext(model)

    for variable_name, new_value in new_data.items():
        if isinstance(model[variable_name], SharedVariable):
            model[variable_name].set_value(pandas_to_array(new_value))
        else:
            message = 'The variable `{}` must be defined as `pymc3.' \
                      'Data` inside the model to allow updating. The ' \
                      'current type is: ' \
                      '{}.'.format(variable_name,
                                   type(model[variable_name]))
            raise TypeError(message)


def fn(outs, mode=None, model=None, *args, **kwargs):
    """Compiles a Theano function which returns the values of ``outs`` and
    takes values of model vars as arguments.

    Parameters
    ----------
    outs : Theano variable or iterable of Theano variables
    mode : Theano compilation mode

    Returns
    -------
    Compiled Theano function
    """
    model = modelcontext(model)
    return model.fn(outs, mode, *args, **kwargs)


def fastfn(outs, mode=None, model=None):
    """Compiles a Theano function which returns ``outs`` and takes values of model
    vars as a dict as an argument.

    Parameters
    ----------
    outs : Theano variable or iterable of Theano variables
    mode : Theano compilation mode

    Returns
    -------
    Compiled Theano function as point function.
    """
    model = modelcontext(model)
    return model.fastfn(outs, mode)


def Point(*args, **kwargs):
    """Build a point. Uses the same args as dict() does.
    Filters out variables not in the model. All keys are strings.

    Parameters
    ----------
    args, kwargs
        arguments to build a dict
    """
    model = modelcontext(kwargs.pop('model', None))
    args = list(args)
    try:
        d = dict(*args, **kwargs)
    except Exception as e:
        raise TypeError(
            "can't turn {} and {} into a dict. {}".format(args, kwargs, e))
    return dict((str(k), np.array(v)) for k, v in d.items()
                if str(k) in map(str, model.vars))

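
# --- Illustrative note (not part of the original file) ---
# `Point` normalizes keys to strings and drops anything that is not a model
# variable, so stale keys are silently ignored. A minimal sketch:
#
#     with pm.Model() as m:
#         x = pm.Normal('x', 0, 1)
#         pt = Point({'x': 0.5, 'not_a_var': 99}, model=m)
#     # pt == {'x': array(0.5)}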

class FastPointFunc:
    """Wraps a function so that it takes a dict of arguments instead of
    positional arguments."""

    def __init__(self, f):
        self.f = f

    def __call__(self, state):
        return self.f(**state)


class LoosePointFunc:
    """Wraps a function so that it takes a dict of arguments instead of
    positional arguments, but can still accept ordinary arguments."""

    def __init__(self, f, model):
        self.f = f
        self.model = model

    def __call__(self, *args, **kwargs):
        point = Point(model=self.model, *args, **kwargs)
        return self.f(**point)


compilef = fastfn


def _get_scaling(total_size, shape, ndim):
    """
    Gets the scaling constant for logp.

    Parameters
    ----------
    total_size : int or list[int]
    shape : shape
        shape to scale
    ndim : int
        ndim hint

    Returns
    -------
    scalar
    """
    if total_size is None:
        coef = floatX(1)
    elif isinstance(total_size, int):
        if ndim >= 1:
            denom = shape[0]
        else:
            denom = 1
        coef = floatX(total_size) / floatX(denom)
    elif isinstance(total_size, (list, tuple)):
        if not all(isinstance(i, int) for i in total_size if (i is not Ellipsis and i is not None)):
            raise TypeError('Unrecognized `total_size` type, expected '
                            'int or list of ints, got %r' % total_size)
        if Ellipsis in total_size:
            sep = total_size.index(Ellipsis)
            begin = total_size[:sep]
            end = total_size[sep+1:]
            if Ellipsis in end:
                raise ValueError('Double Ellipsis in `total_size` is not allowed, got %r' % total_size)
        else:
            begin = total_size
            end = []
        if (len(begin) + len(end)) > ndim:
            raise ValueError('Length of `total_size` is too big, '
                             'number of scalings is bigger than ndim, got %r' % total_size)
        elif (len(begin) + len(end)) == 0:
            return floatX(1)
        if len(end) > 0:
            shp_end = shape[-len(end):]
        else:
            shp_end = np.asarray([])
        shp_begin = shape[:len(begin)]
        begin_coef = [floatX(t) / shp_begin[i] for i, t in enumerate(begin) if t is not None]
        end_coef = [floatX(t) / shp_end[i] for i, t in enumerate(end) if t is not None]
        coefs = begin_coef + end_coef
        coef = tt.prod(coefs)
    else:
        raise TypeError('Unrecognized `total_size` type, expected '
                        'int or list of ints, got %r' % total_size)
    return tt.as_tensor(floatX(coef))

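
# --- Illustrative note (not part of the original file) ---
# Worked example of the scaling rule: for a minibatch of 50 rows drawn from a
# dataset of 1000 rows, total_size=1000 and shape[0]=50 give
#
#     coef = 1000 / 50 = 20.0
#
# so the minibatch log-likelihood is scaled up to estimate the full-data one.
# A list such as total_size=[1000, Ellipsis, 10] scales the first and last
# dimensions independently and multiplies the coefficients together.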

class FreeRV(Factor, PyMC3Variable):
    """Unobserved random variable that a model is specified in terms of."""

    def __init__(self, type=None, owner=None, index=None, name=None,
                 distribution=None, total_size=None, model=None):
        """
        Parameters
        ----------
        type : theano type (optional)
        owner : theano owner (optional)
        name : str
        distribution : Distribution
        model : Model
        total_size : scalar Tensor (optional)
            needed for upscaling logp
        """
        if type is None:
            type = distribution.type
        super().__init__(type, owner, index, name)

        if distribution is not None:
            self.dshape = tuple(distribution.shape)
            self.dsize = int(np.prod(distribution.shape))
            self.distribution = distribution
            self.tag.test_value = np.ones(
                distribution.shape, distribution.dtype) * distribution.default()
            self.logp_elemwiset = distribution.logp(self)
            # The logp might need scaling in minibatches.
            # This is done in `Factor`.
            self.logp_sum_unscaledt = distribution.logp_sum(self)
            self.logp_nojac_unscaledt = distribution.logp_nojac(self)
            self.total_size = total_size
            self.model = model
            self.scaling = _get_scaling(total_size, self.shape, self.ndim)

            incorporate_methods(source=distribution, destination=self,
                                methods=['random'],
                                wrapper=InstanceMethod)

    def _repr_latex_(self, name=None, dist=None):
        if self.distribution is None:
            return None
        if name is None:
            name = self.name
        if dist is None:
            dist = self.distribution
        return self.distribution._repr_latex_(name=name, dist=dist)

    __latex__ = _repr_latex_

    @property
    def init_value(self):
        """Convenience attribute to return tag.test_value"""
        return self.tag.test_value


def pandas_to_array(data):
    if hasattr(data, 'values'):  # pandas
        if data.isnull().any().any():  # missing values
            ret = np.ma.MaskedArray(data.values, data.isnull().values)
        else:
            ret = data.values
    elif hasattr(data, 'mask'):
        if data.mask.any():
            ret = data
        else:  # empty mask
            ret = data.filled()
    elif isinstance(data, theano.gof.graph.Variable):
        ret = data
    elif sps.issparse(data):
        ret = data
    elif isgenerator(data):
        ret = generator(data)
    else:
        ret = np.asarray(data)
    return pm.floatX(ret)

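
# --- Illustrative note (not part of the original file) ---
# The branches above normalize heterogeneous inputs to something theano can
# consume, preserving missing-value masks. A minimal sketch:
#
#     pandas_to_array(pd.Series([1., 2.]))        # -> ndarray, cast to floatX
#     pandas_to_array(pd.Series([1., np.nan]))    # -> masked array
#     pandas_to_array([1, 2, 3])                  # -> ndarray, cast to floatX
#     pandas_to_array(x for x in data_stream)     # -> theano generator op
#
# (`pd` and `data_stream` are hypothetical names for this sketch.)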

def as_tensor(data, name, model, distribution):
    dtype = distribution.dtype
    data = pandas_to_array(data).astype(dtype)

    if hasattr(data, 'mask'):
        impute_message = ('Data in {name} contains missing values and'
                          ' will be automatically imputed from the'
                          ' sampling distribution.'.format(name=name))
        warnings.warn(impute_message, UserWarning)
        from .distributions import NoDistribution
        testval = np.broadcast_to(distribution.default(), data.shape)[data.mask]
        fakedist = NoDistribution.dist(shape=data.mask.sum(), dtype=dtype,
                                       testval=testval, parent_dist=distribution)
        missing_values = FreeRV(name=name + '_missing', distribution=fakedist,
                                model=model)
        constant = tt.as_tensor_variable(data.filled())

        dataTensor = tt.set_subtensor(
            constant[data.mask.nonzero()], missing_values)
        dataTensor.missing_values = missing_values
        return dataTensor
    elif sps.issparse(data):
        data = sparse.basic.as_sparse(data, name=name)
        data.missing_values = None
        return data
    else:
        data = tt.as_tensor_variable(data, name=name)
        data.missing_values = None
        return data


1367
class ObservedRV(Factor, PyMC3Variable):
5✔
1368
    """Observed random variable that a model is specified in terms of.
1369
    Potentially partially observed.
1370
    """
1371

1372
    def __init__(self, type=None, owner=None, index=None, name=None, data=None,
5✔
1373
                 distribution=None, total_size=None, model=None):
1374
        """
1375
        Parameters
1376
        ----------
1377
        type : theano type (optional)
1378
        owner : theano owner (optional)
1379
        name : str
1380
        distribution : Distribution
1381
        model : Model
1382
        total_size : scalar Tensor (optional)
1383
            needed for upscaling logp
1384
        """
1385
        from .distributions import TensorType
5✔
1386

1387
        if hasattr(data, 'type') and isinstance(data.type, tt.TensorType):
5✔
1388
            type = data.type
4✔
1389

1390
        if type is None:
5✔
1391
            data = pandas_to_array(data)
5✔
1392
            type = TensorType(distribution.dtype, data.shape)
5✔
1393

1394
        self.observations = data
5✔
1395

1396
        super().__init__(type, owner, index, name)
5✔
1397

1398
        if distribution is not None:
5✔
1399
            data = as_tensor(data, name, model, distribution)
5✔
1400

1401
            self.missing_values = data.missing_values
5✔
1402
            self.logp_elemwiset = distribution.logp(data)
5✔
1403
            # The logp might need scaling in minibatches.
1404
            # This is done in `Factor`.
1405
            self.logp_sum_unscaledt = distribution.logp_sum(data)
5✔
1406
            self.logp_nojac_unscaledt = distribution.logp_nojac(data)
5✔
1407
            self.total_size = total_size
5✔
1408
            self.model = model
5✔
1409
            self.distribution = distribution
5✔
1410

1411
            # make this RV a view on the combined missing/nonmissing array
1412
            theano.gof.Apply(theano.compile.view_op,
5✔
1413
                             inputs=[data], outputs=[self])
1414
            self.tag.test_value = theano.compile.view_op(data).tag.test_value
5✔
1415
            self.scaling = _get_scaling(total_size, data.shape, data.ndim)
5✔
1416

1417
    def _repr_latex_(self, name=None, dist=None):
5✔
1418
        if self.distribution is None:
×
1419
            return None
×
1420
        if name is None:
×
1421
            name = self.name
×
1422
        if dist is None:
×
1423
            dist = self.distribution
×
1424
        return self.distribution._repr_latex_(name=name, dist=dist)
×
1425

1426
    __latex__ = _repr_latex_
5✔
1427

1428
    @property
5✔
1429
    def init_value(self):
1430
        """Convenience attribute to return tag.test_value"""
1431
        return self.tag.test_value
×
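
# Editorial sketch (illustrative, not part of the original module): users
# rarely construct ObservedRV directly; passing `observed=` to a distribution
# inside a model context creates one:
#
#     with pm.Model():
#         y = pm.Normal('y', 0, 1, observed=np.array([0.1, -0.3, 0.2]))
#     # `y` is an ObservedRV; `y.observations` holds the data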


class MultiObservedRV(Factor):
    """Observed random variable that a model is specified in terms of.
    Potentially partially observed.
    """

    def __init__(self, name, data, distribution, total_size=None, model=None):
        """
        Parameters
        ----------
        name : str
        data : dict of str -> array_like
            mapping from variable names to observed values
        distribution : Distribution
        total_size : scalar Tensor (optional)
            needed for upscaling logp
        model : Model
        """
        self.name = name
        self.data = {name: as_tensor(data, name, model, distribution)
                     for name, data in data.items()}

        self.missing_values = [datum.missing_values for datum in self.data.values()
                               if datum.missing_values is not None]
        self.logp_elemwiset = distribution.logp(**self.data)
        # The logp might need scaling in minibatches.
        # This is done in `Factor`.
        self.logp_sum_unscaledt = distribution.logp_sum(**self.data)
        self.logp_nojac_unscaledt = distribution.logp_nojac(**self.data)
        self.total_size = total_size
        self.model = model
        self.distribution = distribution
        self.scaling = _get_scaling(total_size, self.logp_elemwiset.shape, self.logp_elemwiset.ndim)

    # Make hashable by id for draw_values
    def __hash__(self):
        return id(self)

    def __eq__(self, other):
        return self.id == other.id

    def __ne__(self, other):
        return not self == other
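
# Editorial sketch (illustrative; the logp function below is a stand-in): a
# MultiObservedRV arises when `observed` is a dict, as with DensityDist, and
# each dict entry becomes an entry in `self.data`:
#
#     with pm.Model():
#         pm.DensityDist('lik', lambda value: -0.5 * (value ** 2).sum(),
#                        observed={'value': np.array([0.1, -0.3])})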


def _walk_up_rv(rv):
    """Walk up theano graph to get inputs for deterministic RV."""
    all_rvs = []
    parents = list(itertools.chain(*[j.inputs for j in rv.get_parents()]))
    if parents:
        for parent in parents:
            all_rvs.extend(_walk_up_rv(parent))
    else:
        if rv.name:
            all_rvs.append(r'\text{%s}' % rv.name)
        else:
            all_rvs.append(r'\text{Constant}')
    return all_rvs


def _latex_repr_rv(rv):
    """Make latex string for a Deterministic variable"""
    return r'$\text{%s} \sim \text{Deterministic}(%s)$' % (rv.name, r',~'.join(_walk_up_rv(rv)))
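
# For exposition (hypothetical names): given `x2 = Deterministic('x2', x ** 2)`,
# the two helpers above walk the theano graph down to its leaf inputs and
# render something like
#     $\text{x2} \sim \text{Deterministic}(\text{x})$
# with unnamed graph inputs shown as \text{Constant}.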


def Deterministic(name, var, model=None):
    """Create a named deterministic variable

    Parameters
    ----------
    name : str
    var : theano variable

    Returns
    -------
    var : var, with name attribute
    """
    model = modelcontext(model)
    var = var.copy(model.name_for(name))
    model.deterministics.append(var)
    model.add_random_variable(var)
    var._repr_latex_ = functools.partial(_latex_repr_rv, var)
    var.__latex__ = var._repr_latex_
    return var
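
# Usage sketch (added for exposition; the names are hypothetical):
#
#     with pm.Model():
#         x = pm.Normal('x', 0, 1)
#         x2 = pm.Deterministic('x2', x ** 2)  # recorded in the trace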


def Potential(name, var, model=None):
    """Add an arbitrary factor potential to the model likelihood

    Parameters
    ----------
    name : str
    var : theano variable

    Returns
    -------
    var : var, with name attribute
    """
    model = modelcontext(model)
    var.name = model.name_for(name)
    model.potentials.append(var)
    model.add_random_variable(var)
    return var
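
# Usage sketch (added for exposition; the penalty term is a stand-in):
#
#     with pm.Model():
#         x = pm.Normal('x', 0, 1)
#         pm.Potential('soft_constraint', tt.switch(x > 0, 0.0, -np.inf))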


class TransformedRV(PyMC3Variable):
    """
    Parameters
    ----------
    type : theano type (optional)
    owner : theano owner (optional)
    index : int (optional)
    name : str
    distribution : Distribution
    model : Model
    transform : Transform (optional)
    total_size : scalar Tensor (optional)
        needed for upscaling logp
    """

    def __init__(self, type=None, owner=None, index=None, name=None,
                 distribution=None, model=None, transform=None,
                 total_size=None):
        if type is None:
            type = distribution.type
        super().__init__(type, owner, index, name)

        self.transformation = transform

        if distribution is not None:
            self.model = model
            self.distribution = distribution
            self.dshape = tuple(distribution.shape)
            self.dsize = int(np.prod(distribution.shape))

            transformed_name = get_transformed_name(name, transform)

            self.transformed = model.Var(
                transformed_name, transform.apply(distribution), total_size=total_size)

            normalRV = transform.backward(self.transformed)

            theano.Apply(theano.compile.view_op, inputs=[
                         normalRV], outputs=[self])
            self.tag.test_value = normalRV.tag.test_value
            self.scaling = _get_scaling(total_size, self.shape, self.ndim)
            incorporate_methods(source=distribution, destination=self,
                                methods=['random'],
                                wrapper=InstanceMethod)

    def _repr_latex_(self, name=None, dist=None):
        if self.distribution is None:
            return None
        if name is None:
            name = self.name
        if dist is None:
            dist = self.distribution
        return self.distribution._repr_latex_(name=name, dist=dist)

    __latex__ = _repr_latex_

    @property
    def init_value(self):
        """Convenience attribute to return tag.test_value"""
        return self.tag.test_value
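
# Editorial sketch (illustrative): bounded distributions yield a TransformedRV
# plus a transformed free variable that the samplers actually operate on:
#
#     with pm.Model() as m:
#         s = pm.HalfNormal('s', 1.0)
#     # `s` is a TransformedRV; m.free_RVs holds the log-transformed 's_log__'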


def as_iterargs(data):
    if isinstance(data, tuple):
        return data
    else:
        return [data]


def all_continuous(vars):
    """Check that `vars` contains no discrete variables, apart from
    ObservedRVs."""
    vars_ = [var for var in vars if not isinstance(var, pm.model.ObservedRV)]
    if any(var.dtype in pm.discrete_types for var in vars_):
        return False
    else:
        return True