google / sedpack / 17561841648
08 Sep 2025 07:28PM UTC coverage: 88.513% (+0.6%) from 87.958%
Build 17561841648 · push · github · web-flow

Improve saving and loading of non-NumPy attributes (#227)

Improve saving of `int`, `str`, and `bytes` attributes. This PR mostly
fixes existing bugs. It may cause a regression and incompatibility with
older versions (version bumped).

- In tfrec, `str` is represented as bytes (which is consistent with
TensorFlow).
- `int` is represented as int64; this might be revisited later to enable
arbitrarily large integers.
- In npz, `bytes` are stored as a contiguous array together with an index
array (see the sketch below). This is not the most robust implementation,
but otherwise we could not allow dynamic sizes (with
`examples_per_shard > 1`). The alternative of saving them as padded
fixed-size arrays caused length problems when decoding.

---------

Co-authored-by: wsxrdv <111074929+wsxrdv@users.noreply.github.com>
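
To make the npz layout concrete, here is a minimal sketch of the scheme from the last bullet. It is not sedpack's API: the attribute name "trace", the file name "shard.npz", and the exact prefix construction are illustrative assumptions. All bytes values of one attribute are concatenated into a single contiguous uint8 array, and a companion array of cumulative end indexes is stored under a prefixed, collision-free name, so example i spans data[idx[i]:idx[i + 1]]:

import numpy as np

values = [b"abc", b"de"]  # variable-length bytes values of two examples

# One contiguous uint8 array holding all bytes back to back.
data = np.frombuffer(b"".join(values), dtype=np.uint8)

# Cumulative end indexes [0, 3, 5]; example i spans data[idx[i]:idx[i + 1]].
idx = np.cumsum([0] + [len(v) for v in values])

# The index array goes under a prefixed name that cannot collide with an
# attribute name ("len" plus one underscore per character of the longest name).
prefix = "len" + "_" * len("trace")
np.savez_compressed("shard.npz", **{"trace": data, prefix + "trace": idx})

Both arrays have plain numeric dtypes, so the shard can be written and later read back without allow_pickle.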

208 of 215 new or added lines in 6 files covered. (96.74%)

3 existing lines in 2 files now uncovered.

2851 of 3221 relevant lines covered (88.51%)

0.89 hits per line

Source File
/src/sedpack/io/shard/shard_writer_np.py (91.23% covered)

# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset shard manipulation.

For information on how to read and write TFRecord files see
https://www.tensorflow.org/tutorials/load_data/tfrecord
"""

from pathlib import Path

import numpy as np
from numpy import typing as npt

from sedpack.io.metadata import Attribute, DatasetStructure
from sedpack.io.types import AttributeValueT, CompressionT, ExampleT
from sedpack.io.shard.shard_writer_base import ShardWriterBase


class ShardWriterNP(ShardWriterBase):
    """Shard writing capabilities.
    """

    def __init__(self, dataset_structure: DatasetStructure,
                 shard_file: Path) -> None:
        """Collect information about a new shard.

        Args:

            dataset_structure (DatasetStructure): The structure of data being
            saved.

            shard_file (Path): Full path to the shard file.
        """
        assert dataset_structure.shard_file_type == "npz"

        super().__init__(
            dataset_structure=dataset_structure,
            shard_file=shard_file,
        )

        self._buffer: dict[str, list[AttributeValueT]] = {}

        # A prefix which, when prepended, creates a new name that does not
        # collide with any attribute name.
        self._counting_prefix: str = "len" + "_" * max(
            len(attribute.name)
            for attribute in dataset_structure.saved_data_description)

    def _value_to_np(
        self,
        attribute: Attribute,
        value: AttributeValueT,
    ) -> npt.NDArray[np.generic] | str:
        match attribute.dtype:
            case "bytes":
                raise ValueError("Attributes bytes are saved extra")
            case "str":
                assert isinstance(value, str)
                return value
            case _:
                return np.copy(value)

    def _write(self, values: ExampleT) -> None:
        """Write an example on disk. Writing may be buffered.

        Args:

            values (dict[str, npt.NDArray[np.generic]]): Attribute values.
        """
        # Just buffer all values.
        if not self._buffer:
            self._buffer = {}

        for attribute in self.dataset_structure.saved_data_description:
            name = attribute.name
            value = values[name]

            if attribute.dtype != "bytes":
                current_values = self._buffer.get(name, [])
                current_values.append(
                    self._value_to_np(
                        attribute=attribute,
                        value=value,
                    ))
                self._buffer[name] = current_values
            else:
                # Extend and remember the length. Attributes with dtype "bytes"
                # may have variable length. Handle this case. We need to avoid
                # two things:
                # - Having a wrong length of the bytes array and ideally also
                # avoid padding.
                # - Using allow_pickle when saving since that could lead to code
                # execution when loading a malicious dataset.
                # We prefix the attribute name by `len_?` such that the new name
                # is unique and tells us the lengths of the byte arrays.
                counts = self._buffer.get(self._counting_prefix + name, [0])
                counts.append(counts[-1] +
                              len(value)  # type: ignore[arg-type,operator]
                             )
                self._buffer[self._counting_prefix + name] = counts

                byte_list: list[list[int]]
                byte_list = self._buffer.get(  # type: ignore[assignment]
                    name, [[]])
                byte_list[0].extend(list(value)  # type: ignore[arg-type]
                                   )
                self._buffer[name] = byte_list  # type: ignore[assignment]

    def close(self) -> None:
        """Close the shard file(-s).
        """
        if not self._buffer:
            assert not self._shard_file.is_file()
            return

        # Deal properly with "bytes" attributes.
        for attribute in self.dataset_structure.saved_data_description:
            if attribute.dtype != "bytes":
                continue
            self._buffer[attribute.name] = [
                np.array(
                    self._buffer[attribute.name][0],
                    dtype=np.uint8,
                )
            ]

        # Write the buffer into a file. We should not need to allow_pickle while
        # saving (the default value is True). But on GitHub actions macos-13
        # runner the tests were failing while reading. The security concern
        # (code execution) should be more on the side of loading pickled data.
        match self.dataset_structure.compression:
            case "ZIP":
                np.savez_compressed(
                    str(self._shard_file),
                    **self._buffer,  # type: ignore[arg-type]
                )
            case "":
                np.savez(
                    str(self._shard_file),
                    **self._buffer,  # type: ignore[arg-type]
                )
            case _:
                # Default should never happen since ShardWriterBase checks that
                # the requested compression type is supported.
                raise ValueError(f"Unsupported compression type "
                                 f"{self.dataset_structure.compression} in "
                                 f"ShardWriterNP, supported values are: "
                                 f"{self.supported_compressions()}")

        self._buffer = {}
        assert self._shard_file.is_file()

    @staticmethod
    def supported_compressions() -> list[CompressionT]:
        """Return a list of supported compression types.
        """
        return ["ZIP", ""]
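
As a companion to the writer above, here is a hedged sketch of the reading side for the layout illustrated after the commit message. It is not sedpack's reader (the library ships its own shard readers, and details such as array shapes in real shard files may differ); it only shows how the cumulative counts stored under the prefixed key let the contiguous uint8 array be sliced back into per-example bytes, with np.load keeping its default allow_pickle=False:

import numpy as np

# Illustrative names matching the earlier sketch, not real sedpack output.
prefix = "len" + "_" * len("trace")

with np.load("shard.npz") as shard:      # allow_pickle defaults to False
    data = shard["trace"]                # contiguous uint8 array
    counts = shard[prefix + "trace"]     # cumulative end indexes, e.g. [0, 3, 5]
    examples = [
        data[start:end].tobytes()
        for start, end in zip(counts[:-1], counts[1:])
    ]

assert examples == [b"abc", b"de"]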