google / sedpack / build 17584283333

09 Sep 2025 01:30PM UTC coverage: 88.638%. First build.

Pull Request #235: Improve Rust batching (github / web-flow)
Merge b7044abdd into 198cd9828

37 of 38 new or added lines in 5 files covered (97.37%)
2871 of 3239 relevant lines covered (88.64%)
0.89 hits per line
Source File

/src/sedpack/io/iteration/rust_batched_generator.py (93.65% covered)
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Rust batched generator object wrapping the rust object to behave nicely with
TensorFlow."""
import itertools
import os
from pathlib import Path
from types import TracebackType
from typing import (
    Callable,
    Iterable,
    Iterator,
    Type,
)
from typing_extensions import Self

import numpy as np

from sedpack.io.flatbuffer import IterateShardFlatBuffer
from sedpack.io.metadata import DatasetStructure
from sedpack.io.shard.iterate_shard_base import T
from sedpack.io.shard_file_metadata import ShardInfo
from sedpack.io.types import BatchT, ExampleT

from sedpack._sedpack_rs import BatchedRustIter


class RustBatchedGenerator:
    """Similar to sedpack.io.iteration.RustGenerator with batching.
    Experimental API, expect breaking changes.
    """

    def __init__(
        self,
        *,
        dataset_path: Path,
        dataset_structure: DatasetStructure,
        shard_iterator: Iterable[ShardInfo],
        batch_size: int,
        process_batch: Callable[[BatchT], T] | None = None,
        file_parallelism: int = os.cpu_count() or 1,
    ) -> None:
        """A reentrant generator.

        Args:

          dataset_path (Path): The root path of the dataset.

          dataset_structure (DatasetStructure): The structure of the dataset.

          shard_iterator (Iterable[ShardInfo]): How the shards should be
          iterated.

          batch_size (int): Size of the batches.

          process_batch (Callable[[BatchT], T] | None): Optional transformation
          of a whole batch of examples.

          file_parallelism (int): How many files to read in parallel.
        """
        self._iter: BatchedRustIter | None  # type: ignore[no-any-unimported]
        self._iter = None
        self._stopped: bool = False

        # Workaround until BatchedRustIter supports an Iterable[ShardInfo]. Take
        # _shard_chunk_size shard paths at once.
        self._shard_chunk_size: int = 1_000_000

        # Check file_parallelism is positive.
        if file_parallelism <= 0:
            raise ValueError("The argument file_parallelism should be "
                             f"positive but is {file_parallelism}")

        self._dataset_path: Path = dataset_path
        self._dataset_structure: DatasetStructure = dataset_structure
        # Make sure that any iteration on shard_iterator advances instead of
        # starting again.
        self._shard_iterator: Iterator[ShardInfo] = iter(shard_iterator)
        self._process_batch: Callable[[BatchT], T] | None = process_batch
        self._batch_size: int = batch_size
        self._file_parallelism: int = file_parallelism

        # Which attributes have fixed shapes and which do not.
        self._has_fixed_shape: tuple[bool, ...] = tuple(
            not attribute.has_variable_size()
            for attribute in dataset_structure.saved_data_description)

        # Only FlatBuffers are supported.
        if dataset_structure.shard_file_type != "fb":
            raise ValueError(
                "RustBatchedGenerator is implemented only for FlatBuffers.")

        # Check if the compression type is supported by Rust.
        supported_compressions = BatchedRustIter.supported_compressions()
        if dataset_structure.compression not in supported_compressions:
            raise ValueError(
                f"The compression {dataset_structure.compression} is not "
                f"among the supported compressions: {supported_compressions}")

        def to_dict(example: list[np.typing.NDArray[np.uint8]]) -> BatchT:
            result: BatchT = {}
            for np_bytes, attribute in zip(
                    example, dataset_structure.saved_data_description):
                result[attribute.name] = IterateShardFlatBuffer.decode_batched(
                    np_bytes=np_bytes,
                    attribute=attribute,
                    batch_size=-1,
                )
            return result

        self._to_dict = to_dict

    def __enter__(self) -> Self:
        """Enter the context manager (takes care of freeing memory held by
        Rust).
        """
        return self

    def __exit__(
        self,
        exc_type: Type[BaseException] | None,
        exc_value: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Drop the rust data structure holding content of open files and
        future examples.
        """
        if self._iter is not None:
            self._iter.__exit__(exc_type, exc_value, exc_tb)

    def __call__(self) -> Iterable[ExampleT] | Iterable[T]:
        """Return an iterable.
        """
        while not self._stopped:
            yield from self._single_iter()

    def _single_iter(self) -> Iterable[ExampleT] | Iterable[T]:
        """Iterate over a single chunk of shards.
        """
        if self._iter is None:
            shard_paths: list[str] = [
                str(self._dataset_path / s.file_infos[0].file_path)
                for s in itertools.islice(
                    self._shard_iterator,
                    self._shard_chunk_size,
                )
            ]

            if not shard_paths:
                # No shards to iterate.
                self._stopped = True
                return

            self._iter = BatchedRustIter(
                files=shard_paths,
                threads=self._file_parallelism,
                compression=self._dataset_structure.compression,
                batch_size=self._batch_size,
                has_fixed_shape=self._has_fixed_shape,
            )
            # Manually calling __enter__ and __exit__ -- see class docstring.
            self._iter.__enter__()  # pylint: disable=unnecessary-dunder-call
        elif not self._iter.can_iterate:
            self._iter.__enter__()  # pylint: disable=unnecessary-dunder-call

        example_iterator = map(self._to_dict, iter(self._iter))
        if self._process_batch:
            yield from map(self._process_batch, example_iterator)
        else:
            yield from example_iterator

        self._iter.__exit__(None, None, None)
        self._iter = None
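
For reference, a minimal usage sketch based only on the constructor, __enter__/__exit__, and __call__ defined above. How the DatasetStructure and the Iterable[ShardInfo] are obtained from an existing dataset is not shown in this file, so those inputs (and the concrete path, batch size, and parallelism) are placeholders.

from pathlib import Path

from sedpack.io.iteration.rust_batched_generator import RustBatchedGenerator

# Hypothetical inputs; obtaining them from an existing sedpack dataset is
# outside the scope of this file.
dataset_path = Path("/path/to/dataset")
dataset_structure = ...   # a DatasetStructure with shard_file_type == "fb"
shard_infos = ...         # an Iterable[ShardInfo] selecting the shards to read

with RustBatchedGenerator(
        dataset_path=dataset_path,
        dataset_structure=dataset_structure,
        shard_iterator=shard_infos,
        batch_size=32,
        file_parallelism=4,
) as generator:
    # Calling the instance yields batches: dictionaries mapping attribute
    # names to NumPy arrays, optionally transformed by process_batch.
    for batch in generator():
        ...  # e.g. feed into a training loop or a tf.data pipeline

Since __call__ returns a fresh iterable each time it is invoked, the instance can presumably also be handed as the callable to tf.data.Dataset.from_generator, which matches the module docstring's goal of behaving nicely with TensorFlow.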