• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

google / sedpack / 18842925384

27 Oct 2025 01:37PM UTC coverage: 88.742% (-0.08%) from 88.825%
18842925384

Pull #257

github

web-flow
Merge 8af3dc9bd into db9e72417
Pull Request #257: Implement balancing shard info iterator

78 of 91 new or added lines in 2 files covered. (85.71%)

3019 of 3402 relevant lines covered (88.74%)

0.89 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

80.3
/src/sedpack/io/shard_info_iterator/balanced_iterator.py
1
# Copyright 2025 Google LLC
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     https://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
"""Shard info iterator that interleaves groups of shards in a balanced way."""
15
from collections import defaultdict
1✔
16
from collections.abc import Hashable
1✔
17
import heapq
1✔
18
import itertools
1✔
19
import json
1✔
20
import logging
1✔
21
from pathlib import Path
1✔
22

23
from typing import Callable, Iterator
1✔
24

25
from sedpack.io.metadata import DatasetInfo
1✔
26
from sedpack.io.shard_file_metadata import ShardInfo
1✔
27
from sedpack.io.shard_info_iterator.shard_info_iterator import ShardInfoIterator
1✔
28
from sedpack.io.types import SplitT
1✔
29
from sedpack.io.itertools import shuffle_buffer
1✔
30

31

32
class _SingleLevelBalancer:
    """Interleave several iterators of `ShardInfo` in a weighted-fair way.

    Each wrapped iterator yields `(weight, shard_info)` pairs. The balancer
    always pulls from the iterator whose accumulated
    `sum(weight * shard_info.number_of_examples)` is currently the smallest.
    Ties are broken by the iterator's position in the input list. This keeps
    the (weighted) number of seen examples roughly equal across iterators.
    When one or more iterators are exhausted they are dropped and the
    remaining ones continue until all of them are exhausted.
    """

    def __init__(
        self,
        iterators: list[Iterator[tuple[float, ShardInfo]]],
    ) -> None:
        """Initialize the balancing.

        Args:

          iterators (list[Iterator[tuple[float, ShardInfo]]]): The iterators to
          be interleaved fairly. Each element is passed through `iter()`. The
          float is interpreted as the `weight`, meaning each example of the
          accompanying shard counts for `weight`. A zero weight (or a shard
          with zero examples) does not advance that iterator's counter, so it
          keeps being selected; a negative weight moves the counter backwards.
        """
        self.iterators: list[Iterator[tuple[float, ShardInfo]]] = [
            iter(i) for i in iterators
        ]
        # Min-heap of (weighted seen examples, index into `self.iterators`).
        # All keys start at 0.0 with increasing indices, which is already a
        # valid heap, so no `heapify` is needed.
        self.balancing_heap: list[tuple[float, int]] = [
            (0.0, i) for i in range(len(self.iterators))
        ]

    def __iter__(self) -> Iterator[ShardInfo]:
        """Return `self`; this object is a single-pass iterator."""
        return self

    def __next__(self) -> ShardInfo:
        """Return the next `ShardInfo` from the least-seen iterator.

        Raises:
          StopIteration: once every wrapped iterator is exhausted.
        """
        while self.balancing_heap:
            seen_examples, iterator_id = heapq.heappop(self.balancing_heap)
            try:
                weight, shard_info = next(self.iterators[iterator_id])
            except StopIteration:
                # This iterator is exhausted: do not push it back on the heap.
                continue
            heapq.heappush(
                self.balancing_heap,
                (
                    seen_examples + weight * shard_info.number_of_examples,
                    iterator_id,
                ),
            )
            return shard_info

        raise StopIteration
84

85

86
def _split_balancing(
    shard_list: list[ShardInfo],
    balance_by: tuple[Callable[[ShardInfo], Hashable], ...],
    repeat: bool,
    shuffle: int,
) -> Iterator[ShardInfo]:
    """Balance in a specified order.

    The shards are grouped by the value of the first criterion in
    `balance_by`. Each group is balanced recursively by the remaining
    criteria, and the resulting group iterators are interleaved by
    `_SingleLevelBalancer`. With no criteria left, the shards are iterated in
    list order, optionally cycled (`repeat`) and passed through a shuffle
    buffer (`shuffle`).

    Args:

      shard_list (list[ShardInfo]): The list of shards to be balanced.

      balance_by (tuple[Callable[[ShardInfo], Hashable], ...]): The list of
      priority of balancing. The first will be the most important to be
      balanced. If this callable is an object with a callable
      `weight(self, shard_info) -> float` method, then each example from this
      shard counts for `weight`. Otherwise each example counts as 1. Setting
      the weight to 0.5 therefore results in seeing twice as many of these
      shards. Be careful with zero and negative weights: they do not advance
      (or they rewind) the group's counter, so with `repeat=True` that group
      can be selected indefinitely. The same holds for a group whose shards
      all have zero examples.

      repeat (bool): Should the `ShardInfo` be repeated indefinitely?

      shuffle (int): The size of the shuffle buffer in the lowest level
      iteration. 0 disables shuffling.

    Returns: an iterator of the `ShardInfo` objects.
    """
    if not balance_by:
        inner_iterator: Iterator[ShardInfo]
        if repeat:
            inner_iterator = itertools.cycle(shard_list)
        else:
            inner_iterator = iter(shard_list)
        if shuffle:
            inner_iterator = iter(
                shuffle_buffer(
                    iterable=inner_iterator,
                    buffer_size=shuffle,
                ))
        return inner_iterator

    current_balancer: Callable[[ShardInfo], Hashable] = balance_by[0]

    # Group shards by the current criterion. Dicts keep insertion order, so
    # groups appear in the order in which their first shard is encountered.
    classes: defaultdict[Hashable, list[ShardInfo]] = defaultdict(list)
    for shard_info in shard_list:
        classes[current_balancer(shard_info)].append(shard_info)

    iterators: list[Iterator[ShardInfo]] = [
        _split_balancing(
            shard_list=v,
            balance_by=balance_by[1:],
            repeat=repeat,
            shuffle=shuffle,
        ) for v in classes.values()
    ]

    # Look up the optional `weight` method once rather than on every shard.
    weight_method = getattr(current_balancer, "weight", None)
    has_weight = callable(weight_method)

    def prepend_weight(shard_info: ShardInfo) -> tuple[float, ShardInfo]:
        """Pair `shard_info` with its weight (1.0 without a `weight` method)."""
        if has_weight:
            return (weight_method(shard_info), shard_info)
        return (1.0, shard_info)  # Default: just count examples.

    return _SingleLevelBalancer(
        iterators=[map(prepend_weight, i) for i in iterators])
160

161

162
class BalancedShardInfoIterator(ShardInfoIterator):
    """Iterate shards of a dataset, interleaving groups of shards fairly.

    The (optionally filtered) list of shards is read once at construction
    time and then iterated in a balanced order given by `balance_by`.
    """

    def __init__(
            self,
            *,
            dataset_path: Path,
            dataset_info: DatasetInfo,
            split: SplitT | None,
            repeat: bool = True,
            shard_filter: Callable[[ShardInfo], bool] | None = None,
            shuffle: int = 0,
            balance_by: tuple[Callable[[ShardInfo], Hashable], ...] = (),
    ) -> None:
        """Initialize shard information iteration.

        Args:

          dataset_path (Path): The path to the dataset directory.

          dataset_info (DatasetInfo): The information about the iterated
          dataset.

          split (SplitT | None): Which split to iterate or all if set to None.

          repeat (bool): Should we cycle indefinitely? You most likely want to
          set this to True, especially when using `balance_by`. Otherwise,
          groups that run out of shards are dropped, so only the beginning of
          the iteration is balanced and the less prevalent groups do not
          appear towards the end.

          shard_filter (Callable[[ShardInfo], bool] | None): If present this is
          a function taking the ShardInfo and returning True if the shard shall
          be used for traversal and False otherwise.

          shuffle (int): When set to 0 the iteration is deterministic.
          Otherwise it is passed as `buffer_size` to a shuffle buffer that is
          applied at the lowest level of balancing. Each innermost group of
          shards, after splitting by all of `balance_by`, is shuffled
          separately.

          balance_by (tuple[Callable[[ShardInfo], Hashable], ...]): The list of
          priority of balancing. The first will be the most important to be
          balanced. If this callable is an object with a `weight(self,
          shard_info) -> float` method, then each example from this shard
          counts for `weight`. Otherwise each example counts as 1. Setting the
          weight to 0.5 therefore results in seeing twice as many of these
          shards. Be careful with zero and negative weights.
        """
        super().__init__(
            dataset_path=dataset_path,
            dataset_info=dataset_info,
            split=split,
            repeat=repeat,
        )

        self.shuffle: int = shuffle

        # Logging for non-trivial operations such as filtering custom metadata.
        self._logger = logging.getLogger(__name__)

        # Cache the list of shards with a single non-repeating pass.
        shard_list: list[ShardInfo] = list(
            ShardInfoIterator(
                dataset_path=dataset_path,
                dataset_info=dataset_info,
                split=split,
                repeat=False,
            ))

        if shard_filter is not None:
            shard_list = [
                shard_info for shard_info in shard_list
                if shard_filter(shard_info)
            ]

            # Sorted so that the log message is reproducible between runs.
            kept_metadata: list[str] = sorted({
                json.dumps(
                    s.custom_metadata,
                    sort_keys=True,
                ) for s in shard_list
            })
            self._logger.info(
                "Filtered shards with custom metadata: %s from split: %s",
                kept_metadata,
                split,
            )

        # Cached number of shards (after filtering).
        self._number_of_shards: int = len(shard_list)

        # Balance by `balance_by[0]` first, then within each of those groups
        # by `balance_by[1]`, and so on.
        self._shard_info_iter: Iterator[ShardInfo] = _split_balancing(
            shard_list=shard_list,
            balance_by=balance_by,
            repeat=repeat,
            shuffle=shuffle,
        )

    def number_of_shards(self) -> int:
        """Return the number of distinct shards that are iterated (after
        filtering). When repeated this method still returns a finite answer.
        """
        return self._number_of_shards

    def __iter__(self) -> Iterator[ShardInfo]:
        """Return `self`; this object is a single-pass iterator."""
        return self

    def __next__(self) -> ShardInfo:
        """Get the next item.

        Raises:
          StopIteration: when not repeating and all shards have been returned.
        """
        return next(self._shard_info_iter)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc