
pyro-ppl / pyro, build 7252572304
18 Dec 2023 06:59PM UTC, coverage: 91.931% (remained the same)

Pull Request #3302: Add tutorials using normalizing flows
Merge fc5c5c02c into 834ff633c (via GitHub, committed by web-flow)

26 of 29 new or added lines in 1 file covered (89.66%)
17 existing lines in 5 files now uncovered
22958 of 24973 relevant lines covered (91.93%)
2.29 hits per line
Source File: /pyro/poutine/runtime.py (95.57% covered)

# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import functools
from typing import (
    TYPE_CHECKING,
    Callable,
    Dict,
    List,
    Optional,
    Set,
    Tuple,
    TypeVar,
    Union,
    overload,
)

import torch
from typing_extensions import Literal, ParamSpec, TypedDict

from pyro.params.param_store import (  # noqa: F401
    _MODULE_NAMESPACE_DIVIDER,
    ParamStoreDict,
)

P = ParamSpec("P")
T = TypeVar("T")

if TYPE_CHECKING:
    from pyro.distributions.score_parts import ScoreParts
    from pyro.distributions.torch_distribution import TorchDistributionMixin
    from pyro.poutine.indep_messenger import CondIndepStackFrame
    from pyro.poutine.messenger import Messenger

# the global pyro stack
_PYRO_STACK: List[Messenger] = []

# the global ParamStore
_PYRO_PARAM_STORE = ParamStoreDict()


class InferDict(TypedDict, total=False):
    """
    A dictionary that contains information about inference.
    """

    expand: bool
    is_auxiliary: bool
    is_observed: bool
    num_samples: int
    obs: Optional[torch.Tensor]
    prior: TorchDistributionMixin
    tmc: Literal["diagonal", "mixture"]
    was_observed: bool
    _deterministic: bool
    _dim_to_symbol: Dict[int, str]
    _do_not_trace: bool
    _enumerate_symbol: str
    _markov_scope: Optional[Dict[str, int]]
    _enumerate_dim: int
    _dim_to_id: Dict[int, int]
    _markov_depth: int


class Message(TypedDict, total=False):
    type: str
    name: Optional[str]
    fn: Union[Callable, TorchDistributionMixin]
    is_observed: bool
    args: Tuple
    kwargs: Dict
    value: Optional[torch.Tensor]
    scale: float
    mask: Union[bool, torch.Tensor, None]
    cond_indep_stack: Tuple[CondIndepStackFrame, ...]
    done: bool
    stop: bool
    continuation: Optional[Callable[[Message], None]]
    infer: Optional[InferDict]
    obs: Optional[torch.Tensor]
    log_prob: torch.Tensor
    log_prob_sum: torch.Tensor
    unscaled_log_prob: torch.Tensor
    score_parts: ScoreParts
    packed: "Message"
    _intervener_id: Optional[str]
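
# Illustrative sketch (not part of runtime.py): because Message and InferDict
# are TypedDicts with total=False, a site message is just a plain dict and any
# subset of keys may be present. The names below are hypothetical examples.
_example_msg = Message(
    type="sample",
    name="x",
    is_observed=False,
    infer={"_do_not_trace": True},
)
assert _example_msg["name"] == "x"  # TypedDicts behave like ordinary dicts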


class _DimAllocator:
    """
    Dimension allocator for internal use by :class:`plate`.
    There is a single global instance.

    Note that dimensions are indexed from the right, e.g. -1, -2.
    """

    def __init__(self) -> None:
        # in reverse orientation of log_prob.shape
        self._stack: List[Optional[str]] = []

    def allocate(self, name: str, dim: Optional[int]) -> int:
        """
        Allocate a dimension to a :class:`plate` with the given name.
        Dim should be either None for automatic allocation or a negative
        integer for manual allocation.
        """
        if name in self._stack:
            raise ValueError(f"duplicate plate '{name}'")
        if dim is None:
            # Automatically designate the rightmost available dim for allocation.
            dim = -1
            while -dim <= len(self._stack) and self._stack[-1 - dim] is not None:
                dim -= 1
        elif dim >= 0:
            raise ValueError(f"Expected dim < 0 to index from the right, actual {dim}")

        # Allocate the requested dimension.
        while dim < -len(self._stack):
            self._stack.append(None)
        if self._stack[-1 - dim] is not None:
            raise ValueError(
                "\n".join(
                    [
                        'at plates "{}" and "{}", collide at dim={}'.format(
                            name, self._stack[-1 - dim], dim
                        ),
                        "\nTry moving the dim of one plate to the left, e.g. dim={}".format(
                            dim - 1
                        ),
                    ]
                )
            )
        self._stack[-1 - dim] = name
        return dim

    def free(self, name: str, dim: int) -> None:
        """
        Free a dimension.
        """
        free_idx = -1 - dim  # stack index to free
        assert self._stack[free_idx] == name
        self._stack[free_idx] = None
        while self._stack and self._stack[-1] is None:
            self._stack.pop()


# Handles placement of plate dimensions
_DIM_ALLOCATOR = _DimAllocator()
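
# Illustrative sketch (not part of runtime.py): a fresh allocator hands out the
# rightmost free dim and recycles it on free(); the variable name is hypothetical.
_example_alloc = _DimAllocator()
assert _example_alloc.allocate("a", None) == -1  # rightmost dim is free
assert _example_alloc.allocate("b", None) == -2  # next dim to the left
_example_alloc.free("a", -1)
assert _example_alloc.allocate("c", None) == -1  # -1 was recycled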


class _EnumAllocator:
    """
    Dimension allocator for internal use by :func:`~pyro.poutine.markov`.
    There is a single global instance.

    Note that dimensions are indexed from the right, e.g. -1, -2.
    Note that ids are simply nonnegative integers here.
    """

    def set_first_available_dim(self, first_available_dim: int) -> None:
        """
        Set the first available dim, which should be to the left of all
        :class:`plate` dimensions, e.g. ``-1 - max_plate_nesting``. This should
        be called once per program. In SVI this should be called only once per
        (guide,model) pair.
        """
        assert first_available_dim < 0, first_available_dim
        self.next_available_dim = first_available_dim
        self.next_available_id = 0
        self.dim_to_id: Dict[int, int] = {}  # only the global ids

    def allocate(self, scope_dims: Optional[Set[int]] = None) -> Tuple[int, int]:
        """
        Allocate a new recyclable dim and a unique id.

        If ``scope_dims`` is None, this allocates a global enumeration dim
        that will never be recycled. If ``scope_dims`` is specified, this
        allocates a local enumeration dim that can be reused by any other
        local site whose scope excludes this site.

        :param set scope_dims: An optional set of (negative integer)
            local enumeration dims to avoid when allocating this dim.
        :return: A pair ``(dim, id)``, where ``dim`` is a negative integer
            and ``id`` is a nonnegative integer.
        :rtype: tuple
        """
        id_ = self.next_available_id
        self.next_available_id += 1

        dim = self.next_available_dim
        if dim == -float("inf"):
            raise ValueError(
                "max_plate_nesting must be set to a finite value for parallel enumeration"
            )
        if scope_dims is None:
            # allocate a new global dimension
            self.next_available_dim -= 1
            self.dim_to_id[dim] = id_
        else:
            # allocate a new local dimension
            while dim in scope_dims:
                dim -= 1

        return dim, id_


# Handles placement of enumeration dimensions
_ENUM_ALLOCATOR = _EnumAllocator()
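
# Illustrative sketch (not part of runtime.py): with plates occupying dims -1
# and -2, enumeration starts at dim -3; local allocation skips dims already in
# scope. The variable name is hypothetical.
_example_enum = _EnumAllocator()
_example_enum.set_first_available_dim(-3)
assert _example_enum.allocate() == (-3, 0)  # global dim, never recycled
assert _example_enum.allocate(scope_dims={-4}) == (-5, 1)  # local dim skips -4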


class NonlocalExit(Exception):
    """
    Exception for exiting nonlocally from poutine execution.

    Used by poutine.EscapeMessenger to return site information.
    """

    def __init__(self, site: Message, *args, **kwargs) -> None:
        """
        :param site: message at a pyro site constructor.
            Just stores the input site.
        """
        super().__init__(*args, **kwargs)
        self.site = site

    def reset_stack(self) -> None:
        """
        Reset the state of the frames remaining in the stack.
        Necessary for multiple re-executions in poutine.queue.
        """
        from pyro.poutine.block_messenger import BlockMessenger

        for frame in reversed(_PYRO_STACK):
            frame._reset()
            if isinstance(frame, BlockMessenger) and frame.hide_fn(self.site):
                break
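
# Illustrative sketch (not part of runtime.py): an escaping handler raises
# NonlocalExit carrying the offending site, which the catcher can inspect.
try:
    raise NonlocalExit(Message(type="sample", name="escaped_site"))
except NonlocalExit as _example_exc:
    assert _example_exc.site["name"] == "escaped_site"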


def default_process_message(msg: Message) -> None:
    """
    Default method for processing messages in inference.

    :param msg: a message to be processed
    :returns: None
    """
    if msg["done"] or msg["is_observed"] or msg["value"] is not None:
        msg["done"] = True
        return

    msg["value"] = msg["fn"](*msg["args"], **msg["kwargs"])

    # after fn has been called, update msg to prevent it from being called again.
    msg["done"] = True
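
# Illustrative sketch (not part of runtime.py): for an unfinished, unobserved
# site, default_process_message calls fn and marks the message done so fn is
# never invoked twice. The message below is a hypothetical example.
_example_site = Message(
    type="sample",
    name="z",
    fn=lambda: torch.zeros(2),
    is_observed=False,
    args=(),
    kwargs={},
    value=None,
    done=False,
)
default_process_message(_example_site)
assert _example_site["done"] and _example_site["value"].shape == (2,)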


def apply_stack(initial_msg: Message) -> None:
    """
    Execute the effect stack at a single site according to the following scheme:

        1. For each ``Messenger`` in the stack from bottom to top,
           execute ``Messenger._process_message`` with the message;
           if the message field "stop" is True, stop;
           otherwise, continue
        2. Apply default behavior (``default_process_message``) to finish remaining site execution
        3. For each ``Messenger`` in the stack from top to bottom,
           execute ``_postprocess_message`` to update the message and internal messenger state with the site results
        4. If the message field "continuation" is not ``None``, call it with the message

    :param dict initial_msg: the starting version of the trace site
    :returns: ``None``
    """
    stack = _PYRO_STACK
    # TODO check at runtime if stack is valid

    # msg is used to pass information up and down the stack
    msg = initial_msg

    pointer = 0
    # walk the stack from innermost to outermost until a frame sets "stop"
    for frame in reversed(stack):
        pointer = pointer + 1

        frame._process_message(msg)

        if msg["stop"]:
            break

    default_process_message(msg)

    for frame in stack[-pointer:]:
        frame._postprocess_message(msg)

    cont = msg["continuation"]
    if cont is not None:
        cont(msg)
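
# Illustrative sketch (not part of runtime.py): any object with
# _process_message and _postprocess_message can sit on the stack (real
# handlers subclass Messenger, whose context manager pushes and pops
# _PYRO_STACK). The class and variable names here are hypothetical.
class _ExampleScaler:
    def _process_message(self, msg: Message) -> None:
        msg["scale"] = msg["scale"] * 2.0  # runs before the site fires

    def _postprocess_message(self, msg: Message) -> None:
        pass  # runs after the site has a value

_PYRO_STACK.append(_ExampleScaler())
try:
    _example_msg2 = Message(
        type="sample",
        name="y",
        fn=lambda: torch.ones(()),
        is_observed=False,
        args=(),
        kwargs={},
        value=None,
        scale=1.0,
        done=False,
        stop=False,
        continuation=None,
    )
    apply_stack(_example_msg2)
    assert _example_msg2["scale"] == 2.0 and _example_msg2["done"]
finally:
    _PYRO_STACK.pop()  # restore the global stack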


def am_i_wrapped() -> bool:
    """
    Checks whether the current computation is wrapped in a poutine.
    :returns: bool
    """
    return len(_PYRO_STACK) > 0


@overload
def effectful(
    fn: None = ..., type: Optional[str] = ...
) -> Callable[[Callable[P, T]], Callable[..., Union[T, torch.Tensor, None]]]:
    ...


@overload
def effectful(
    fn: Callable[P, T] = ..., type: Optional[str] = ...
) -> Callable[..., Union[T, torch.Tensor, None]]:
    ...


def effectful(
    fn: Optional[Callable[P, T]] = None, type: Optional[str] = None
) -> Callable:
    """
    :param fn: function or callable that performs an effectful computation
    :param str type: the type label of the operation, e.g. `"sample"`

    Wrapper for calling :func:`~pyro.poutine.runtime.apply_stack` to apply any active effects.
    """
    if fn is None:
        return functools.partial(effectful, type=type)

    if getattr(fn, "_is_effectful", None):
        return fn

    assert type is not None, f"must provide a type label for operation {fn}"
    assert type != "message", "cannot use 'message' as keyword"

    @functools.wraps(fn)
    def _fn(
        *args: P.args,
        name: Optional[str] = None,
        infer: Optional[InferDict] = None,
        obs: Optional[torch.Tensor] = None,
        **kwargs: P.kwargs,
    ) -> Union[T, torch.Tensor, None]:
        is_observed = obs is not None

        if not am_i_wrapped():
            return fn(*args, **kwargs)
        else:
            msg = Message(
                type=type,
                name=name,
                fn=fn,
                is_observed=is_observed,
                args=args,
                kwargs=kwargs,
                value=obs,
                scale=1.0,
                mask=None,
                cond_indep_stack=(),
                done=False,
                stop=False,
                continuation=None,
                infer=infer if infer is not None else {},
            )
            # apply the stack and return its return value
            apply_stack(msg)
            return msg["value"]

    _fn._is_effectful = True  # type: ignore[attr-defined]
    return _fn
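
# Illustrative sketch (not part of runtime.py): outside any handler an
# effectful function is a plain call; when handlers are active, passing obs=
# pre-fills the message value so fn is never invoked. The names below are
# hypothetical.
_example_mul = effectful(lambda a, b: a * b, type="example_mul")
assert _example_mul(3, 4, name="mul_site") == 12  # empty stack: plain call
assert am_i_wrapped() is False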


def _inspect() -> Message:
    """
    EXPERIMENTAL Inspect the Pyro stack.

    .. warning:: The format of the returned message may change at any time and
        does not guarantee backwards compatibility.

    :returns: A message with all effects applied.
    :rtype: dict
    """
    msg = Message(
        type="inspect",
        name="_pyro_inspect",
        fn=lambda: True,
        is_observed=False,
        args=(),
        kwargs={},
        value=None,
        infer={"_do_not_trace": True},
        scale=1.0,
        mask=None,
        cond_indep_stack=(),
        done=False,
        stop=False,
        continuation=None,
    )
    apply_stack(msg)
    return msg


def get_mask() -> Union[bool, torch.Tensor, None]:
    """
    Records the effects of enclosing ``poutine.mask`` handlers.

    This is useful for avoiding expensive ``pyro.factor()`` computations during
    prediction, when the log density need not be computed, e.g.::

        def model():
            # ...
            if poutine.get_mask() is not False:
                log_density = my_expensive_computation()
                pyro.factor("foo", log_density)
            # ...

    :returns: The mask.
    :rtype: None, bool, or torch.Tensor
    """
    return _inspect()["mask"]


def get_plates() -> Tuple[CondIndepStackFrame, ...]:
    """
    Records the effects of enclosing ``pyro.plate`` contexts.

    :returns: A tuple of
        :class:`pyro.poutine.indep_messenger.CondIndepStackFrame` objects.
    :rtype: tuple
    """
    return _inspect()["cond_indep_stack"]
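
# Illustrative sketch (not part of runtime.py): with no handlers installed,
# _inspect() sees the pristine message defaults, so there is no mask and no
# plates.
assert get_mask() is None
assert get_plates() == ()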