• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

google / sedpack / 17561841648

08 Sep 2025 07:28PM UTC coverage: 88.513% (+0.6%) from 87.958%
17561841648

push

github

web-flow
Improve saving and loading of non-NumPy attributes (#227)

Improve saving of `int`, `str`, and `bytes` attributes. This PR mostly
fixes existing bugs. It might cause a regression and incompatibility
with older versions (version bumped).

- tfrec str is represented as bytes (which is consistent with
TensorFlow)
- int is represented as int64, which might be revisited later to enable
arbitrarily large integers
- bytes in npz are stored as a contiguous array together with an array of
indexes into it. This is not the most robust implementation, but without it
we could not allow dynamic sizes (with `examples_per_shard > 1`). The other
option, saving them as padded fixed-size arrays, was causing length problems
when decoding.

---------

Co-authored-by: wsxrdv <111074929+wsxrdv@users.noreply.github.com>

208 of 215 new or added lines in 6 files covered. (96.74%)

3 existing lines in 2 files now uncovered.

2851 of 3221 relevant lines covered (88.51%)

0.89 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.77
/src/sedpack/io/npz/iterate_npz.py
1
# Copyright 2024-2025 Google LLC
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     https://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
"""Iterate a npz shard. See sedpack.io.shard.shard_writer_np.ShardWriterNP
15
for more information how a npz shard is saved.
16
"""
17

18
import io
1✔
19
from pathlib import Path
1✔
20
from typing import AsyncIterator, Iterable
1✔
21

22
import aiofiles
1✔
23
import numpy as np
1✔
24

25
from sedpack.io.metadata import Attribute
1✔
26
from sedpack.io.shard import IterateShardBase
1✔
27
from sedpack.io.shard.iterate_shard_base import T
1✔
28
from sedpack.io.types import AttributeValueT, ExampleT
1✔
29
from sedpack.io.utils import func_or_identity
1✔
30

31

32
class IterateShardNP(IterateShardBase[T]):
1✔
33
    """Iterate a shard saved in the npz format.
34
    """
35

36
    @staticmethod
1✔
37
    def decode_attribute(
1✔
38
        attribute: Attribute,
39
        example_index: int,
40
        prefixed_name: str,
41
        shard_content: dict[str, list[AttributeValueT]],
42
    ) -> AttributeValueT:
43
        """Choose the correct way to decode the given attribute.
44

45
        Args:
46

47
          attribute (Attribute): Information about the attribute being decoded.
48

49
          example_index (int): Which example from this shard is being decoded.
50

51
          prefixed_name (str): For the case of `bytes` attributes we need to
52
          store them in continuous array otherwise variable length would require
53
          allowing pickling and result in a potential arbitrary code execution.
54
          This name is the prefix-sum encoded lengths of the attribute values.
55

56
          shard_content (dict[str, list[AttributeValueT]]): The shard values.
57
        """
58
        if attribute.dtype == "bytes":
1✔
59
            return IterateShardNP.decode_bytes_attribute(
1✔
60
                value=shard_content[attribute.name][0],
61
                indexes=shard_content[prefixed_name],
62
                attribute=attribute,
63
                index=example_index,
64
            )
65

66
        return IterateShardNP.decode_non_bytes_attribute(
1✔
67
            np_value=shard_content[attribute.name][example_index],
68
            attribute=attribute,
69
        )
70

71
    @staticmethod
1✔
72
    def decode_non_bytes_attribute(
1✔
73
        np_value: AttributeValueT,
74
        attribute: Attribute,
75
    ) -> AttributeValueT:
76
        match attribute.dtype:
1✔
77
            case "str":
1✔
78
                return str(np_value)
1✔
79
            case "bytes":
1✔
NEW
80
                raise ValueError("One needs to use decode_bytes_attribute")
×
81
            case "int":
1✔
82
                return int(np_value)
1✔
83
            case _:
1✔
84
                return np_value
1✔
85

86
    @staticmethod
1✔
87
    def decode_bytes_attribute(
1✔
88
        value: AttributeValueT,
89
        indexes: list[AttributeValueT],
90
        attribute: Attribute,
91
        index: int,
92
    ) -> AttributeValueT:
93
        """Decode a bytes attribute. We are saving the byte attributes as a
94
        continuous array across multiple examples and on the side we also save
95
        the indexes into this array.
96

97
        Args:
98

99
          value (AttributeValueT): The NumPy array of np.uint8 containing
100
          concatenated bytes values.
101

102
          indexes (list[AttributeValueT]): Indexes into this array.
103

104
          attribute (Attribute): The attribute description.
105

106
          index (int): Which example out of this shard to return.
107
        """
108
        if attribute.dtype != "bytes":
1✔
NEW
109
            raise ValueError("One needs to use decode_attribute")
×
110
        # Help with type-checking:
111
        my_value = np.array(value, np.uint8)
1✔
112
        my_indexes = np.array(indexes, np.int64)
1✔
113
        del value
1✔
114
        del indexes
1✔
115

116
        begin: int = my_indexes[index]
1✔
117
        end: int = my_indexes[index + 1]
1✔
118
        return bytes(my_value[begin:end])
1✔
119

120
    def iterate_shard(self, file_path: Path) -> Iterable[ExampleT]:
1✔
121
        """Iterate a shard saved in the NumPy format npz.
122
        """
123
        # A prefix such that prepended it creates a new name without collision
124
        # with any attribute name.
125
        counting_prefix: str = "len" + "_" * max(
1✔
126
            len(attribute.name)
127
            for attribute in self.dataset_structure.saved_data_description)
128
        self._prefixed_names: dict[str, str] = {
1✔
129
            attribute.name: counting_prefix + attribute.name
130
            for attribute in self.dataset_structure.saved_data_description
131
        }
132

133
        shard_content: dict[str, list[AttributeValueT]] = np.load(
1✔
134
            file_path,
135
            allow_pickle=False,
136
        )
137

138
        # A given shard contains the same number of elements for each
139
        # attribute.
140
        elements: int
141
        first_attribute = self.dataset_structure.saved_data_description[0]
1✔
142
        if self._prefixed_names[first_attribute.name] in shard_content:
1✔
143
            elements = len(
1✔
144
                shard_content[self._prefixed_names[first_attribute.name]]) - 1
145
        else:
146
            elements = len(shard_content[first_attribute.name])
1✔
147

148
        for example_index in range(elements):
1✔
149
            yield {
1✔
150
                attribute.name:
151
                    IterateShardNP.decode_attribute(
152
                        attribute=attribute,
153
                        example_index=example_index,
154
                        prefixed_name=self._prefixed_names[attribute.name],
155
                        shard_content=shard_content,
156
                    )
157
                for attribute in self.dataset_structure.saved_data_description
158
            }
159

160
    # TODO(issue #85) fix and test async iterator typing
161
    async def iterate_shard_async(  # pylint: disable=invalid-overridden-method
1✔
162
        self,
163
        file_path: Path,
164
    ) -> AsyncIterator[ExampleT]:
165
        """Asynchronously iterate a shard saved in the NumPy format npz.
166
        """
167
        async with aiofiles.open(file_path, "rb") as f:
1✔
168
            content_bytes: bytes = await f.read()
1✔
169
            content_io = io.BytesIO(content_bytes)
1✔
170

171
        shard_content: dict[str, list[AttributeValueT]] = np.load(
1✔
172
            content_io,
173
            allow_pickle=False,
174
        )
175

176
        # A given shard contains the same number of elements for each
177
        # attribute.
178
        elements: int = 0
1✔
179
        for values in shard_content.values():
1✔
180
            elements = len(values)
1✔
181
            break
1✔
182

183
        for example_index in range(elements):
1✔
184
            yield {
1✔
185
                name: value[example_index]
186
                for name, value in shard_content.items()
187
            }
188

189
    def process_and_list(self, shard_file: Path) -> list[T]:
1✔
190
        process_record = func_or_identity(self.process_record)
1✔
191

192
        return [
1✔
193
            process_record(example)
194
            for example in self.iterate_shard(file_path=shard_file)
195
        ]
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc