• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

cenkalti / kuyruk / 9402813842

06 Jun 2024 02:33PM UTC coverage: 90.953% (-0.2%) from 91.108%
9402813842

push

github

web-flow
Add ability to set priority for workers (#81)

93 of 101 branches covered (92.08%)

754 of 829 relevant lines covered (90.95%)

1.82 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

89.51
kuyruk/worker.py
1
import os
2✔
2
import sys
2✔
3
import json
2✔
4
import platform
2✔
5
import socket
2✔
6
import signal
2✔
7
import logging
2✔
8
import logging.config
2✔
9
import threading
2✔
10
import traceback
2✔
11
import argparse
2✔
12
import multiprocessing
2✔
13
from time import monotonic
2✔
14
from typing import Dict, Any, Tuple, Optional, Type, cast, List
2✔
15

16
import amqp
2✔
17

18
from kuyruk import importer, signals
2✔
19
from kuyruk.kuyruk import Kuyruk
2✔
20
from kuyruk.task import Task
2✔
21
from kuyruk.heartbeat import Heartbeat
2✔
22
from kuyruk.exceptions import Reject, Discard, HeartbeatError, ExcInfoType
2✔
23

24
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
2✔
25

26

27
class Worker:
2✔
28
    """Consumes tasks from queues and runs them.
29

30
    :param app: An instance of :class:`~kuyruk.Kuyruk`
31
    :param args: Command line arguments
32

33
    """
34

35
    def __init__(self, app: Kuyruk, args: argparse.Namespace) -> None:
        """Build the worker state from the application config and CLI args.

        For logging level, max run time, max load and priority the command
        line value, when it is not ``None``, overrides the ``app.config``
        value.

        :param app: An instance of :class:`~kuyruk.Kuyruk`
        :param args: Command line arguments

        """
        self.kuyruk = app

        if not args.queues:
            args.queues = ['kuyruk']

        self._hostname = socket.gethostname()

        local_suffix = '.localhost'

        def add_host(queue: str) -> str:
            # Replace a trailing ".localhost" with this machine's hostname.
            # Only the suffix is removed; slicing (instead of splitting on
            # every occurrence) keeps any earlier ".localhost" in the name.
            if queue.endswith(local_suffix):
                return "%s.%s" % (queue[:-len(local_suffix)], self._hostname)
            return queue

        self.queues = [add_host(q) for q in args.queues]
        self._tasks: Dict[Tuple[str, str], Task] = {}
        self.shutdown_pending = threading.Event()
        self.consuming = False
        self.current_task: Optional[Task] = None
        self.current_args: Optional[Tuple] = None
        self.current_kwargs: Optional[Dict[str, Any]] = None
        # Must exist before any signal arrives: the SIGHUP handler reads it
        # even when no heartbeat error has been recorded yet.
        self._heartbeat_error: Optional[Exception] = None

        self._started_at: Optional[float] = None
        self._pid = os.getpid()

        self._logging_level = app.config.WORKER_LOGGING_LEVEL
        if args.logging_level is not None:
            self._logging_level = args.logging_level

        self._max_run_time = app.config.WORKER_MAX_RUN_TIME
        if args.max_run_time is not None:
            self._max_run_time = args.max_run_time

        self._max_load = app.config.WORKER_MAX_LOAD
        if args.max_load is not None:
            self._max_load = args.max_load
        if self._max_load == -1:
            # -1 selects the number of CPUs. This has to be an assignment;
            # a comparison here would leave the limit at -1.
            self._max_load = multiprocessing.cpu_count()

        self._priority = app.config.WORKER_PRIORITY
        if args.priority is not None:
            self._priority = args.priority

        self._reconnect_interval = app.config.WORKER_RECONNECT_INTERVAL

        # Helper threads are only created here; run() starts and joins them.
        self._threads: List[threading.Thread] = []
        if self._max_load:
            self._threads.append(threading.Thread(target=self._watch_load))
        if self._max_run_time:
            self._threads.append(threading.Thread(target=self._shutdown_timer))

        signals.worker_init.send(self.kuyruk, worker=self)
2✔
88

89
    def run(self) -> None:
        """Start the worker and consume messages from RabbitMQ.

        Configures logging, installs the signal handlers, starts the helper
        threads and then consumes until shutdown is requested, waiting and
        reconnecting after heartbeat or connection errors.
        Returns only after `shutdown()` is called.

        """
        level_name = self._logging_level
        if level_name:
            log_format = "%(levelname).1s %(name)s.%(funcName)s:%(lineno)d - %(message)s"
            logging.basicConfig(level=getattr(logging, level_name.upper()), format=log_format)

        for signum, handler in (
                (signal.SIGINT, self._handle_sigint),
                (signal.SIGTERM, self._handle_sigterm),
        ):
            signal.signal(signum, handler)

        if platform.system() != 'Windows':
            # These features will not be available on Windows, but that is OK.
            # Read this issue for more details:
            # https://github.com/cenkalti/kuyruk/issues/54
            for signum, handler in (
                    (signal.SIGHUP, self._handle_sighup),
                    (signal.SIGUSR1, self._handle_sigusr1),
                    (signal.SIGUSR2, self._handle_sigusr2),
            ):
                signal.signal(signum, handler)

        self._started_at = os.times().elapsed

        for thread in self._threads:
            thread.start()

        try:
            signals.worker_start.send(self.kuyruk, worker=self)
            while not self.shutdown_pending.is_set():
                try:
                    self._consume_messages()
                except HeartbeatError:
                    logger.error("Heartbeat error")
                except (ConnectionError, amqp.exceptions.ConnectionError) as err:
                    logger.error("Connection error: %s", err)
                    traceback.print_exc()
                else:
                    # Consuming ended without an error: stop retrying.
                    break

                logger.info("Waiting %d seconds before reconnecting...", self._reconnect_interval)
                self.shutdown_pending.wait(self._reconnect_interval)
        finally:
            # Setting the event lets the helper threads leave their waits.
            self.shutdown_pending.set()
            for thread in self._threads:
                thread.join()

            signals.worker_shutdown.send(self.kuyruk, worker=self)

        logger.debug("End run worker")
2✔
136

137
    def _consume_messages(self) -> None:
        """Open a connection, start consuming and run the main loop."""
        with self.kuyruk.new_connection() as conn:
            channel = conn.channel()

            # Set prefetch count to 1. If we don't set this, RabbitMQ keeps
            # sending messages while we are already working on a message.
            channel.basic_qos(0, 1, False)

            self._declare_queues(channel)
            self._consume_queues(channel)
            logger.info('Consumer started')
            self._main_loop(channel)
            channel.close()
2✔
150

151
    def _main_loop(self, ch: amqp.Channel) -> None:
2✔
152
        while not self.shutdown_pending.is_set():
2✔
153
            self._pause_or_resume(ch)
2✔
154
            ch.connection.heartbeat_tick()
2✔
155
            try:
2✔
156
                ch.connection.drain_events(timeout=1)
2✔
157
            except socket.timeout:
2✔
158
                pass
2✔
159

160
    def _consumer_tag(self, queue: str) -> str:
2✔
161
        return "%s:%s@%s" % (queue, self._pid, self._hostname)
2✔
162

163
    def _declare_queues(self, ch: amqp.Channel) -> None:
        """Declare every configured queue as durable and not auto-deleted."""
        for name in self.queues:
            logger.debug("queue_declare: %s", name)
            ch.queue_declare(queue=name, durable=True, auto_delete=False)
2✔
167

168
    def _pause_or_resume(self, channel: amqp.Channel) -> None:
2✔
169
        if not self._max_load:
2✔
170
            return
2✔
171

172
        try:
×
173
            load = self._current_load
×
174
        except AttributeError:
×
175
            should_pause = False
×
176
        else:
177
            should_pause = load > self._max_load
×
178

179
        if should_pause and self.consuming:
×
180
            logger.warning('Load is above the treshold (%.2f/%s), ' 'pausing consumer', load, self._max_load)
×
181
            self._cancel_queues(channel)
×
182
        elif not should_pause and not self.consuming:
×
183
            logger.warning('Load is below the treshold (%.2f/%s), ' 'resuming consumer', load, self._max_load)
×
184
            self._consume_queues(channel)
×
185

186
    def _consume_queues(self, ch: amqp.Channel) -> None:
        """Mark the worker as consuming and subscribe to every queue.

        When a priority is set it is passed to the broker as the
        ``x-priority`` consumer argument.

        """
        self.consuming = True
        for name in self.queues:
            logger.debug("basic_consume: %s", name)

            consumer_args = {'x-priority': self._priority} if self._priority else {}

            ch.basic_consume(
                queue=name,
                consumer_tag=self._consumer_tag(name),
                callback=self._process_message,
                arguments=consumer_args,
            )
199

200
    def _cancel_queues(self, ch: amqp.Channel) -> None:
        """Mark the worker as not consuming and cancel every consumer."""
        self.consuming = False
        for name in self.queues:
            logger.debug("basic_cancel: %s", name)
            tag = self._consumer_tag(name)
            ch.basic_cancel(tag)
×
205

206
    def _process_message(self, message: amqp.Message) -> None:
2✔
207
        """Processes the message received from the queue."""
208
        if self.shutdown_pending.is_set():
2✔
209
            return
×
210

211
        try:
2✔
212
            if isinstance(message.body, bytes):
2✔
213
                message.body = message.body.decode()
×
214
            description = json.loads(message.body)
2✔
215
        except Exception:
2✔
216
            logger.error("Cannot decode message. Dropping. Message: %r", message.body)
2✔
217
            traceback.print_exc()
2✔
218
            message.channel.basic_reject(message.delivery_tag, requeue=False)
2✔
219
        else:
220
            logger.info("Processing task: %r", description)
2✔
221
            self._process_description(message, description)
2✔
222

223
    def _process_description(self, message: amqp.Message, description: Dict[str, Any]) -> None:
        """Resolve the task named in *description* and process it.

        If the task cannot be imported or the description lacks its
        arguments, the failure signal is sent and the message is rejected
        without requeue.

        """
        try:
            task = self._import_task(description['module'], description['function'])
            args, kwargs = description['args'], description['kwargs']
        except Exception:
            logger.error('Cannot import task')
            failure = sys.exc_info()
            signals.worker_failure.send(self.kuyruk, description=description, exc_info=failure, worker=self)
            message.channel.basic_reject(message.delivery_tag, requeue=False)
            return

        self._process_task(message, description, task, args, kwargs)
2✔
234

235
    def _import_task(self, module: str, function: str) -> Task:
2✔
236
        if (module, function) in self._tasks:
2✔
237
            return self._tasks[(module, function)]
2✔
238

239
        task = importer.import_object(module, function)
2✔
240
        self._tasks[(module, function)] = task
2✔
241
        return task
2✔
242

243
    def _process_task(
            self,
            message: amqp.Message,
            description: Dict[str, Any],
            task: Task,
            args: Tuple,
            kwargs: Dict[str, Any],
    ) -> None:
        """Run the task and settle the message according to the outcome.

        * success: the message is acked
        * :class:`Reject`: the message is rejected and requeued
        * :class:`Discard` or any other exception: the message is rejected
          without requeue
        * :class:`HeartbeatError`: the failure signal is sent and the error
          is re-raised without settling the message

        When the message carries ``reply_to`` a reply is published for the
        success, discard and generic failure outcomes.

        """
        queue = message.delivery_info['routing_key']
        reply_to = message.properties.get('reply_to')

        def send_failure_signal(failure: Any) -> None:
            # Shared by the heartbeat and generic failure branches.
            signals.worker_failure.send(
                self.kuyruk,
                description=description,
                task=task,
                args=args,
                kwargs=kwargs,
                exc_info=failure,
                worker=self,
                queue=queue)

        try:
            result = self._run_task(message.channel.connection, task, args, kwargs)
        except Reject:
            logger.warning('Task is rejected')
            message.channel.basic_reject(message.delivery_tag, requeue=True)
        except Discard:
            logger.warning('Task is discarded')
            message.channel.basic_reject(message.delivery_tag, requeue=False)
            if reply_to:
                self._send_reply(reply_to, message.channel, None, sys.exc_info())
        except HeartbeatError:
            failure = sys.exc_info()
            logger.error('Heartbeat error:\n%s', ''.join(traceback.format_exception(*failure)))
            send_failure_signal(failure)
            raise
        except Exception:
            failure = sys.exc_info()
            logger.error('Task raised an exception:\n%s', ''.join(traceback.format_exception(*failure)))
            send_failure_signal(failure)
            message.channel.basic_reject(message.delivery_tag, requeue=False)
            if reply_to:
                self._send_reply(reply_to, message.channel, None, failure)
        else:
            logger.info('Task is successful')
            message.channel.basic_ack(message.delivery_tag)
            if reply_to:
                self._send_reply(reply_to, message.channel, result, None)
        finally:
            logger.debug("Task is processed")
2✔
299

300
    def _run_task(self, connection: amqp.Connection, task: Task, args: Tuple, kwargs: Dict[str, Any]) -> Any:
        """Apply the task while a heartbeat runs on *connection*.

        The task and its arguments are exposed through ``current_task``,
        ``current_args`` and ``current_kwargs`` for the duration of the
        call and cleared afterwards. The heartbeat is stopped on the way
        out, whether the task returned or raised.

        """
        heartbeat = Heartbeat(connection, self._on_heartbeat_error)
        heartbeat.start()

        self.current_task, self.current_args, self.current_kwargs = task, args, kwargs
        try:
            return self._apply_task(task, args, kwargs)
        finally:
            self.current_task = self.current_args = self.current_kwargs = None

            heartbeat.stop()
2✔
315

316
    def _on_heartbeat_error(self, error: Exception) -> None:
        """Record *error* and send SIGHUP to the current process.

        The SIGHUP handler picks the stored error up again.

        """
        self._heartbeat_error = error
        current_pid = os.getpid()
        os.kill(current_pid, signal.SIGHUP)
2✔
319

320
    @staticmethod
    def _apply_task(task: Task, args: Tuple, kwargs: Dict[str, Any]) -> Any:
        """Apply the task and log how long it took.

        ``None`` for *args* or *kwargs* is treated as "no arguments". The
        duration is logged whether the task returns or raises.

        """
        positional = () if args is None else args
        keywords = {} if kwargs is None else kwargs

        started = monotonic()
        try:
            return task.apply(*positional, **keywords)
        finally:
            elapsed = monotonic() - started
            logger.info("%s finished in %i seconds." % (task.name, elapsed))
2✔
334

335
    def _send_reply(
            self,
            reply_to: str,
            channel: amqp.Channel,
            result: Any,
            exc_info: Optional[ExcInfoType],
    ) -> None:
        """Publish the task outcome as JSON to the *reply_to* routing key.

        The reply holds ``result`` and, when *exc_info* is given, an
        ``exception`` entry. If the reply cannot be serialized, a reply
        describing the serialization error is sent instead.

        """
        logger.debug("Sending reply result=%r", result)

        payload: Dict[str, Any] = {'result': result}
        if exc_info:
            payload['exception'] = self._exc_info_dict(exc_info)

        try:
            body = json.dumps(payload)
        except Exception as err:
            logger.error('Cannot serialize result as JSON: %s', err)
            failure = self._exc_info_dict(sys.exc_info())
            body = json.dumps({'result': None, 'exception': failure})

        # Published on the default exchange.
        channel.basic_publish(amqp.Message(body=body), exchange="", routing_key=reply_to)
2✔
358

359
    @staticmethod
2✔
360
    def _exc_info_dict(exc_info: ExcInfoType) -> Dict[str, str]:
2✔
361
        type_, val, tb = exc_info
2✔
362
        return {
2✔
363
            'type': '%s.%s' % (type_.__module__, cast(Type[BaseException], type_).__name__),
364
            'value': str(val),
365
            'traceback': ''.join(traceback.format_tb(tb)),
366
        }
367

368
    def _watch_load(self) -> None:
2✔
369
        """Pause consuming messages if lood goes above the allowed limit."""
370
        while not self.shutdown_pending.wait(1):
×
371
            self._current_load = os.getloadavg()[0]
×
372

373
    @property
2✔
374
    def uptime(self) -> float:
2✔
375
        if not self._started_at:
2✔
376
            return 0
×
377

378
        return os.times().elapsed - self._started_at
2✔
379

380
    def _shutdown_timer(self) -> None:
2✔
381
        """Counts down from MAX_WORKER_RUN_TIME. When it reaches zero sutdown
382
        gracefully.
383

384
        """
385
        remaining = cast(float, self._max_run_time) - self.uptime
2✔
386
        if not self.shutdown_pending.wait(remaining):
2✔
387
            logger.warning('Run time reached zero')
2✔
388
            self.shutdown()
2✔
389

390
    def shutdown(self) -> None:
        """Request shutdown by setting the ``shutdown_pending`` event."""
        logger.warning("Shutdown requested")
        pending = self.shutdown_pending
        pending.set()
2✔
394

395
    def _handle_sigint(self, signum: int, frame: Any) -> None:
        """SIGINT handler: log the signal and request shutdown."""
        logger.warning("Catched SIGINT")
        self.shutdown()
2✔
399

400
    def _handle_sigterm(self, signum: int, frame: Any) -> None:
        """SIGTERM handler: log the signal and request shutdown."""
        logger.warning("Catched SIGTERM")
        self.shutdown()
2✔
404

405
    def _handle_sighup(self, signum: int, frame: Any) -> None:
        """SIGHUP handler: raise :class:`HeartbeatError`.

        Used internally to fail the task when connection to RabbitMQ is
        lost during the execution of the task. The stored heartbeat error
        is cleared and attached as the cause of the raised exception.

        """
        logger.debug("Catched SIGHUP")
        cause, self._heartbeat_error = self._heartbeat_error, None
        raise HeartbeatError from cause
2✔
414

415
    @staticmethod
2✔
416
    def _handle_sigusr1(signum: int, frame: Any) -> None:
2✔
417
        """Print stacktrace."""
418
        print('=' * 70)
2✔
419
        print(''.join(traceback.format_stack()))
2✔
420
        print('-' * 70)
2✔
421

422
    def _handle_sigusr2(self, signum: int, frame: Any) -> None:
        """SIGUSR2 handler: raise :class:`Discard` if a task is running."""
        logger.warning("Catched SIGUSR2")
        if not self.current_task:
            return

        logger.warning("Dropping current task...")
        raise Discard
2✔
428

429
    def drop_task(self) -> None:
        """Send SIGUSR2 to the current process."""
        current_pid = os.getpid()
        os.kill(current_pid, signal.SIGUSR2)
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc