• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

SwissDataScienceCenter / renku-data-services / 14382014257

10 Apr 2025 01:42PM UTC coverage: 86.576% (+0.2%) from 86.351%
14382014257

Pull #759

github

web-flow
Merge 470ff1568 into 74eb7d965
Pull Request #759: feat: add new service cache and migrations

412 of 486 new or added lines in 15 files covered. (84.77%)

18 existing lines in 6 files now uncovered.

20232 of 23369 relevant lines covered (86.58%)

1.53 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

87.6
/components/renku_data_services/k8s_watcher/db.py
1
"""K8s watcher database and k8s wrappers."""
2

3
from __future__ import annotations
2✔
4

5
import asyncio
2✔
6
import contextlib
2✔
7
from asyncio import Task
2✔
8
from collections.abc import AsyncIterable, Awaitable, Callable
2✔
9
from dataclasses import dataclass
2✔
10
from datetime import timedelta
2✔
11
from typing import Any, Self, cast
2✔
12

13
import kr8s
2✔
14
import sqlalchemy
2✔
15
from box import Box
2✔
16
from kr8s.asyncio import Api
2✔
17
from kr8s.asyncio.objects import APIObject
2✔
18
from sanic.log import logger
2✔
19
from sqlalchemy import bindparam, select
2✔
20
from sqlalchemy.dialects.postgresql import JSONB
2✔
21
from sqlalchemy.ext.asyncio import AsyncSession
2✔
22

23
from renku_data_services.errors import errors
2✔
24
from renku_data_services.k8s_watcher.models import ClusterId, K8sObject, K8sObjectMeta, ListFilter
2✔
25
from renku_data_services.k8s_watcher.orm import K8sObjectORM
2✔
26

27

28
@dataclass(frozen=True, eq=True)
class Cluster:
    """An immutable description of a single k8s cluster."""

    # Identifier under which the cluster is registered.
    id: ClusterId
    # Namespace associated with this cluster.
    namespace: str
    # kr8s asyncio API client used to talk to the cluster.
    api: Api
2✔
35

36

37
@dataclass
class APIObjectInCluster:
    """A kr8s k8s object paired with the id of the cluster it comes from."""

    obj: APIObject
    cluster: ClusterId

    @property
    def user_id(self) -> str | None:
        """The user id derived from the wrapped api object, None if there is none."""
        return user_id_from_api_object(self.obj)

    @property
    def meta(self) -> K8sObjectMeta:
        """Build the metadata of the wrapped object, using "default" when no namespace is set."""
        api_obj = self.obj
        return K8sObjectMeta(
            name=api_obj.name,
            namespace=api_obj.namespace or "default",
            cluster=self.cluster,
            version=api_obj.version,
            kind=api_obj.kind,
            user_id=self.user_id,
        )

    def to_k8s_object(self) -> K8sObject:
        """Turn the api object into a regular k8s object.

        Raises a ProgrammingError when the wrapped object has no name or no namespace.
        """
        api_obj = self.obj
        if api_obj.name is None or api_obj.namespace is None:
            raise errors.ProgrammingError()
        return K8sObject(
            name=api_obj.name,
            namespace=api_obj.namespace,
            kind=api_obj.kind,
            version=api_obj.version,
            manifest=Box(api_obj.to_dict()),
            cluster=self.cluster,
            user_id=self.user_id,
        )

    @classmethod
    def from_k8s_object(cls, obj: K8sObject, api: Api | None = None) -> Self:
        """Wrap a regular k8s object into an api object, optionally bound to an api client."""
        meta = obj.meta

        # An APIObject subclass is declared on the fly so that it carries the
        # kind/version/endpoint information of this specific object.
        class _APIObj(APIObject):
            kind = meta.kind
            version = meta.version
            singular = meta.singular
            plural = meta.plural
            endpoint = meta.plural
            namespaced = meta.namespaced

        wrapped = _APIObj(
            resource=obj.manifest,
            namespace=obj.meta.namespace,
            api=api,
        )
        return cls(obj=wrapped, cluster=obj.cluster)
96

97

98
# Async callback that receives an api object together with the id of its cluster.
type EventHandler = Callable[[APIObjectInCluster], Awaitable[None]]
2✔
99

100

101
class K8sClient:
    """Thin layer over kr8s clients that operates on resources across many clusters."""

    def __init__(self, clusters: dict[ClusterId, Cluster]) -> None:
        self.__clusters = clusters

    def __get_cluster_or_die(self, cluster_id: ClusterId) -> Cluster:
        """Return the cluster with the given id or raise a MissingResourceError."""
        found = self.__clusters.get(cluster_id)
        if found:
            return found
        raise errors.MissingResourceError(
            message=f"Could not find cluster with id {cluster_id} in the list of clusters."
        )

    async def create(self, obj: K8sObject) -> K8sObject:
        """Create the k8s object in its cluster and return it with the resulting manifest."""
        target = self.__get_cluster_or_die(obj.cluster)
        wrapped = APIObjectInCluster.from_k8s_object(obj, target.api)
        await wrapped.obj.create()
        created_meta = wrapped.meta
        return created_meta.with_manifest(wrapped.obj.to_dict())

    async def patch(self, meta: K8sObjectMeta, patch: dict[str, Any] | list[dict[str, Any]]) -> K8sObject:
        """Patch a k8s object.

        A list is treated as a rfc6902 json patch, e.g.
        `[{ "op": "add", "path": "/a/b/c", "value": [ "foo", "bar" ] }]`,
        while a dictionary is treated as a rfc7386 json merge patch.
        A MissingResourceError is raised when the object cannot be found.
        """
        found = await self._get(meta)
        if not found:
            raise errors.MissingResourceError(message=f"The k8s resource with metadata {meta} cannot be found.")
        if isinstance(patch, list):
            patch_type = "json"
        else:
            patch_type = None
        kr8s_obj = found.obj
        await kr8s_obj.patch(patch, type=patch_type)
        return meta.with_manifest(kr8s_obj.to_dict())

    async def delete(self, meta: K8sObjectMeta) -> None:
        """Delete a k8s object, doing nothing when it does not exist."""
        found = await self._get(meta)
        if found:
            # The object may vanish between the lookup and the delete call.
            with contextlib.suppress(kr8s.NotFoundError):
                await found.obj.delete(propagation_policy="Foreground")

    async def _get(self, meta: K8sObjectMeta) -> APIObjectInCluster | None:
        """Return the first api object matching the metadata, None when nothing matches."""
        matches = self.__list(meta.to_list_filter())
        return await anext(aiter(matches), None)

    async def get(self, meta: K8sObjectMeta) -> K8sObject | None:
        """Fetch one k8s object, returning None if it does not exist."""
        found = await self._get(meta)
        return meta.with_manifest(found.obj.to_dict()) if found else None

    async def __list(self, filter: ListFilter) -> AsyncIterable[APIObjectInCluster]:
        """Yield the api objects matching the filter from every relevant cluster."""
        if filter.cluster:
            only = self.__clusters.get(filter.cluster)
            candidates = [only] if only else []
        else:
            candidates = list(self.__clusters.values())

        for cluster in candidates:
            namespace_mismatch = filter.namespace is not None and filter.namespace != cluster.namespace
            if namespace_mismatch:
                continue
            name_args = [filter.name] if filter.name else []

            try:
                found = await cluster.api.async_get(
                    filter.kind,
                    *name_args,
                    label_selector=filter.label_selector,
                    namespace=filter.namespace,
                )
            except (kr8s.ServerError, kr8s.APITimeoutError):
                # A cluster that errors out or times out is skipped.
                continue

            items = found if isinstance(found, list) else [found]
            for item in items:
                yield APIObjectInCluster(item, cluster.id)

    async def list(self, filter: ListFilter) -> AsyncIterable[K8sObject]:
        """List the k8s objects matching the filter."""
        async for item in self.__list(filter):
            yield item.to_k8s_object()
×
184

185

186
class K8sDbCache:
    """Postgres-backed cache of k8s objects."""

    def __init__(self, session_maker: Callable[..., AsyncSession]) -> None:
        self.__session_maker = session_maker

    async def __get(self, meta: K8sObjectMeta, session: AsyncSession) -> K8sObjectORM | None:
        """Load the cached row matching the metadata, None when there is no such row."""
        conditions = [
            K8sObjectORM.name == meta.name,
            K8sObjectORM.namespace == meta.namespace,
            K8sObjectORM.cluster == meta.cluster,
            K8sObjectORM.kind == meta.kind,
            K8sObjectORM.version == meta.version,
        ]
        # The user id only narrows the lookup when the metadata carries one.
        if meta.user_id is not None:
            conditions.append(K8sObjectORM.user_id == meta.user_id)

        return await session.scalar(select(K8sObjectORM).where(*conditions))

    async def upsert(self, obj: K8sObject) -> None:
        """Store the object in the cache, updating the manifest if it is already cached.

        A ValidationError is raised when the object has no user id.
        """
        if obj.user_id is None:
            raise errors.ValidationError(message="user_id is required to upsert k8s object.")
        async with self.__session_maker() as session, session.begin():
            existing = await self.__get(obj.meta, session)
            if existing is None:
                session.add(
                    K8sObjectORM(
                        name=obj.name,
                        namespace=obj.namespace or "default",
                        kind=obj.kind,
                        version=obj.version,
                        manifest=obj.manifest,
                        cluster=obj.cluster,
                        user_id=obj.user_id,
                    )
                )
            else:
                existing.manifest = obj.manifest
            await session.flush()

    async def delete(self, meta: K8sObjectMeta) -> None:
        """Remove an object from the cache, doing nothing if it is not cached."""
        async with self.__session_maker() as session, session.begin():
            existing = await self.__get(meta, session)
            if existing is not None:
                await session.delete(existing)

    async def get(self, meta: K8sObjectMeta) -> K8sObject | None:
        """Read a single object from the cache, None if it is not cached."""
        async with self.__session_maker() as session, session.begin():
            cached = await self.__get(meta, session)
            return meta.with_manifest(cached.manifest) if cached else None

    async def list(self, filter: ListFilter) -> AsyncIterable[K8sObject]:
        """Stream the cached objects matching the filter."""
        async with self.__session_maker() as session, session.begin():
            stmt = select(K8sObjectORM)
            # Only the filter fields that are set restrict the query.
            equality_filters = (
                (K8sObjectORM.name, filter.name),
                (K8sObjectORM.namespace, filter.namespace),
                (K8sObjectORM.cluster, filter.cluster),
                (K8sObjectORM.kind, filter.kind),
                (K8sObjectORM.version, filter.version),
                (K8sObjectORM.user_id, filter.user_id),
            )
            for column, wanted in equality_filters:
                if wanted:
                    stmt = stmt.where(column == wanted)
            if filter.label_selector:
                # JSONB containment: the labels in the manifest must include the requested ones.
                labels_clause = sqlalchemy.text("manifest -> 'metadata' -> 'labels' @> :labels").bindparams(
                    bindparam("labels", filter.label_selector, type_=JSONB)
                )
                stmt = stmt.where(labels_clause)
            rows = await session.stream_scalars(stmt)
            async for row in rows:
                yield row.dump()
1✔
272

273

274
class CachedK8sClient(K8sClient):
    """A kr8s k8s client wrapper backed by a cache.

    Reading and listing of the cached kinds is served from the cache, every other
    operation goes to the cluster.
    """

    def __init__(self, clusters: dict[ClusterId, Cluster], cache: K8sDbCache, kinds_to_cache: list[str]) -> None:
        super().__init__(clusters)
        self.cache = cache
        self.__kinds_to_cache = kinds_to_cache

    async def create(self, obj: K8sObject) -> K8sObject:
        """Create the k8s object and store it in the cache if its kind is cached."""
        created = await super().create(obj)
        if created.meta.kind in self.__kinds_to_cache:
            await self.cache.upsert(created)
        return created

    async def patch(self, meta: K8sObjectMeta, patch: dict[str, Any] | list[dict[str, Any]]) -> K8sObject:
        """Patch a k8s object and refresh the cache if its kind is cached."""
        patched = await super().patch(meta, patch)
        if meta.kind in self.__kinds_to_cache:
            await self.cache.upsert(patched)
        return patched

    async def delete(self, meta: K8sObjectMeta) -> None:
        """Delete a k8s object and drop it from the cache if its kind is cached."""
        await super().delete(meta)
        if meta.kind in self.__kinds_to_cache:
            await self.cache.delete(meta)

    async def get(self, meta: K8sObjectMeta) -> K8sObject | None:
        """Get a specific k8s object, returning None if the object does not exist."""
        if meta.kind in self.__kinds_to_cache:
            return await self.cache.get(meta)
        return await super().get(meta)

    async def get_api_object(self, meta: K8sObjectMeta) -> APIObjectInCluster | None:
        """Fetch the kr8s object straight from the cluster, without consulting the cache.

        Note: only use this if you actually need to do k8s operations.
        """
        return await super()._get(meta)

    async def list(self, filter: ListFilter) -> AsyncIterable[K8sObject]:
        """List k8s objects, from the cache when the requested kind is cached."""
        if filter.kind in self.__kinds_to_cache:
            source = self.cache.list(filter)
        else:
            source = super().list(filter)
        async for item in source:
            yield item
1✔
330

331

332
class K8sWatcher:
    """Watch k8s events and call the handler with every event."""

    def __init__(self, handler: EventHandler, clusters: dict[ClusterId, Cluster], kinds: list[str]) -> None:
        self.__handler = handler
        # One long-running task per cluster, None until the watcher is started.
        self.__tasks: dict[ClusterId, Task] | None = None
        self.__kinds = kinds
        self.__clusters = clusters

    async def __watch_kind(self, kind: str, cluster: Cluster) -> None:
        """Watch one kind in one cluster forever, restarting the watch when it ends or fails."""
        while True:
            try:
                watch = cluster.api.async_watch(kind=kind, namespace=cluster.namespace)
                async for _, obj in watch:
                    await self.__handler(APIObjectInCluster(obj, cluster.id))
                    # in some cases, the kr8s loop above just never yields, especially if there's exceptions which
                    # can bypass async scheduling. This sleep here is as a last line of defence so this code does not
                    # execute indefinitely and prevent an other resource kind from being watched.
                    await asyncio.sleep(0)
            except Exception:
                # Cancellation is a BaseException and is deliberately not caught here, so stop() still works.
                logger.exception("Watching kind %s in cluster %s failed, restarting the watch", kind, cluster.id)
                # without sleeping, this can just hang the code as exceptions seem to bypass the async scheduler
                await asyncio.sleep(1)

    async def __run_single(self, cluster: Cluster) -> None:
        """Watch all configured kinds of one cluster concurrently until cancelled.

        The per-kind watchers are awaited rather than fired and forgotten: this keeps a
        reference to them (unreferenced tasks may be garbage collected mid-execution),
        keeps the per-cluster task alive while they run and propagates a cancellation of
        the per-cluster task to every one of them.
        """
        await asyncio.gather(*(self.__watch_kind(kind, cluster) for kind in self.__kinds))

    async def start(self) -> None:
        """Start the watcher.

        Clusters that already have a running watch task are left alone, so calling this
        more than once does not spawn duplicate watchers.
        """
        if self.__tasks is None:
            self.__tasks = {}
        for cluster in self.__clusters.values():
            existing = self.__tasks.get(cluster.id)
            if existing is not None and not existing.done():
                continue
            self.__tasks[cluster.id] = asyncio.create_task(self.__run_single(cluster))

    async def wait(self) -> None:
        """Wait for all tasks.

        This is mainly used to block the main function.
        """
        if self.__tasks is None:
            return
        await asyncio.gather(*self.__tasks.values())

    async def stop(self, timeout: timedelta = timedelta(seconds=10)) -> None:
        """Stop the watcher or timeout.

        All running tasks are cancelled and then awaited together for at most `timeout`.
        Tasks that did not finish in time are reported in the log; the cancellation of
        the tasks is never raised into the caller.
        """
        if self.__tasks is None:
            return
        running = [task for task in self.__tasks.values() if not task.done()]
        if not running:
            return
        for task in running:
            task.cancel()
        # asyncio.wait neither raises the CancelledError of the awaited tasks nor a TimeoutError,
        # it simply hands back the tasks that are still pending once the timeout has elapsed.
        _, pending = await asyncio.wait(running, timeout=timeout.total_seconds())
        for _ in pending:
            logger.error("timeout trying to cancel k8s watcher task")
×
392

393

394
def k8s_object_handler(cache: K8sDbCache) -> EventHandler:
    """Build an event handler that applies k8s events to the cache."""

    async def handler(obj: APIObjectInCluster) -> None:
        being_deleted = obj.obj.metadata.get("deletionTimestamp")
        if not being_deleted:
            await cache.upsert(obj.to_k8s_object())
        else:
            # A deletion timestamp means the object is on its way out, so it is dropped from the cache.
            await cache.delete(obj.meta)

    return handler
1✔
405

406

407
def user_id_from_api_object(obj: APIObject) -> str | None:
    """Read the user id from the labels of an api object.

    Only the JupyterServer and AmaltheaSession kinds carry a user id label, None is
    returned for any other kind.
    """
    kind = obj.kind
    if kind == "JupyterServer":
        label = "renku.io/userId"
    elif kind == "AmaltheaSession":
        label = "renku.io/safe-username"
    else:
        return None
    return cast(str, obj.metadata.labels[label])
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc