• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

tableau / TabPy / 4694295226

pending completion
4694295226

Pull #595

github

GitHub
Merge a85607d7c into fad6807d4
Pull Request #595: TabPy Arrow Support

207 of 207 new or added lines in 7 files covered. (100.0%)

1317 of 2311 relevant lines covered (56.99%)

0.57 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

48.78
/tabpy/tabpy_server/app/arrow_server.py
1
# Licensed to the Apache Software Foundation (ASF) under one
2
# or more contributor license agreements.  See the NOTICE file
3
# distributed with this work for additional information
4
# regarding copyright ownership.  The ASF licenses this file
5
# to you under the Apache License, Version 2.0 (the
6
# "License"); you may not use this file except in compliance
7
# with the License.  You may obtain a copy of the License at
8
#
9
#   http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing,
12
# software distributed under the License is distributed on an
13
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
# KIND, either express or implied.  See the License for the
15
# specific language governing permissions and limitations
16
# under the License.
17

18

19
import ast
1✔
20
import logging
1✔
21
import threading
1✔
22
import time
1✔
23
import uuid
1✔
24

25
import pyarrow
1✔
26
import pyarrow.flight
1✔
27

28

29
# Module-level logger; its name is this module's name prefixed with
# '__main__.', which places it under the '__main__' logger in the
# logging hierarchy.
logger = logging.getLogger('.'.join(('__main__', __name__)))
1✔
30

31
class FlightServer(pyarrow.flight.FlightServerBase):
    """Arrow Flight server that keeps uploaded tables in memory.

    Tables received through ``do_put`` are stored in ``self.flights`` under
    a key derived from their descriptor. A table is removed from the store
    the first time it is fetched with ``do_get``.
    """

    # Seconds to wait in the background thread before shutting down.
    _SHUTDOWN_DELAY_SECONDS = 2

    def __init__(self, host="localhost", location=None,
                 tls_certificates=None, verify_client=False,
                 root_certificates=None, auth_handler=None, middleware=None):
        """Create the server and an empty flight store.

        Args:
            host: Host name used when building the endpoint locations that
                are reported in flight info.
            location: Forwarded to ``FlightServerBase``; also kept on the
                instance as ``self.location``.
            tls_certificates: Forwarded to ``FlightServerBase``. When
                truthy, reported endpoint locations use gRPC over TLS.
            verify_client: Forwarded to ``FlightServerBase``.
            root_certificates: Forwarded to ``FlightServerBase``.
            auth_handler: Forwarded to ``FlightServerBase``.
            middleware: Forwarded to ``FlightServerBase``.
        """
        super().__init__(
            location, auth_handler, tls_certificates, verify_client,
            root_certificates, middleware)
        self.flights = {}
        self.host = host
        self.tls_certificates = tls_certificates
        self.location = location

    @classmethod
    def descriptor_to_key(cls, descriptor):
        """Return the hashable store key for a flight descriptor.

        The key is the tuple ``(descriptor type value, command, path)``
        where ``path`` is always a tuple (empty when the descriptor has no
        path).
        """
        return (descriptor.descriptor_type.value, descriptor.command,
                tuple(descriptor.path or tuple()))

    def _make_flight_info(self, key, descriptor, table):
        """Build the ``FlightInfo`` describing one stored table.

        Args:
            key: Store key of the table; its ``repr`` becomes the endpoint
                ticket that ``do_get`` later parses back into the key.
            descriptor: Descriptor to report for the flight.
            table: The stored table.

        Returns:
            A ``pyarrow.flight.FlightInfo`` with a single endpoint, the
            table's row count and its size in bytes when written as a
            record batch stream.
        """
        if self.tls_certificates:
            location = pyarrow.flight.Location.for_grpc_tls(
                self.host, self.port)
        else:
            location = pyarrow.flight.Location.for_grpc_tcp(
                self.host, self.port)
        endpoints = [pyarrow.flight.FlightEndpoint(repr(key), [location])]

        # Write the table to a mock sink to measure its serialized size.
        mock_sink = pyarrow.MockOutputStream()
        stream_writer = pyarrow.RecordBatchStreamWriter(
            mock_sink, table.schema)
        stream_writer.write_table(table)
        stream_writer.close()
        data_size = mock_sink.size()

        return pyarrow.flight.FlightInfo(table.schema,
                                         descriptor, endpoints,
                                         table.num_rows, data_size)

    def list_flights(self, context, criteria):
        """Yield a ``FlightInfo`` for every stored table.

        ``context`` and ``criteria`` are not used; no filtering is applied.
        """
        for key, table in self.flights.items():
            _, command, path = key
            if command is not None:
                descriptor = \
                    pyarrow.flight.FlightDescriptor.for_command(command)
            else:
                descriptor = pyarrow.flight.FlightDescriptor.for_path(*path)

            yield self._make_flight_info(key, descriptor, table)

    def get_flight_info(self, context, descriptor):
        """Return the ``FlightInfo`` for the table stored under descriptor.

        Raises:
            KeyError: If no table is stored for the descriptor.
        """
        key = self.descriptor_to_key(descriptor)
        logger.info(f"get_flight_info: key={key}")
        table = self.flights.get(key)
        if table is None:
            raise KeyError('Flight not found.')
        return self._make_flight_info(key, descriptor, table)

    def do_put(self, context, descriptor, reader, writer):
        """Read the whole uploaded stream and store it under descriptor.

        An existing table stored under the same key is replaced.
        """
        key = self.descriptor_to_key(descriptor)
        logger.info(f"do_put: key={key}")
        self.flights[key] = reader.read_all()

    def do_get(self, context, ticket):
        """Return a stream for the requested table and drop it from the store.

        Returns:
            A ``pyarrow.flight.RecordBatchStream`` over the stored table,
            or ``None`` when nothing is stored under the ticket's key.
        """
        logger.info(f"do_get: ticket={ticket}")
        # The ticket carries the repr() of the store key (see
        # _make_flight_info). literal_eval parses Python literals only, so
        # client-supplied ticket text is never executed as code.
        key = ast.literal_eval(ticket.ticket.decode())
        # Look up and remove in one step rather than check-then-pop.
        flight = self.flights.pop(key, None)
        if flight is None:
            logger.warning(f"do_get: key={key} not found")
            return None
        logger.info(f"do_get: returning key={key}")
        return pyarrow.flight.RecordBatchStream(flight)

    def list_actions(self, context):
        """Return an iterator of ``(action type, description)`` pairs."""
        return iter([
            ("getUniquePath", "Get a unique FlightDescriptor path to put data to."),
            ("clear", "Clear the stored flights."),
            ("healthcheck", "Check that this server is responding."),
            ("shutdown", "Shut down this server."),
        ])

    def do_action(self, context, action):
        """Run a named action, yielding any results.

        Supported action types:
            getUniquePath: yields a new UUID4 string encoded as UTF-8.
            clear: empties the flight store; yields nothing.
            healthcheck: does nothing; yields nothing.
            shutdown: empties the flight store, yields a ``Shutdown!``
                result, then shuts the server down from a background
                thread.

        This is a generator, so nothing runs (including the error below)
        until the caller starts iterating.

        Raises:
            KeyError: If the action type is not one of the above.
        """
        action_type = action.type
        logger.info(f"do_action: action={action_type}")
        if action_type == "getUniquePath":
            unique_id = str(uuid.uuid4())
            logger.info(f"getUniquePath id={unique_id}")
            yield unique_id.encode('utf-8')
        elif action_type == "clear":
            self._clear()
        elif action_type == "healthcheck":
            # Completing without error is the whole response.
            pass
        elif action_type == "shutdown":
            self._clear()
            yield pyarrow.flight.Result(pyarrow.py_buffer(b'Shutdown!'))
            # Shut down on background thread to avoid blocking current
            # request
            threading.Thread(target=self._shutdown).start()
        else:
            raise KeyError(f"Unknown action {action_type!r}")

    def _clear(self):
        """Clear the stored flights."""
        self.flights = {}

    def _shutdown(self):
        """Shut down after a delay."""
        logger.info("Server is shutting down...")
        time.sleep(self._SHUTDOWN_DELAY_SECONDS)
        self.shutdown()
×
138

139
def start(server):
    """Log the server's location, then run ``server.serve()``.

    Args:
        server: Object exposing a ``location`` attribute and a ``serve()``
            method.
    """
    logger.info(
        f"Serving on {server.location}"
    )
    server.serve()
×
142

143

144
if __name__ == '__main__':
    # start() needs a server instance, so build one with the default
    # constructor arguments.
    # NOTE(review): with location=None the listening address is left to
    # pyarrow's default -- confirm the host/port wanted for script use.
    start(FlightServer())
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc