• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pyro-ppl / pyro / 3735025322

pending completion
3735025322

Pull #3012

github

GitHub
Merge 4ba3734bd into 3422c3a43
Pull Request #3012: Add make tutorial to docs CI workflow

22222 of 24227 relevant lines covered (91.72%)

2.23 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

38.24
/pyro/contrib/examples/util.py
1
# Copyright (c) 2017-2019 Uber Technologies, Inc.
2
# SPDX-License-Identifier: Apache-2.0
3

4
import os
1✔
5
import sys
1✔
6

7
import torchvision.datasets as datasets
1✔
8
from torch.utils.data import DataLoader
1✔
9
from torchvision import transforms
1✔
10

11

12
class MNIST(datasets.MNIST):
    """``torchvision.datasets.MNIST`` with an additional download mirror.

    The CloudFront mirror is listed ahead of torchvision's own
    ``MNIST.mirrors`` entries, which follow it in their original order.
    """

    mirrors = [
        "https://d2hg8soec8ck9v.cloudfront.net/datasets/mnist/",
        *datasets.MNIST.mirrors,
    ]
16

17

18
def get_data_loader(
    dataset_name,
    data_dir,
    batch_size=1,
    dataset_transforms=None,
    is_training_set=True,
    shuffle=True,
):
    """Build a ``DataLoader`` over a torchvision dataset.

    :param str dataset_name: ``"MNIST"`` (uses this module's ``MNIST``
        subclass and its mirror list) or the name of any class in
        ``torchvision.datasets``.
    :param str data_dir: root directory handed to the dataset constructor;
        ``download=True`` is always passed.
    :param int batch_size: mini-batch size for the loader.
    :param dataset_transforms: optional iterable (list, tuple, generator, ...)
        of extra transforms applied after ``transforms.ToTensor()``.
    :param bool is_training_set: forwarded as the dataset's ``train`` flag.
    :param bool shuffle: forwarded to ``DataLoader``.
    :return: a ``torch.utils.data.DataLoader`` over the requested dataset.
    """
    # Normalize to a list: concatenating ``[ToTensor()] + x`` only works when
    # ``x`` is itself a list, so tuples or generators used to raise TypeError.
    extra_transforms = list(dataset_transforms) if dataset_transforms else []
    trans = transforms.Compose([transforms.ToTensor()] + extra_transforms)
    if dataset_name == "MNIST":
        dataset = MNIST
    else:
        dataset = getattr(datasets, dataset_name)
    print("downloading data")
    dset = dataset(root=data_dir, train=is_training_set, transform=trans, download=True)
    print("download complete.")
    return DataLoader(dset, batch_size=batch_size, shuffle=shuffle)
37

38

39
def print_and_log(logger, msg):
    """Echo ``msg`` to stdout and, if ``logger`` is given, write it there too.

    :param logger: ``None`` or an object exposing ``write`` and ``flush``
        (for example an open text file).
    :param msg: the message; written to ``logger`` followed by a newline.
    """
    print(msg)
    # Flush right away so the message is visible even when stdout is buffered.
    sys.stdout.flush()
    if logger is None:
        return
    logger.write("{}\n".format(msg))
    logger.flush()
46

47

48
def get_data_directory(filepath=None):
    """Return the directory where example datasets should be stored.

    :param filepath: optional path of a file (typically a script's
        ``__file__``); outside CI the data directory is ``.data`` next to it.
        When omitted, ``.data`` under the current working directory is used.
    :return: an absolute path, or ``~/.data`` expanded when the ``CI``
        environment variable is set.
    """
    if "CI" in os.environ:
        return os.path.expanduser("~/.data")
    # The default ``filepath=None`` previously crashed in os.path.dirname;
    # fall back to the current working directory instead.
    base_dir = os.getcwd() if filepath is None else os.path.dirname(filepath)
    return os.path.abspath(os.path.join(base_dir, ".data"))
52

53

54
def _mkdir_p(dirname):
1✔
55
    if not os.path.exists(dirname):
×
56
        try:
×
57
            os.makedirs(dirname)
×
58
        except FileExistsError:
×
59
            pass
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc