• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pymc-devs / pymc3 / 9391

pending completion
9391

Pull #3638

travis-ci

web-flow
Drop first dimension when computing determinant of the Jacobian of the transformation.
Pull Request #3638: Simple stick breaking (Formerly #3620)

23 of 23 new or added lines in 1 file covered. (100.0%)

52178 of 100270 relevant lines covered (52.04%)

2.04 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

76.38
/pymc3/parallel_sampling.py
1
import multiprocessing
6✔
2
import multiprocessing.sharedctypes
6✔
3
import ctypes
6✔
4
import time
6✔
5
import logging
6✔
6
from collections import namedtuple
6✔
7
import traceback
6✔
8
from pymc3.exceptions import SamplingError
6✔
9
import errno
6✔
10

11
import numpy as np
6✔
12

13
from . import theanof
6✔
14

15
logger = logging.getLogger("pymc3")
6✔
16

17

18
def _get_broken_pipe_exception():
6✔
19
    import sys
×
20
    if sys.platform == 'win32':
×
21
        return RuntimeError("The communication pipe between the main process "
×
22
                            "and its spawned children is broken.\n"
23
                            "In Windows OS, this usually means that the child "
24
                            "process raised an exception while it was being "
25
                            "spawned, before it was setup to communicate to "
26
                            "the main process.\n"
27
                            "The exceptions raised by the child process while "
28
                            "spawning cannot be caught or handled from the "
29
                            "main process, and when running from an IPython or "
30
                            "jupyter notebook interactive kernel, the child's "
31
                            "exception and traceback appears to be lost.\n"
32
                            "A known way to see the child's error, and try to "
33
                            "fix or handle it, is to run the problematic code "
34
                            "as a batch script from a system's Command Prompt. "
35
                            "The child's exception will be printed to the "
36
                            "Command Promt's stderr, and it should be visible "
37
                            "above this error and traceback.\n"
38
                            "Note that if running a jupyter notebook that was "
39
                            "invoked from a Command Prompt, the child's "
40
                            "exception should have been printed to the Command "
41
                            "Prompt on which the notebook is running.")
42
    else:
43
        return None
×
44

45

46
class ParallelSamplingError(Exception):
    """Raised when a worker chain fails; carries the chain id and warnings."""

    def __init__(self, message, chain, warnings=None):
        super().__init__(message)
        self._chain = chain
        self._warnings = [] if warnings is None else warnings
53

54

55
# Taken from https://hg.python.org/cpython/rev/c4f92b597074
class RemoteTraceback(Exception):
    """Exception that renders as a pre-formatted traceback string."""

    def __init__(self, tb):
        self.tb = tb

    def __str__(self):
        return self.tb
62

63

64
class ExceptionWithTraceback:
6✔
65
    def __init__(self, exc, tb):
6✔
66
        tb = traceback.format_exception(type(exc), exc, tb)
×
67
        tb = "".join(tb)
×
68
        self.exc = exc
×
69
        self.tb = '\n"""\n%s"""' % tb
×
70

71
    def __reduce__(self):
6✔
72
        return rebuild_exc, (self.exc, self.tb)
×
73

74

75
def rebuild_exc(exc, tb):
    """Reattach the remote traceback to *exc* when it is unpickled."""
    exc.__cause__ = RemoteTraceback(tb)
    return exc
78

79

80
# Messages
81
# ('writing_done', is_last, sample_idx, tuning, stats, warns)
82
# ('error', warnings, *exception_info)
83

84
# ('abort', reason)
85
# ('write_next',)
86
# ('start',)
87

88

89
class _Process(multiprocessing.Process):
    """Separate process for each chain.

    We communicate with the main process using a pipe
    (messages: "start", "write_next", "abort" inbound;
    "writing_done", "error" outbound), and send finished
    samples using shared memory.
    """

    def __init__(self, name, msg_pipe, step_method, shared_point, draws, tune, seed):
        super().__init__(daemon=True, name=name)
        self._msg_pipe = msg_pipe
        self._step_method = step_method
        self._shared_point = shared_point
        self._seed = seed
        # Offset seed so the theano RNG stream differs from numpy's.
        self._tt_seed = seed + 1
        self._draws = draws
        self._tune = tune

    def run(self):
        """Process entry point: build shared views, then run the sample loop."""
        try:
            # We do not create this in __init__, as pickling this
            # would destroy the shared memory.
            self._point = self._make_numpy_refs()
            self._start_loop()
        except KeyboardInterrupt:
            # Raised by _start_loop on an "abort" message; a clean shutdown.
            pass
        except BaseException as e:
            e = ExceptionWithTraceback(e, e.__traceback__)
            # Send is not blocking so we have to force a wait for the abort
            # message
            self._msg_pipe.send(("error", None, e))
            self._wait_for_abortion()
        finally:
            self._msg_pipe.close()

    def _wait_for_abortion(self):
        # Drain incoming messages until the main process acknowledges
        # the error with an "abort".
        while True:
            msg = self._recv_msg()
            if msg[0] == "abort":
                break

    def _make_numpy_refs(self):
        """Build numpy views over the shared-memory buffers, one per variable."""
        shape_dtypes = self._step_method.vars_shape_dtype
        point = {}
        for name, (shape, dtype) in shape_dtypes.items():
            array = self._shared_point[name]
            # NOTE(review): this re-assignment is a no-op — `array` was just
            # read from the same key.
            self._shared_point[name] = array
            point[name] = np.frombuffer(array, dtype).reshape(shape)
        return point

    def _write_point(self, point):
        # Copy draw values into the shared-memory views so the main
        # process can read them.
        for name, vals in point.items():
            self._point[name][...] = vals

    def _recv_msg(self):
        return self._msg_pipe.recv()

    def _start_loop(self):
        """Main worker loop: compute a draw, then wait for "write_next"."""
        np.random.seed(self._seed)
        theanof.set_tt_rng(self._tt_seed)

        draw = 0
        tuning = True

        msg = self._recv_msg()
        if msg[0] == "abort":
            raise KeyboardInterrupt()
        if msg[0] != "start":
            raise ValueError("Unexpected msg " + msg[0])

        while True:
            if draw < self._draws + self._tune:
                try:
                    point, stats = self._compute_point()
                except SamplingError as e:
                    # Report the failure; the main process is expected to
                    # answer with "abort" (handled below).
                    # NOTE(review): if this fires, `point`/`stats` below may
                    # be unbound when the main process sends "write_next"
                    # instead of "abort" — confirm the protocol guarantees
                    # an abort here.
                    warns = self._collect_warnings()
                    e = ExceptionWithTraceback(e, e.__traceback__)
                    self._msg_pipe.send(("error", warns, e))
            else:
                return

            if draw == self._tune:
                self._step_method.stop_tuning()
                tuning = False

            msg = self._recv_msg()
            if msg[0] == "abort":
                raise KeyboardInterrupt()
            elif msg[0] == "write_next":
                self._write_point(point)
                is_last = draw + 1 == self._draws + self._tune
                if is_last:
                    # Warnings are only collected once, with the final draw.
                    warns = self._collect_warnings()
                else:
                    warns = None
                self._msg_pipe.send(
                    ("writing_done", is_last, draw, tuning, stats, warns)
                )
                draw += 1
            else:
                raise ValueError("Unknown message " + msg[0])

    def _compute_point(self):
        # Some step methods also produce per-draw sampler statistics.
        if self._step_method.generates_stats:
            point, stats = self._step_method.step(self._point)
        else:
            point = self._step_method.step(self._point)
            stats = None
        return point, stats

    def _collect_warnings(self):
        # Not every step method tracks warnings.
        if hasattr(self._step_method, "warnings"):
            return self._step_method.warnings()
        else:
            return []
202

203

204
class ProcessAdapter:
    """Control a Chain process from the main thread."""

    def __init__(self, draws, tune, step_method, chain, seed, start):
        self.chain = chain
        process_name = "worker_chain_%s" % chain
        self._msg_pipe, remote_conn = multiprocessing.Pipe()

        # Allocate one raw shared-memory buffer per model variable and keep
        # a numpy view onto it, initialized from the start point.
        self._shared_point = {}
        self._point = {}
        for name, (shape, dtype) in step_method.vars_shape_dtype.items():
            size = 1
            for dim in shape:
                size *= int(dim)
            size *= dtype.itemsize
            # Guard against overflowing the C size_t used by RawArray.
            if size != ctypes.c_size_t(size).value:
                raise ValueError("Variable %s is too large" % name)

            array = multiprocessing.sharedctypes.RawArray("c", size)
            self._shared_point[name] = array
            array_np = np.frombuffer(array, dtype).reshape(shape)
            array_np[...] = start[name]
            self._point[name] = array_np

        self._readable = True
        self._num_samples = 0

        self._process = _Process(
            process_name,
            remote_conn,
            step_method,
            self._shared_point,
            draws,
            tune,
            seed,
        )
        # We fork right away, so that the main process can start tqdm threads
        try:
            self._process.start()
        except IOError as e:
            # Something may have gone wrong during the fork / spawn
            if e.errno == errno.EPIPE:
                exc = _get_broken_pipe_exception()
                if exc is not None:
                    # Sleep a little to give the child process time to flush
                    # all its error message
                    time.sleep(0.2)
                    raise exc
            raise

    @property
    def shared_point_view(self):
        """May only be written to or read between a `recv_draw`
        call from the process and a `write_next` or `abort` call.
        """
        if not self._readable:
            raise RuntimeError()
        return self._point

    def start(self):
        self._msg_pipe.send(("start",))

    def write_next(self):
        # The shared buffers are handed to the worker until it reports back.
        self._readable = False
        self._msg_pipe.send(("write_next",))

    def abort(self):
        self._msg_pipe.send(("abort",))

    def join(self, timeout=None):
        self._process.join(timeout)

    def terminate(self):
        self._process.terminate()

    @staticmethod
    def recv_draw(processes, timeout=3600):
        """Wait for the next message from any worker.

        Returns ``(proc, is_last, draw_idx, tuning, stats, warns)`` for a
        finished draw.  Re-raises a worker's error with its remote
        traceback chained, and raises ``multiprocessing.TimeoutError`` if
        no worker responds within `timeout` seconds.
        """
        if not processes:
            raise ValueError("No processes.")
        pipes = [proc._msg_pipe for proc in processes]
        # Fix: `timeout` was previously not passed to wait(), so the
        # TimeoutError branch below was unreachable and this could block
        # forever on a hung worker.
        ready = multiprocessing.connection.wait(pipes, timeout)
        if not ready:
            raise multiprocessing.TimeoutError("No message from samplers.")
        idxs = {id(proc._msg_pipe): proc for proc in processes}
        proc = idxs[id(ready[0])]
        msg = ready[0].recv()

        if msg[0] == "error":
            warns, old_error = msg[1:]
            if warns is not None:
                error = ParallelSamplingError(str(old_error), proc.chain, warns)
            else:
                error = RuntimeError("Chain %s failed." % proc.chain)
            raise error from old_error
        elif msg[0] == "writing_done":
            proc._readable = True
            proc._num_samples += 1
            return (proc,) + msg[1:]
        else:
            raise ValueError("Sampler sent bad message.")

    @staticmethod
    def terminate_all(processes, patience=2):
        """Ask all workers to stop; force-terminate any that outlast `patience`."""
        for process in processes:
            try:
                process.abort()
            except EOFError:
                pass

        start_time = time.time()
        try:
            for process in processes:
                # Fix: remaining time is start + patience - now; the old
                # expression (now + patience - start) grew with elapsed time
                # and could never go negative, so slow workers were waited
                # on far longer than `patience`.
                timeout = start_time + patience - time.time()
                if timeout < 0:
                    raise multiprocessing.TimeoutError()
                process.join(timeout)
        except multiprocessing.TimeoutError:
            # Fix: logger.warn is deprecated in favor of logger.warning.
            logger.warning(
                "Chain processes did not terminate as expected. "
                "Terminating forcefully..."
            )
            for process in processes:
                process.terminate()
            for process in processes:
                process.join()
329

330

331
# Container for one finished draw handed back to the caller of ParallelSampler.
Draw = namedtuple(
    "Draw",
    ["chain", "is_last", "draw_idx", "tuning", "stats", "point", "warnings"],
)
334

335

336
class ParallelSampler:
    """Run several chains in parallel, keeping at most `cores` active.

    Must be used as a context manager (workers are terminated on exit);
    iterating yields ``Draw`` tuples as they arrive from the workers.
    """

    def __init__(
        self,
        draws,
        tune,
        chains,
        cores,
        seeds,
        start_points,
        step_method,
        start_chain_num=0,
        progressbar=True,
    ):
        # Imported lazily so tqdm is only needed when a progressbar is shown.
        if progressbar:
            from tqdm import tqdm

        if any(len(arg) != chains for arg in [seeds, start_points]):
            raise ValueError("Number of seeds and start_points must be %s." % chains)

        self._samplers = [
            ProcessAdapter(
                draws, tune, step_method, chain + start_chain_num, seed, start
            )
            for chain, seed, start in zip(range(chains), seeds, start_points)
        ]

        # Scheduling state: workers move from inactive -> active -> finished.
        self._inactive = self._samplers.copy()
        self._finished = []
        self._active = []
        self._max_active = cores

        self._in_context = False
        self._start_chain_num = start_chain_num

        self._progress = None
        self._divergences = 0
        # Format template reads _chains and _divergences off self ({0}).
        self._desc = "Sampling {0._chains:d} chains, {0._divergences:,d} divergences"
        self._chains = chains
        if progressbar:
            self._progress = tqdm(
                total=chains * (draws + tune),
                unit="draws",
                desc=self._desc.format(self)
            )

    def _make_active(self):
        # Start queued workers until `cores` of them are running.
        while self._inactive and len(self._active) < self._max_active:
            proc = self._inactive.pop(0)
            proc.start()
            proc.write_next()
            self._active.append(proc)

    def __iter__(self):
        if not self._in_context:
            raise ValueError("Use ParallelSampler as context manager.")
        self._make_active()

        while self._active:
            draw = ProcessAdapter.recv_draw(self._active)
            proc, is_last, draw, tuning, stats, warns = draw
            if self._progress is not None:
                # After tuning, count divergences (first stats dict) and
                # reflect the new total in the bar description.
                if not tuning and stats and stats[0].get('diverging'):
                    self._divergences += 1
                    self._progress.set_description(self._desc.format(self))
                self._progress.update()

            if is_last:
                proc.join()
                self._active.remove(proc)
                self._finished.append(proc)
                # A queued worker (if any) takes the freed slot.
                self._make_active()

            # We could also yield proc.shared_point_view directly,
            # and only call proc.write_next() after the yield returns.
            # This seems to be faster overall though, as the worker
            # loses less time waiting.
            point = {name: val.copy() for name, val in proc.shared_point_view.items()}

            # Already called for new proc in _make_active
            if not is_last:
                proc.write_next()

            yield Draw(proc.chain, is_last, draw, tuning, stats, point, warns)

    def __enter__(self):
        self._in_context = True
        return self

    def __exit__(self, *args):
        # Always shut workers down, even if iteration raised.
        ProcessAdapter.terminate_all(self._samplers)
        if self._progress is not None:
            self._progress.close()
428

429
def _cpu_count():
6✔
430
    """Try to guess the number of CPUs in the system.
431

432
    We use the number provided by psutil if that is installed.
433
    If not, we use the number provided by multiprocessing, but assume
434
    that half of the cpus are only hardware threads and ignore those.
435
    """
436
    try:
6✔
437
        cpus = multiprocessing.cpu_count() // 2
6✔
438
    except NotImplementedError:
×
439
        cpus = 1
×
440
    return cpus
6✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc