
vanvalenlab / kiosk-client / build 402

5 Jun 2020 - 19:09. Coverage: 93.461% (-3.7%) from 97.148%.

Pull Request #50 (travis-ci-com): JobManagers download zipped image results for each job.
Merge a814a240b into 30e9812e4.

7 of 31 new or added lines in 1 file covered (22.58%).
586 of 627 relevant lines covered (93.46%).
4.67 hits per line.

Source file: /kiosk_client/manager.py (79.35% covered)
# Copyright 2016-2020 The Van Valen Lab at the California Institute of
# Technology (Caltech), with support from the Paul Allen Family Foundation,
# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
# All rights reserved.
#
# Licensed under a modified Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.github.com/vanvalenlab/kiosk-client/LICENSE
#
# The Work provided may be used for non-commercial academic purposes only.
# For any other use of the Work, including commercial use, please contact:
# vanvalenlab@gmail.com
#
# Neither the name of Caltech nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific
# prior written permission.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Manager class used to create and manage jobs"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import io
import json
import logging
import os
import shutil
import timeit
import uuid
import zipfile

import requests
from google.cloud import storage as google_storage
from twisted.internet import defer, reactor
from twisted.web.client import HTTPConnectionPool

from kiosk_client.job import Job
from kiosk_client.utils import iter_image_files, sleep, strip_bucket_prefix
from kiosk_client import settings

from kiosk_client.cost import CostGetter


class JobManager(object):
    """Manages many DeepCell Kiosk jobs.

    Args:
        host (str): public IP address of the DeepCell Kiosk cluster.
        job_type (str): DeepCell Kiosk job type (e.g. "segmentation").
        upload_prefix (str): upload all files to this folder in the bucket.
        refresh_rate (int): seconds between each manager status check.
        update_interval (int): seconds between each job status refresh.
        expire_time (int): seconds until finished jobs are expired.
        start_delay (int): delay between each job, in seconds.
    """

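    # A minimal usage sketch (the host, bucket, and model values below are
    # hypothetical, shown for illustration only; see the subclasses at the
    # bottom of this module for the real entry points):
    #
    #     manager = JobManager('127.0.0.1', 'segmentation',
    #                          storage_bucket='my-bucket',
    #                          model='NuclearSegmentation:0')
    #     job = manager.make_job('cells.tif')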
    def __init__(self, host, job_type, **kwargs):
        self.logger = logging.getLogger(str(self.__class__.__name__))
        self.created_at = timeit.default_timer()
        self.all_jobs = []

        host = str(host)
        if not any(host.startswith(x) for x in ('http://', 'https://')):
            host = 'http://{}'.format(host)

        self.host = host
        self.job_type = job_type

        model = kwargs.get('model', '')
        if model:
            try:
                model_name, model_version = str(model).split(':')
                model_version = int(model_version)
            except Exception as err:
                self.logger.error('Invalid model name, must be of the form '
                                  '"ModelName:Version", for example "model:0".')
                raise err
        else:
            model_name, model_version = '', ''

        self.model_name = model_name
        self.model_version = model_version

        data_scale = str(kwargs.get('data_scale', ''))
        if data_scale:
            try:
                data_scale = float(data_scale)
            except ValueError:
                raise ValueError('data_scale must be a number.')
        self.data_scale = data_scale

        data_label = str(kwargs.get('data_label', ''))
        if data_label:
            try:
                data_label = int(data_label)
            except ValueError:
                raise ValueError('data_label must be an integer.')
        self.data_label = data_label

        self.preprocess = kwargs.get('preprocess', '')
        self.postprocess = kwargs.get('postprocess', '')
        self.upload_prefix = kwargs.get('upload_prefix', 'uploads')
        self.upload_prefix = strip_bucket_prefix(self.upload_prefix)
        self.refresh_rate = int(kwargs.get('refresh_rate', 10))
        self.update_interval = kwargs.get('update_interval', 10)
        self.expire_time = kwargs.get('expire_time', 3600)
        self.start_delay = kwargs.get('start_delay', 0.1)
        self.bucket = kwargs.get('storage_bucket')
        self.upload_results = kwargs.get('upload_results', False)
        self.download_results = kwargs.get('download_results', False)
        self.calculate_cost = kwargs.get('calculate_cost', False)

        # initializing cost estimation workflow
        self.cost_getter = CostGetter()

        self.sleep = sleep  # allow monkey-patch

        # twisted configuration
        self.pool = HTTPConnectionPool(reactor, persistent=True)
        self.pool.maxPersistentPerHost = settings.CONCURRENT_REQUESTS_PER_HOST
        self.pool.retryAutomatically = False

    def upload_file(self, filepath, acl='publicRead',
                    hash_filename=True, prefix=None):
        prefix = self.upload_prefix if prefix is None else prefix
        start = timeit.default_timer()
        storage_client = google_storage.Client()

        self.logger.debug('Uploading %s.', filepath)
        if hash_filename:
            _, ext = os.path.splitext(filepath)
            dest = '{}{}'.format(uuid.uuid4().hex, ext)
        else:
            dest = os.path.basename(filepath)

        bucket = storage_client.get_bucket(self.bucket)
        blob = bucket.blob(os.path.join(prefix, dest))
        blob.upload_from_filename(filepath, predefined_acl=acl)
        self.logger.debug('Uploaded %s to %s in %s seconds.',
                          filepath, dest, timeit.default_timer() - start)
        return dest

    def make_job(self, filepath):
        return Job(filepath=filepath,
                   host=self.host,
                   model_name=self.model_name,
                   model_version=self.model_version,
                   job_type=self.job_type,
                   data_scale=self.data_scale,
                   data_label=self.data_label,
                   postprocess=self.postprocess,
                   upload_prefix=self.upload_prefix,
                   update_interval=self.update_interval,
                   expire_time=self.expire_time,
                   pool=self.pool)

    def get_completed_job_count(self):
        created, complete, failed, expired = 0, 0, 0, 0

        statuses = {}

        for j in self.all_jobs:
            expired += int(j.is_expired)  # true mark of being done
            complete += int(j.is_summarized)
            created += int(j.job_id is not None)

            if j.status is not None:
                if j.status not in statuses:
                    statuses[j.status] = 1
                else:
                    statuses[j.status] += 1

            if j.failed:
                failed += 1  # count failures so each restart is staggered
                j.restart(delay=self.start_delay * failed)

            # # TODO: patched! "done" jobs can get stranded before summarization
            # if j.status == 'done' and not j.is_summarized:
            #     j.summarize()
            #
            # # TODO: patched! sometimes jobs don't get expired?
            # elif j.status == 'done' and j.is_summarized and not j.is_expired:
            #     j.expire()

        self.logger.info('%s created; %s finished; %s summarized; '
                         '%s; %s jobs total', created, expired, complete,
                         '; '.join('%s %s' % (v, k)
                                   for k, v in statuses.items()),
                         len(self.all_jobs))

        if len(self.all_jobs) - expired <= 25:
            for j in self.all_jobs:
                if not j.is_expired:
                    self.logger.info('Waiting on key `%s` with status %s',
                                     j.job_id, j.status)

        return expired

    @defer.inlineCallbacks
    def _stop(self):
        yield reactor.stop()  # pylint: disable=no-member

    @defer.inlineCallbacks
    def check_job_status(self):
        complete = -1  # initialize comparison value

        while complete != len(self.all_jobs):
            yield self.sleep(self.refresh_rate)

            complete = self.get_completed_job_count()  # synchronous

        self.summarize()  # synchronous

        yield self._stop()

    def download_result_files(self, output_filepath):
        """Download all output image files"""
        # TODO: resolve treq SSL issue, replace requests with treq.
        # Each job's results arrive as a single zip archive; fetch each
        # archive and extract its members flat into a per-run results folder
        # named after the JSON output file.
        dir_name = os.path.splitext(os.path.basename(output_filepath))[0]
        results_dir = os.path.join(settings.OUTPUT_DIR, dir_name)
        os.mkdir(results_dir)

        for job in self.all_jobs:
            try:
                start = timeit.default_timer()
                with requests.get(job.output_url, stream=True) as r:
                    r.raise_for_status()
                    fileobj = io.BytesIO(r.content)
                    with zipfile.ZipFile(fileobj) as zf:
                        for f in zf.namelist():
                            basename = os.path.basename(f)
                            filepath = os.path.join(results_dir, basename)
                            with open(filepath, 'wb') as outfile:
                                shutil.copyfileobj(zf.open(f), outfile)
                            end = timeit.default_timer()
                            self.logger.debug('Saved %s in %s s.',
                                              filepath, end - start)

            except Exception as err:
                self.logger.error('Could not download %s due to %s. '
                                  'Please manually download this file.',
                                  job.output_url, err)
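    # For reference, a sketch of the resulting layout (names hypothetical):
    # given an output_filepath of '<OUTPUT_DIR>/2jobs_0.1delay_ab12.json',
    # every image in each job's zipped results is extracted flat into
    # '<OUTPUT_DIR>/2jobs_0.1delay_ab12/'.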

    def summarize(self):
        time_elapsed = timeit.default_timer() - self.created_at
        self.logger.info('Finished %s jobs in %s seconds.',
                         len(self.all_jobs), time_elapsed)

        # add cost and timing data to json output
        cpu_cost, gpu_cost, total_cost = '', '', ''
        if self.calculate_cost:
            try:
                cpu_cost, gpu_cost, total_cost = self.cost_getter.finish()
            except Exception as err:  # pylint: disable=broad-except
                self.logger.error('Encountered %s while getting cost data: %s',
                                  type(err).__name__, err)

        jsondata = {
            'cpu_node_cost': cpu_cost,
            'gpu_node_cost': gpu_cost,
            'total_node_and_networking_costs': total_cost,
            'start_delay': self.start_delay,
            'num_jobs': len(self.all_jobs),
            'time_elapsed': time_elapsed,
            'job_data': [j.json() for j in self.all_jobs]
        }

        output_filepath = '{}{}jobs_{}delay_{}.json'.format(
            '{}gpu_'.format(settings.NUM_GPUS) if settings.NUM_GPUS else '',
            len(self.all_jobs), self.start_delay, uuid.uuid4().hex)
        output_filepath = os.path.join(settings.OUTPUT_DIR, output_filepath)

        if self.download_results:
            try:
                self.download_result_files(output_filepath)
            except Exception as err:  # pylint: disable=broad-except
                self.logger.error(err)
                self.logger.error('Could not download all results.')

        with open(output_filepath, 'w') as jsonfile:
            json.dump(jsondata, jsonfile, indent=4)
            self.logger.info('Wrote job data as JSON to %s.', output_filepath)

        if self.upload_results:
            try:
                _ = self.upload_file(output_filepath,
                                     hash_filename=False,
                                     prefix='output')
            except Exception as err:  # pylint: disable=broad-except
                self.logger.error(err)
                self.logger.error('Could not upload output file to bucket. '
                                  'Copy this file from the docker container to '
                                  'keep the data.')

    def run(self, *args, **kwargs):
        raise NotImplementedError


class BenchmarkingJobManager(JobManager):
    # pylint: disable=arguments-differ

    @defer.inlineCallbacks
    def run(self, filepath, count, upload=False):
        self.logger.info('Benchmarking %s jobs of file `%s`', count, filepath)

        for i in range(count):

            job = self.make_job(filepath)

            self.all_jobs.append(job)

            # stagger the delay seconds; if upload it will be staggered already
            job.start(delay=self.start_delay * i * int(not upload),
                      upload=upload)

            yield self.sleep(self.start_delay * upload)

            if upload:
                self.get_completed_job_count()  # log during uploading

        yield self.check_job_status()


class BatchProcessingJobManager(JobManager):
    # pylint: disable=arguments-differ

    @defer.inlineCallbacks
    def run(self, filepath):
        self.logger.info('Benchmarking all image/zip files in `%s`', filepath)

        for i, f in enumerate(iter_image_files(filepath)):
            job = self.make_job(f)
            self.all_jobs.append(job)
            # stagger the delay seconds
            job.start(delay=self.start_delay * i, upload=True)

        yield self.check_job_status()
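

# A minimal sketch of driving one of these managers under the Twisted
# reactor (the host, bucket, file, and count below are hypothetical; the
# real command-line entry point lives elsewhere in kiosk_client):
#
#     from twisted.internet import reactor
#     from kiosk_client.manager import BenchmarkingJobManager
#
#     manager = BenchmarkingJobManager('127.0.0.1', 'segmentation',
#                                      storage_bucket='my-bucket',
#                                      download_results=True)
#     reactor.callWhenRunning(manager.run, 'cells.tif', count=10)
#     reactor.run()  # check_job_status() stops the reactor once all jobs finish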