• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

google / sedpack / 12982388578

08 Jan 2025 11:35AM UTC coverage: 86.728% (-0.1%) from 86.849%
12982388578

push

github

web-flow
Add mypy.ini (#86)

A copy from https://github.com/google/scaaml

Fix https://github.com/google/sedpack/issues/61

---------

Co-authored-by: Karel <karel@star-lab.xyz>
Co-authored-by: wsxrdv <111074929+wsxrdv@users.noreply.github.com>

111 of 133 new or added lines in 14 files covered. (83.46%)

13 existing lines in 5 files now uncovered.

2359 of 2720 relevant lines covered (86.73%)

0.87 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.52
/src/sedpack/io/flatbuffer/iterate.py
1
# Copyright 2024 Google LLC
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     https://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
"""Iterate a FlatBuffers shard. See src/sedpack/io/flatbuffer/shard.fbs
15
sedpack.io.shard.shard_writer_flatbuffer.ShardWriterFlatBuffer for more
16
information how it is saved.
17
"""
18

19
import logging
1✔
20
from pathlib import Path
1✔
21
from typing import AsyncIterator, Iterable
1✔
22

23
import aiofiles
1✔
24
import numpy as np
1✔
25
import numpy.typing as npt
1✔
26

27
from sedpack.io.compress import CompressedFile
1✔
28
from sedpack.io.metadata import Attribute
1✔
29
from sedpack.io.types import ExampleT
1✔
30
from sedpack.io.shard import IterateShardBase
1✔
31
from sedpack.io.shard.iterate_shard_base import T
1✔
32
from sedpack.io.utils import func_or_identity
1✔
33

34
# Autogenerated from src/sedpack/io/flatbuffer/shard.fbs
35
import sedpack.io.flatbuffer.shardfile.Attribute as fbapi_Attribute
1✔
36
import sedpack.io.flatbuffer.shardfile.Example as fbapi_Example
1✔
37
import sedpack.io.flatbuffer.shardfile.Shard as fbapi_Shard
1✔
38

39

40
class IterateShardFlatBuffer(IterateShardBase[T]):
    """Remember everything to be able to iterate shards. This can be pickled
    and passed as a callable object into another process.
    """

    def _iterate_content(self, content: bytes) -> Iterable[ExampleT]:
        """Decode all examples stored in a (decompressed) FlatBuffers shard.

        Args:

          content (bytes): The decompressed content of a shard file.

        Returns: an iterable of examples. Each example is a dictionary mapping
        the attribute name to a freshly allocated np.ndarray. An example which
        cannot be fully decoded (the example itself or one of its attributes
        is missing) is logged and skipped; it is never yielded partially.
        """
        logger = logging.getLogger("sedpack.io.Dataset")
        shard: fbapi_Shard.Shard = fbapi_Shard.Shard.GetRootAs(content, 0)

        for example_id in range(shard.ExamplesLength()):
            maybe_example: fbapi_Example.Example | None = shard.Examples(
                example_id)
            if maybe_example is None:
                logger.error("Unable to get an example, corrupted shard?")
                continue
            example: fbapi_Example.Example = maybe_example

            example_dictionary: ExampleT = {}

            for attribute_id, attribute in enumerate(
                    self.dataset_structure.saved_data_description):
                # No-copy fast retrieval, represented as bytes.
                # This is a manually written method which uses the fact
                # that we know what dtype to decode. It might be cleaner to do
                # this using a union. There are two caveats:
                # - FlatBuffers only support a subset of types we care about
                #   (e.g., float16 which is not included in
                #   flatbuffers/python/flatbuffers/number_types.py).
                # - Speed, since we first need to check the type for every
                #   attribute.
                # Bytearray representation. Little endian, just loaded.
                maybe_attribute_data: fbapi_Attribute.Attribute | None
                maybe_attribute_data = example.Attributes(attribute_id)
                if maybe_attribute_data is None:
                    logger.error("Unable to get an attribute, corrupted shard?")
                    # Abandon this example; the `else` clause below is then
                    # skipped so no incomplete dictionary is yielded.
                    break
                attribute_data: fbapi_Attribute.Attribute = maybe_attribute_data
                np_bytes = attribute_data.AttributeBytesAsNumpy()

                np_array = IterateShardFlatBuffer.decode_array(
                    np_bytes=np_bytes,
                    attribute=attribute,
                )

                # Copy otherwise the arrays are immutable and keep the whole
                # file content from being garbage collected.
                example_dictionary[attribute.name] = np.copy(np_array)
            else:
                # Reached only when the attribute loop finished without
                # `break`, i.e., every attribute was decoded.
                yield example_dictionary

    @staticmethod
    def decode_array(np_bytes: npt.NDArray[np.uint8],
                     attribute: Attribute,
                     batch_size: int = 0) -> npt.NDArray[np.generic]:
        """Decode an array. See `sedpack.io.shard.shard_writer_flatbuffer
        .ShardWriterFlatBuffer.save_numpy_vector_as_bytearray`
        for format description. The code tries to avoid unnecessary copies.

        Args:

          np_bytes (np.ndarray): The bytes as an np.array of bytes.

          attribute (Attribute): Description of the final array (dtype and
          shape).

          batch_size (int): If `batch_size` is larger than zero we received a
          batch of these attributes. In case when `batch_size == -1` the
          `np.reshape` auto-deduces the dimension. Otherwise we received
          exactly one value of this attribute.

        Returns: the parsed np.ndarray of the correct dtype and shape. The
        result is a view into `np_bytes` (no copy is made here).
        """
        dt = np.dtype(attribute.dtype)
        # FlatBuffers are little-endian. There is no byteswap by
        # `np.frombuffer` but the array will be interpreted correctly.
        dt = dt.newbyteorder("<")
        np_array = np.frombuffer(
            buffer=np_bytes,  # a view into the buffer, not a copy
            dtype=dt,
        )

        # Reshape if needed.
        if batch_size > 0 or batch_size == -1:
            np_array = np_array.reshape((batch_size, *attribute.shape))
        else:
            np_array = np_array.reshape(attribute.shape)

        return np_array

    def iterate_shard(self, file_path: Path) -> Iterable[ExampleT]:
        """Iterate a shard.

        Args:

          file_path (Path): Path to the (possibly compressed) shard file.

        Returns: an iterable of decoded examples.
        """
        # Read then decompress (nice for benchmarking).
        with open(file_path, "rb") as f:
            content: bytes = f.read()
        content = CompressedFile(
            self.dataset_structure.compression).decompress(content)
        yield from self._iterate_content(content=content)

    # TODO(issue #85) fix and test async iterator typing
    async def iterate_shard_async(  # pylint: disable=invalid-overridden-method
        self,
        file_path: Path,
    ) -> AsyncIterator[ExampleT]:
        """Asynchronously iterate a shard.

        Only the file read is asynchronous; decompression and decoding run
        synchronously once the content is in memory.

        Args:

          file_path (Path): Path to the (possibly compressed) shard file.

        Returns: an asynchronous iterator of decoded examples.
        """
        async with aiofiles.open(file_path, "rb") as f:
            content: bytes = await f.read()
        # Decompress after the file handle is released so the file is not
        # kept open during CPU-bound work.
        content = CompressedFile(
            self.dataset_structure.compression).decompress(content)

        for example in self._iterate_content(content=content):
            yield example

    def process_and_list(self, shard_file: Path) -> list[T]:
        """Return a list of processed examples. Used as a function call in a
        different process. Returning a list as opposed to an iterator allows to
        do all work in another process and all that needs to be done is a
        memory copy between processes.

        TODO think of a way to avoid copying memory between processes.

        Args:

            shard_file (Path): Path to the shard file.

        Returns: every example of the shard passed through `process_record`
        (or left unchanged when `func_or_identity` yields the identity).
        """
        process_record = func_or_identity(self.process_record)

        return [
            process_record(example)
            for example in self.iterate_shard(shard_file)
        ]
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc