• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

cenkalti / kuyruk / 5350155812

22 Jun 2023 08:47PM UTC coverage: 91.655% (-0.1%) from 91.799%
5350155812

push

github

cenkalti
set global QoS to False

78 of 87 branches covered (89.66%)

637 of 695 relevant lines covered (91.65%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

89.94
kuyruk/worker.py
1
import os
1✔
2
import sys
1✔
3
import json
1✔
4
import platform
1✔
5
import socket
1✔
6
import signal
1✔
7
import logging
1✔
8
import logging.config
1✔
9
import threading
1✔
10
import traceback
1✔
11
import argparse
1✔
12
import multiprocessing
1✔
13
from time import monotonic
1✔
14
from typing import Dict, Any, List, Tuple, Optional, Type, cast  # noqa
1✔
15

16
import amqp
1✔
17

18
from kuyruk import importer, signals
1✔
19
from kuyruk.kuyruk import Kuyruk
1✔
20
from kuyruk.task import Task
1✔
21
from kuyruk.heartbeat import Heartbeat
1✔
22
from kuyruk.exceptions import Reject, Discard, HeartbeatError, ExcInfoType
1✔
23

24
# Module-level logger; records are emitted under this module's import name.
logger = logging.getLogger(__name__)
1✔
25

26

27
class Worker:
    """Consumes tasks from queues and runs them.

    :param app: An instance of :class:`~kuyruk.Kuyruk`
    :param args: Command line arguments. The attributes read here are
        ``queues``, ``logging_level``, ``max_run_time`` and ``max_load``;
        for the last three, ``None`` means "use the value from
        ``app.config``".

    """
    def __init__(self, app: Kuyruk, args: argparse.Namespace) -> None:
        self.kuyruk = app

        if not args.queues:
            args.queues = ['kuyruk']

        # Needed by add_host() below.
        self._hostname = socket.gethostname()

        def add_host(queue: str) -> str:
            # "<name>.localhost" is expanded to "<name>.<this hostname>".
            suffix = '.localhost'
            if queue.endswith(suffix):
                # Strip only the trailing suffix. rsplit() without maxsplit
                # would also cut at any earlier ".localhost" in the name.
                return "%s.%s" % (queue[:-len(suffix)], self._hostname)
            return queue

        self.queues = [add_host(q) for q in args.queues]
        self._tasks = {}  # type: Dict[Tuple[str, str], Task]
        self.shutdown_pending = threading.Event()
        self.consuming = False
        self.current_task = None  # type: Optional[Task]
        self.current_args = None  # type: Optional[Tuple]
        self.current_kwargs = None  # type: Optional[Dict[str, Any]]
        # Written by _on_heartbeat_error() and consumed by _handle_sighup().
        # Initialised so that a SIGHUP arriving before any heartbeat failure
        # raises HeartbeatError rather than AttributeError.
        self._heartbeat_error = None  # type: Optional[Exception]
        # Latest os.getloadavg()[0] reading written by the _watch_load
        # thread; None until the first reading is taken.
        self._current_load = None  # type: Optional[float]

        self._started_at = None  # type: Optional[float]
        self._pid = os.getpid()

        self._logging_level = app.config.WORKER_LOGGING_LEVEL
        if args.logging_level is not None:
            self._logging_level = args.logging_level

        self._max_run_time = app.config.WORKER_MAX_RUN_TIME
        if args.max_run_time is not None:
            self._max_run_time = args.max_run_time

        self._max_load = app.config.WORKER_MAX_LOAD
        if args.max_load is not None:
            self._max_load = args.max_load
        if self._max_load == -1:
            # -1 is replaced by the number of CPUs on this machine.
            self._max_load = multiprocessing.cpu_count()

        self._reconnect_interval = app.config.WORKER_RECONNECT_INTERVAL

        # Helper threads; started in run() and joined when it exits.
        self._threads = []  # type: List[threading.Thread]
        if self._max_load:
            self._threads.append(threading.Thread(target=self._watch_load))
        if self._max_run_time:
            self._threads.append(threading.Thread(target=self._shutdown_timer))

        signals.worker_init.send(self.kuyruk, worker=self)

    def run(self) -> None:
        """Runs the worker and consumes messages from RabbitMQ.
        Returns only after `shutdown()` is called.

        Heartbeat and connection errors are logged, and the worker reconnects
        after ``WORKER_RECONNECT_INTERVAL`` seconds unless a shutdown is
        pending.

        """
        if self._logging_level:
            logging.basicConfig(
                level=getattr(logging, self._logging_level.upper()),
                format="%(levelname).1s %(name)s.%(funcName)s:%(lineno)d - %(message)s")

        signal.signal(signal.SIGINT, self._handle_sigint)
        signal.signal(signal.SIGTERM, self._handle_sigterm)
        if platform.system() != 'Windows':
            # These features will not be available on Windows, but that is OK.
            # Read this issue for more details:
            # https://github.com/cenkalti/kuyruk/issues/54
            signal.signal(signal.SIGHUP, self._handle_sighup)
            signal.signal(signal.SIGUSR1, self._handle_sigusr1)
            signal.signal(signal.SIGUSR2, self._handle_sigusr2)

        # monotonic() rather than os.times().elapsed: the latter is always
        # zero on Windows, which would break uptime and the run-time limit.
        self._started_at = monotonic()

        for t in self._threads:
            t.start()

        try:
            signals.worker_start.send(self.kuyruk, worker=self)
            while not self.shutdown_pending.is_set():
                try:
                    self._consume_messages()
                    break
                except HeartbeatError:
                    logger.error("Heartbeat error")
                except (ConnectionError, amqp.exceptions.ConnectionError) as e:
                    logger.error("Connection error: %s", e)
                    traceback.print_exc()

                logger.info("Waiting %d seconds before reconnecting...", self._reconnect_interval)
                self.shutdown_pending.wait(self._reconnect_interval)
        finally:
            # Make the helper threads exit their wait loops before joining.
            self.shutdown_pending.set()
            for t in self._threads:
                t.join()

            signals.worker_shutdown.send(self.kuyruk, worker=self)

        logger.debug("End run worker")

    def _consume_messages(self) -> None:
        """Opens a channel, declares and consumes the queues, and runs the
        main loop until a shutdown is requested."""
        with self.kuyruk.channel() as ch:
            # Set prefetch count to 1. If we don't set this, RabbitMQ keeps
            # sending messages while we are already working on a message.
            # Arguments are (prefetch_size, prefetch_count, global).
            ch.basic_qos(0, 1, False)

            self._declare_queues(ch)
            self._consume_queues(ch)
            logger.info('Consumer started')
            self._main_loop(ch)

    def _main_loop(self, ch: amqp.Channel) -> None:
        """Dispatches incoming messages until `shutdown_pending` is set.

        Each iteration re-checks the load limit, ticks the connection
        heartbeat and waits up to one second for events.
        """
        while not self.shutdown_pending.is_set():
            self._pause_or_resume(ch)
            ch.connection.heartbeat_tick()
            try:
                ch.connection.drain_events(timeout=1)
            except socket.timeout:
                pass

    def _consumer_tag(self, queue: str) -> str:
        """Returns the consumer tag "<queue>:<pid>@<hostname>"."""
        return "%s:%s@%s" % (queue, self._pid, self._hostname)

    def _declare_queues(self, ch: amqp.Channel) -> None:
        """Declares every queue as durable and non-auto-deleting."""
        for queue in self.queues:
            logger.debug("queue_declare: %s", queue)
            ch.queue_declare(queue=queue, durable=True, auto_delete=False)

    def _pause_or_resume(self, channel: amqp.Channel) -> None:
        """Cancels the consumers while the load is above ``_max_load`` and
        re-registers them once it is back at or below it.

        Does nothing when load watching is disabled or when no load reading
        has been taken yet.
        """
        if not self._max_load:
            return

        load = self._current_load
        if load is None:
            # The watcher thread has not produced a reading yet.
            return

        should_pause = load > self._max_load
        if should_pause and self.consuming:
            logger.warning('Load is above the treshold (%.2f/%s), pausing consumer', load, self._max_load)
            self._cancel_queues(channel)
        elif not should_pause and not self.consuming:
            logger.warning('Load is below the treshold (%.2f/%s), resuming consumer', load, self._max_load)
            self._consume_queues(channel)

    def _consume_queues(self, ch: amqp.Channel) -> None:
        """Registers a consumer on every queue, using _process_message as the callback."""
        self.consuming = True
        for queue in self.queues:
            logger.debug("basic_consume: %s", queue)
            ch.basic_consume(queue=queue, consumer_tag=self._consumer_tag(queue), callback=self._process_message)

    def _cancel_queues(self, ch: amqp.Channel) -> None:
        """Cancels the consumers registered by _consume_queues()."""
        self.consuming = False
        for queue in self.queues:
            logger.debug("basic_cancel: %s", queue)
            ch.basic_cancel(self._consumer_tag(queue))

    def _process_message(self, message: amqp.Message) -> None:
        """Processes the message received from the queue.

        The body is decoded as JSON. Undecodable messages are rejected
        without requeue. If a shutdown is pending, the message is left
        unacknowledged (neither acked nor rejected here).
        """
        if self.shutdown_pending.is_set():
            return

        try:
            if isinstance(message.body, bytes):
                message.body = message.body.decode()
            description = json.loads(message.body)
        except Exception:
            logger.error("Cannot decode message. Dropping. Message: %r", message.body)
            traceback.print_exc()
            message.channel.basic_reject(message.delivery_tag, requeue=False)
        else:
            logger.info("Processing task: %r", description)
            self._process_description(message, description)

    def _process_description(self, message: amqp.Message, description: Dict[str, Any]) -> None:
        """Resolves the task named in *description* and runs it.

        *description* must contain 'module', 'function', 'args' and 'kwargs'.
        If the task cannot be resolved, ``worker_failure`` is signalled and
        the message is rejected without requeue.
        """
        try:
            task = self._import_task(description['module'], description['function'])
            args, kwargs = description['args'], description['kwargs']
        except Exception:
            logger.error('Cannot import task')
            exc_info = sys.exc_info()
            signals.worker_failure.send(self.kuyruk, description=description, exc_info=exc_info, worker=self)
            message.channel.basic_reject(message.delivery_tag, requeue=False)
        else:
            self._process_task(message, description, task, args, kwargs)

    def _import_task(self, module: str, function: str) -> Task:
        """Imports ``module.function``, caching the result per (module, function)."""
        key = (module, function)
        if key in self._tasks:
            return self._tasks[key]

        task = importer.import_object(module, function)
        self._tasks[key] = task
        return task

    def _process_task(
            self,
            message: amqp.Message,
            description: Dict[str, Any],
            task: Task,
            args: Tuple,
            kwargs: Dict[str, Any],
    ) -> None:
        """Runs *task* and settles *message* according to the outcome.

        - success: ack, then send the result to ``reply_to`` if set;
        - :class:`Reject`: reject with ``requeue=True``;
        - :class:`Discard`: reject without requeue, then send the exception
          to ``reply_to`` if set;
        - :class:`HeartbeatError`: signal ``worker_failure`` and re-raise
          (``run()`` catches it and reconnects); the message is neither acked
          nor rejected here;
        - any other exception: signal ``worker_failure``, reject without
          requeue, then send the exception to ``reply_to`` if set.
        """
        queue = message.delivery_info['routing_key']
        reply_to = message.properties.get('reply_to')
        try:
            result = self._run_task(message.channel.connection, task, args, kwargs)
        except Reject:
            logger.warning('Task is rejected')
            message.channel.basic_reject(message.delivery_tag, requeue=True)
        except Discard:
            logger.warning('Task is discarded')
            message.channel.basic_reject(message.delivery_tag, requeue=False)
            if reply_to:
                exc_info = sys.exc_info()
                self._send_reply(reply_to, message.channel, None, exc_info)
        except HeartbeatError:
            exc_info = sys.exc_info()
            logger.error('Heartbeat error:\n%s', ''.join(traceback.format_exception(*exc_info)))
            self._signal_task_failure(description, task, args, kwargs, exc_info, queue)
            raise
        except Exception:
            exc_info = sys.exc_info()
            logger.error('Task raised an exception:\n%s', ''.join(traceback.format_exception(*exc_info)))
            self._signal_task_failure(description, task, args, kwargs, exc_info, queue)
            message.channel.basic_reject(message.delivery_tag, requeue=False)
            if reply_to:
                self._send_reply(reply_to, message.channel, None, exc_info)
        else:
            logger.info('Task is successful')
            message.channel.basic_ack(message.delivery_tag)
            if reply_to:
                self._send_reply(reply_to, message.channel, result, None)
        finally:
            logger.debug("Task is processed")

    def _signal_task_failure(
            self,
            description: Dict[str, Any],
            task: Task,
            args: Tuple,
            kwargs: Dict[str, Any],
            exc_info: ExcInfoType,
            queue: str,
    ) -> None:
        """Sends ``signals.worker_failure`` for a task that raised while running."""
        signals.worker_failure.send(
            self.kuyruk,
            description=description,
            task=task,
            args=args,
            kwargs=kwargs,
            exc_info=exc_info,
            worker=self,
            queue=queue)

    def _run_task(self, connection: amqp.Connection, task: Task, args: Tuple, kwargs: Dict[str, Any]) -> Any:
        """Runs *task* while a Heartbeat for *connection* is active.

        The task is recorded in `current_task`/`current_args`/`current_kwargs`
        for the duration of the call (SIGUSR2 uses `current_task` to decide
        whether to drop it). These are cleared and the Heartbeat is stopped
        afterwards, even if the task raises.
        """
        hb = Heartbeat(connection, self._on_heartbeat_error)
        hb.start()

        self.current_task = task
        self.current_args = args
        self.current_kwargs = kwargs
        try:
            return self._apply_task(task, args, kwargs)
        finally:
            self.current_task = None
            self.current_args = None
            self.current_kwargs = None

            hb.stop()

    def _on_heartbeat_error(self, error: Exception) -> None:
        """Error callback given to Heartbeat.

        Stores *error* and sends SIGHUP to this process; _handle_sighup then
        raises HeartbeatError chained from it. The SIGHUP handler is only
        installed on non-Windows platforms (see run()).
        """
        self._heartbeat_error = error
        os.kill(os.getpid(), signal.SIGHUP)

    @staticmethod
    def _apply_task(task: Task, args: Tuple, kwargs: Dict[str, Any]) -> Any:
        """Logs the time spent while running the task.

        ``None`` args/kwargs are treated as empty.
        """
        if args is None:
            args = ()
        if kwargs is None:
            kwargs = {}

        start = monotonic()
        try:
            return task.apply(*args, **kwargs)
        finally:
            delta = monotonic() - start
            logger.info("%s finished in %i seconds.", task.name, delta)

    def _send_reply(
            self,
            reply_to: str,
            channel: amqp.Channel,
            result: Any,
            exc_info: Optional[ExcInfoType],
    ) -> None:
        """Publishes ``{'result': ..., ['exception': ...]}`` as JSON to the
        *reply_to* queue through the default exchange.

        If *result* cannot be serialized, the serialization error is sent
        instead, with a ``None`` result.
        """
        logger.debug("Sending reply result=%r", result)

        reply = {'result': result}  # type: Dict[str, Any]
        if exc_info:
            reply['exception'] = self._exc_info_dict(exc_info)

        try:
            body = json.dumps(reply)
        except Exception as e:
            logger.error('Cannot serialize result as JSON: %s', e)
            exc_info = sys.exc_info()
            reply = {'result': None, 'exception': self._exc_info_dict(exc_info)}
            body = json.dumps(reply)

        msg = amqp.Message(body=body)
        channel.basic_publish(msg, exchange="", routing_key=reply_to)

    @staticmethod
    def _exc_info_dict(exc_info: ExcInfoType) -> Dict[str, str]:
        """Converts a ``sys.exc_info()`` triple to a JSON-friendly dict with
        'type' (qualified by module), 'value' and 'traceback' strings."""
        type_, val, tb = exc_info
        return {
            'type': '%s.%s' % (type_.__module__, cast(Type[BaseException], type_).__name__),
            'value': str(val),
            'traceback': ''.join(traceback.format_tb(tb)),
        }

    def _watch_load(self) -> None:
        """Samples the load average about once per second until shutdown.

        _pause_or_resume() compares the reading against the allowed limit.
        os.getloadavg() is not available on every platform; if it raises,
        this thread ends and no reading is ever recorded.
        """
        while not self.shutdown_pending.wait(1):
            self._current_load = os.getloadavg()[0]

    @property
    def uptime(self) -> float:
        """Seconds elapsed since :meth:`run` started; 0 before that."""
        if self._started_at is None:
            return 0
        return monotonic() - self._started_at

    def _shutdown_timer(self) -> None:
        """Counts down from WORKER_MAX_RUN_TIME. When it reaches zero,
        shuts down gracefully.

        Returns early without requesting a shutdown if one is already
        pending.
        """
        remaining = cast(float, self._max_run_time) - self.uptime
        if not self.shutdown_pending.wait(remaining):
            logger.warning('Run time reached zero')
            self.shutdown()

    def shutdown(self) -> None:
        """Exits after the current task is finished."""
        logger.warning("Shutdown requested")
        self.shutdown_pending.set()

    def _handle_sigint(self, signum: int, frame: Any) -> None:
        """Shutdown after processing current task."""
        logger.warning("Catched SIGINT")
        self.shutdown()

    def _handle_sigterm(self, signum: int, frame: Any) -> None:
        """Shutdown after processing current task."""
        logger.warning("Catched SIGTERM")
        self.shutdown()

    def _handle_sighup(self, signum: int, frame: Any) -> None:
        """Used internally to fail the task when connection to RabbitMQ is
        lost during the execution of the task.

        Raises HeartbeatError chained from the error stored by
        _on_heartbeat_error() (or from None if there is none), and clears
        the stored error.
        """
        logger.debug("Catched SIGHUP")
        error = self._heartbeat_error
        self._heartbeat_error = None
        raise HeartbeatError from error

    @staticmethod
    def _handle_sigusr1(signum: int, frame: Any) -> None:
        """Print stacktrace."""
        print('=' * 70)
        print(''.join(traceback.format_stack()))
        print('-' * 70)

    def _handle_sigusr2(self, signum: int, frame: Any) -> None:
        """Drop current task by raising Discard, if a task is running."""
        logger.warning("Catched SIGUSR2")
        if self.current_task:
            logger.warning("Dropping current task...")
            raise Discard

    def drop_task(self) -> None:
        """Drops the currently running task by sending SIGUSR2 to this process."""
        os.kill(os.getpid(), signal.SIGUSR2)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc