
pyro-ppl / pyro · build 8393367952

22 Mar 2024 04:22PM UTC. Coverage: 91.846% (+0.01%) from 91.834%.

Pull Request #3345 (github / web-flow): Introducing pyro.infer.predictive.WeighedPredictive which reports weights along with predicted samples. Merge a4c4d8585 into 81def9c53.

50 of 52 new or added lines in 3 files covered (96.15%).

10 existing lines in 2 files now uncovered.

23272 of 25338 relevant lines covered (91.85%).

2.3 hits per line.

Source File: /pyro/infer/importance.py (80.46% covered)
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import math
import warnings

import torch

import pyro
import pyro.poutine as poutine
from pyro.ops.stats import fit_generalized_pareto

from .abstract_infer import TracePosterior
from .enum import get_importance_trace


class Importance(TracePosterior):
    """
    :param model: probabilistic model defined as a function
    :param guide: guide used for sampling defined as a function
    :param num_samples: number of samples to draw from the guide (default 10)

    This method performs posterior inference by importance sampling
    using the guide as the proposal distribution.
    If no guide is provided, it defaults to proposing from the model's prior.
    """

    def __init__(self, model, guide=None, num_samples=None):
        """
        Constructor. Defaults to ``num_samples=10`` and, if no guide is
        provided, to a guide derived from the model by hiding its observe
        statements.
        """
        super().__init__()
        if num_samples is None:
            num_samples = 10
            warnings.warn(
                "num_samples not provided, defaulting to {}".format(num_samples)
            )
        if guide is None:
            # propose from the prior by making a guide from the model by hiding observes
            guide = poutine.block(model, hide_types=["observe"])
        self.num_samples = num_samples
        self.model = model
        self.guide = guide

    def _traces(self, *args, **kwargs):
        """
        Generator of weighted samples from the proposal distribution.
        """
        for i in range(self.num_samples):
            guide_trace = poutine.trace(self.guide).get_trace(*args, **kwargs)
            model_trace = poutine.trace(
                poutine.replay(self.model, trace=guide_trace)
            ).get_trace(*args, **kwargs)
            log_weight = model_trace.log_prob_sum() - guide_trace.log_prob_sum()
            yield (model_trace, log_weight)

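    # Note (added comment): each iteration of _traces computes the log
    # importance weight log w = log p(x, z) - log q(z), where the model trace
    # scores the joint density and the guide trace scores the proposal.
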
    def get_log_normalizer(self):
        """
        Estimator of the normalizing constant of the target distribution
        (mean of the unnormalized weights).
        """
        # ensure list is not empty
        if self.log_weights:
            log_w = torch.tensor(self.log_weights)
            log_num_samples = torch.log(torch.tensor(self.num_samples * 1.0))
            return torch.logsumexp(log_w - log_num_samples, 0)
        else:
            warnings.warn(
                "The log_weights list is empty, cannot compute normalizing constant estimate."
            )

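    # Note (added comment): in log space the estimator above is
    #     log(Z_hat) = logsumexp(log_w) - log(N),
    # i.e. the log of the average unnormalized importance weight over the
    # N = num_samples stored log weights.
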
    def get_normalized_weights(self, log_scale=False):
        """
        Compute the normalized importance weights.
        """
        if self.log_weights:
            log_w = torch.tensor(self.log_weights)
            log_w_norm = log_w - torch.logsumexp(log_w, 0)
            return log_w_norm if log_scale else torch.exp(log_w_norm)
        else:
            warnings.warn(
                "The log_weights list is empty. There is nothing to normalize."
            )

    def get_ESS(self):
        """
        Compute (Importance Sampling) Effective Sample Size (ESS).
        """
        if self.log_weights:
            log_w_norm = self.get_normalized_weights(log_scale=True)
            ess = torch.exp(-torch.logsumexp(2 * log_w_norm, 0))
        else:
            warnings.warn(
                "The log_weights list is empty, effective sample size is zero."
            )
            ess = 0
        return ess


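# A minimal usage sketch of Importance (illustrative, not part of the original
# module). The toy model, its site names, and all numeric values are
# assumptions chosen for this example.
def _example_importance_usage():
    import pyro.distributions as dist

    def model():
        # latent variable with a standard normal prior
        x = pyro.sample("x", dist.Normal(0.0, 1.0))
        # noisy observation of the latent variable
        pyro.sample("y", dist.Normal(x, 1.0), obs=torch.tensor(0.5))

    # with guide=None, Importance proposes from the model's prior
    importance = Importance(model, guide=None, num_samples=100)
    importance.run()  # runs _traces and populates importance.log_weights
    # ESS close to num_samples indicates a well-matched proposal;
    # ESS close to 1 indicates severe weight degeneracy
    return importance.get_ESS()

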
def vectorized_importance_weights(model, guide, *args, **kwargs):
    """
    :param model: probabilistic model defined as a function
    :param guide: guide used for sampling defined as a function
    :param num_samples: number of samples to draw from the guide (default 1)
    :param int max_plate_nesting: Bound on max number of nested :func:`pyro.plate` contexts.
    :param bool normalized: set to True to return self-normalized importance weights
    :returns: returns a ``(num_samples,)``-shaped tensor of importance weights
        and the model and guide traces that produced them

    Vectorized computation of importance weights for models with static structure::

        log_weights, model_trace, guide_trace = \\
            vectorized_importance_weights(model, guide, *args,
                                          num_samples=1000,
                                          max_plate_nesting=4,
                                          normalized=False)
    """
    num_samples = kwargs.pop("num_samples", 1)
    max_plate_nesting = kwargs.pop("max_plate_nesting", None)
    normalized = kwargs.pop("normalized", False)

    if max_plate_nesting is None:
        raise ValueError("must provide max_plate_nesting")
    max_plate_nesting += 1

    def vectorize(fn):
        def _fn(*args, **kwargs):
            with pyro.plate(
                "num_particles_vectorized", num_samples, dim=-max_plate_nesting
            ):
                return fn(*args, **kwargs)

        return _fn

    model_trace, guide_trace = get_importance_trace(
        "flat", max_plate_nesting, vectorize(model), vectorize(guide), args, kwargs
    )

    guide_trace.pack_tensors()
    model_trace.pack_tensors(guide_trace.plate_to_symbol)

    if num_samples == 1:
        log_weights = model_trace.log_prob_sum() - guide_trace.log_prob_sum()
    else:
        wd = guide_trace.plate_to_symbol["num_particles_vectorized"]
        log_weights = 0.0
        for site in model_trace.nodes.values():
            if site["type"] != "sample":
                continue
            log_weights += torch.einsum(
                site["packed"]["log_prob"]._pyro_dims + "->" + wd,
                [site["packed"]["log_prob"]],
            )

        for site in guide_trace.nodes.values():
            if site["type"] != "sample":
                continue
            log_weights -= torch.einsum(
                site["packed"]["log_prob"]._pyro_dims + "->" + wd,
                [site["packed"]["log_prob"]],
            )

    if normalized:
        # self-normalize the weights over the particle dimension
        log_weights = log_weights - torch.logsumexp(log_weights, 0)
    return log_weights, model_trace, guide_trace

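# A minimal sketch of calling vectorized_importance_weights (illustrative, not
# part of the original module). The model/guide pair and the
# max_plate_nesting value are assumptions for this example.
def _example_vectorized_weights():
    import pyro.distributions as dist

    def model():
        x = pyro.sample("x", dist.Normal(0.0, 1.0))
        pyro.sample("y", dist.Normal(x, 1.0), obs=torch.tensor(0.5))

    def guide():
        pyro.sample("x", dist.Normal(0.0, 2.0))

    # the model has no pyro.plate contexts, so max_plate_nesting=0 suffices;
    # internally a particle plate is added at dim=-(max_plate_nesting + 1)
    log_weights, model_trace, guide_trace = vectorized_importance_weights(
        model, guide, num_samples=1000, max_plate_nesting=0, normalized=False
    )
    return log_weights  # shape: (1000,)
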

@torch.no_grad()
def psis_diagnostic(model, guide, *args, **kwargs):
    """
    Computes the Pareto tail index k for a model/guide pair using the technique
    described in [1], which builds on previous work in [2]. If :math:`0 < k < 0.5`
    the guide is a good approximation to the model posterior, in the sense
    described in [1]. If :math:`0.5 \\le k \\le 0.7`, the guide provides a suboptimal
    approximation to the posterior, but may still be useful in practice. If
    :math:`k > 0.7` the guide program provides a poor approximation to the full
    posterior, and caution should be used when using the guide. Note, however,
    that a guide may be a poor fit to the full posterior while still yielding
    reasonable model predictions. If :math:`k < 0.0` the importance weights
    corresponding to the model and guide appear to be bounded from above; this
    would be a bizarre outcome for a guide trained via ELBO maximization. Please
    see [1] for a more complete discussion of how the tail index k should be
    interpreted.

    Please be advised that a large number of samples may be required for an
    accurate estimate of k.

    Note that we assume that the model and guide are both vectorized and have
    static structure. As is canonical in Pyro, the args and kwargs are passed
    to the model and guide.

    References
    [1] 'Yes, but Did It Work?: Evaluating Variational Inference.'
    Yuling Yao, Aki Vehtari, Daniel Simpson, Andrew Gelman
    [2] 'Pareto Smoothed Importance Sampling.'
    Aki Vehtari, Andrew Gelman, Jonah Gabry

    :param callable model: the model program.
    :param callable guide: the guide program.
    :param int num_particles: the total number of times we run the model and guide in
        order to compute the diagnostic. defaults to 1000.
    :param max_simultaneous_particles: the maximum number of simultaneous samples drawn
        from the model and guide. defaults to `num_particles`. `num_particles` must be
        divisible by `max_simultaneous_particles`.
    :param int max_plate_nesting: optional bound on max number of nested :func:`pyro.plate`
        contexts in the model/guide. defaults to 7.
    :returns float: the PSIS diagnostic k
    """

    num_particles = kwargs.pop("num_particles", 1000)
    max_simultaneous_particles = kwargs.pop("max_simultaneous_particles", num_particles)
    max_plate_nesting = kwargs.pop("max_plate_nesting", 7)

    if num_particles % max_simultaneous_particles != 0:
        raise ValueError(
            "num_particles must be divisible by max_simultaneous_particles."
        )

    N = num_particles // max_simultaneous_particles
    log_weights = [
        vectorized_importance_weights(
            model,
            guide,
            num_samples=max_simultaneous_particles,
            max_plate_nesting=max_plate_nesting,
            *args,
            **kwargs,
        )[0]
        for _ in range(N)
    ]
    log_weights = torch.cat(log_weights)
    log_weights -= log_weights.max()
    log_weights = torch.sort(log_weights, descending=False)[0]

    cutoff_index = (
        -int(math.ceil(min(0.2 * num_particles, 3.0 * math.sqrt(num_particles)))) - 1
    )
    lw_cutoff = max(math.log(1.0e-15), log_weights[cutoff_index])
    lw_tail = log_weights[log_weights > lw_cutoff]

    if len(lw_tail) < 10:
        warnings.warn(
            "Not enough tail samples to compute PSIS diagnostic; increase num_particles."
        )
        k = float("inf")
    else:
        k, _ = fit_generalized_pareto(lw_tail.exp() - math.exp(lw_cutoff))

    return k
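

# A minimal sketch of running the PSIS diagnostic (illustrative, not part of
# the original module). The model/guide pair is an assumption; the k
# thresholds follow the docstring above.
def _example_psis_diagnostic():
    import pyro.distributions as dist

    def model():
        x = pyro.sample("x", dist.Normal(0.0, 1.0))
        pyro.sample("y", dist.Normal(x, 1.0), obs=torch.tensor(0.5))

    def guide():
        # a deliberately overdispersed proposal for the posterior over x
        pyro.sample("x", dist.Normal(0.25, 1.0))

    k = psis_diagnostic(model, guide, num_particles=2000, max_plate_nesting=0)
    if k < 0.5:
        print("guide looks like a good posterior approximation (k = %.3f)" % k)
    elif k <= 0.7:
        print("guide is suboptimal but may still be usable (k = %.3f)" % k)
    else:
        print("guide is a poor approximation to the posterior (k = %.3f)" % k)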