• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

SwissDataScienceCenter / renku-data-services / 8540651701

03 Apr 2024 02:38PM UTC coverage: 89.274% (+0.06%) from 89.212%
8540651701

push

gihub-action

web-flow
update to python 3.12 (#167)

5718 of 6405 relevant lines covered (89.27%)

0.89 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.67
/components/renku_data_services/message_queue/redis_queue.py
1
"""Message queue implementation for redis streams."""
1✔
2

3
import base64
1✔
4
import copy
1✔
5
import glob
1✔
6
import inspect
1✔
7
import json
1✔
8
from collections.abc import Callable
1✔
9
from dataclasses import dataclass
1✔
10
from datetime import datetime
1✔
11
from functools import wraps
1✔
12
from io import BytesIO
1✔
13
from pathlib import Path
1✔
14
from types import NoneType, UnionType
1✔
15
from typing import Optional, TypeVar, Union
1✔
16

17
from dataclasses_avroschema.schema_generator import AvroModel
1✔
18
from dataclasses_avroschema.utils import standardize_custom_type
1✔
19
from fastavro import parse_schema, schemaless_reader, schemaless_writer
1✔
20
from ulid import ULID
1✔
21

22
from renku_data_services.message_queue.avro_models.io.renku.events.v1.header import Header
1✔
23
from renku_data_services.message_queue.avro_models.io.renku.events.v1.project_authorization_added import (
1✔
24
    ProjectAuthorizationAdded,
25
)
26
from renku_data_services.message_queue.avro_models.io.renku.events.v1.project_authorization_removed import (
1✔
27
    ProjectAuthorizationRemoved,
28
)
29
from renku_data_services.message_queue.avro_models.io.renku.events.v1.project_authorization_updated import (
1✔
30
    ProjectAuthorizationUpdated,
31
)
32
from renku_data_services.message_queue.avro_models.io.renku.events.v1.project_created import ProjectCreated
1✔
33
from renku_data_services.message_queue.avro_models.io.renku.events.v1.project_removed import ProjectRemoved
1✔
34
from renku_data_services.message_queue.avro_models.io.renku.events.v1.project_updated import ProjectUpdated
1✔
35
from renku_data_services.message_queue.avro_models.io.renku.events.v1.user_added import UserAdded
1✔
36
from renku_data_services.message_queue.avro_models.io.renku.events.v1.user_removed import UserRemoved
1✔
37
from renku_data_services.message_queue.avro_models.io.renku.events.v1.user_updated import UserUpdated
1✔
38
from renku_data_services.message_queue.config import RedisConfig
1✔
39
from renku_data_services.message_queue.interface import IMessageQueue
1✔
40

41
_root = Path(__file__).parent.resolve()
_filter = f"{_root}/schemas/**/*.avsc"


def _load_schemas(pattern: str) -> dict[str, dict]:
    """Load every Avro schema file matching ``pattern`` and index it by full name.

    Each matching file is parsed as JSON. A schema that is a JSON object with a ``name``
    key is stored under ``"<namespace>.<name>"`` when it has a non-empty ``namespace``,
    otherwise under the bare ``name``. Files whose top-level value is not an object
    (e.g. a bare union list) or that have no ``name`` are skipped. If two files produce
    the same full name, the one visited last by ``glob`` wins.

    Args:
        pattern: A ``glob`` pattern; ``**`` is expanded recursively.

    Returns:
        Mapping of full schema name to the parsed JSON schema.
    """
    schemas: dict[str, dict] = {}
    for path in glob.glob(pattern, recursive=True):
        # Avro schemas are JSON, which is UTF-8; don't depend on the platform's locale encoding.
        with open(path, encoding="utf-8") as fh:
            schema = json.load(fh)
        if not isinstance(schema, dict) or "name" not in schema:
            continue
        name = schema["name"]
        namespace = schema.get("namespace")
        if namespace:
            name = f"{namespace}.{name}"
        schemas[name] = schema
    return schemas


# Named schemas from the bundled .avsc files; used as fastavro's ``named_schemas`` when
# parsing message schemas in this module.
_schemas = _load_schemas(_filter)
53

54

55
def serialize_binary(obj: AvroModel) -> bytes:
    """Serialize a message with avro, making sure to use the original schema.

    The schema comes from the object's ``_schema`` attribute when it is present. Only
    when it is missing is the generated ``avro_schema()`` used. Named types referenced by
    the schema are resolved against the module-level ``_schemas`` mapping.

    Args:
        obj: The Avro model instance to encode.

    Returns:
        The record encoded with fastavro's ``schemaless_writer``, i.e. without an
        embedded schema.
    """
    # getattr(obj, "_schema", obj.avro_schema()) evaluates the default eagerly, building the
    # generated schema on every call even when `_schema` exists; only build it when needed.
    raw_schema = getattr(obj, "_schema", None)
    if raw_schema is None:
        raw_schema = obj.avro_schema()
    schema = parse_schema(schema=json.loads(raw_schema), named_schemas=_schemas)
    buffer = BytesIO()
    schemaless_writer(buffer, schema, obj.asdict(standardize_factory=standardize_custom_type))
    return buffer.getvalue()
61

62

63
T = TypeVar("T", bound=AvroModel)


def deserialize_binary(data: bytes, model: type[T]) -> T:
    """Deserialize an avro binary message, using the original schema.

    The schema comes from the model class' ``_schema`` attribute when it is present. Only
    when it is missing is the generated ``avro_schema()`` used. Named types are resolved
    against the module-level ``_schemas`` mapping.

    Args:
        data: Bytes produced by a schemaless Avro writer for the same schema.
        model: The model class to decode into.

    Returns:
        An instance of ``model`` built with ``model.parse_obj`` from the decoded record.
    """
    # Only build the generated schema when `_schema` is absent (getattr's default is eager).
    raw_schema = getattr(model, "_schema", None)
    if raw_schema is None:
        raw_schema = model.avro_schema()
    schema = parse_schema(schema=json.loads(raw_schema), named_schemas=_schemas)

    # The same schema is used as writer and reader schema, so no schema resolution
    # between differing versions takes place.
    payload = schemaless_reader(BytesIO(data), schema, schema)
    return model.parse_obj(payload)  # type: ignore
76

77

78
def create_header(message_type: str, content_type: str = "application/avro+binary") -> Header:
    """Create a message header.

    Args:
        message_type: Value for the header's ``type`` field. Callers in this module pass
            the queue name, e.g. ``"project.created"``.
        content_type: Value for the header's ``dataContentType`` field.

    Returns:
        A ``Header`` with these fields:
        - ``source="renku-data-services"``
        - ``schemaVersion="1"``
        - the current time as a timezone-aware UTC datetime
        - a freshly generated ULID (hex) as ``requestId``
    """
    # Local import: the module only imports the `datetime` class at the top.
    from datetime import timezone

    return Header(
        type=message_type,
        source="renku-data-services",
        dataContentType=content_type,
        schemaVersion="1",
        # datetime.utcnow() is deprecated since Python 3.12 and returns a naive value.
        # An aware UTC datetime denotes the same instant without relying on consumers
        # to assume UTC.
        time=datetime.now(timezone.utc),
        requestId=ULID().hex,
    )
88

89

90
# Queue (stream) name for each message class. Keyed by the class' __qualname__ so that both
# real type annotations and string annotations (from `from __future__ import annotations`)
# can be looked up.
_QUEUE_NAMES: dict[str, str] = {
    ProjectCreated.__qualname__: "project.created",
    ProjectUpdated.__qualname__: "project.updated",
    ProjectRemoved.__qualname__: "project.removed",
    UserAdded.__qualname__: "user.added",
    UserUpdated.__qualname__: "user.updated",
    UserRemoved.__qualname__: "user.removed",
    ProjectAuthorizationAdded.__qualname__: "projectAuth.added",
    ProjectAuthorizationUpdated.__qualname__: "projectAuth.updated",
    ProjectAuthorizationRemoved.__qualname__: "projectAuth.removed",
}


def _resolve_queue_name(transform: Callable[..., Union[AvroModel, Optional[AvroModel]]]) -> str:
    """Map the declared return type of ``transform`` to the queue name to publish on.

    The return annotation may take three forms:
    - a message class;
    - a union of exactly one class and ``None`` (``X | None`` or ``Optional[X]``);
    - when annotations are strings, the text ``"X | None"``.

    Raises:
        NotImplementedError: If the union contains more than one non-None type, or the
            resulting name has no known queue.
    """
    signature = inspect.signature(transform).return_annotation

    # Handle type unions: `X | None` (types.UnionType) and `Optional[X]` / `Union[X, None]`.
    non_none_types = None
    if isinstance(signature, UnionType) or getattr(signature, "__origin__", None) is Union:
        non_none_types = [t for t in signature.__args__ if t is not NoneType]
    elif isinstance(signature, str) and " | " in signature:
        non_none_types = [t for t in signature.split(" | ") if t != "None"]

    if non_none_types is not None:
        if len(non_none_types) != 1:
            raise NotImplementedError(f"Only optional types are supported, got {signature}")
        signature = non_none_types[0]
    if not isinstance(signature, str):
        # depending on 'from __future__ import annotations' this can be a string or a type
        signature = signature.__qualname__

    queue_name = _QUEUE_NAMES.get(signature)
    if queue_name is None:
        raise NotImplementedError(f"Can't create message using transform {transform}:{signature}")
    return queue_name


def dispatch_message(transform: Callable[..., Union[AvroModel, Optional[AvroModel]]]):
    """Sends a message on the message queue.

    The transform method is called with the arguments and result of the wrapped method. It is
    responsible for creating the message to dispatch. The queue the message is sent on is
    chosen from the transform's return type annotation.

    This wrapper guarantees at-least-once delivery of messages by using a backup 'events'
    table that stores messages for redelivery should sending fail. For this to work correctly,
    the messages need to be stored in the events table in the same database transaction as the
    metadata update that they are related to. The wrapped method therefore receives the
    database session as its first argument after ``self``. The wrapper stores the event with
    that session and commits it before sending.

    All this is to ensure that downstream consumers are kept up to date. They are expected to
    handle multiple delivery of the same message correctly.

    This code addresses these potential error cases:
    - Data being persisted in our database but no message being sent due to errors/restarts of
      the service at the wrong time.
    - Redis not being available.

    Downstream consumers are expected to handle the following:
    - The same message being delivered more than once. Deduplication can be done because the
      message ids are identical.
    - Messages being delivered out of order. This should be super rare. For example, a user
      edits a project and message delivery fails due to redis being down. The user then
      deletes the project and that message delivery works. Then the first message is
      delivered again and this works, meaning the project deletion arrives downstream before
      the project update. Order can be maintained using the timestamps in the messages.

    Raises:
        NotImplementedError: At call time, when a payload is produced but the transform's
            return annotation does not map to a known queue.
    """

    def decorator(f):
        @wraps(f)
        async def message_wrapper(self, session, *args, **kwargs):
            result = await f(self, session, *args, **kwargs)
            payload = transform(result, *args, **kwargs)

            if payload is None:
                # don't send message if transform returned None
                return result

            queue_name = _resolve_queue_name(transform)
            headers = create_header(queue_name)
            message_id = ULID().hex
            message: dict[bytes | memoryview | str | int | float, bytes | memoryview | str | int | float] = {
                "id": message_id,
                "headers": headers.serialize_json(),
                "payload": base64.b64encode(serialize_binary(payload)).decode(),
            }
            event_id = await self.event_repo.store_event(session, queue_name, message)

            # An async session's commit() returns a coroutine. Calling it without awaiting it
            # (as before) never committed, so the event was not persisted before sending and
            # delete_event() below could not see it. Await it when awaitable; a synchronous
            # session has already committed by the time commit() returns.
            commit_result = session.commit()
            if inspect.isawaitable(commit_result):
                await commit_result

            try:
                await self.message_queue.send_message(queue_name, message)
            except Exception:
                # Best effort: the stored event is kept for redelivery. Catching Exception
                # (not a bare except) lets task cancellation and interpreter exit propagate.
                import logging

                logging.getLogger(__name__).exception(
                    "Failed to send message on queue %s; it is kept for redelivery", queue_name
                )
                return result
            await self.event_repo.delete_event(event_id)
            return result

        return message_wrapper

    return decorator
181

182

183
@dataclass
class RedisQueue(IMessageQueue):
    """Message queue backed by Redis streams."""

    config: RedisConfig

    async def send_message(
        self,
        channel: str,
        message: dict[bytes | memoryview | str | int | float, bytes | memoryview | str | int | float],
    ):
        """Append ``message`` as a new entry to the Redis stream named ``channel``.

        The caller's mapping is left untouched; a shallow copy is sent instead. When the copy
        carries a ``"payload"`` field, that value is base64-decoded back to raw bytes first.
        """
        outgoing = copy.copy(message)
        if "payload" in outgoing:
            encoded_payload = outgoing["payload"]
            outgoing["payload"] = base64.b64decode(encoded_payload)  # type: ignore

        connection = self.config.redis_connection
        await connection.xadd(channel, outgoing)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc