• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

flamingo-run / django-cloud-tasks / 13136871008

04 Feb 2025 01:46PM UTC coverage: 86.693% (+0.03%) from 86.661%
13136871008

push

github

frnsimoes
build: bump version

112 of 162 branches covered (69.14%)

Branch coverage included in aggregate %.

976 of 1093 relevant lines covered (89.3%)

2.68 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

79.26
/django_cloud_tasks/tasks/task.py
1
import abc
3✔
2
import inspect
3✔
3
from dataclasses import dataclass, fields
3✔
4
from datetime import datetime, timedelta, timezone
3✔
5
from functools import lru_cache
3✔
6
from random import randint
3✔
7
from typing import Any, Self
3✔
8
from urllib.parse import urljoin
3✔
9
from concurrent.futures import ThreadPoolExecutor
3✔
10
from django.apps import apps
3✔
11
from django.urls import reverse
3✔
12
from django.utils.timezone import now
3✔
13
from gcp_pilot.exceptions import DeletedRecently
3✔
14
from gcp_pilot.tasks import CloudTasks
3✔
15
from google.cloud.tasks_v2 import Task as GoogleCloudTask
3✔
16

17
from django_cloud_tasks.apps import DjangoCloudTasksAppConfig
3✔
18
from django_cloud_tasks.context import get_current_headers
3✔
19
from django_cloud_tasks.serializers import deserialize, serialize
3✔
20
import json
3✔
21
from django.http import HttpRequest
3✔
22

23

24
def register(task_class) -> None:
3✔
25
    app: DjangoCloudTasksAppConfig = apps.get_app_config("django_cloud_tasks")
3✔
26
    app.register_task(task_class=task_class)
3✔
27

28

29
@dataclass
3✔
30
class TaskMetadata:
3✔
31
    task_id: str
3✔
32
    queue_name: str
3✔
33
    dispatch_number: int  # number of dispatches (0 means first attempt)
3✔
34
    execution_number: int  # number of responses received (excluding 5XX)
3✔
35
    eta: datetime
3✔
36
    previous_response: str | None = None
3✔
37
    previous_failure: str | None = None
3✔
38
    project_id: str | None = None
3✔
39
    custom_headers: dict | None = None
3✔
40
    is_cloud_scheduler: bool | None = None
3✔
41
    cloud_scheduler_schedule_time: datetime | None = None
3✔
42
    cloud_scheduler_job_name: str | None = None
3✔
43

44
    def __post_init__(self):
3✔
45
        self.custom_headers = get_current_headers()
3✔
46
        self._max_attempts = None
3✔
47

48
    @classmethod
3✔
49
    def from_headers(cls, headers: dict) -> Self:
3✔
50
        # Available data: https://cloud.google.com/tasks/docs/creating-http-target-tasks#handler
51
        cloud_tasks_prefix = "X-Cloudtasks-"
3✔
52
        cloud_scheduler_prefix = "X-Cloudscheduler"
3✔
53

54
        if (attempt_str := headers.get(f"{cloud_tasks_prefix}Taskexecutioncount")) is not None:
3✔
55
            execution_number = int(attempt_str)
3✔
56
        else:
57
            execution_number = None
3✔
58

59
        if (retry_str := headers.get(f"{cloud_tasks_prefix}Taskretrycount")) is not None:
3✔
60
            dispatch_number = int(retry_str)
3✔
61
        else:
62
            dispatch_number = None
3✔
63

64
        if eta_epoch := headers.get(f"{cloud_tasks_prefix}Tasketa"):
3✔
65
            eta = datetime.fromtimestamp(int(eta_epoch.split(".")[0]), tz=timezone.utc)
3✔
66
        else:
67
            eta = None
3✔
68

69
        cloud_scheduler_job_name = headers.get(f"{cloud_scheduler_prefix}-Jobname")
3✔
70

71
        if schedule_time_str := headers.get(f"{cloud_scheduler_prefix}-Scheduletime"):
3✔
72
            try:
3✔
73
                schedule_time = datetime.fromisoformat(schedule_time_str)
3✔
74
            except ValueError:
×
75
                schedule_time = None
×
76
        else:
77
            schedule_time = None
3✔
78

79
        is_cloud_scheduler = headers.get(cloud_scheduler_prefix) == "true"
3✔
80

81
        return cls(
3✔
82
            project_id=headers.get(f"{cloud_tasks_prefix}Projectname"),
83
            queue_name=headers.get(f"{cloud_tasks_prefix}Queuename"),
84
            task_id=headers.get(f"{cloud_tasks_prefix}Taskname"),
85
            dispatch_number=dispatch_number,
86
            execution_number=execution_number,
87
            eta=eta,
88
            previous_response=headers.get(f"{cloud_tasks_prefix}TaskPreviousResponse"),
89
            previous_failure=headers.get(f"{cloud_tasks_prefix}TaskRetryReason"),
90
            is_cloud_scheduler=is_cloud_scheduler,
91
            cloud_scheduler_schedule_time=schedule_time,
92
            cloud_scheduler_job_name=cloud_scheduler_job_name,
93
        )
94

95
    def to_headers(self) -> dict:
3✔
96
        cloud_tasks_prefix = "X-Cloudtasks-"
3✔
97
        cloud_tasks_headers = {
3✔
98
            f"{cloud_tasks_prefix}Taskname": self.task_id,
99
            f"{cloud_tasks_prefix}Queuename": self.queue_name,
100
            f"{cloud_tasks_prefix}Projectname": self.project_id,
101
            f"{cloud_tasks_prefix}Taskexecutioncount": str(self.execution_number),
102
            f"{cloud_tasks_prefix}Taskretrycount": str(self.dispatch_number),
103
            f"{cloud_tasks_prefix}Tasketa": str(int(self.eta.timestamp())),
104
            f"{cloud_tasks_prefix}TaskPreviousResponse": self.previous_response,
105
            f"{cloud_tasks_prefix}TaskRetryReason": self.previous_failure,
106
        }
107

108
        if self.is_cloud_scheduler:
3✔
109
            cloud_scheduler_prefix = "X-Cloudscheduler"
3✔
110
            cloud_scheduler_headers = {
3✔
111
                f"{cloud_scheduler_prefix}-Jobname": self.cloud_scheduler_job_name,
112
                f"{cloud_scheduler_prefix}-Scheduletime": self.cloud_scheduler_schedule_time.isoformat(),
113
                f"{cloud_scheduler_prefix}": "true",
114
            }
115
            return cloud_tasks_headers | cloud_scheduler_headers
3✔
116

117
        return cloud_tasks_headers
3✔
118

119
    @classmethod
3✔
120
    def from_task_obj(cls, task_obj: GoogleCloudTask) -> Self:
3✔
121
        _, project_id, _, _, _, queue_name, _, task_id = task_obj.name.split("/")  # TODO: use regex
×
122
        return cls(
×
123
            project_id=project_id,
124
            queue_name=queue_name,
125
            task_id=task_id,
126
            dispatch_number=task_obj.dispatch_count,
127
            execution_number=task_obj.response_count,
128
            eta=task_obj.schedule_time,
129
            previous_response=None,
130
            previous_failure=None,
131
            custom_headers=dict(task_obj.http_request.headers),
132
        )
133

134
    @classmethod
3✔
135
    def build_eager(cls, task_class) -> Self:
3✔
136
        return cls(
3✔
137
            project_id=None,
138
            queue_name=task_class.queue(),
139
            task_id="--SYNC--",
140
            dispatch_number=0,
141
            execution_number=0,
142
            eta=now(),
143
            previous_response=None,
144
            previous_failure=None,
145
        )
146

147
    @property
3✔
148
    def max_retries(self) -> int:
3✔
149
        queue = CloudTasks(project_id=self.project_id).get_queue(queue_name=self.queue_name)
×
150
        return queue.retry_config.max_attempts
×
151

152
    @property
3✔
153
    def attempt_number(self) -> int:
3✔
154
        return self.dispatch_number + 1
3✔
155

156
    @property
3✔
157
    def first_attempt(self) -> bool:
3✔
158
        return self.dispatch_number == 0
×
159

160
    @property
3✔
161
    def last_attempt(self) -> bool:
3✔
162
        return self.attempt_number == self.max_retries
×
163

164
    @property
3✔
165
    def eager(self) -> bool:
3✔
166
        return self.task_id == "--SYNC--"
×
167

168
    def __eq__(self, other) -> bool:
3✔
169
        if not isinstance(other, TaskMetadata):
3✔
170
            return False
3✔
171

172
        check_fields = [field.name for field in fields(TaskMetadata)]
3✔
173
        for field in check_fields:
3✔
174
            try:
3✔
175
                value = getattr(self, field)
3✔
176
                other_value = getattr(other, field)
3✔
177
                if value != other_value:
3✔
178
                    return False
3✔
179
            except (AttributeError, ValueError):
×
180
                return False
×
181
        return True
3✔
182

183

184
class TaskMeta(type):
3✔
185
    def __new__(cls, name, bases, attrs):
3✔
186
        klass = type.__new__(cls, name, bases, attrs)
3✔
187
        if not inspect.isabstract(klass) and abc.ABC not in bases:
3✔
188
            register(task_class=klass)
3✔
189
        return klass
3✔
190

191
    def __str__(self):
3✔
192
        return self.__name__
3✔
193

194

195
class DjangoCloudTask(abc.ABCMeta, TaskMeta): ...
3✔
196

197

198
class Task(abc.ABC, metaclass=DjangoCloudTask):
3✔
199
    only_once: bool = False
3✔
200

201
    def __init__(self, metadata: TaskMetadata | None = None):
3✔
202
        self._metadata = metadata or TaskMetadata.build_eager(task_class=self.__class__)
3✔
203

204
    @abc.abstractmethod
3✔
205
    def run(self, **kwargs):
3✔
206
        raise NotImplementedError()
×
207

208
    def process(self, **task_kwargs) -> Any:
3✔
209
        return self.run(**task_kwargs)
3✔
210

211
    @classmethod
3✔
212
    def sync(cls, **kwargs):
3✔
213
        return cls().run(**kwargs)
3✔
214

215
    @classmethod
3✔
216
    def asap(cls, **kwargs):
3✔
217
        return cls.push(task_kwargs=kwargs)
3✔
218

219
    @classmethod
3✔
220
    def later(cls, task_kwargs: dict, eta: int | timedelta | datetime, queue: str = None, headers: dict | None = None):
3✔
221
        delay_in_seconds = cls._calculate_delay_in_seconds(eta=eta)
3✔
222
        cls._validate_delay(delay_in_seconds=delay_in_seconds)
3✔
223
        return cls.push(
3✔
224
            task_kwargs=task_kwargs,
225
            queue=queue,
226
            headers=headers,
227
            delay_in_seconds=delay_in_seconds,
228
        )
229

230
    @staticmethod
3✔
231
    def _calculate_delay_in_seconds(eta: int | timedelta | datetime) -> float | int:
3✔
232
        if isinstance(eta, int) or isinstance(eta, float):
3✔
233
            return eta
3✔
234
        elif isinstance(eta, timedelta):
3✔
235
            return eta.total_seconds()
3✔
236
        elif isinstance(eta, datetime):
3✔
237
            return (eta - now()).total_seconds()
3✔
238
        else:
239
            raise ValueError(
3✔
240
                f"Unsupported schedule {eta} of type {eta.__class__.__name__}. Must be int, timedelta or datetime."
241
            )
242

243
    @staticmethod
3✔
244
    def _validate_delay(delay_in_seconds: int | float):
3✔
245
        max_eta_task = get_config("tasks_max_eta")
3✔
246
        if max_eta_task is not None and delay_in_seconds > max_eta_task:
3✔
247
            raise ValueError(f"Invalid delay time {delay_in_seconds}, maximum is {max_eta_task}")
3✔
248

249
    @classmethod
3✔
250
    def until(cls, task_kwargs: dict, max_eta: datetime, queue: str = None, headers: dict | None = None):
3✔
251
        if not isinstance(max_eta, datetime):
×
252
            raise ValueError("max_date must be a datetime")
×
253
        if max_eta < now():
×
254
            raise ValueError("max_date must be in the future")
×
255

256
        max_seconds = (max_eta - now()).total_seconds()
×
257
        delay_in_seconds = randint(0, int(max_seconds))
×
258
        return cls.push(
×
259
            task_kwargs=task_kwargs,
260
            queue=queue,
261
            headers=headers,
262
            delay_in_seconds=delay_in_seconds,
263
        )
264

265
    @classmethod
3✔
266
    def push(
3✔
267
        cls,
268
        task_kwargs: dict,
269
        headers: dict | None = None,
270
        queue: str | None = None,
271
        delay_in_seconds: int | float | None = None,
272
        task_timeout: timedelta | None = None,
273
    ):
274
        payload = serialize(value=task_kwargs)
3✔
275

276
        if cls.eager():
3✔
277
            return cls.sync(**deserialize(value=payload))
3✔
278

279
        client = cls._get_tasks_client()
3✔
280

281
        headers = get_current_headers() | (headers or {})
3✔
282
        headers.setdefault("X-CloudTasks-Projectname", client.project_id)
3✔
283

284
        api_kwargs = {
3✔
285
            "queue_name": queue or cls.queue(),
286
            "url": cls.url(),
287
            "payload": payload,
288
            "headers": headers,
289
            "task_timeout": task_timeout or cls.get_task_timeout(),
290
        }
291

292
        if delay_in_seconds:
3✔
293
            api_kwargs["delay_in_seconds"] = delay_in_seconds
3✔
294

295
        if cls.only_once:
3✔
296
            api_kwargs.update(
3✔
297
                {
298
                    "task_name": cls.name(),
299
                    "unique": False,
300
                }
301
            )
302

303
        try:
3✔
304
            outcome = client.push(**api_kwargs)
3✔
305
        except DeletedRecently:
3✔
306
            # If the task queue was "accidentally" removed, GCP does not let us recreate it in 1 week
307
            # so we'll use a temporary queue (defined in settings) for some time
308
            backup_queue_name = apps.get_app_config("django_cloud_tasks").get_backup_queue_name(
3✔
309
                original_name=cls.queue(),
310
            )
311
            if not backup_queue_name:
3!
312
                raise
×
313

314
            api_kwargs["queue_name"] = backup_queue_name
3✔
315
            outcome = cls._get_tasks_client().push(**api_kwargs)
3✔
316

317
        task_metadata_class = get_config(name="task_metadata_class")
3✔
318
        return task_metadata_class.from_task_obj(task_obj=outcome)
3✔
319

320
    @classmethod
3✔
321
    def get_task_timeout(cls):
3✔
322
        return None
3✔
323

324
    @classmethod
3✔
325
    def debug(cls, task_id: str):
3✔
326
        client = cls._get_tasks_client()
×
327
        task_obj = client.get_task(queue_name=cls.queue(), task_name=task_id)
×
328
        task_kwargs = json.loads(task_obj.http_request.body)
×
329

330
        task_metadata_class = get_config(name="task_metadata_class")
×
331
        metadata = task_metadata_class.from_task_obj(task_obj=task_obj)
×
332
        return cls(metadata=metadata).run(**task_kwargs)
×
333

334
    @classmethod
3✔
335
    def discard(cls, task_id: str | None = None, min_retries: int = 0):
3✔
336
        client = cls._get_tasks_client()
×
337
        if task_id:
×
338
            task_objects = [client.get_task(queue_name=cls.queue(), task_name=task_id)]
×
339
        else:
340
            task_objects = client.list_tasks(queue_name=cls.queue())
×
341

342
        def process(task_obj):
×
343
            task_name = task_obj.http_request.url.rsplit("/", 1)[-1]
×
344
            task_id = task_obj.name.split("/")[-1]
×
345
            client.delete_task(queue_name=cls.queue(), task_name=task_id)
×
346
            return f"{task_name}/{task_id}"
×
347

348
        def jobs():
×
349
            for task_obj in task_objects:
×
350
                task_name = task_obj.http_request.url.rsplit("/", 1)[-1]
×
351
                if task_name == cls.name() and task_obj.dispatch_count >= min_retries:
×
352
                    yield task_obj
×
353

354
        pool = ThreadPoolExecutor()
×
355

356
        outputs = []
×
357
        for output in pool.map(process, jobs()):
×
358
            outputs.append(output)
×
359
        return outputs
×
360

361
    @classmethod
3✔
362
    def name(cls) -> str:
3✔
363
        return str(cls)
3✔
364

365
    @classmethod
3✔
366
    def queue(cls) -> str:
3✔
367
        app_name = get_config(name="app_name")
3✔
368
        return app_name or "tasks"
3✔
369

370
    @classmethod
3✔
371
    @lru_cache()
3✔
372
    def url(cls) -> str:
3✔
373
        domain = get_config(name="domain")
3✔
374
        url_name = get_config(name="tasks_url_name")
3✔
375
        path = reverse(url_name, args=(cls.name(),))
3✔
376
        return urljoin(domain, path)
3✔
377

378
    @classmethod
3✔
379
    @lru_cache()
3✔
380
    def eager(cls) -> bool:
3✔
381
        return get_config(name="eager")
3✔
382

383
    @classmethod
3✔
384
    @lru_cache()
3✔
385
    def _get_tasks_client(cls) -> CloudTasks:
3✔
386
        return CloudTasks()
3✔
387

388

389
def get_config(name: str) -> Any:
3✔
390
    app: DjangoCloudTasksAppConfig = apps.get_app_config("django_cloud_tasks")
3✔
391
    return getattr(app, name)
3✔
392

393

394
def is_task_route(request: HttpRequest) -> bool:
3✔
395
    parts = request.path.removesuffix("/").rsplit("/", 1)
3✔
396
    if len(parts) != 2:
3!
397
        return False
×
398

399
    _, task_name = parts
3✔
400
    if not task_name:
3!
401
        return False
×
402

403
    expected_url = reverse(get_config(name="tasks_url_name"), args=(task_name,))
3✔
404
    return request.path == expected_url
3✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc