• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

tableau / TabPy / 4544451387

pending completion
4544451387

Pull #595

github

GitHub
Merge 6036c9260 into fad6807d4
Pull Request #595: Draft: TabPy Arrow Support

285 of 285 new or added lines in 5 files covered. (100.0%)

1292 of 2389 relevant lines covered (54.08%)

0.54 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

25.0
/tabpy/tabpy_server/app/arrow_server.py
1
# Licensed to the Apache Software Foundation (ASF) under one
2
# or more contributor license agreements.  See the NOTICE file
3
# distributed with this work for additional information
4
# regarding copyright ownership.  The ASF licenses this file
5
# to you under the Apache License, Version 2.0 (the
6
# "License"); you may not use this file except in compliance
7
# with the License.  You may obtain a copy of the License at
8
#
9
#   http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing,
12
# software distributed under the License is distributed on an
13
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
# KIND, either express or implied.  See the License for the
15
# specific language governing permissions and limitations
16
# under the License.
17

18
"""An example Flight Python server."""
1✔
19

20
import argparse
1✔
21
import ast
1✔
22
import logging
1✔
23
import threading
1✔
24
import time
1✔
25
import uuid
1✔
26

27
import pyarrow
1✔
28
import pyarrow.flight
1✔
29

30
from tabpy.tabpy_server.app.app_parameters import SettingsParameters
1✔
31

32

33
# Module logger, named as a child of the '__main__' logger (presumably so
# records reach handlers configured on the entry-point logger - TODO confirm).
logger = logging.getLogger(f"__main__.{__name__}")
1✔
34

35
class FlightServer(pyarrow.flight.FlightServerBase):
    """Arrow Flight server that keeps uploaded tables in memory.

    Tables are stored in ``self.flights`` under the key produced by
    :meth:`descriptor_to_key`. ``do_put`` stores a table and ``do_get``
    removes it from the store while returning it, so a stored table is
    handed out at most once.
    """

    def __init__(self, host="localhost", location=None,
                 tls_certificates=None, verify_client=False,
                 root_certificates=None, auth_handler=None):
        """Initialise the base server and an empty flight store.

        Args:
            host: Host name used when building endpoint locations in
                ``_make_flight_info``.
            location: Forwarded to ``FlightServerBase.__init__``.
            tls_certificates: Forwarded to the base class; also kept on the
                instance to choose between TLS and TCP endpoint locations.
            verify_client: Forwarded to the base class.
            root_certificates: Forwarded to the base class.
            auth_handler: Forwarded to the base class.
        """
        super().__init__(
            location, auth_handler, tls_certificates, verify_client,
            root_certificates)
        self.flights = {}
        self.host = host
        self.tls_certificates = tls_certificates

    @classmethod
    def descriptor_to_key(cls, descriptor):
        """Return the hashable store key for a flight descriptor.

        The key is ``(descriptor type value, command, path as a tuple)``;
        a missing path becomes an empty tuple.
        """
        return (descriptor.descriptor_type.value, descriptor.command,
                tuple(descriptor.path or tuple()))

    def _make_flight_info(self, key, descriptor, table):
        """Build the ``FlightInfo`` describing a stored table.

        The endpoint ticket is ``repr(key)``, which ``do_get`` parses back
        with ``ast.literal_eval``.
        """
        if self.tls_certificates:
            location = pyarrow.flight.Location.for_grpc_tls(
                self.host, self.port)
        else:
            location = pyarrow.flight.Location.for_grpc_tcp(
                self.host, self.port)
        endpoints = [pyarrow.flight.FlightEndpoint(repr(key), [location])]

        # Write the table to a mock sink only to measure its stream size.
        mock_sink = pyarrow.MockOutputStream()
        stream_writer = pyarrow.RecordBatchStreamWriter(
            mock_sink, table.schema)
        stream_writer.write_table(table)
        stream_writer.close()
        data_size = mock_sink.size()

        return pyarrow.flight.FlightInfo(table.schema,
                                         descriptor, endpoints,
                                         table.num_rows, data_size)

    def list_flights(self, context, criteria):
        """Yield a ``FlightInfo`` for every stored table.

        ``criteria`` is accepted but not used for filtering.
        """
        # This is a generator, so other handlers (do_put / do_get) may add or
        # remove entries between yields; iterating a snapshot avoids
        # "dictionary changed size during iteration".
        for key, table in list(self.flights.items()):
            if key[1] is not None:
                descriptor = \
                    pyarrow.flight.FlightDescriptor.for_command(key[1])
            else:
                descriptor = pyarrow.flight.FlightDescriptor.for_path(*key[2])

            yield self._make_flight_info(key, descriptor, table)

    def get_flight_info(self, context, descriptor):
        """Return the ``FlightInfo`` for the table stored under ``descriptor``.

        Raises:
            KeyError: If no table is stored for the descriptor.
        """
        key = FlightServer.descriptor_to_key(descriptor)
        logger.info("get_flight_info: key=%s", key)
        # Single lookup; stored values are tables, never None.
        table = self.flights.get(key)
        if table is None:
            raise KeyError('Flight not found.')
        return self._make_flight_info(key, descriptor, table)

    def do_put(self, context, descriptor, reader, writer):
        """Read the whole uploaded stream and store it under the descriptor's key.

        An existing table under the same key is replaced.
        """
        key = FlightServer.descriptor_to_key(descriptor)
        logger.info("do_put: key=%s", key)
        self.flights[key] = reader.read_all()

    def do_get(self, context, ticket):
        """Remove the table named by ``ticket`` from the store and stream it.

        The ticket bytes are expected to be the ``repr`` of a store key, as
        produced by ``_make_flight_info``. Returns ``None`` when the key is
        not stored.
        """
        logger.info("do_get: ticket=%s", ticket)
        # literal_eval only accepts Python literals, so the client-supplied
        # ticket cannot execute code.
        key = ast.literal_eval(ticket.ticket.decode())
        # Look up and remove in one step so two concurrent requests for the
        # same key cannot both pass a membership test and then fail on pop.
        flight = self.flights.pop(key, None)
        if flight is None:
            logger.warning("do_get: key=%s not found", key)
            # NOTE(review): FlightServerBase.do_get is documented to return a
            # FlightDataStream; confirm how pyarrow reports a None return to
            # the client, or raise KeyError as get_flight_info does.
            return None
        logger.info("do_get: returning key=%s", key)
        return pyarrow.flight.RecordBatchStream(flight)

    def list_actions(self, context):
        """Return an iterator of ``(name, description)`` for supported actions."""
        return iter([
            ("getUniquePath", "Get a unique FileDescriptor path to put data to."),
            ("clear", "Clear the stored flights."),
            ("healthcheck", "Do nothing; succeeds if the server is responding."),
            ("shutdown", "Shut down this server."),
        ])

    def do_action(self, context, action):
        """Run the action named by ``action.type``, yielding any results.

        This is a generator: nothing runs until the results are iterated.

        Raises:
            KeyError: If the action type is not recognised.
        """
        logger.info("do_action: action=%s", action.type)
        if action.type == "getUniquePath":
            unique_id = str(uuid.uuid4())
            logger.info("getUniquePath id=%s", unique_id)
            yield unique_id.encode('utf-8')
        elif action.type == "clear":
            self._clear()
        elif action.type == "healthcheck":
            pass
        elif action.type == "shutdown":
            self._clear()
            yield pyarrow.flight.Result(pyarrow.py_buffer(b'Shutdown!'))
            # Shut down on background thread to avoid blocking current
            # request
            threading.Thread(target=self._shutdown).start()
        else:
            raise KeyError(f"Unknown action {action.type!r}")

    def _clear(self):
        """Clear the stored flights."""
        self.flights = {}

    def _shutdown(self):
        """Shut down after a delay."""
        logger.info("Server is shutting down...")
        # Give the in-flight "shutdown" response time to be delivered.
        time.sleep(2)
        self.shutdown()
×
141

142
def _parse_args():
    """Parse the command-line options of the Arrow Flight server.

    Returns:
        argparse.Namespace with ``host``, ``tls`` (``None`` or a
        ``[CERTFILE, KEYFILE]`` list), ``verify_client`` (bool) and
        ``config``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost",
                        help="Address or hostname to listen on")
    parser.add_argument("--tls", nargs=2, default=None,
                        metavar=('CERTFILE', 'KEYFILE'),
                        help="Enable transport-level security")
    # argparse's type=bool would treat every non-empty string (including
    # "False") as True, so the value is parsed explicitly.
    parser.add_argument("--verify_client", type=_str_to_bool, default=False,
                        help="enable mutual TLS and verify the client if True")
    # TODO: implement config
    parser.add_argument("--config", type=str, default=None,
                        help="should be ignored")

    return parser.parse_args()


def _str_to_bool(value):
    """Convert a command-line word such as ``"True"`` or ``"0"`` to a bool.

    Raises:
        argparse.ArgumentTypeError: If the word is not a recognised boolean.
    """
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1", "on"):
        return True
    if normalized in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value, got {value!r}")
×
154

155
def _get_tls_certificates(args):
    """Load the TLS certificate chain and private key named in ``args.tls``.

    ``args.tls[0]`` is the certificate file and ``args.tls[1]`` the key file;
    both are read as raw bytes.

    Returns:
        A one-element list holding the ``(cert_chain, private_key)`` tuple.
    """
    with open(args.tls[0], "rb") as cert_stream:
        cert_chain_bytes = cert_stream.read()
    with open(args.tls[1], "rb") as key_stream:
        private_key_bytes = key_stream.read()
    return [(cert_chain_bytes, private_key_bytes)]
×
163

164
def start(config):
    """Create a FlightServer from the CLI options and ``config``, then serve.

    The host and TLS options come from ``_parse_args()``; the port is read
    from ``config`` under ``SettingsParameters.ArrowFlightPort``. The call
    to ``server.serve()`` is the last statement, so this function returns
    only when that call does.

    Args:
        config: Object with a ``get`` method that returns the port for
            ``SettingsParameters.ArrowFlightPort``.
    """
    args = _parse_args()

    # TLS is enabled only when certificate and key files were supplied.
    if args.tls:
        scheme = "grpc+tls"
        tls_certificates = _get_tls_certificates(args)
    else:
        scheme = "grpc+tcp"
        tls_certificates = None

    port = config.get(SettingsParameters.ArrowFlightPort)
    location = f"{scheme}://{args.host}:{port}"

    server = FlightServer(
        host=args.host,
        location=location,
        tls_certificates=tls_certificates,
        verify_client=args.verify_client,
    )
    logger.info(f"Serving on {location}")
    server.serve()
×
181

182

183
if __name__ == '__main__':
1✔
184
    # NOTE(review): start is defined in this file as start(config) with no
    # default, so this bare call raises TypeError when the module is run as a
    # script. A config object providing get(SettingsParameters.ArrowFlightPort)
    # must be supplied here - TODO decide where the standalone port comes from.
    start()
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc