• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

SwissDataScienceCenter / renku-data-services / 14382014257

10 Apr 2025 01:42PM UTC coverage: 86.576% (+0.2%) from 86.351%
14382014257

Pull #759

github

web-flow
Merge 470ff1568 into 74eb7d965
Pull Request #759: feat: add new service cache and migrations

412 of 486 new or added lines in 15 files covered. (84.77%)

18 existing lines in 6 files now uncovered.

20232 of 23369 relevant lines covered (86.58%)

1.53 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

63.1
/components/renku_data_services/notebooks/api/classes/k8s_client.py
1
"""An abstraction over the kr8s kubernetes client and the k8s-watcher."""
2

3
import base64
2✔
4
import json
2✔
5
from contextlib import suppress
2✔
6
from typing import Any, Generic, Optional, TypeVar, cast
2✔
7

8
import httpx
2✔
9
from box import Box
2✔
10
from kr8s import NotFoundError, ServerError
2✔
11
from kr8s.asyncio.objects import APIObject, Pod, Secret, StatefulSet
2✔
12
from kubernetes.client import ApiClient, V1Secret
2✔
13

14
from renku_data_services.errors import errors
2✔
15
from renku_data_services.k8s_watcher.db import CachedK8sClient
2✔
16
from renku_data_services.k8s_watcher.models import ClusterId, K8sObject, K8sObjectMeta, ListFilter
2✔
17
from renku_data_services.notebooks.api.classes.auth import GitlabToken, RenkuTokens
2✔
18
from renku_data_services.notebooks.constants import JUPYTER_SESSION_KIND, JUPYTER_SESSION_VERSION
2✔
19
from renku_data_services.notebooks.crs import AmaltheaSessionV1Alpha1, JupyterServerV1Alpha1
2✔
20
from renku_data_services.notebooks.errors.programming import ProgrammingError
2✔
21
from renku_data_services.notebooks.util.kubernetes_ import find_env_var
2✔
22

23
# Bound method of a single kubernetes ``ApiClient`` instantiated once at import time.
# It is used below to turn kubernetes client models (e.g. ``V1Secret``) into plain dicts
# that can be wrapped in a ``Box`` manifest.
sanitize_for_serialization = ApiClient().sanitize_for_serialization
2✔
24

25

26
# NOTE: The kr8s library ships no type stubs (its authors claim pyright handles its type hints better).
27
class JupyterServerV1Alpha1Kr8s(APIObject):
    """kr8s object class for the Jupyter server custom resource (plural ``jupyterservers``).

    The class attributes describe the resource to kr8s: its kind/version, its REST naming,
    and whether it is namespaced or scalable.
    """

    # Resource identity.
    version: str = JUPYTER_SESSION_VERSION
    kind: str = JUPYTER_SESSION_KIND
    # REST naming.
    endpoint: str = "jupyterservers"
    plural: str = "jupyterservers"
    singular: str = "jupyterserver"
    # Capabilities.
    namespaced: bool = True
    scalable: bool = False
37

38

39
# The session models a K8sClient can be parameterized with. This is a constrained TypeVar,
# so it resolves to exactly one of these two classes, never a union of them.
_SessionType = TypeVar("_SessionType", JupyterServerV1Alpha1, AmaltheaSessionV1Alpha1)
2✔
40

41

42
class K8sClient(Generic[_SessionType]):
    """The K8s client for session servers and their related resources, backed by a cached k8s client.

    Every object is read and written through ``cached_client`` and is scoped to the namespace and
    cluster given at construction time. Server manifests are validated into ``server_type`` models,
    and servers created by this client are labelled with the owner's safe username under
    ``username_label``.
    """

    def __init__(
        self,
        cached_client: CachedK8sClient,
        username_label: str,
        namespace: str,
        cluster: ClusterId,
        server_type: type[_SessionType],
        server_kind: str,
        server_api_version: str,
    ):
        # Fail fast: the label is required both to tag new servers and to list a user's servers.
        if not username_label:
            raise ProgrammingError("username_label has to be provided to K8sClient")
        self.cached_client: CachedK8sClient = cached_client
        self.username_label = username_label
        self.namespace = namespace
        self.cluster = cluster
        self.server_type: type[_SessionType] = server_type
        self.server_kind = server_kind
        self.server_api_version = server_api_version

    def _server_meta(self, server_name: str, safe_username: str) -> K8sObjectMeta:
        """Build the metadata that identifies one of the user's servers in this namespace and cluster."""
        return K8sObjectMeta(
            kind=self.server_kind,
            version=self.server_api_version,
            user_id=safe_username,
            name=server_name,
            namespace=self.namespace,
            cluster=self.cluster,
        )

    @staticmethod
    def _server_error_has_status(err: ServerError, status_code: int) -> bool:
        """Check whether a kr8s ``ServerError`` was caused by an HTTP response with ``status_code``."""
        # NOTE: kr8s sets ``err.status`` to the Kubernetes Status object returned by the API server,
        # not to an integer, so the numeric code is read from the attached httpx response
        # (or from the Status object's "code" field if no response is attached).
        response = err.response
        if response is not None:
            return bool(response.status_code == status_code)
        status = err.status
        return isinstance(status, dict) and status.get("code") == status_code

    @staticmethod
    def _patch_type(patch: dict[str, Any] | list[dict[str, Any]]) -> str | None:
        """Select the kr8s patch type: a list is an RFC 6902 JSON patch, a dict an RFC 7386 merge patch."""
        return "json" if isinstance(patch, list) else None

    async def list_servers(self, safe_username: str) -> list[_SessionType]:
        """Get a list of servers that belong to a user.

        The servers are listed through the cached client, filtered by server kind/version, user id,
        the username label and this client's namespace.
        """
        list_filter = ListFilter(
            kind=self.server_kind,
            version=self.server_api_version,
            user_id=safe_username,
            label_selector={self.username_label: safe_username},
            namespace=self.namespace,
        )
        return [self.server_type.model_validate(s.manifest) async for s in self.cached_client.list(list_filter)]

    async def get_server(self, name: str, safe_username: str) -> _SessionType | None:
        """Get a specific server of the user by name through the cached client.

        Returns ``None`` when the cached client does not find the server.
        """
        found = await self.cached_client.get(self._server_meta(name, safe_username))
        if found is None:
            return None
        return self.server_type.model_validate(found.manifest)

    async def get_server_logs(
        self, server_name: str, safe_username: str, max_log_lines: Optional[int] = None
    ) -> dict[str, str]:
        """Get the logs of the containers of a server's pod, keyed by container name.

        Both regular and init containers are included. Containers whose logs cannot be read yet
        (still starting) are skipped. An empty dict is returned when the pod ``<server_name>-0``
        is not found through the cached client.

        Raises:
            errors.MissingResourceError: if the user has no such server, or the pod disappears
                while its logs are being read.
        """
        # NOTE: this get_server call ensures the user has access to the server; without it you
        # could read someone else's logs.
        server = await self.get_server(server_name, safe_username)
        if not server:
            raise errors.MissingResourceError(
                message=f"Cannot find server {server_name} for user {safe_username} to retrieve logs."
            )
        pod_name = f"{server_name}-0"
        result = await self.cached_client.get_api_object(
            K8sObjectMeta(
                name=pod_name, namespace=self.namespace, cluster=self.cluster, kind=Pod.kind, version=Pod.version
            )
        )
        logs: dict[str, str] = {}
        if result is None:
            return logs
        pod = Pod(resource=result.obj, namespace=result.meta.namespace, api=result.obj.api)
        all_containers = pod.spec.containers + (pod.spec.get("initContainers") or [])
        containers = [container.name for container in all_containers]
        for container in containers:
            try:
                # NOTE: calling pod.logs without a container name set crashes the library
                clogs: list[str] = [clog async for clog in pod.logs(container=container, tail_lines=max_log_lines)]
            except httpx.ResponseNotRead:
                # NOTE: This occurs when the container is still starting but we try to read its logs
                continue
            except NotFoundError as err:
                raise errors.MissingResourceError(message=f"The session pod {pod_name} does not exist.") from err
            except ServerError as err:
                if self._server_error_has_status(err, 404):
                    raise errors.MissingResourceError(message=f"The session pod {pod_name} does not exist.") from err
                raise
            else:
                logs[container] = "\n".join(clogs)
        return logs

    async def create_server(self, manifest: _SessionType, safe_username: str) -> _SessionType:
        """Create a server, or return the existing one if a server with the same name already exists.

        The manifest is mutated in place: its labels get ``username_label`` set to ``safe_username``.
        """
        server_name = manifest.metadata.name
        existing = await self.get_server(server_name, safe_username)
        if existing:
            # NOTE: server already exists
            return existing
        manifest.metadata.labels[self.username_label] = safe_username
        result = await self.cached_client.create(
            K8sObject(
                name=server_name,
                namespace=self.namespace,
                cluster=self.cluster,
                kind=self.server_kind,
                version=self.server_api_version,
                user_id=safe_username,
                manifest=Box(manifest.model_dump(exclude_none=True, mode="json")),
            )
        )
        return self.server_type.model_validate(result.manifest)

    async def patch_server(
        self, server_name: str, safe_username: str, patch: dict[str, Any] | list[dict[str, Any]]
    ) -> _SessionType:
        """Patch a server of the user and return the patched server.

        Raises:
            errors.MissingResourceError: if the server cannot be found.
        """
        server = await self.cached_client.get(self._server_meta(server_name, safe_username))
        if not server:
            raise errors.MissingResourceError(
                message=f"Cannot find server {server_name} for user {safe_username} in order to patch it."
            )
        result = await self.cached_client.patch(server, patch=patch)
        return self.server_type.model_validate(result.manifest)

    async def patch_statefulset(
        self, server_name: str, patch: dict[str, Any] | list[dict[str, Any]]
    ) -> StatefulSet | None:
        """Patch the statefulset named ``server_name``; return it, or ``None`` if it is not found.

        A list ``patch`` is sent as an RFC 6902 JSON patch, a dict as a merge patch.
        """
        result = await self.cached_client.get_api_object(
            K8sObjectMeta(
                name=server_name,
                namespace=self.namespace,
                cluster=self.cluster,
                kind=StatefulSet.kind,
                version=StatefulSet.version,
            )
        )
        if result is None:
            return None
        sts = StatefulSet(resource=result.obj, namespace=result.meta.namespace, api=result.obj.api)
        # A JSON patch (list) must be flagged as such, otherwise kr8s sends it as a merge patch.
        await sts.patch(patch, type=self._patch_type(patch))

        return sts

    async def delete_server(self, server_name: str, safe_username: str) -> None:
        """Delete a server of the user through the cached client."""
        return await self.cached_client.delete(self._server_meta(server_name, safe_username))

    async def patch_tokens(self, server_name: str, renku_tokens: RenkuTokens, gitlab_token: GitlabToken) -> None:
        """Patch the Renku and Gitlab access tokens used in a session.

        The Renku tokens are replaced in the session statefulset's container env vars, and the
        Gitlab token in the session's image pull secret. Nothing is done if the statefulset is
        not found.
        """
        result = await self.cached_client.get_api_object(
            K8sObjectMeta(
                name=server_name,
                namespace=self.namespace,
                cluster=self.cluster,
                kind=StatefulSet.kind,
                version=StatefulSet.version,
            )
        )
        if result is None:
            return None
        sts = StatefulSet(resource=result.obj, namespace=result.meta.namespace, api=result.obj.api)
        patches = self._get_statefulset_token_patches(sts, renku_tokens)
        await sts.patch(patch=patches, type="json")
        await self.patch_image_pull_secret(server_name, gitlab_token)

    async def patch_image_pull_secret(self, server_name: str, gitlab_token: GitlabToken) -> None:
        """Replace the password in the session's image pull secret with the Gitlab access token.

        The secret ``<server_name>-image-secret`` is looked up; nothing is done if it is not found.
        The first registry host in its ``.dockerconfigjson`` is kept, with its original email.

        Raises:
            ProgrammingError: if the existing docker config has no registry entries.
        """
        secret_name = f"{server_name}-image-secret"
        result = await self.cached_client.get_api_object(
            K8sObjectMeta(
                name=secret_name,
                namespace=self.namespace,
                cluster=self.cluster,
                kind=Secret.kind,
                version=Secret.version,
            )
        )
        if result is None:
            return None
        secret = Secret(resource=result.obj, namespace=result.meta.namespace, api=result.obj.api)

        secret_data = secret.data.to_dict()
        # Secret data values are base64-encoded.
        old_docker_config = json.loads(base64.b64decode(secret_data[".dockerconfigjson"]).decode())
        hostname = next(iter(old_docker_config["auths"].keys()), None)
        if not hostname:
            raise ProgrammingError(
                "Failed to refresh the access credentials in the image pull secret.",
                detail="Please contact a Renku administrator.",
            )
        new_docker_config = {
            "auths": {
                hostname: {
                    "Username": "oauth2",
                    "Password": gitlab_token.access_token,
                    "Email": old_docker_config["auths"][hostname]["Email"],
                }
            }
        }
        patch = [
            {
                "op": "replace",
                "path": "/data/.dockerconfigjson",
                "value": base64.b64encode(json.dumps(new_docker_config).encode()).decode(),
            }
        ]
        await secret.patch(patch, type="json")

    @staticmethod
    def _get_statefulset_token_patches(sts: StatefulSet, renku_tokens: RenkuTokens) -> list[dict[str, str]]:
        """Build the JSON patch operations that replace the Renku tokens in a session statefulset.

        The env vars are updated only where the container and the env var both exist:
        - ``git-proxy``: ``GIT_PROXY_RENKU_ACCESS_TOKEN`` and ``GIT_PROXY_RENKU_REFRESH_TOKEN``
        - ``git-clone`` (init): ``GIT_CLONE_USER__RENKU_TOKEN``
        - ``init-user-secrets`` (init): ``RENKU_ACCESS_TOKEN``
        """
        pod_spec = sts.spec.template.spec
        containers = cast(list[Box], pod_spec.containers)
        # A pod spec without init containers (or a container without env vars) has no such key at
        # all, and Box attribute access would raise on it, so fall back to an empty list.
        init_containers = cast(list[Box], pod_spec.get("initContainers") or [])

        def _find_container(candidates: list[Box], name: str) -> tuple[int | None, Box | None]:
            return next(((i, c) for i, c in enumerate(candidates) if c.name == name), (None, None))

        # (container list key in the pod spec, containers, container name, env var name, new value)
        targets = [
            ("containers", containers, "git-proxy", "GIT_PROXY_RENKU_ACCESS_TOKEN", renku_tokens.access_token),
            ("containers", containers, "git-proxy", "GIT_PROXY_RENKU_REFRESH_TOKEN", renku_tokens.refresh_token),
            ("initContainers", init_containers, "git-clone", "GIT_CLONE_USER__RENKU_TOKEN", renku_tokens.access_token),
            ("initContainers", init_containers, "init-user-secrets", "RENKU_ACCESS_TOKEN", renku_tokens.access_token),
        ]

        patches: list[dict[str, str]] = []
        for list_key, candidates, container_name, env_name, value in targets:
            container_index, container = _find_container(candidates, container_name)
            if container_index is None or container is None:
                continue
            env_var = find_env_var(cast(list[Box], container.get("env") or []), env_name)
            if env_var is None:
                continue
            # env_var[0] is the index of the env var within the container's env list.
            patches.append(
                {
                    "op": "replace",
                    "path": f"/spec/template/spec/{list_key}/{container_index}/env/{env_var[0]}/value",
                    "value": value,
                }
            )

        return patches

    @property
    def preferred_namespace(self) -> str:
        """Get the preferred namespace for creating jupyter servers."""
        return self.namespace

    async def create_secret(self, secret: V1Secret) -> V1Secret:
        """Create a secret, or overwrite the existing one when creation conflicts (HTTP 409).

        On a conflict, the existing secret's data, string data, annotations and labels are set to
        the ones from ``secret``.

        Raises:
            ProgrammingError: if ``secret`` has no metadata.
        """
        if secret.metadata is None:
            raise ProgrammingError("The secret to create must have metadata.")
        secret_obj = K8sObject(
            name=secret.metadata.name,
            namespace=self.namespace,
            cluster=self.cluster,
            kind=Secret.kind,
            version=Secret.version,
            manifest=Box(sanitize_for_serialization(secret)),
        )
        try:
            result = await self.cached_client.create(secret_obj)
        except ServerError as err:
            if not self._server_error_has_status(err, 409):
                raise
            annotations: Box | None = secret_obj.manifest.metadata.get("annotations")
            labels: Box | None = secret_obj.manifest.metadata.get("labels")
            # NOTE: RFC 6902 "replace" fails when the target member does not exist, e.g. an existing
            # secret without labels or annotations. Kubernetes never returns the write-only
            # "stringData" field, so replacing it would always fail. "add" overwrites an existing
            # object member and creates a missing one.
            patches = [
                {
                    "op": "add",
                    "path": "/data",
                    "value": secret.data or {},
                },
                {
                    "op": "add",
                    "path": "/stringData",
                    "value": secret.string_data or {},
                },
                {
                    "op": "add",
                    "path": "/metadata/annotations",
                    "value": annotations.to_dict() if annotations is not None else {},
                },
                {
                    "op": "add",
                    "path": "/metadata/labels",
                    "value": labels.to_dict() if labels is not None else {},
                },
            ]
            result = await self.cached_client.patch(secret_obj, patches)
        return V1Secret(
            metadata=result.manifest.metadata,
            data=result.manifest.get("data", {}),
            string_data=result.manifest.get("stringData", {}),
            type=result.manifest.get("type"),
        )

    async def delete_secret(self, name: str) -> None:
        """Delete a secret in this client's namespace and cluster."""
        return await self.cached_client.delete(
            K8sObjectMeta(
                name=name,
                namespace=self.namespace,
                cluster=self.cluster,
                kind=Secret.kind,
                version=Secret.version,
            )
        )

    async def patch_secret(self, name: str, patch: dict[str, Any] | list[dict[str, Any]]) -> None:
        """Patch a secret.

        A list ``patch`` is sent as an RFC 6902 JSON patch, a dict as an RFC 7386 merge patch. If
        the secret disappears between lookup and patch, the ``NotFoundError`` is ignored.

        Raises:
            errors.MissingResourceError: if the secret cannot be found.
        """
        result = await self.cached_client.get_api_object(
            K8sObjectMeta(
                name=name,
                namespace=self.namespace,
                cluster=self.cluster,
                kind=Secret.kind,
                version=Secret.version,
            )
        )
        if result is None:
            raise errors.MissingResourceError(message=f"Cannot find secret {name}.")
        # Wrap the cached object the same way as the other methods instead of asserting its type.
        secret = Secret(resource=result.obj, namespace=result.meta.namespace, api=result.obj.api)

        with suppress(NotFoundError):
            await secret.patch(patch, type=self._patch_type(patch))
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc