• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pyro-ppl / pyro / 5702491366

pending completion
5702491366

push

github

fritzo
Merge branch 'dev'

50 of 51 new or added lines in 10 files covered. (98.04%)

2 existing lines in 2 files now uncovered.

22701 of 24722 relevant lines covered (91.83%)

2.27 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

99.17
/pyro/poutine/trace_struct.py
1
# Copyright (c) 2017-2019 Uber Technologies, Inc.
2
# SPDX-License-Identifier: Apache-2.0
3

4
import sys
5✔
5
from collections import OrderedDict
5✔
6

7
import opt_einsum
5✔
8

9
from pyro.distributions.score_parts import ScoreParts
5✔
10
from pyro.distributions.util import scale_and_mask
5✔
11
from pyro.ops.packed import pack
5✔
12
from pyro.poutine.util import is_validation_enabled
5✔
13
from pyro.util import warn_if_inf, warn_if_nan
14

15

16
class Trace:
    """
    Graph data structure denoting the relationships amongst different pyro primitives
    in the execution trace.

    An execution trace of a Pyro program is a record of every call
    to ``pyro.sample()`` and ``pyro.param()`` in a single execution of that program.
    Traces are directed graphs whose nodes represent primitive calls or input/output,
    and whose edges represent conditional dependence relationships
    between those primitive calls. They are created and populated by ``poutine.trace``.

    Each node (or site) in a trace contains the name, input and output value of the site,
    as well as additional metadata added by inference algorithms or user annotation.
    In the case of ``pyro.sample``, the trace also includes the stochastic function
    at the site, and any observed data added by users.

    Consider the following Pyro program:

        >>> def model(x):
        ...     s = pyro.param("s", torch.tensor(0.5))
        ...     z = pyro.sample("z", dist.Normal(x, s))
        ...     return z ** 2

    We can record its execution using ``pyro.poutine.trace``
    and use the resulting data structure to compute the log-joint probability
    of all of the sample sites in the execution or extract all parameters.

        >>> trace = pyro.poutine.trace(model).get_trace(0.0)
        >>> logp = trace.log_prob_sum()
        >>> params = [trace.nodes[name]["value"].unconstrained() for name in trace.param_nodes]

    We can also inspect or manipulate individual nodes in the trace.
    ``trace.nodes`` contains a ``collections.OrderedDict``
    of site names and metadata corresponding to ``x``, ``s``, ``z``, and the return value:

        >>> list(name for name in trace.nodes.keys())  # doctest: +SKIP
        ["_INPUT", "s", "z", "_RETURN"]

    Values of ``trace.nodes`` are dictionaries of node metadata:

        >>> trace.nodes["z"]  # doctest: +SKIP
        {'type': 'sample', 'name': 'z', 'is_observed': False,
         'fn': Normal(), 'value': tensor(0.6480), 'args': (), 'kwargs': {},
         'infer': {}, 'scale': 1.0, 'cond_indep_stack': (),
         'done': True, 'stop': False, 'continuation': None}

    ``'infer'`` is a dictionary of user- or algorithm-specified metadata.
    ``'args'`` and ``'kwargs'`` are the arguments passed via ``pyro.sample``
    to ``fn.__call__`` or ``fn.log_prob``.
    ``'scale'`` is used to scale the log-probability of the site when computing the log-joint.
    ``'cond_indep_stack'`` contains data structures corresponding to ``pyro.plate`` contexts
    appearing in the execution.
    ``'done'``, ``'stop'``, and ``'continuation'`` are only used by Pyro's internals.

    :param string graph_type: string specifying the kind of trace graph to construct
    """

    def __init__(self, graph_type="flat"):
        assert graph_type in ("flat", "dense"), "{} not a valid graph type".format(
            graph_type
        )
        self.graph_type = graph_type
        # Site name -> site metadata dict, in execution (insertion) order.
        self.nodes = OrderedDict()
        # Adjacency maps: site name -> set of successor / predecessor site names.
        self._succ = OrderedDict()
        self._pred = OrderedDict()
81

82
    def __contains__(self, name):
        """Return True if a site with the given ``name`` exists in the trace."""
        return name in self.nodes
84

85
    def __iter__(self):
5✔
86
        return iter(self.nodes.keys())
2✔
87

88
    def __len__(self):
        """Return the number of sites currently recorded in the trace."""
        return len(self.nodes)
90

91
    @property
5✔
92
    def edges(self):
5✔
93
        for site, adj_nodes in self._succ.items():
3✔
94
            for adj_node in adj_nodes:
3✔
95
                yield site, adj_node
3✔
96

97
    def add_node(self, site_name, **kwargs):
5✔
98
        """
99
        :param string site_name: the name of the site to be added
100

101
        Adds a site to the trace.
102

103
        Raises an error when attempting to add a duplicate node
104
        instead of silently overwriting.
105
        """
106
        if site_name in self:
5✔
107
            site = self.nodes[site_name]
4✔
108
            if site["type"] != kwargs["type"]:
4✔
109
                # Cannot sample or observe after a param statement.
110
                raise RuntimeError(
1✔
111
                    "{} is already in the trace as a {}".format(site_name, site["type"])
112
                )
113
            elif kwargs["type"] != "param":
4✔
114
                # Cannot sample after a previous sample statement.
115
                raise RuntimeError(
1✔
116
                    "Multiple {} sites named '{}'".format(kwargs["type"], site_name)
117
                )
118

119
        # XXX should copy in case site gets mutated, or dont bother?
120
        self.nodes[site_name] = kwargs
5✔
121
        self._pred[site_name] = set()
5✔
122
        self._succ[site_name] = set()
5✔
123

124
    def add_edge(self, site1, site2):
5✔
125
        for site in (site1, site2):
3✔
126
            if site not in self.nodes:
3✔
127
                self.add_node(site)
3✔
128
        self._succ[site1].add(site2)
3✔
129
        self._pred[site2].add(site1)
3✔
130

131
    def remove_node(self, site_name):
5✔
132
        self.nodes.pop(site_name)
4✔
133
        for p in self._pred[site_name]:
4✔
134
            self._succ[p].remove(site_name)
1✔
135
        for s in self._succ[site_name]:
4✔
UNCOV
136
            self._pred[s].remove(site_name)
×
137
        self._pred.pop(site_name)
4✔
138
        self._succ.pop(site_name)
4✔
139

140
    def predecessors(self, site_name):
        """
        :param string site_name: the name of the site
        :return: the set of names of this site's direct predecessors
        :raises KeyError: if the site is not in the trace
        """
        return self._pred[site_name]
142

143
    def successors(self, site_name):
        """
        :param string site_name: the name of the site
        :return: the set of names of this site's direct successors
        :raises KeyError: if the site is not in the trace
        """
        return self._succ[site_name]
145

146
    def copy(self):
5✔
147
        """
148
        Makes a shallow copy of self with nodes and edges preserved.
149
        """
150
        new_tr = Trace(graph_type=self.graph_type)
4✔
151
        new_tr.nodes.update(self.nodes)
4✔
152
        new_tr._succ.update(self._succ)
4✔
153
        new_tr._pred.update(self._pred)
4✔
154
        return new_tr
4✔
155

156
    def _dfs(self, site, visited):
5✔
157
        if site in visited:
3✔
158
            return
3✔
159
        for s in self._succ[site]:
3✔
160
            for node in self._dfs(s, visited):
3✔
161
                yield node
3✔
162
        visited.add(site)
3✔
163
        yield site
3✔
164

165
    def topological_sort(self, reverse=False):
5✔
166
        """
167
        Return a list of nodes (site names) in topologically sorted order.
168

169
        :param bool reverse: Return the list in reverse order.
170
        :return: list of topologically sorted nodes (site names).
171
        """
172
        visited = set()
3✔
173
        top_sorted = []
3✔
174
        for s in self._succ:
3✔
175
            for node in self._dfs(s, visited):
3✔
176
                top_sorted.append(node)
3✔
177
        return top_sorted if reverse else list(reversed(top_sorted))
3✔
178

179
    def log_prob_sum(self, site_filter=lambda name, site: True):
        """
        Compute the site-wise log probabilities of the trace.
        Each ``log_prob`` has shape equal to the corresponding ``batch_shape``.
        Each ``log_prob_sum`` is a scalar.
        The computation of ``log_prob_sum`` is memoized.

        :returns: total log probability.
        :rtype: torch.Tensor
        """
        result = 0.0
        for name, site in self.nodes.items():
            # Only sample sites that pass the filter contribute to the total.
            if site["type"] != "sample" or not site_filter(name, site):
                continue
            if "log_prob_sum" in site:
                # Memoized on a previous call.
                log_p = site["log_prob_sum"]
            else:
                try:
                    log_p = site["fn"].log_prob(
                        site["value"], *site["args"], **site["kwargs"]
                    )
                except ValueError as e:
                    # Re-raise with a shape table to aid debugging.
                    _, exc_value, traceback = sys.exc_info()
                    shapes = self.format_shapes(last_site=site["name"])
                    raise ValueError(
                        "Error while computing log_prob_sum at site '{}':\n{}\n{}\n".format(
                            name, exc_value, shapes
                        )
                    ).with_traceback(traceback) from e
                log_p = scale_and_mask(log_p, site["scale"], site["mask"]).sum()
                site["log_prob_sum"] = log_p
                if is_validation_enabled():
                    warn_if_nan(log_p, "log_prob_sum at site '{}'".format(name))
                    warn_if_inf(
                        log_p,
                        "log_prob_sum at site '{}'".format(name),
                        allow_neginf=True,
                    )
            result = result + log_p
        return result
218

219
    def compute_log_prob(self, site_filter=lambda name, site: True):
        """
        Compute the site-wise log probabilities of the trace.
        Each ``log_prob`` has shape equal to the corresponding ``batch_shape``.
        Each ``log_prob_sum`` is a scalar.
        Both computations are memoized.
        """
        for name, site in self.nodes.items():
            if site["type"] != "sample" or not site_filter(name, site):
                continue
            if "log_prob" in site:
                # Already computed on a previous call.
                continue
            try:
                log_p = site["fn"].log_prob(
                    site["value"], *site["args"], **site["kwargs"]
                )
            except ValueError as e:
                # Re-raise with a shape table to aid debugging.
                _, exc_value, traceback = sys.exc_info()
                shapes = self.format_shapes(last_site=site["name"])
                raise ValueError(
                    "Error while computing log_prob at site '{}':\n{}\n{}".format(
                        name, exc_value, shapes
                    )
                ).with_traceback(traceback) from e
            site["unscaled_log_prob"] = log_p
            log_p = scale_and_mask(log_p, site["scale"], site["mask"])
            site["log_prob"] = log_p
            site["log_prob_sum"] = log_p.sum()
            if is_validation_enabled():
                warn_if_nan(
                    site["log_prob_sum"],
                    "log_prob_sum at site '{}'".format(name),
                )
                warn_if_inf(
                    site["log_prob_sum"],
                    "log_prob_sum at site '{}'".format(name),
                    allow_neginf=True,
                )
255

256
    def compute_score_parts(self):
        """
        Compute the batched local score parts at each site of the trace.
        Each ``log_prob`` has shape equal to the corresponding ``batch_shape``.
        Each ``log_prob_sum`` is a scalar.
        All computations are memoized.
        """
        for name, site in self.nodes.items():
            if site["type"] != "sample" or "score_parts" in site:
                continue
            # Note that ScoreParts overloads the multiplication operator
            # to correctly scale each of its three parts.
            try:
                value = site["fn"].score_parts(
                    site["value"], *site["args"], **site["kwargs"]
                )
            except ValueError as e:
                # Re-raise with a shape table to aid debugging.
                _, exc_value, traceback = sys.exc_info()
                shapes = self.format_shapes(last_site=site["name"])
                raise ValueError(
                    "Error while computing score_parts at site '{}':\n{}\n{}".format(
                        name, exc_value, shapes
                    )
                ).with_traceback(traceback) from e
            site["unscaled_log_prob"] = value.log_prob
            value = value.scale_and_mask(site["scale"], site["mask"])
            site["score_parts"] = value
            site["log_prob"] = value.log_prob
            site["log_prob_sum"] = value.log_prob.sum()
            if is_validation_enabled():
                warn_if_nan(
                    site["log_prob_sum"], "log_prob_sum at site '{}'".format(name)
                )
                warn_if_inf(
                    site["log_prob_sum"],
                    "log_prob_sum at site '{}'".format(name),
                    allow_neginf=True,
                )
293

294
    def detach_(self):
5✔
295
        """
296
        Detach values (in-place) at each sample site of the trace.
297
        """
298
        for _, site in self.nodes.items():
2✔
299
            if site["type"] == "sample":
2✔
300
                site["value"] = site["value"].detach()
2✔
301

302
    @property
5✔
303
    def observation_nodes(self):
5✔
304
        """
305
        :return: a list of names of observe sites
306
        """
307
        return [
1✔
308
            name
309
            for name, node in self.nodes.items()
310
            if node["type"] == "sample" and node["is_observed"]
311
        ]
312

313
    @property
5✔
314
    def param_nodes(self):
5✔
315
        """
316
        :return: a list of names of param sites
317
        """
318
        return [name for name, node in self.nodes.items() if node["type"] == "param"]
1✔
319

320
    @property
5✔
321
    def stochastic_nodes(self):
5✔
322
        """
323
        :return: a list of names of sample sites
324
        """
325
        return [
1✔
326
            name
327
            for name, node in self.nodes.items()
328
            if node["type"] == "sample" and not node["is_observed"]
329
        ]
330

331
    @property
5✔
332
    def reparameterized_nodes(self):
5✔
333
        """
334
        :return: a list of names of sample sites whose stochastic functions
335
            are reparameterizable primitive distributions
336
        """
337
        return [
1✔
338
            name
339
            for name, node in self.nodes.items()
340
            if node["type"] == "sample"
341
            and not node["is_observed"]
342
            and getattr(node["fn"], "has_rsample", False)
343
        ]
344

345
    @property
5✔
346
    def nonreparam_stochastic_nodes(self):
5✔
347
        """
348
        :return: a list of names of sample sites whose stochastic functions
349
            are not reparameterizable primitive distributions
350
        """
351
        return list(set(self.stochastic_nodes) - set(self.reparameterized_nodes))
1✔
352

353
    def iter_stochastic_nodes(self):
5✔
354
        """
355
        :return: an iterator over stochastic nodes in the trace.
356
        """
357
        for name, node in self.nodes.items():
4✔
358
            if node["type"] == "sample" and not node["is_observed"]:
4✔
359
                yield name, node
4✔
360

361
    def symbolize_dims(self, plate_to_symbol=None):
        """
        Assign unique symbols to all tensor dimensions.

        Stores the resulting maps on ``self.plate_to_symbol`` and
        ``self.symbol_to_dim``, and records a per-site ``_dim_to_symbol``
        map in each sample site's ``infer`` dict.
        """
        if plate_to_symbol is None:
            plate_to_symbol = {}
        symbol_to_dim = {}
        for site in self.nodes.values():
            if site["type"] != "sample":
                continue

            # Allocate even symbols for plate dims.
            dim_to_symbol = {}
            for frame in site["cond_indep_stack"]:
                if not frame.vectorized:
                    continue
                if frame.name in plate_to_symbol:
                    symbol = plate_to_symbol[frame.name]
                else:
                    symbol = opt_einsum.get_symbol(2 * len(plate_to_symbol))
                    plate_to_symbol[frame.name] = symbol
                symbol_to_dim[symbol] = frame.dim
                dim_to_symbol[frame.dim] = symbol

            # Allocate odd symbols for enum dims.
            for dim, id_ in site["infer"].get("_dim_to_id", {}).items():
                symbol = opt_einsum.get_symbol(1 + 2 * id_)
                symbol_to_dim[symbol] = dim
                dim_to_symbol[dim] = symbol
            enum_dim = site["infer"].get("_enumerate_dim")
            if enum_dim is not None:
                site["infer"]["_enumerate_symbol"] = dim_to_symbol[enum_dim]
            site["infer"]["_dim_to_symbol"] = dim_to_symbol

        self.plate_to_symbol = plate_to_symbol
        self.symbol_to_dim = symbol_to_dim
395

396
    def pack_tensors(self, plate_to_symbol=None):
        """
        Computes packed representations of tensors in the trace.
        This should be called after :meth:`compute_log_prob` or :meth:`compute_score_parts`.
        """
        self.symbolize_dims(plate_to_symbol)
        for site in self.nodes.values():
            if site["type"] != "sample":
                continue
            dim_to_symbol = site["infer"]["_dim_to_symbol"]
            packed = site.setdefault("packed", {})
            try:
                packed["mask"] = pack(site["mask"], dim_to_symbol)
                if "score_parts" in site:
                    # Pack all three parts (log_prob, score_function, entropy_term).
                    parts = [pack(part, dim_to_symbol) for part in site["score_parts"]]
                    packed["score_parts"] = ScoreParts(*parts)
                    packed["log_prob"] = parts[0]
                    packed["unscaled_log_prob"] = pack(
                        site["unscaled_log_prob"], dim_to_symbol
                    )
                elif "log_prob" in site:
                    packed["log_prob"] = pack(site["log_prob"], dim_to_symbol)
                    packed["unscaled_log_prob"] = pack(
                        site["unscaled_log_prob"], dim_to_symbol
                    )
            except ValueError as e:
                # Re-raise with a shape table to aid debugging.
                _, exc_value, traceback = sys.exc_info()
                shapes = self.format_shapes(last_site=site["name"])
                raise ValueError(
                    "Error while packing tensors at site '{}':\n  {}\n{}".format(
                        site["name"], exc_value, shapes
                    )
                ).with_traceback(traceback) from e
434

435
    def format_shapes(self, title="Trace Shapes:", last_site=None):
5✔
436
        """
437
        Returns a string showing a table of the shapes of all sites in the
438
        trace.
439
        """
440
        if not self.nodes:
2✔
441
            return title
×
442
        rows = [[title]]
2✔
443

444
        rows.append(["Param Sites:"])
2✔
445
        for name, site in self.nodes.items():
2✔
446
            if site["type"] == "param":
2✔
447
                rows.append([name, None] + [str(size) for size in site["value"].shape])
1✔
448
            if name == last_site:
2✔
449
                break
1✔
450

451
        rows.append(["Sample Sites:"])
2✔
452
        for name, site in self.nodes.items():
2✔
453
            if site["type"] == "sample":
2✔
454
                # param shape
455
                batch_shape = getattr(site["fn"], "batch_shape", ())
2✔
456
                event_shape = getattr(site["fn"], "event_shape", ())
2✔
457
                rows.append(
2✔
458
                    [name + " dist", None]
459
                    + [str(size) for size in batch_shape]
460
                    + ["|", None]
461
                    + [str(size) for size in event_shape]
462
                )
463

464
                # value shape
465
                event_dim = len(event_shape)
2✔
466
                shape = getattr(site["value"], "shape", ())
2✔
467
                batch_shape = shape[: len(shape) - event_dim]
2✔
468
                event_shape = shape[len(shape) - event_dim :]
2✔
469
                rows.append(
2✔
470
                    ["value", None]
471
                    + [str(size) for size in batch_shape]
472
                    + ["|", None]
473
                    + [str(size) for size in event_shape]
474
                )
475

476
                # log_prob shape
477
                if "log_prob" in site:
2✔
478
                    batch_shape = getattr(site["log_prob"], "shape", ())
1✔
479
                    rows.append(
1✔
480
                        ["log_prob", None]
481
                        + [str(size) for size in batch_shape]
482
                        + ["|", None]
483
                    )
484
            if name == last_site:
2✔
485
                break
1✔
486

487
        return _format_table(rows)
2✔
488

489

490
def _format_table(rows):
5✔
491
    """
492
    Formats a right justified table using None as column separator.
493
    """
494
    # compute column widths
495
    column_widths = [0, 0, 0]
2✔
496
    for row in rows:
2✔
497
        widths = [0, 0, 0]
2✔
498
        j = 0
2✔
499
        for cell in row:
2✔
500
            if cell is None:
2✔
501
                j += 1
2✔
502
            else:
503
                widths[j] += 1
2✔
504
        for j in range(3):
2✔
505
            column_widths[j] = max(column_widths[j], widths[j])
2✔
506

507
    # justify columns
508
    for i, row in enumerate(rows):
2✔
509
        cols = [[], [], []]
2✔
510
        j = 0
2✔
511
        for cell in row:
2✔
512
            if cell is None:
2✔
513
                j += 1
2✔
514
            else:
515
                cols[j].append(cell)
2✔
516
        cols = [
2✔
517
            [""] * (width - len(col)) + col
518
            if direction == "r"
519
            else col + [""] * (width - len(col))
520
            for width, col, direction in zip(column_widths, cols, "rrl")
521
        ]
522
        rows[i] = sum(cols, [])
2✔
523

524
    # compute cell widths
525
    cell_widths = [0] * len(rows[0])
2✔
526
    for row in rows:
2✔
527
        for j, cell in enumerate(row):
2✔
528
            cell_widths[j] = max(cell_widths[j], len(cell))
2✔
529

530
    # justify cells
531
    return "\n".join(
2✔
532
        " ".join(cell.rjust(width) for cell, width in zip(row, cell_widths))
533
        for row in rows
534
    )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc