
google / sedpack / 17561841648

08 Sep 2025 07:28PM UTC coverage: 88.513% (+0.6%) from 87.958%
push · github · web-flow
Improve saving and loading of non-NumPy attributes (#227)

Improve saving of `int`, `str`, and `bytes` attributes. This PR mostly
fixes existing bugs. It might cause regressions and incompatibility
with datasets written by older versions (the version has been bumped).

- A tfrec `str` is represented as `bytes` (which is consistent with
  TensorFlow).
- An `int` is represented as int64; this might be revisited later to
  enable arbitrarily large integers.
- `bytes` in npz are stored as one contiguous array together with an
  index array (see the sketch after this list). This is not the most
  robust implementation, but otherwise we could not allow dynamic sizes
  (with `examples_per_shard > 1`). The other option, saving them as
  padded fixed-size arrays, was causing length problems when decoding.
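
To make the npz layout concrete, here is a minimal sketch of the
contiguous-array-with-indexes idea. The exact arrays sedpack writes may
differ; pack_bytes and unpack_bytes are hypothetical helpers for
illustration, not sedpack API.

import numpy as np

def pack_bytes(items: list[bytes]) -> tuple[np.ndarray, np.ndarray]:
    # Concatenate all items into one contiguous uint8 buffer and record
    # the end offset of each item so it can be sliced back out later.
    data = np.frombuffer(b"".join(items), dtype=np.uint8)
    ends = np.cumsum([len(item) for item in items], dtype=np.int64)
    return data, ends

def unpack_bytes(data: np.ndarray, ends: np.ndarray) -> list[bytes]:
    # Recover the original variable-length items from the packed buffer.
    result = []
    begin = 0
    for end in ends:
        result.append(data[begin:end].tobytes())
        begin = int(end)
    return result

# Round trip preserves dynamic sizes, which fixed-size padding loses
# (the decoding length problem mentioned above).
items = [b"short", b"a somewhat longer value", b""]
assert unpack_bytes(*pack_bytes(items)) == items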

---------

Co-authored-by: wsxrdv <111074929+wsxrdv@users.noreply.github.com>

208 of 215 new or added lines in 6 files covered. (96.74%)

3 existing lines in 2 files now uncovered.

2851 of 3221 relevant lines covered (88.51%)

0.89 hits per line

Source File: /tests/io/test_end2end_dtypes.py (99.06% covered)
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path
from typing import Any
import random

import numpy as np
import numpy.typing as npt
import pytest

import sedpack
from sedpack.io import Dataset
from sedpack.io.shard_info_iterator import ShardInfoIterator
from sedpack.io import Metadata
from sedpack.io.types import TRAIN_SPLIT, CompressionT, ShardFileTypeT


def dataset_and_values_dynamic_shape(
    tmpdir: str | Path,
    shard_file_type: str,
    compression: str,
    dtypes: list[str],
    items: int,
) -> tuple[dict[str, list[Any]], Dataset]:
    """Create a dataset of random values for the given dtypes and return
    the expected values together with the reopened dataset."""
    values: dict[str, list[Any]] = {}
    ds_path = Path(tmpdir) / f"e2e_{shard_file_type}_{'_'.join(dtypes)}"
    dataset_metadata = Metadata(description="Test of the lib")

    # The order should not play a role.
    random.shuffle(dtypes)

    example_attributes = [
        sedpack.io.metadata.Attribute(
            name=f"attribute_{dtype}",
            dtype=dtype,
            shape=(),
        ) for dtype in dtypes
    ]

    for dtype in dtypes:
        values[f"attribute_{dtype}"] = []

        match dtype:
            case "int":
                for _ in range(items):
                    # TODO larger range than just int64
                    values[f"attribute_{dtype}"].append(
                        random.randint(-2**60, 2**60))
            case "str":
                long_string = "Ḽơᶉëᶆ ȋṕšᶙṁ ḍỡḽǭᵳ ʂǐť ӓṁệẗ, ĉṓɲṩḙċťᶒțûɾ" \
                      "https://arxiv.org/abs/2306.07249 ḹẩḇőꝛế" \
                      "ấɖḯƥĭṩčįɳġ ḝłįʈ, șếᶑ ᶁⱺ ẽḭŭŝḿꝋď ṫĕᶆᶈṓɍ ỉñḉīḑȋᵭṵńť ṷŧ" \
                      ":(){ :|:& };: éȶ đꝍꞎôꝛȇ ᵯáꞡᶇā ąⱡîɋṹẵ."
                for _ in range(items):
                    begin: int = random.randint(0, len(long_string) // 2)
                    end: int = random.randint(begin + 1, len(long_string))
                    values[f"attribute_{dtype}"].append(long_string[begin:end])
            case "bytes":
                for _ in range(items):
                    values[f"attribute_{dtype}"].append(
                        np.random.randint(
                            0,
                            256,
                            size=random.randint(5, 20),
                            dtype=np.uint8,
                        ).tobytes())

    dataset_structure = sedpack.io.metadata.DatasetStructure(
        saved_data_description=example_attributes,
        compression=compression,
        examples_per_shard=3,
        shard_file_type=shard_file_type,
    )

    # Test attribute_by_name
    for attribute in example_attributes:
        assert dataset_structure.attribute_by_name(
            attribute_name=attribute.name) == attribute

    dataset = Dataset.create(
        path=ds_path,
        metadata=dataset_metadata,
        dataset_structure=dataset_structure,
    )

    # Fill data in the dataset

    with dataset.filler() as filler:
        for i in range(items):
            filler.write_example(
                values={
                    name: value[i] for name, value in values.items()
                },
                split=TRAIN_SPLIT,
            )

    # Check the data is correct
    # Reopen the dataset
    dataset = Dataset(ds_path)
    dataset.check()

    return (values, dataset)


@pytest.fixture(
    scope="module",
    params=[
        {
            "dtypes": ["str"],
            "compression": "GZIP",
        },
        {
            "dtypes": ["bytes"],
            "compression": "GZIP",
        },
        {
            "dtypes": ["int"],
            "compression": "GZIP",
        },
        {
            "dtypes": ["str", "bytes", "int"],
            "compression": "GZIP",
        },
    ],
)
def values_and_dataset_tfrec(request, tmpdir_factory) -> None:
    shard_file_type: str = "tfrec"
    yield dataset_and_values_dynamic_shape(
        tmpdir=tmpdir_factory.mktemp(f"dtype_{shard_file_type}"),
        shard_file_type=shard_file_type,
        compression=request.param["compression"],
        dtypes=request.param["dtypes"],
        items=137,
    )
    # Teardown.


@pytest.fixture(
    scope="module",
    params=[
        {
            "dtypes": ["str"],
            "compression": "ZIP",
        },
        {
            "dtypes": ["bytes"],
            "compression": "ZIP",
        },
        {
            "dtypes": ["int"],
            "compression": "ZIP",
        },
        {
            "dtypes": ["str", "bytes", "int"],
            "compression": "ZIP",
        },
    ],
)
def values_and_dataset_npz(request, tmpdir_factory) -> None:
    shard_file_type: str = "npz"
    yield dataset_and_values_dynamic_shape(
        tmpdir=tmpdir_factory.mktemp(f"dtype_{shard_file_type}"),
        shard_file_type=shard_file_type,
        compression=request.param["compression"],
        dtypes=request.param["dtypes"],
        items=137,
    )
    # Teardown.


@pytest.fixture(
    scope="module",
    params=[
        {
            "dtypes": ["str"],
            "compression": "LZ4",
        },
        {
            "dtypes": ["bytes"],
            "compression": "LZ4",
        },
        {
            "dtypes": ["int"],
            "compression": "LZ4",
        },
        {
            "dtypes": ["str", "bytes", "int"],
            "compression": "LZ4",
        },
    ],
)
def values_and_dataset_fb(request, tmpdir_factory) -> None:
    shard_file_type: str = "fb"
    yield dataset_and_values_dynamic_shape(
        tmpdir=tmpdir_factory.mktemp(f"dtype_{shard_file_type}"),
        shard_file_type=shard_file_type,
        compression=request.param["compression"],
        dtypes=request.param["dtypes"],
        items=137,
    )
    # Teardown.


def check_iteration_of_values(
    method: str,
    dataset: Dataset,
    values: dict[str, list[Any]],
) -> None:
    match method:
        case "as_tfdataset":
            for i, example in enumerate(
                    dataset.as_tfdataset(
                        split=TRAIN_SPLIT,
                        shuffle=0,
                        repeat=False,
                        batch_size=1,
                    )):
                assert len(example) == len(values)

                # No idea how to have an actual string or bytes in TensorFlow.
                # Maybe it is best to leave it as a tensor anyway since that is
                # the "native" type.

                for name, returned_batch in example.items():
                    assert returned_batch == values[name][i:i + 1]
        case "as_numpy_iterator":
            for i, example in enumerate(
                    dataset.as_numpy_iterator(
                        split=TRAIN_SPLIT,
                        shuffle=0,
                        repeat=False,
                    )):
                assert len(example) == len(values)
                for name, returned_value in example.items():
                    if dataset.dataset_structure.shard_file_type != "tfrec":
                        assert returned_value == values[name][i]
                        assert type(returned_value) == type(values[name][i])
                    else:
                        if "attribute_str" == name:
                            assert returned_value == values[name][i].encode(
                                "utf-8")
                        else:
                            assert returned_value == values[name][i]
        case "as_numpy_iterator_concurrent":
            for i, example in enumerate(
                    dataset.as_numpy_iterator_concurrent(
                        split=TRAIN_SPLIT,
                        shuffle=0,
                        repeat=False,
                    )):
                assert len(example) == len(values)
                for name, returned_value in example.items():
                    if dataset.dataset_structure.shard_file_type != "tfrec":
                        assert returned_value == values[name][i]
                        assert type(returned_value) == type(values[name][i])
                    else:
                        if "attribute_str" == name:
                            assert returned_value == values[name][i].encode(
                                "utf-8")
                        else:
                            assert returned_value == values[name][i]
        case "as_numpy_iterator_rust":
            # Note that this case also iterates via
            # as_numpy_iterator_concurrent.
            for i, example in enumerate(
                    dataset.as_numpy_iterator_concurrent(
                        split=TRAIN_SPLIT,
                        shuffle=0,
                        repeat=False,
                    )):
                assert len(example) == len(values)
                for name, returned_value in example.items():
                    assert returned_value == values[name][i]
                    assert type(returned_value) == type(values[name][i])

    # We tested everything
    if i + 1 != len(next(iter(values.values()))):
        raise AssertionError("Not all examples have been iterated")

    # Number of shards matches
    full_iterator = ShardInfoIterator(
        dataset_path=dataset.path,
        dataset_info=dataset.dataset_info,
        split=None,
    )
    number_of_all_shards: int = full_iterator.number_of_shards()
    assert number_of_all_shards == len(full_iterator)
    assert number_of_all_shards == len(list(full_iterator))
    assert number_of_all_shards == sum(
        ShardInfoIterator(
            dataset_path=dataset.path,
            dataset_info=dataset.dataset_info,
            split=split,
        ).number_of_shards() for split in ["train", "test", "holdout"])


@pytest.mark.parametrize("method", [
    "as_tfdataset",
    "as_numpy_iterator",
    "as_numpy_iterator_concurrent",
])
def test_end2end_dtypes_str_tfrec(
    method: str,
    values_and_dataset_tfrec,
) -> None:
    values, dataset = values_and_dataset_tfrec
    check_iteration_of_values(
        method=method,
        dataset=dataset,
        values=values,
    )


@pytest.mark.parametrize("method", [
    "as_numpy_iterator",
    "as_numpy_iterator_concurrent",
])
def test_end2end_dtypes_str_npz(
    method: str,
    values_and_dataset_npz,
) -> None:
    values, dataset = values_and_dataset_npz
    check_iteration_of_values(
        method=method,
        dataset=dataset,
        values=values,
    )


@pytest.mark.parametrize("method", [
    "as_numpy_iterator",
    "as_numpy_iterator_concurrent",
    "as_numpy_iterator_rust",
])
def test_end2end_dtypes_str_fb(
    method: str,
    values_and_dataset_fb,
) -> None:
    values, dataset = values_and_dataset_fb
    check_iteration_of_values(
        method=method,
        dataset=dataset,
        values=values,
    )
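
Assuming a standard pytest setup for the repository, the file above can be run directly, e.g. with `python -m pytest tests/io/test_end2end_dtypes.py`; the `as_tfdataset` cases presumably require TensorFlow to be installed.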