• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

google / sedpack / 17467328657

04 Sep 2025 02:09PM UTC coverage: 87.958%. First build
17467328657

Pull #233

github

web-flow
Merge 2146e58e4 into 8ea3087db
Pull Request #233: Allow batching in as_numpy_iterator_rust

20 of 21 new or added lines in 2 files covered. (95.24%)

2666 of 3031 relevant lines covered (87.96%)

0.88 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.18
/tests/io/iteration/test_rust_batched_generator.py
1
# Copyright 2025 Google LLC
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     https://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import random
import uuid
from collections.abc import Iterator
from pathlib import Path
from typing import Union

import numpy as np
import pytest

import sedpack
from sedpack.io import Dataset
from sedpack.io.iteration import RustBatchedGenerator
from sedpack.io.metadata import DatasetStructure, Metadata
from sedpack.io.shard_info_iterator import CachedShardInfoIterator
from sedpack.io.types import TRAIN_SPLIT, CompressionT, ShardFileTypeT
29

30

31
@pytest.fixture(scope="module")
def dataset_and_values(
        tmpdir_factory) -> Iterator[tuple[Dataset, dict[str, np.ndarray]]]:
    """Create a small dataset once per module and yield it with ground truth.

    Writes `data_points` examples with two fixed-shape float attributes into a
    freshly created sedpack dataset, reopens and checks it, then yields both
    the dataset and the raw numpy values so tests can compare iterator output
    against what was written.

    Yields:
        A tuple `(dataset, values)` where `values` maps attribute name to the
        numpy array of all saved examples for that attribute.
    """
    data_points: int = 1_024
    dtype: str = "float32"

    # Values saved in the dataset (ground truth for the iteration tests).
    values = {
        "fixed": np.random.random((data_points, 138)).astype(dtype),
        "fixed_2d": np.random.random((data_points, 3, 5)).astype(dtype),
        # TODO(reintroduce) when https://github.com/google/sedpack/pull/227 is
        # merged
        #"dynamic_shape": [
        #    uuid.uuid4().hex[:random.randint(15, 25)]
        #    for _ in range(data_points)
        #],
    }
    tmpdir = tmpdir_factory.mktemp("end_2_end_data")

    tiny_experiment_path: Path = Path(tmpdir) / "e2e_experiment"

    # Create a dataset.
    dataset_metadata = Metadata(description="Test of the lib")

    example_attributes = [
        sedpack.io.metadata.Attribute(
            name="fixed",
            # `dtype` is already a str; the previous str(dtype) was redundant.
            dtype=dtype,
            shape=values["fixed"][0].shape,
        ),
        sedpack.io.metadata.Attribute(
            name="fixed_2d",
            dtype=dtype,
            shape=values["fixed_2d"][0].shape,
        ),
        #sedpack.io.metadata.Attribute(
        #    name="dynamic_shape",
        #    dtype="str",
        #    shape=(),
        #),
    ]

    dataset_structure = sedpack.io.metadata.DatasetStructure(
        saved_data_description=example_attributes,
        compression="LZ4",
        examples_per_shard=24,
        shard_file_type="fb",
    )

    # Test attribute_by_name.
    for attribute in example_attributes:
        assert dataset_structure.attribute_by_name(
            attribute_name=attribute.name) == attribute

    dataset = Dataset.create(
        path=tiny_experiment_path,
        metadata=dataset_metadata,
        dataset_structure=dataset_structure,
    )

    # Fill data in the dataset.
    with dataset.filler() as filler:
        for i in range(data_points):
            filler.write_example(
                values={
                    name: val[i] for name, val in values.items()
                },
                split=TRAIN_SPLIT,
            )

    # Check the data is correct: reopen the dataset and run its self-check.
    dataset = Dataset(tiny_experiment_path)
    dataset.check()

    yield (dataset, values)

    # Teardown: nothing to do, tmpdir_factory cleans up the directory.
def test_wrong_file_paralelism() -> None:
    """RustBatchedGenerator rejects a non-positive file_parallelism.

    NOTE(review): the "paralelism" typo in the name is kept on purpose --
    renaming would change the pytest test id.
    """
    with pytest.raises(
            ValueError,
            match="The argument file_parallelism should be positive.*",
    ):
        # The constructor itself is expected to raise, so the previously
        # assigned-but-unused `g` variable (flake8 F841) was dropped.
        RustBatchedGenerator(
            dataset_path=Path(),
            dataset_structure=DatasetStructure(),
            shard_iterator=[],
            process_record=None,
            file_parallelism=0,
            batch_size=1,
        )
def test_wrong_shard_type() -> None:
    """RustBatchedGenerator rejects non-FlatBuffers shard file types."""
    with pytest.raises(
            ValueError,
            match="RustBatchedGenerator is implemented only for FlatBuffers.",
    ):
        # The constructor itself is expected to raise, so the previously
        # assigned-but-unused `g` variable (flake8 F841) was dropped.
        RustBatchedGenerator(
            dataset_path=Path(),
            dataset_structure=DatasetStructure(shard_file_type="tfrec"),
            shard_iterator=[],
            process_record=None,
            file_parallelism=1,
            batch_size=1,
        )
def test_wrong_compression() -> None:
    """RustBatchedGenerator rejects compressions unsupported by the Rust side."""
    with pytest.raises(
            ValueError,
            match=
            "The compression .* is not among the supported compressions: .*",
    ):
        # The constructor itself is expected to raise, so the previously
        # assigned-but-unused `g` variable (flake8 F841) was dropped.
        RustBatchedGenerator(
            dataset_path=Path(),
            dataset_structure=DatasetStructure(
                shard_file_type="fb",
                compression="ZIP",
            ),
            shard_iterator=[],
            process_record=None,
            file_parallelism=1,
            batch_size=1,
        )
@pytest.mark.parametrize("batch_size", [1, 2, 7])
def test_end_to_end_rust_batched(
    batch_size,
    dataset_and_values,
):
    """End-to-end: iterate via RustBatchedGenerator and compare every batch
    against the ground-truth values captured by the fixture."""
    dataset, values = dataset_and_values

    with RustBatchedGenerator(
            dataset_path=dataset.path,
            dataset_structure=dataset.dataset_structure,
            shard_iterator=CachedShardInfoIterator(
                dataset_path=dataset.path,
                dataset_info=dataset.dataset_info,
                split="train",
                repeat=False,
                shards=None,
                custom_metadata_type_limit=None,
                shard_filter=None,
                shuffle=0,
            ),
            batch_size=batch_size,
            process_record=None,
            file_parallelism=8,
    ) as g:
        index: int = 0  # Number of examples verified so far.
        for batch in g():
            current_batch_size: int = -1  # Unknown until the first attribute.

            for name, attribute_values in batch.items():
                if current_batch_size < 0:
                    current_batch_size = len(attribute_values)
                else:
                    # Every attribute in a batch must have the same length.
                    assert len(attribute_values) == current_batch_size

                for i in range(current_batch_size):
                    if name == "dynamic_shape":
                        # Strings compare directly.
                        assert values[name][index + i] == attribute_values[i]
                    else:
                        # Numpy arrays compare element-wise.
                        assert (values[name][index +
                                             i] == attribute_values[i]).all()

            index += current_batch_size

        # Guard against a vacuous pass where the generator yields no batches
        # at all. The exact total is deliberately not asserted here because
        # the last-partial-batch behavior is not pinned down by this test.
        assert index > 0
@pytest.mark.parametrize("batch_size", [1, 3])
def test_end_to_end_as_numpy_iterator_rust(
    batch_size,
    dataset_and_values,
):
    """End-to-end: `Dataset.as_numpy_iterator_rust` yields batches whose
    contents match the values originally written by the fixture."""
    dataset, values = dataset_and_values
    seen: int = 0  # Examples verified so far.

    iterator = dataset.as_numpy_iterator_rust(
        split="train",
        process_record=None,
        shards=None,
        shard_filter=None,
        repeat=False,
        batch_size=batch_size,
        file_parallelism=8,
        shuffle=0,
    )
    for batch in iterator:
        batch_len: int = -1  # Determined from the first attribute seen.

        for name, got in batch.items():
            if batch_len < 0:
                batch_len = len(got)
            else:
                # All attributes of one batch must agree on length.
                assert len(got) == batch_len

            for offset in range(batch_len):
                expected = values[name][seen + offset]
                if name == "dynamic_shape":
                    # String attribute: direct equality.
                    assert expected == got[offset]
                else:
                    # Array attribute: element-wise equality.
                    assert (expected == got[offset]).all()

        seen += batch_len
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc