pyro-ppl / numpyro / build 14091721588

26 Mar 2025 07:20PM UTC. Coverage: 92.867% (-0.04%) from 92.91%.

Pull Request #2005: Equinox Integration (merge 5a1c8a840 into fc3a7b169, committed via GitHub web-flow)

29 of 33 new or added lines in 1 file covered (87.88%).
34 existing lines in 1 file now uncovered.
13619 of 14665 relevant lines covered (92.87%).
1.98 hits per line.

Source File: /numpyro/examples/datasets.py (85.6% covered)

# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from collections import namedtuple
import csv
import gzip
import io
import os
import pickle
import struct
from urllib.parse import urlparse
from urllib.request import urlretrieve
import warnings
import zipfile

import numpy as np

from jax import lax

from numpyro.util import find_stack_level

if "CI" in os.environ:
    DATA_DIR = os.path.expanduser("~/.data")
else:
    DATA_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".data"))
os.makedirs(DATA_DIR, exist_ok=True)

dset = namedtuple("dset", ["name", "urls"])

BART = dset(
    "bart",
    [
        "https://github.com/pyro-ppl/datasets/blob/master/bart/bart_0.npz?raw=true",
        "https://github.com/pyro-ppl/datasets/blob/master/bart/bart_1.npz?raw=true",
        "https://github.com/pyro-ppl/datasets/blob/master/bart/bart_2.npz?raw=true",
        "https://github.com/pyro-ppl/datasets/blob/master/bart/bart_3.npz?raw=true",
    ],
)

BASEBALL = dset(
    "baseball",
    ["https://github.com/pyro-ppl/datasets/blob/master/EfronMorrisBB.txt?raw=true"],
)

BOSTON_HOUSING = dset(
    "boston_housing",
    ["https://archive.ics.uci.edu/ml/machine-learning-databases/housing/housing.data"],
)

COVTYPE = dset(
    "covtype", ["https://github.com/pyro-ppl/datasets/blob/master/covtype.npz?raw=true"]
)

DIPPER_VOLE = dset(
    "dipper_vole",
    ["https://github.com/pyro-ppl/datasets/blob/master/dipper_vole.zip?raw=true"],
)

MNIST = dset(
    "mnist",
    [
        "https://github.com/pyro-ppl/datasets/blob/master/mnist/train-images-idx3-ubyte.gz?raw=true",
        "https://github.com/pyro-ppl/datasets/blob/master/mnist/train-labels-idx1-ubyte.gz?raw=true",
        "https://github.com/pyro-ppl/datasets/blob/master/mnist/t10k-images-idx3-ubyte.gz?raw=true",
        "https://github.com/pyro-ppl/datasets/blob/master/mnist/t10k-labels-idx1-ubyte.gz?raw=true",
    ],
)

SP500 = dset(
    "SP500", ["https://github.com/pyro-ppl/datasets/blob/master/SP500.csv?raw=true"]
)

UCBADMIT = dset(
    "ucbadmit",
    ["https://github.com/pyro-ppl/datasets/blob/master/UCBadmit.csv?raw=true"],
)

LYNXHARE = dset(
    "lynxhare",
    ["https://github.com/pyro-ppl/datasets/blob/master/LynxHare.txt?raw=true"],
)

JSB_CHORALES = dset(
    "jsb_chorales",
    [
        "https://github.com/pyro-ppl/datasets/blob/master/polyphonic/jsb_chorales.pickle?raw=true"
    ],
)

HIGGS = dset(
    "higgs",
    ["https://archive.ics.uci.edu/ml/machine-learning-databases/00280/HIGGS.csv.gz"],
)

NINE_MERS = dset(
    "9mers",
    ["https://github.com/pyro-ppl/datasets/blob/master/9mers_data.pkl?raw=true"],
)

MORTALITY = dset(
    "mortality",
    [
        "https://github.com/pyro-ppl/datasets/blob/master/simulated_mortality.csv?raw=true"
    ],
)


2✔
109
    for url in dset.urls:
1✔
110
        file = os.path.basename(urlparse(url).path)
1✔
111
        out_path = os.path.join(DATA_DIR, file)
1✔
112
        if not os.path.exists(out_path):
1✔
113
            print("Downloading - {}.".format(url))
1✔
114
            urlretrieve(url, out_path)
1✔
115
            print("Download complete.")
1✔
116

117

118
def load_bart_od():
2✔
UNCOV
119
    _download(BART)
×
120

UNCOV
121
    filenames = [os.path.join(DATA_DIR, f"bart_{i}.npz") for i in range(4)]
×
UNCOV
122
    datasets = [np.load(filename, allow_pickle=True) for filename in filenames]
×
UNCOV
123
    counts = np.vstack([dataset["counts"] for dataset in datasets])
×
UNCOV
124
    return {
×
125
        "stations": datasets[0]["stations"],
126
        "start_date": datasets[0]["start_date"],
127
        "counts": counts,
128
    }
129

130

131
def _load_baseball():
2✔
132
    _download(BASEBALL)
1✔
133

134
    def train_test_split(file):
1✔
135
        train, test, player_names = [], [], []
1✔
136
        with open(file, "r") as f:
1✔
137
            csv_reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
1✔
138
            for row in csv_reader:
1✔
139
                player_names.append(row["FirstName"] + " " + row["LastName"])
1✔
140
                at_bats, hits = row["At-Bats"], row["Hits"]
1✔
141
                train.append(np.array([int(at_bats), int(hits)]))
1✔
142
                season_at_bats, season_hits = row["SeasonAt-Bats"], row["SeasonHits"]
1✔
143
                test.append(np.array([int(season_at_bats), int(season_hits)]))
1✔
144
        return np.stack(train), np.stack(test), player_names
1✔
145

146
    train, test, player_names = train_test_split(
1✔
147
        os.path.join(DATA_DIR, "EfronMorrisBB.txt")
148
    )
149
    return {"train": (train, player_names), "test": (test, player_names)}
1✔
150

151

152
def _load_boston_housing():
2✔
153
    _download(BOSTON_HOUSING)
1✔
154
    file_path = os.path.join(DATA_DIR, "housing.data")
1✔
155
    data = np.loadtxt(file_path)
1✔
156
    return {"train": (data[:, :-1], data[:, -1])}
1✔
157

158

159
def _load_covtype():
2✔
160
    _download(COVTYPE)
1✔
161

162
    file_path = os.path.join(DATA_DIR, "covtype.npz")
1✔
163
    data = np.load(file_path)
1✔
164

165
    return {"train": (data["data"], data["target"])}
1✔
166

167

168
def _load_dipper_vole():
2✔
169
    _download(DIPPER_VOLE)
1✔
170

171
    file_path = os.path.join(DATA_DIR, "dipper_vole.zip")
1✔
172
    data = {}
1✔
173
    with zipfile.ZipFile(file_path) as zipper:
1✔
174
        data["dipper"] = (
1✔
175
            np.genfromtxt(zipper.open("dipper_capture_history.csv"), delimiter=",")[
176
                :, 1:
177
            ].astype(int),
178
            np.genfromtxt(zipper.open("dipper_sex.csv"), delimiter=",")[:, 1].astype(
179
                int
180
            ),
181
        )
182
        data["vole"] = (
1✔
183
            np.genfromtxt(
184
                zipper.open("meadow_voles_capture_history.csv"), delimiter=","
185
            )[:, 1:],
186
        )
187

188
    return data
1✔
189

190

191
def _load_mnist():
2✔
192
    _download(MNIST)
1✔
193

194
    def read_label(file):
1✔
195
        with gzip.open(file, "rb") as f:
1✔
196
            f.read(8)
1✔
197
            data = np.frombuffer(f.read(), dtype=np.int8)
1✔
198
            return data
1✔
199

200
    def read_img(file):
1✔
201
        with gzip.open(file, "rb") as f:
1✔
202
            _, _, nrows, ncols = struct.unpack(">IIII", f.read(16))
1✔
203
            data = np.frombuffer(f.read(), dtype=np.uint8) / np.float32(255.0)
1✔
204
            return data.reshape(-1, nrows, ncols)
1✔
205

206
    files = [
1✔
207
        os.path.join(DATA_DIR, os.path.basename(urlparse(url).path))
208
        for url in MNIST.urls
209
    ]
210
    return {
1✔
211
        "train": (read_img(files[0]), read_label(files[1])),
212
        "test": (read_img(files[2]), read_label(files[3])),
213
    }
214

215

216
def _load_sp500():
2✔
217
    _download(SP500)
1✔
218

219
    date, value = [], []
1✔
220
    with open(os.path.join(DATA_DIR, "SP500.csv"), "r") as f:
1✔
221
        csv_reader = csv.DictReader(f, quoting=csv.QUOTE_NONE)
1✔
222
        for row in csv_reader:
1✔
223
            date.append(row["DATE"])
1✔
224
            value.append(float(row["VALUE"]))
1✔
225
    value = np.stack(value)
1✔
226

227
    return {"train": (date, value)}
1✔
228

229

230
def _load_ucbadmit():
2✔
231
    _download(UCBADMIT)
1✔
232

233
    dept, male, applications, admit = [], [], [], []
1✔
234
    with open(os.path.join(DATA_DIR, "UCBadmit.csv")) as f:
1✔
235
        csv_reader = csv.DictReader(
1✔
236
            f,
237
            delimiter=";",
238
            fieldnames=["index", "dept", "gender", "admit", "reject", "applications"],
239
        )
240
        next(csv_reader)  # skip the first row
1✔
241
        for row in csv_reader:
1✔
242
            dept.append(ord(row["dept"]) - ord("A"))
1✔
243
            male.append(row["gender"] == "male")
1✔
244
            applications.append(int(row["applications"]))
1✔
245
            admit.append(int(row["admit"]))
1✔
246

247
    return {
1✔
248
        "train": (
249
            np.stack(dept),
250
            np.stack(male),
251
            np.stack(applications),
252
            np.stack(admit),
253
        )
254
    }
255

256

257
def _load_lynxhare():
2✔
258
    _download(LYNXHARE)
1✔
259

260
    file_path = os.path.join(DATA_DIR, "LynxHare.txt")
1✔
261
    data = np.loadtxt(file_path)
1✔
262

263
    return {"train": (data[:, 0].astype(int), data[:, 1:])}
1✔
264

265

266
def _pad_sequence(sequences):
2✔
267
    # like torch.nn.utils.rnn.pad_sequence with batch_first=True
268
    max_length = max(x.shape[0] for x in sequences)
1✔
269
    padded_sequences = []
1✔
270
    for x in sequences:
1✔
271
        pad = [(0, 0)] * np.ndim(x)
1✔
272
        pad[0] = (0, max_length - x.shape[0])
1✔
273
        padded_sequences.append(np.pad(x, pad))
1✔
274
    return np.stack(padded_sequences)
1✔
275

276

277
def _load_jsb_chorales():
2✔
278
    _download(JSB_CHORALES)
1✔
279

280
    file_path = os.path.join(DATA_DIR, "jsb_chorales.pickle")
1✔
281
    with open(file_path, "rb") as f:
1✔
282
        data = pickle.load(f)
1✔
283

284
    # XXX: we might expose those in `load_dataset` keywords
285
    min_note = 21
1✔
286
    note_range = 88
1✔
287
    processed_dataset = {}
1✔
288
    for split, data_split in data.items():
1✔
289
        processed_dataset[split] = {}
1✔
290
        n_seqs = len(data_split)
1✔
291
        processed_dataset[split]["sequence_lengths"] = np.zeros(n_seqs, dtype=int)
1✔
292
        processed_dataset[split]["sequences"] = []
1✔
293
        for seq in range(n_seqs):
1✔
294
            seq_length = len(data_split[seq])
1✔
295
            processed_dataset[split]["sequence_lengths"][seq] = seq_length
1✔
296
            processed_sequence = np.zeros((seq_length, note_range))
1✔
297
            for t in range(seq_length):
1✔
298
                note_slice = np.array(list(data_split[seq][t])) - min_note
1✔
299
                slice_length = len(note_slice)
1✔
300
                if slice_length > 0:
1✔
301
                    processed_sequence[t, note_slice] = np.ones(slice_length)
1✔
302
            processed_dataset[split]["sequences"].append(processed_sequence)
1✔
303

304
    for k, v in processed_dataset.items():
1✔
305
        lengths = v["sequence_lengths"]
1✔
306
        sequences = v["sequences"]
1✔
307
        processed_dataset[k] = (lengths, _pad_sequence(sequences).astype("int32"))
1✔
308
    return processed_dataset
1✔
309

310

311
def _load_higgs(num_datapoints):
2✔
UNCOV
312
    warnings.warn(
×
313
        "Higgs is a 2.6 GB dataset",
314
        stacklevel=find_stack_level(),
315
    )
UNCOV
316
    _download(HIGGS)
×
317

UNCOV
318
    file_path = os.path.join(DATA_DIR, "HIGGS.csv.gz")
×
UNCOV
319
    with io.TextIOWrapper(gzip.open(file_path, "rb")) as f:
×
UNCOV
320
        csv_reader = csv.reader(f, delimiter=",", quoting=csv.QUOTE_NONE)
×
UNCOV
321
        obs = []
×
UNCOV
322
        data = []
×
UNCOV
323
        for i, row in enumerate(csv_reader):
×
UNCOV
324
            obs.append(int(float(row[0])))
×
UNCOV
325
            data.append([float(v) for v in row[1:]])
×
UNCOV
326
            if num_datapoints and i > num_datapoints:
×
UNCOV
327
                break
×
UNCOV
328
    obs = np.stack(obs)
×
UNCOV
329
    data = np.stack(data)
×
UNCOV
330
    (n,) = obs.shape
×
331

UNCOV
332
    return {
×
333
        "train": (data[: -(n // 20)], obs[: -(n // 20)]),
334
        "test": (data[-(n // 20) :], obs[-(n // 20) :]),
335
    }  # standard split -500_000: as test
336

337

338
def _load_9mers():
2✔
339
    _download(NINE_MERS)
1✔
340
    file_path = os.path.join(DATA_DIR, "9mers_data.pkl")
1✔
341
    return pickle.load(open(file_path, "rb"))
1✔
342

343

344
def _load_mortality():
2✔
345
    _download(MORTALITY)
1✔
346

347
    a, s1, s2, t, deaths, population = [], [], [], [], [], []
1✔
348
    with open(os.path.join(DATA_DIR, "simulated_mortality.csv")) as f:
1✔
349
        csv_reader = csv.DictReader(
1✔
350
            f,
351
            fieldnames=[
352
                "age_group",
353
                "year",
354
                "a",
355
                "s1",
356
                "s2",
357
                "t",
358
                "deaths",
359
                "population",
360
            ],
361
        )
362
        next(csv_reader)  # skip the first row
1✔
363
        for row in csv_reader:
1✔
364
            a.append(int(row["a"]))
1✔
365
            s1.append(int(row["s1"]))
1✔
366
            s2.append(int(row["s2"]))
1✔
367
            t.append(int(row["t"]))
1✔
368
            deaths.append(int(row["deaths"]))
1✔
369
            population.append(int(row["population"]))
1✔
370

371
    return {
1✔
372
        "train": (
373
            np.stack(a),
374
            np.stack(s1),
375
            np.stack(s2),
376
            np.stack(t),
377
            np.stack(deaths),
378
            np.stack(population),
379
        )
380
    }
381

382

383
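# Dispatch from a dataset constant to its loader; num_datapoints is only used by the HIGGS loader.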
def _load(dset, num_datapoints=-1):
    if dset == BASEBALL:
        return _load_baseball()
    elif dset == BOSTON_HOUSING:
        return _load_boston_housing()
    elif dset == COVTYPE:
        return _load_covtype()
    elif dset == DIPPER_VOLE:
        return _load_dipper_vole()
    elif dset == MNIST:
        return _load_mnist()
    elif dset == SP500:
        return _load_sp500()
    elif dset == UCBADMIT:
        return _load_ucbadmit()
    elif dset == LYNXHARE:
        return _load_lynxhare()
    elif dset == JSB_CHORALES:
        return _load_jsb_chorales()
    elif dset == HIGGS:
        return _load_higgs(num_datapoints)
    elif dset == NINE_MERS:
        return _load_9mers()
    elif dset == MORTALITY:
        return _load_mortality()
    raise ValueError("Dataset - {} not found.".format(dset.name))


2✔
UNCOV
412
    arrays = _load(dset)[split]
×
UNCOV
413
    num_records = len(arrays[0])
×
UNCOV
414
    idxs = np.arange(num_records)
×
UNCOV
415
    if not batch_size:
×
UNCOV
416
        batch_size = num_records
×
UNCOV
417
    if shuffle:
×
UNCOV
418
        idxs = np.random.permutation(idxs)
×
UNCOV
419
    for i in range(num_records // batch_size):
×
UNCOV
420
        start_idx = i * batch_size
×
UNCOV
421
        end_idx = min((i + 1) * batch_size, num_records)
×
UNCOV
422
        yield tuple(a[idxs[start_idx:end_idx]] for a in arrays)
×
423

424

425
def load_dataset(
2✔
426
    dset,
427
    batch_size=None,
428
    split="train",
429
    shuffle=True,
430
    num_datapoints=None,
431
):
432
    data = _load(dset, num_datapoints)
1✔
433
    if isinstance(data, dict):
1✔
434
        arrays = data[split]
1✔
435
    num_records = len(arrays[0])
1✔
436
    idxs = np.arange(num_records)
1✔
437
    if not batch_size:
1✔
438
        batch_size = num_records
1✔
439

440
    def init():
1✔
441
        return (
1✔
442
            num_records // batch_size,
443
            np.random.permutation(idxs) if shuffle else idxs,
444
        )
445

446
    def get_batch(i=0, idxs=idxs):
1✔
447
        ret_idx = lax.dynamic_slice_in_dim(idxs, i * batch_size, batch_size)
1✔
448
        return tuple(
1✔
449
            np.take(a, ret_idx, axis=0)
450
            if isinstance(a, list)
451
            else lax.index_take(a, (ret_idx,), axes=(0,))
452
            for a in arrays
453
        )
454

455
    return init, get_batch
1✔
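
A minimal usage sketch of the batching API defined above (not part of the source file). The import path follows the file location shown in this report; the BASEBALL dataset, the "train" split, and the batch size of 5 are arbitrary illustrative choices.

from numpyro.examples.datasets import BASEBALL, load_dataset

# init() returns the number of full batches and a (possibly shuffled) index array.
init, get_batch = load_dataset(BASEBALL, batch_size=5, split="train", shuffle=True)
num_batches, idxs = init()
for i in range(num_batches):
    # One entry per array stored in the split: here (at_bats_hits, player_names).
    at_bats_hits, player_names = get_batch(i, idxs)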