SwissDataScienceCenter / renku-data-services / build 19233917490

10 Nov 2025 01:51PM UTC, coverage: 86.34% (-0.01% from 86.352%)

Pull Request #1103: feat: use data connectors in remote sessions
Merge 6ac07dc99 into ffc3f8457 (merge commit created via GitHub web-flow)

0 of 6 new or added lines in 1 file covered. (0.0%)

4 existing lines in 3 files now uncovered.

22792 of 26398 relevant lines covered (86.34%)

1.52 hits per line
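
As a quick consistency check, the headline percentage matches the line counts above (a sketch, using the report's own numbers):

    >>> round(22792 / 26398 * 100, 2)
    86.34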

Source File

/components/renku_data_services/notebooks/core_sessions.py (file coverage: 61.98%)
"""A selection of core functions for AmaltheaSessions."""

import base64
import json
import os
import random
import string
from collections.abc import AsyncIterator, Mapping, Sequence
from datetime import timedelta
from pathlib import PurePosixPath
from typing import Protocol, TypeVar, cast
from urllib.parse import urljoin, urlparse

import httpx
from kubernetes.client import V1ObjectMeta, V1Secret
from sanic import Request
from toml import dumps
from ulid import ULID
from yaml import safe_dump

import renku_data_services.notebooks.image_check as ic
from renku_data_services.app_config import logging
from renku_data_services.base_models import RESET, AnonymousAPIUser, APIUser, AuthenticatedAPIUser, ResetType
from renku_data_services.base_models.metrics import MetricsService
from renku_data_services.connected_services.db import ConnectedServicesRepository
from renku_data_services.crc.db import ClusterRepository, ResourcePoolRepository
from renku_data_services.crc.models import (
    ClusterSettings,
    GpuKind,
    RemoteConfigurationFirecrest,
    ResourceClass,
    ResourcePool,
)
from renku_data_services.data_connectors.db import (
    DataConnectorSecretRepository,
)
from renku_data_services.data_connectors.models import DataConnectorSecret, DataConnectorWithSecrets
from renku_data_services.errors import errors
from renku_data_services.k8s.models import K8sSecret, sanitizer
from renku_data_services.notebooks import apispec, core
from renku_data_services.notebooks.api.amalthea_patches import git_proxy, init_containers
from renku_data_services.notebooks.api.amalthea_patches.init_containers import user_secrets_extras
from renku_data_services.notebooks.api.classes.image import Image
from renku_data_services.notebooks.api.classes.repository import GitProvider, Repository
from renku_data_services.notebooks.api.schemas.cloud_storage import RCloneStorage
from renku_data_services.notebooks.config import GitProviderHelperProto, NotebooksConfig
from renku_data_services.notebooks.crs import (
    AmaltheaSessionSpec,
    AmaltheaSessionV1Alpha1,
    AmaltheaSessionV1Alpha1MetadataPatch,
    AmaltheaSessionV1Alpha1Patch,
    AmaltheaSessionV1Alpha1SpecPatch,
    AmaltheaSessionV1Alpha1SpecSessionPatch,
    Authentication,
    AuthenticationType,
    Culling,
    CullingPatch,
    DataSource,
    ExtraContainer,
    ExtraVolume,
    ExtraVolumeMount,
    ImagePullPolicy,
    ImagePullSecret,
    Ingress,
    InitContainer,
    Limits,
    LimitsStr,
    Metadata,
    ReconcileStrategy,
    Requests,
    RequestsStr,
    Resources,
    ResourcesPatch,
    SecretAsVolume,
    SecretAsVolumeItem,
    Session,
    SessionEnvItem,
    SessionLocation,
    ShmSizeStr,
    SizeStr,
    State,
    Storage,
)
from renku_data_services.notebooks.models import (
    ExtraSecret,
    SessionDataConnectorOverride,
    SessionEnvVar,
    SessionExtraResources,
    SessionLaunchRequest,
)
from renku_data_services.notebooks.util.kubernetes_ import (
    renku_2_make_server_name,
)
from renku_data_services.notebooks.utils import (
    node_affinity_from_resource_class,
    node_affinity_patch_from_resource_class,
    tolerations_from_resource_class,
)
from renku_data_services.project.db import ProjectRepository, ProjectSessionSecretRepository
from renku_data_services.project.models import Project, SessionSecret
from renku_data_services.session.db import SessionRepository
from renku_data_services.session.models import SessionLauncher
from renku_data_services.users.db import UserRepo
from renku_data_services.utils.cryptography import get_encryption_key

logger = logging.getLogger(__name__)


async def get_extra_init_containers(
    nb_config: NotebooksConfig,
    user: AnonymousAPIUser | AuthenticatedAPIUser,
    repositories: list[Repository],
    git_providers: list[GitProvider],
    storage_mount: PurePosixPath,
    work_dir: PurePosixPath,
    uid: int = 1000,
    gid: int = 1000,
) -> SessionExtraResources:
    """Get all extra init containers that should be added to an amalthea session."""
    # TODO: The above statement is not correct: the init container for user secrets is not included here
    cert_init, cert_vols = init_containers.certificates_container(nb_config)
    session_init_containers = [InitContainer.model_validate(sanitizer(cert_init))]
    extra_volumes = [ExtraVolume.model_validate(sanitizer(volume)) for volume in cert_vols]
    git_clone = await init_containers.git_clone_container_v2(
        user=user,
        config=nb_config,
        repositories=repositories,
        git_providers=git_providers,
        workspace_mount_path=storage_mount,
        work_dir=work_dir,
        uid=uid,
        gid=gid,
    )
    if git_clone is not None:
        session_init_containers.append(InitContainer.model_validate(git_clone))
    return SessionExtraResources(
        init_containers=session_init_containers,
        volumes=extra_volumes,
    )


async def get_extra_containers(
    nb_config: NotebooksConfig,
    user: AnonymousAPIUser | AuthenticatedAPIUser,
    repositories: list[Repository],
    git_providers: list[GitProvider],
) -> SessionExtraResources:
    """Get the extra containers added to amalthea sessions."""
    conts: list[ExtraContainer] = []
    git_proxy_container = await git_proxy.main_container(
        user=user, config=nb_config, repositories=repositories, git_providers=git_providers
    )
    if git_proxy_container:
        conts.append(ExtraContainer.model_validate(sanitizer(git_proxy_container)))
    return SessionExtraResources(containers=conts)


async def get_auth_secret_authenticated(
    nb_config: NotebooksConfig,
    user: AuthenticatedAPIUser,
    server_name: str,
    base_server_url: str,
    base_server_https_url: str,
    base_server_path: str,
) -> ExtraSecret:
    """Get the extra secrets that need to be added to the session for an authenticated user."""
    secret_data = {}

    parsed_proxy_url = urlparse(urljoin(base_server_url + "/", "oauth2"))
    vol = ExtraVolume(
        name="renku-authorized-emails",
        secret=SecretAsVolume(
            secretName=server_name,
            items=[SecretAsVolumeItem(key="authorized_emails", path="authorized_emails")],
        ),
    )
    secret_data["auth"] = dumps(
        {
            "provider": "oidc",
            "client_id": nb_config.sessions.oidc.client_id,
            "oidc_issuer_url": nb_config.sessions.oidc.issuer_url,
            "session_cookie_minimal": True,
            "skip_provider_button": True,
            # NOTE: If the redirect url is not HTTPS then some identity providers will fail.
            "redirect_url": urljoin(base_server_https_url + "/", "oauth2/callback"),
            "cookie_path": base_server_path,
            "proxy_prefix": parsed_proxy_url.path,
            "authenticated_emails_file": "/authorized_emails",
            "client_secret": nb_config.sessions.oidc.client_secret,
            "cookie_secret": base64.urlsafe_b64encode(os.urandom(32)).decode(),
            "insecure_oidc_allow_unverified_email": nb_config.sessions.oidc.allow_unverified_email,
        }
    )
    secret_data["authorized_emails"] = user.email
    secret = V1Secret(metadata=V1ObjectMeta(name=server_name), string_data=secret_data)
    vol_mount = ExtraVolumeMount(
        name="renku-authorized-emails",
        mountPath="/authorized_emails",
        subPath="authorized_emails",
    )
    return ExtraSecret(secret, vol, vol_mount)
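
# For orientation: the "auth" entry above is TOML consumed by oauth2-proxy. A sketch of
# the rendered payload, with illustrative values (real ones come from nb_config and the
# cluster ingress parameters):
#
#   provider = "oidc"
#   client_id = "renku"
#   redirect_url = "https://renku.example.org/sessions/<server_name>/oauth2/callback"
#   cookie_path = "/sessions/<server_name>"
#   proxy_prefix = "/sessions/<server_name>/oauth2"
#   authenticated_emails_file = "/authorized_emails"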


def get_auth_secret_anonymous(nb_config: NotebooksConfig, server_name: str, request: Request) -> ExtraSecret:
    """Get the extra secrets that need to be added to the session for an anonymous user."""
    # NOTE: We extract the session cookie value here in order to avoid creating a cookie.
    # The gateway encrypts and signs cookies so the user ID injected in the request headers does not
    # match the value of the session cookie.
    session_id = cast(str | None, request.cookies.get(nb_config.session_id_cookie_name))
    if not session_id:
        raise errors.UnauthorizedError(
            message=f"You must have a renku session cookie at {nb_config.session_id_cookie_name} "
            "in order to launch an anonymous session."
        )
    # NOTE: Amalthea looks for the token value first in the cookie and then in the authorization header
    secret_data = {
        "auth": safe_dump(
            {
                "authproxy": {
                    "token": session_id,
                    "cookie_key": nb_config.session_id_cookie_name,
                    "verbose": True,
                }
            }
        )
    }
    secret = V1Secret(metadata=V1ObjectMeta(name=server_name), string_data=secret_data)
    return ExtraSecret(secret)
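
# The "auth" entry above is YAML for Amalthea's token auth proxy. A sketch of what
# safe_dump produces here (cookie name and token value are illustrative):
#
#   authproxy:
#     cookie_key: <nb_config.session_id_cookie_name>
#     token: <value of the anonymous user's session cookie>
#     verbose: true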


async def __get_gitlab_image_pull_secret(
    nb_config: NotebooksConfig, user: AuthenticatedAPIUser, image_pull_secret_name: str, access_token: str
) -> ExtraSecret:
    """Create a Kubernetes secret for private GitLab registry authentication."""

    k8s_namespace = await nb_config.k8s_client.namespace()

    registry_secret = {
        "auths": {
            nb_config.git.registry: {
                "Username": "oauth2",
                "Password": access_token,
                "Email": user.email,
            }
        }
    }
    registry_secret = json.dumps(registry_secret)

    secret_data = {".dockerconfigjson": registry_secret}
    secret = V1Secret(
        metadata=V1ObjectMeta(name=image_pull_secret_name, namespace=k8s_namespace),
        string_data=secret_data,
        type="kubernetes.io/dockerconfigjson",
    )

    return ExtraSecret(secret)
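
# The secret above follows the standard kubernetes.io/dockerconfigjson layout. A sketch
# with placeholder values (registry host, token and email are illustrative):
#
#   {"auths": {"registry.gitlab.example.org": {"Username": "oauth2",
#       "Password": "<access token>", "Email": "jane@example.org"}}}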


async def get_data_sources(
    nb_config: NotebooksConfig,
    user: AnonymousAPIUser | AuthenticatedAPIUser,
    server_name: str,
    session_location: SessionLocation,
    data_connectors_stream: AsyncIterator[DataConnectorWithSecrets],
    work_dir: PurePosixPath,
    data_connectors_overrides: list[SessionDataConnectorOverride],
    user_repo: UserRepo,
) -> SessionExtraResources:
    """Generate cloud storage related resources."""
    data_sources: list[DataSource] = []
    secrets: list[ExtraSecret] = []
    dcs: dict[str, RCloneStorage] = {}
    dcs_secrets: dict[str, list[DataConnectorSecret]] = {}
    user_secret_key: str | None = None
    async for dc in data_connectors_stream:
        mount_folder = (
            dc.data_connector.storage.target_path
            if PurePosixPath(dc.data_connector.storage.target_path).is_absolute()
            else (work_dir / dc.data_connector.storage.target_path).as_posix()
        )
        dcs[str(dc.data_connector.id)] = RCloneStorage(
            source_path=dc.data_connector.storage.source_path,
            mount_folder=mount_folder,
            configuration=dc.data_connector.storage.configuration,
            readonly=dc.data_connector.storage.readonly,
            name=dc.data_connector.name,
            secrets={str(secret.secret_id): secret.name for secret in dc.secrets},
            storage_class=nb_config.cloud_storage.storage_class,
        )
        if len(dc.secrets) > 0 and session_location == SessionLocation.local:
            dcs_secrets[str(dc.data_connector.id)] = dc.secrets
        elif len(dc.secrets) > 0 and session_location == SessionLocation.remote:
            # NOTE: special handling for remote sessions; collect all secrets in the "all" key
            dcs_secrets_all = dcs_secrets.get("all", [])
            dcs_secrets_all.extend(
                [
                    DataConnectorSecret(
                        name=f"{str(dc.data_connector.id)}_{dc_secret.name}",
                        user_id=dc_secret.user_id,
                        data_connector_id=dc_secret.data_connector_id,
                        secret_id=dc_secret.secret_id,
                    )
                    for dc_secret in dc.secrets
                ]
            )
            dcs_secrets["all"] = dcs_secrets_all
    if isinstance(user, AuthenticatedAPIUser) and len(dcs_secrets) > 0:
        secret_key = await user_repo.get_or_create_user_secret_key(user)
        user_secret_key = get_encryption_key(secret_key.encode(), user.id.encode()).decode("utf-8")
    # NOTE: Check the cloud storage overrides from the request body and if any match
    # then overwrite the project's cloud storages
    # NOTE: Cloud storages in the session launch request body that are not from the DB will cause a 404 error
    # TODO: Is this correct? -> NOTE: Overriding the configuration when a saved secret is there will cause a 422 error
    for dco in data_connectors_overrides:
        dc_id = str(dco.data_connector_id)
        if dc_id not in dcs:
            raise errors.MissingResourceError(
                message=f"You have requested a data connector with ID {dc_id} which does not exist "
                "or which you do not have access to."
            )
        # NOTE: if 'skip' is true, we do not mount that data connector
        if dco.skip:
            del dcs[dc_id]
            dcs_secrets[dc_id] = []  # Unset any data connector secret
            continue
        if dco.target_path is not None and not PurePosixPath(dco.target_path).is_absolute():
            dco.target_path = (work_dir / dco.target_path).as_posix()
        dcs[dc_id] = dcs[dc_id].with_override(dco)

    # Handle potential duplicate target_path
    dcs = _deduplicate_target_paths(dcs)

    for cs_id, cs in dcs.items():
        secret_name = f"{server_name}-ds-{cs_id.lower()}"
        secret_key_needed = len(dcs_secrets.get(cs_id, [])) > 0
        if secret_key_needed and user_secret_key is None:
            raise errors.ProgrammingError(
                message=f"You have saved storage secrets for data connector {cs_id} "
                f"associated with your user ID {user.id} but no key to decrypt them, "
                "therefore we cannot mount the requested data connector. "
                "Please report this to the renku administrators."
            )
        secret = ExtraSecret(
            cs.secret(
                secret_name,
                await nb_config.k8s_client.namespace(),
                user_secret_key=user_secret_key if secret_key_needed else None,
            )
        )
        secrets.append(secret)
        data_sources.append(
            DataSource(
                mountPath=cs.mount_folder,
                secretRef=secret.ref(),
                accessMode="ReadOnlyMany" if cs.readonly else "ReadWriteOnce",
            )
        )
    return SessionExtraResources(
        data_sources=data_sources,
        secrets=secrets,
        data_connector_secrets=dcs_secrets,
    )

364

365
async def request_dc_secret_creation(
2✔
366
    user: AuthenticatedAPIUser | AnonymousAPIUser,
367
    nb_config: NotebooksConfig,
368
    manifest: AmaltheaSessionV1Alpha1,
369
    dc_secrets: dict[str, list[DataConnectorSecret]],
370
) -> None:
371
    """Request the specified data connector secrets to be created by the secret service."""
372
    if isinstance(user, AnonymousAPIUser):
1✔
373
        return
×
374
    owner_reference = {
1✔
375
        "apiVersion": manifest.apiVersion,
376
        "kind": manifest.kind,
377
        "name": manifest.metadata.name,
378
        "uid": manifest.metadata.uid,
379
    }
380
    secrets_url = nb_config.user_secrets.secrets_storage_service_url + "/api/secrets/kubernetes"
1✔
381
    headers = {"Authorization": f"bearer {user.access_token}"}
1✔
382

383
    cluster_id = None
1✔
384
    namespace = await nb_config.k8s_v2_client.namespace()
1✔
385
    if (cluster := await nb_config.k8s_v2_client.cluster_by_class_id(manifest.resource_class_id(), user)) is not None:
1✔
386
        cluster_id = cluster.id
1✔
387
        namespace = cluster.namespace
1✔
388

389
    for s_id, secrets in dc_secrets.items():
1✔
390
        if len(secrets) == 0:
×
391
            continue
×
392
        request_data = {
×
393
            "name": f"{manifest.metadata.name}-ds-{s_id.lower()}-secrets",
394
            "namespace": namespace,
395
            "secret_ids": [str(secret.secret_id) for secret in secrets],
396
            "owner_references": [owner_reference],
397
            "key_mapping": {str(secret.secret_id): secret.name for secret in secrets},
398
            "cluster_id": str(cluster_id),
399
        }
400
        async with httpx.AsyncClient(timeout=10) as client:
×
401
            res = await client.post(secrets_url, headers=headers, json=request_data)
×
402
            if res.status_code >= 300 or res.status_code < 200:
×
403
                raise errors.ProgrammingError(
×
404
                    message=f"The secret for data connector with {s_id} could not be "
405
                    f"successfully created, the status code was {res.status_code}."
406
                    "Please contact a Renku administrator.",
407
                    detail=res.text,
408
                )
409

410

411
def get_launcher_env_variables(launcher: SessionLauncher, launch_request: SessionLaunchRequest) -> list[SessionEnvItem]:
2✔
412
    """Get the environment variables from the launcher, with overrides from the request."""
413
    output: list[SessionEnvItem] = []
1✔
414
    env_overrides = {i.name: i.value for i in launch_request.env_variable_overrides or []}
1✔
415
    for env in launcher.env_variables or []:
1✔
416
        if env.name in env_overrides:
×
417
            output.append(SessionEnvItem(name=env.name, value=env_overrides[env.name]))
×
418
        else:
419
            output.append(SessionEnvItem(name=env.name, value=env.value))
×
420
    return output
1✔
421

422

423
def verify_launcher_env_variable_overrides(launcher: SessionLauncher, launch_request: SessionLaunchRequest) -> None:
2✔
424
    """Raise an error if there are env variables that are not defined in the launcher."""
425
    env_overrides = {i.name: i.value for i in launch_request.env_variable_overrides or []}
1✔
426
    known_env_names = {i.name for i in launcher.env_variables or []}
1✔
427
    unknown_env_names = set(env_overrides.keys()) - known_env_names
1✔
428
    if unknown_env_names:
1✔
429
        message = f"""The following environment variables are not defined in the session launcher: {unknown_env_names}.
×
430
            Please remove them from the launch request or add them to the session launcher."""
431
        raise errors.ValidationError(message=message)
×
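
# Taken together: get_launcher_env_variables applies request overrides only to variables
# the launcher already defines, and verify_launcher_env_variable_overrides rejects any
# other name. A sketch (the env variable model shown is illustrative):
#
#   launcher.env_variables                = [EnvVar(name="DEBUG", value="0")]
#   launch_request.env_variable_overrides = [EnvVar(name="DEBUG", value="1")]
#   get_launcher_env_variables(launcher, launch_request)
#       -> [SessionEnvItem(name="DEBUG", value="1")]
#
#   Adding EnvVar(name="OTHER", value="x") to the overrides makes
#   verify_launcher_env_variable_overrides raise a ValidationError naming {'OTHER'}.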


async def request_session_secret_creation(
    user: AuthenticatedAPIUser | AnonymousAPIUser,
    nb_config: NotebooksConfig,
    manifest: AmaltheaSessionV1Alpha1,
    session_secrets: list[SessionSecret],
) -> None:
    """Request the specified user session secrets to be created by the secret service."""
    if isinstance(user, AnonymousAPIUser):
        return
    if not session_secrets:
        return
    owner_reference = {
        "apiVersion": manifest.apiVersion,
        "kind": manifest.kind,
        "name": manifest.metadata.name,
        "uid": manifest.metadata.uid,
    }
    key_mapping: dict[str, list[str]] = dict()
    for s in session_secrets:
        secret_id = str(s.secret_id)
        if secret_id not in key_mapping:
            key_mapping[secret_id] = list()
        key_mapping[secret_id].append(s.secret_slot.filename)

    cluster_id = None
    namespace = await nb_config.k8s_v2_client.namespace()
    if (cluster := await nb_config.k8s_v2_client.cluster_by_class_id(manifest.resource_class_id(), user)) is not None:
        cluster_id = cluster.id
        namespace = cluster.namespace

    request_data = {
        "name": f"{manifest.metadata.name}-secrets",
        "namespace": namespace,
        "secret_ids": [str(s.secret_id) for s in session_secrets],
        "owner_references": [owner_reference],
        "key_mapping": key_mapping,
        "cluster_id": str(cluster_id),
    }
    secrets_url = nb_config.user_secrets.secrets_storage_service_url + "/api/secrets/kubernetes"
    headers = {"Authorization": f"bearer {user.access_token}"}
    async with httpx.AsyncClient(timeout=10) as client:
        res = await client.post(secrets_url, headers=headers, json=request_data)
        if res.status_code >= 300 or res.status_code < 200:
            raise errors.ProgrammingError(
                message="The session secrets could not be successfully created, "
                f"the status code was {res.status_code}. "
                "Please contact a Renku administrator.",
                detail=res.text,
            )


def resources_patch_from_resource_class(resource_class: ResourceClass) -> ResourcesPatch:
    """Convert the resource class to a k8s resources patch."""
    gpu_name = GpuKind.NVIDIA.value + "/gpu"
    resources = resources_from_resource_class(resource_class)
    requests: Mapping[str, Requests | RequestsStr | ResetType] | ResetType | None = None
    limits: Mapping[str, Limits | LimitsStr | ResetType] | ResetType | None = None
    default_requests = {"memory": RESET, "cpu": RESET, gpu_name: RESET}
    default_limits = {"memory": RESET, "cpu": RESET, gpu_name: RESET}
    if resources.requests is not None:
        requests = RESET if len(resources.requests.keys()) == 0 else {**default_requests, **resources.requests}
    if resources.limits is not None:
        limits = RESET if len(resources.limits.keys()) == 0 else {**default_limits, **resources.limits}
    return ResourcesPatch(requests=requests, limits=limits)


def resources_from_resource_class(resource_class: ResourceClass) -> Resources:
    """Convert the resource class to a k8s resources spec."""
    requests: dict[str, Requests | RequestsStr] = {
        "cpu": RequestsStr(str(round(resource_class.cpu * 1000)) + "m"),
        "memory": RequestsStr(f"{resource_class.memory}Gi"),
    }
    limits: dict[str, Limits | LimitsStr] = {"memory": LimitsStr(f"{resource_class.memory}Gi")}
    if resource_class.gpu > 0:
        gpu_name = GpuKind.NVIDIA.value + "/gpu"
        requests[gpu_name] = Requests(resource_class.gpu)
        # NOTE: GPUs have to be set in limits too since GPUs cannot be overcommitted;
        # otherwise, on some clusters the session will fail to start.
        limits[gpu_name] = Limits(resource_class.gpu)
    return Resources(requests=requests, limits=limits if len(limits) > 0 else None)
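
# A worked example of the conversion above, for a class with half a CPU core, 4 GiB of
# memory and no GPU (the ResourceClass constructor shown is illustrative):
#
#   resources_from_resource_class(ResourceClass(cpu=0.5, memory=4, gpu=0, ...))
#   -> Resources(requests={"cpu": "500m", "memory": "4Gi"}, limits={"memory": "4Gi"})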


def repositories_from_project(project: Project, git_providers: list[GitProvider]) -> list[Repository]:
    """Get the list of git repositories from a project."""
    repositories: list[Repository] = []
    for repo in project.repositories:
        found_provider_id: str | None = None
        for provider in git_providers:
            if urlparse(provider.url).netloc == urlparse(repo).netloc:
                found_provider_id = provider.id
                break
        repositories.append(Repository(url=repo, provider=found_provider_id))
    return repositories
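
# Provider matching above compares URL hosts only. A sketch (values illustrative):
#
#   provider.url = "https://gitlab.example.org"          -> netloc "gitlab.example.org"
#   repo         = "https://gitlab.example.org/a/b.git"  -> netloc "gitlab.example.org"
#   => Repository(url=repo, provider=provider.id); unmatched repos get provider=None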


async def repositories_from_session(
    user: AnonymousAPIUser | AuthenticatedAPIUser,
    session: AmaltheaSessionV1Alpha1,
    project_repo: ProjectRepository,
    git_providers: list[GitProvider],
) -> list[Repository]:
    """Get the list of git repositories from a session."""
    try:
        project = await project_repo.get_project(user, session.project_id)
    except errors.MissingResourceError:
        return []
    return repositories_from_project(project, git_providers)


def get_culling(
    user: AuthenticatedAPIUser | AnonymousAPIUser, resource_pool: ResourcePool, nb_config: NotebooksConfig
) -> Culling:
    """Create the culling specification for an AmaltheaSession."""
    if user.is_anonymous:
        # NOTE: Anonymous sessions should not be hibernated at all, but there is no such option in Amalthea.
        # So in this case we set a very low hibernation threshold so the session is deleted quickly after
        # it is hibernated.
        hibernation_threshold: timedelta | None = timedelta(seconds=1)
    else:
        hibernation_threshold = (
            timedelta(seconds=resource_pool.hibernation_threshold)
            if resource_pool.hibernation_threshold is not None
            else None
        )
    return Culling(
        maxAge=timedelta(seconds=nb_config.sessions.culling.registered.max_age_seconds),
        maxFailedDuration=timedelta(seconds=nb_config.sessions.culling.registered.failed_seconds),
        maxHibernatedDuration=hibernation_threshold,
        maxIdleDuration=timedelta(seconds=resource_pool.idle_threshold)
        if resource_pool.idle_threshold is not None
        else None,
        maxStartingDuration=timedelta(seconds=nb_config.sessions.culling.registered.pending_seconds),
    )


def get_culling_patch(
    user: AuthenticatedAPIUser | AnonymousAPIUser, resource_pool: ResourcePool, nb_config: NotebooksConfig
) -> CullingPatch:
    """Get the patch for the culling durations of a session."""
    culling = get_culling(user, resource_pool, nb_config)
    return CullingPatch(
        maxAge=culling.maxAge or RESET,
        maxFailedDuration=culling.maxFailedDuration or RESET,
        maxHibernatedDuration=culling.maxHibernatedDuration or RESET,
        maxIdleDuration=culling.maxIdleDuration or RESET,
        maxStartingDuration=culling.maxStartingDuration or RESET,
    )


async def __requires_image_pull_secret(nb_config: NotebooksConfig, image: str, internal_gitlab_user: APIUser) -> bool:
    """Determine if an image requires a pull secret based on its visibility and the user's GitLab access token."""

    parsed_image = Image.from_path(image)
    image_repo = parsed_image.repo_api()

    image_exists_publicly = await image_repo.image_exists(parsed_image)
    if image_exists_publicly:
        return False

    if parsed_image.hostname == nb_config.git.registry and internal_gitlab_user.access_token:
        image_repo = image_repo.with_oauth2_token(internal_gitlab_user.access_token)
        image_exists_privately = await image_repo.image_exists(parsed_image)
        if image_exists_privately:
            return True
    # No pull secret needed if the image is private and the user cannot access it
    return False


def __format_image_pull_secret(secret_name: str, access_token: str, registry_domain: str) -> ExtraSecret:
    registry_secret = {
        "auths": {registry_domain: {"auth": base64.b64encode(f"oauth2:{access_token}".encode()).decode()}}
    }
    registry_secret = json.dumps(registry_secret)
    registry_secret = base64.b64encode(registry_secret.encode()).decode()
    return ExtraSecret(
        V1Secret(
            data={".dockerconfigjson": registry_secret},
            metadata=V1ObjectMeta(name=secret_name),
            type="kubernetes.io/dockerconfigjson",
        )
    )


async def __get_connected_services_image_pull_secret(
    secret_name: str, connected_svcs_repo: ConnectedServicesRepository, image: str, user: APIUser
) -> ExtraSecret | None:
    """Return a secret for accessing the image if one is available for the given user."""
    image_parsed = Image.from_path(image)
    image_check_result = await ic.check_image(image_parsed, user, connected_svcs_repo, None)
    logger.debug(f"Set pull secret for {image} to connection {image_check_result.image_provider}")
    if not image_check_result.token:
        return None

    if not image_check_result.image_provider:
        return None

    return __format_image_pull_secret(
        secret_name=secret_name,
        access_token=image_check_result.token,
        registry_domain=image_check_result.image_provider.registry_url,
    )


async def get_image_pull_secret(
    image: str,
    server_name: str,
    nb_config: NotebooksConfig,
    user: APIUser,
    internal_gitlab_user: APIUser,
    connected_svcs_repo: ConnectedServicesRepository,
) -> ExtraSecret | None:
    """Get an image pull secret."""

    v2_secret = await __get_connected_services_image_pull_secret(
        f"{server_name}-image-secret", connected_svcs_repo, image, user
    )
    if v2_secret:
        return v2_secret

    if (
        nb_config.enable_internal_gitlab
        and isinstance(user, AuthenticatedAPIUser)
        and internal_gitlab_user.access_token is not None
    ):
        needs_pull_secret = await __requires_image_pull_secret(nb_config, image, internal_gitlab_user)
        if needs_pull_secret:
            v1_secret = await __get_gitlab_image_pull_secret(
                nb_config, user, f"{server_name}-image-secret-v1", internal_gitlab_user.access_token
            )
            return v1_secret

    return None


def get_remote_secret(
    user: AuthenticatedAPIUser | AnonymousAPIUser,
    config: NotebooksConfig,
    server_name: str,
    remote_provider_id: str,
    git_providers: list[GitProvider],
) -> ExtraSecret | None:
    """Return the secret containing the configuration for the remote session controller."""
    if not user.is_authenticated or user.access_token is None or user.refresh_token is None:
        return None
    remote_provider = next(filter(lambda p: p.id == remote_provider_id, git_providers), None)
    if not remote_provider:
        return None
    renku_base_url = "https://" + config.sessions.ingress.host
    renku_base_url = renku_base_url.rstrip("/")
    renku_auth_token_uri = f"{renku_base_url}/auth/realms/{config.keycloak_realm}/protocol/openid-connect/token"
    secret_data = {
        "RSC_AUTH_KIND": "renku",
        "RSC_AUTH_TOKEN_URI": remote_provider.access_token_url,
        "RSC_AUTH_RENKU_ACCESS_TOKEN": user.access_token,
        "RSC_AUTH_RENKU_REFRESH_TOKEN": user.refresh_token,
        "RSC_AUTH_RENKU_TOKEN_URI": renku_auth_token_uri,
        "RSC_AUTH_RENKU_CLIENT_ID": config.sessions.git_proxy.renku_client_id,
        "RSC_AUTH_RENKU_CLIENT_SECRET": config.sessions.git_proxy.renku_client_secret,
    }
    secret_name = f"{server_name}-remote-secret"
    secret = V1Secret(metadata=V1ObjectMeta(name=secret_name), string_data=secret_data)
    return ExtraSecret(secret)


def get_remote_env(
    remote: RemoteConfigurationFirecrest,
) -> list[SessionEnvItem]:
    """Return the env variables used for remote sessions."""
    env = [
        SessionEnvItem(name="RSC_REMOTE_KIND", value=remote.kind.value),
        SessionEnvItem(name="RSC_FIRECREST_API_URL", value=remote.api_url),
        SessionEnvItem(name="RSC_FIRECREST_SYSTEM_NAME", value=remote.system_name),
        # TODO: remove fake start
        SessionEnvItem(name="RSC_FAKE_START", value="true"),
    ]
    if remote.partition:
        env.append(SessionEnvItem(name="RSC_FIRECREST_PARTITION", value=remote.partition))
    return env
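
# A sketch of the resulting environment for a Firecrest remote configuration (values are
# illustrative):
#
#   RSC_REMOTE_KIND=firecrest
#   RSC_FIRECREST_API_URL=https://firecrest.example.org
#   RSC_FIRECREST_SYSTEM_NAME=daint
#   RSC_FAKE_START=true             # temporary, see the TODO above
#   RSC_FIRECREST_PARTITION=debug   # only added when a partition is configured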
712

713

714
async def start_session(
2✔
715
    request: Request,
716
    launch_request: SessionLaunchRequest,
717
    user: AnonymousAPIUser | AuthenticatedAPIUser,
718
    internal_gitlab_user: APIUser,
719
    nb_config: NotebooksConfig,
720
    git_provider_helper: GitProviderHelperProto,
721
    cluster_repo: ClusterRepository,
722
    data_connector_secret_repo: DataConnectorSecretRepository,
723
    project_repo: ProjectRepository,
724
    project_session_secret_repo: ProjectSessionSecretRepository,
725
    rp_repo: ResourcePoolRepository,
726
    session_repo: SessionRepository,
727
    user_repo: UserRepo,
728
    metrics: MetricsService,
729
    connected_svcs_repo: ConnectedServicesRepository,
730
) -> tuple[AmaltheaSessionV1Alpha1, bool]:
731
    """Start an Amalthea session.
732

733
    Returns a tuple where the first item is an instance of an Amalthea session
734
    and the second item is a boolean set to true iff a new session was created.
735
    """
736
    launcher = await session_repo.get_launcher(user=user, launcher_id=launch_request.launcher_id)
1✔
737
    launcher_id = launcher.id
1✔
738
    project = await project_repo.get_project(user=user, project_id=launcher.project_id)
1✔
739

740
    # Determine resource_class_id: the class can be overwritten at the user's request
741
    resource_class_id = launch_request.resource_class_id or launcher.resource_class_id
1✔
742

743
    cluster = await nb_config.k8s_v2_client.cluster_by_class_id(resource_class_id, user)
1✔
744

745
    server_name = renku_2_make_server_name(
1✔
746
        user=user, project_id=str(launcher.project_id), launcher_id=str(launcher_id), cluster_id=str(cluster.id)
747
    )
748
    existing_session = await nb_config.k8s_v2_client.get_session(name=server_name, safe_username=user.id)
1✔
749
    if existing_session is not None and existing_session.spec is not None:
1✔
750
        return existing_session, False
×
751

752
    # Fully determine the resource pool and resource class
753
    if resource_class_id is None:
1✔
754
        resource_pool = await rp_repo.get_default_resource_pool()
×
755
        resource_class = resource_pool.get_default_resource_class()
×
756
        if not resource_class and len(resource_pool.classes) > 0:
×
757
            resource_class = resource_pool.classes[0]
×
758
        if not resource_class or not resource_class.id:
×
759
            raise errors.ProgrammingError(message="Cannot find any resource classes in the default pool.")
×
760
        resource_class_id = resource_class.id
×
761
    else:
762
        resource_pool = await rp_repo.get_resource_pool_from_class(user, resource_class_id)
1✔
763
        resource_class = resource_pool.get_resource_class(resource_class_id)
1✔
764
        if not resource_class or not resource_class.id:
1✔
765
            raise errors.MissingResourceError(message=f"The resource class with ID {resource_class_id} does not exist.")
×
766
    await nb_config.crc_validator.validate_class_storage(user, resource_class.id, launch_request.disk_storage)
1✔
767
    disk_storage = launch_request.disk_storage or resource_class.default_storage
1✔
768

769
    # Determine session location
770
    session_location = SessionLocation.remote if resource_pool.remote else SessionLocation.local
1✔
771
    if session_location == SessionLocation.remote and not user.is_authenticated:
1✔
772
        raise errors.ValidationError(message="Anonymous users cannot start remote sessions.")
×
773

774
    environment = launcher.environment
1✔
775
    image = environment.container_image
1✔
776
    work_dir = environment.working_directory
1✔
777
    if not work_dir:
1✔
778
        image_workdir = await core.docker_image_workdir(nb_config, environment.container_image, internal_gitlab_user)
1✔
779
        work_dir_fallback = PurePosixPath("/home/jovyan")
1✔
780
        work_dir = image_workdir or work_dir_fallback
1✔
781
    storage_mount_fallback = work_dir / "work"
1✔
782
    storage_mount = launcher.environment.mount_directory or storage_mount_fallback
1✔
783
    secrets_mount_directory = storage_mount / project.secrets_mount_directory
1✔
784
    session_secrets = await project_session_secret_repo.get_all_session_secrets_from_project(
1✔
785
        user=user, project_id=project.id
786
    )
787
    data_connectors_stream = data_connector_secret_repo.get_data_connectors_with_secrets(user, project.id)
1✔
788
    git_providers = await git_provider_helper.get_providers(user=user)
1✔
789
    repositories = repositories_from_project(project, git_providers)
1✔
790

791
    # User secrets
792
    session_extras = SessionExtraResources()
1✔
793
    session_extras = session_extras.concat(
1✔
794
        user_secrets_extras(
795
            user=user,
796
            config=nb_config,
797
            secrets_mount_directory=secrets_mount_directory.as_posix(),
798
            k8s_secret_name=f"{server_name}-secrets",
799
            session_secrets=session_secrets,
800
        )
801
    )
802

803
    # Data connectors
804
    session_extras = session_extras.concat(
1✔
805
        await get_data_sources(
806
            nb_config=nb_config,
807
            server_name=server_name,
808
            user=user,
809
            session_location=session_location,
810
            data_connectors_stream=data_connectors_stream,
811
            work_dir=work_dir,
812
            data_connectors_overrides=launch_request.data_connectors_overrides or [],
813
            user_repo=user_repo,
814
        )
815
    )
816

817
    # More init containers
818
    session_extras = session_extras.concat(
1✔
819
        await get_extra_init_containers(
820
            nb_config,
821
            user,
822
            repositories,
823
            git_providers,
824
            storage_mount,
825
            work_dir,
826
            uid=environment.uid,
827
            gid=environment.gid,
828
        )
829
    )
830

831
    # Extra containers
832
    session_extras = session_extras.concat(await get_extra_containers(nb_config, user, repositories, git_providers))
1✔
833

834
    # Cluster settings (ingress, storage class, etc)
835
    cluster_settings: ClusterSettings
836
    try:
1✔
837
        cluster_settings = await cluster_repo.select(cluster.id)
1✔
838
    except errors.MissingResourceError:
1✔
839
        # Fallback to global, main cluster parameters
840
        cluster_settings = nb_config.local_cluster_settings()
1✔
841

842
    (
1✔
843
        base_server_path,
844
        base_server_url,
845
        base_server_https_url,
846
        host,
847
        tls_secret,
848
        ingress_class_name,
849
        ingress_annotations,
850
    ) = cluster_settings.get_ingress_parameters(server_name)
851
    storage_class = cluster_settings.get_storage_class()
1✔
852
    service_account_name = cluster_settings.service_account_name
1✔
853

854
    ui_path = f"{base_server_path}/{environment.default_url.lstrip('/')}"
1✔
855

856
    ingress = Ingress(
1✔
857
        host=host,
858
        ingressClassName=ingress_class_name,
859
        annotations=ingress_annotations,
860
        tlsSecret=tls_secret,
861
        pathPrefix=base_server_path,
862
    )
863

864
    # Annotations
865
    annotations: dict[str, str] = {
1✔
866
        "renku.io/project_id": str(launcher.project_id),
867
        "renku.io/launcher_id": str(launcher_id),
868
        "renku.io/resource_class_id": str(resource_class_id),
869
    }
870

871
    # Authentication
872
    if isinstance(user, AuthenticatedAPIUser):
1✔
873
        auth_secret = await get_auth_secret_authenticated(
×
874
            nb_config, user, server_name, base_server_url, base_server_https_url, base_server_path
875
        )
876
    else:
877
        auth_secret = get_auth_secret_anonymous(nb_config, server_name, request)
1✔
878
    session_extras = session_extras.concat(
1✔
879
        SessionExtraResources(
880
            secrets=[auth_secret],
881
            volumes=[auth_secret.volume] if auth_secret.volume else [],
882
        )
883
    )
884
    authn_extra_volume_mounts: list[ExtraVolumeMount] = []
1✔
885
    if auth_secret.volume_mount:
1✔
886
        authn_extra_volume_mounts.append(auth_secret.volume_mount)
×
887

888
    cert_vol_mounts = init_containers.certificates_volume_mounts(nb_config)
1✔
889
    if cert_vol_mounts:
1✔
890
        authn_extra_volume_mounts.extend(cert_vol_mounts)
1✔
891

892
    image_secret = await get_image_pull_secret(
1✔
893
        image=image,
894
        server_name=server_name,
895
        nb_config=nb_config,
896
        user=user,
897
        internal_gitlab_user=internal_gitlab_user,
898
        connected_svcs_repo=connected_svcs_repo,
899
    )
900
    if image_secret:
1✔
901
        session_extras = session_extras.concat(SessionExtraResources(secrets=[image_secret]))
×
902

903
    # Remote session configuration
904
    remote_secret = None
1✔
905
    if session_location == SessionLocation.remote:
1✔
906
        assert resource_pool.remote is not None
×
907
        if resource_pool.remote.provider_id is None:
×
908
            raise errors.ProgrammingError(
×
909
                message=f"The resource pool {resource_pool.id} configuration is not valid (missing field 'remote_provider_id')."  # noqa E501
910
            )
911
        remote_secret = get_remote_secret(
×
912
            user=user,
913
            config=nb_config,
914
            server_name=server_name,
915
            remote_provider_id=resource_pool.remote.provider_id,
916
            git_providers=git_providers,
917
        )
918
    if remote_secret is not None:
1✔
919
        session_extras = session_extras.concat(SessionExtraResources(secrets=[remote_secret]))
×
920

921
    # Raise an error if there are invalid environment variables in the request body
922
    verify_launcher_env_variable_overrides(launcher, launch_request)
1✔
923
    env = [
1✔
924
        SessionEnvItem(name="RENKU_BASE_URL_PATH", value=base_server_path),
925
        SessionEnvItem(name="RENKU_BASE_URL", value=base_server_url),
926
        SessionEnvItem(name="RENKU_MOUNT_DIR", value=storage_mount.as_posix()),
927
        SessionEnvItem(name="RENKU_SESSION", value="1"),
928
        SessionEnvItem(name="RENKU_SESSION_IP", value="0.0.0.0"),  # nosec B104
929
        SessionEnvItem(name="RENKU_SESSION_PORT", value=f"{environment.port}"),
930
        SessionEnvItem(name="RENKU_WORKING_DIR", value=work_dir.as_posix()),
931
        SessionEnvItem(name="RENKU_SECRETS_PATH", value=project.secrets_mount_directory.as_posix()),
932
        SessionEnvItem(name="RENKU_PROJECT_ID", value=str(project.id)),
933
        SessionEnvItem(name="RENKU_PROJECT_PATH", value=project.path.serialize()),
934
        SessionEnvItem(name="RENKU_LAUNCHER_ID", value=str(launcher.id)),
935
    ]
936
    if session_location == SessionLocation.remote:
1✔
937
        assert resource_pool.remote is not None
×
938
        env.extend(
×
939
            get_remote_env(
940
                remote=resource_pool.remote,
941
            )
942
        )
943
    launcher_env_variables = get_launcher_env_variables(launcher, launch_request)
1✔
944
    env.extend(launcher_env_variables)
1✔
945

946
    session = AmaltheaSessionV1Alpha1(
1✔
947
        metadata=Metadata(name=server_name, annotations=annotations),
948
        spec=AmaltheaSessionSpec(
949
            location=session_location,
950
            imagePullSecrets=[ImagePullSecret(name=image_secret.name, adopt=True)] if image_secret else [],
951
            codeRepositories=[],
952
            hibernated=False,
953
            reconcileStrategy=ReconcileStrategy.whenFailedOrHibernated,
954
            priorityClassName=resource_class.quota,
955
            session=Session(
956
                image=image,
957
                imagePullPolicy=ImagePullPolicy.Always,
958
                urlPath=ui_path,
959
                port=environment.port,
960
                storage=Storage(
961
                    className=storage_class,
962
                    size=SizeStr(str(disk_storage) + "G"),
963
                    mountPath=storage_mount.as_posix(),
964
                ),
965
                workingDir=work_dir.as_posix(),
966
                runAsUser=environment.uid,
967
                runAsGroup=environment.gid,
968
                resources=resources_from_resource_class(resource_class),
969
                extraVolumeMounts=session_extras.volume_mounts,
970
                command=environment.command,
971
                args=environment.args,
972
                shmSize=ShmSizeStr("1G"),
973
                stripURLPath=environment.strip_path_prefix,
974
                env=env,
975
                remoteSecretRef=remote_secret.ref() if remote_secret else None,
976
            ),
977
            ingress=ingress,
978
            extraContainers=session_extras.containers,
979
            initContainers=session_extras.init_containers,
980
            extraVolumes=session_extras.volumes,
981
            culling=get_culling(user, resource_pool, nb_config),
982
            authentication=Authentication(
983
                enabled=True,
984
                type=AuthenticationType.oauth2proxy
985
                if isinstance(user, AuthenticatedAPIUser)
986
                else AuthenticationType.token,
987
                secretRef=auth_secret.key_ref("auth"),
988
                extraVolumeMounts=authn_extra_volume_mounts,
989
            ),
990
            dataSources=session_extras.data_sources,
991
            tolerations=tolerations_from_resource_class(resource_class, nb_config.sessions.tolerations_model),
992
            affinity=node_affinity_from_resource_class(resource_class, nb_config.sessions.affinity_model),
993
            serviceAccountName=service_account_name,
994
        ),
995
    )
996
    secrets_to_create = session_extras.secrets or []
1✔
997
    for s in secrets_to_create:
1✔
998
        await nb_config.k8s_v2_client.create_secret(K8sSecret.from_v1_secret(s.secret, cluster))
1✔
999
    try:
1✔
1000
        session = await nb_config.k8s_v2_client.create_session(session, user)
1✔
1001
    except Exception as err:
×
1002
        for s in secrets_to_create:
×
1003
            await nb_config.k8s_v2_client.delete_secret(K8sSecret.from_v1_secret(s.secret, cluster))
×
1004
        raise errors.ProgrammingError(message="Could not start the amalthea session") from err
×
1005
    else:
1006
        try:
1✔
1007
            await request_session_secret_creation(user, nb_config, session, session_secrets)
1✔
1008
            data_connector_secrets = session_extras.data_connector_secrets or dict()
1✔
1009
            await request_dc_secret_creation(user, nb_config, session, data_connector_secrets)
1✔
1010
        except Exception:
×
1011
            await nb_config.k8s_v2_client.delete_session(server_name, user.id)
×
1012
            raise
×
1013

1014
    await metrics.user_requested_session_launch(
1✔
1015
        user=user,
1016
        metadata={
1017
            "cpu": int(resource_class.cpu * 1000),
1018
            "memory": resource_class.memory,
1019
            "gpu": resource_class.gpu,
1020
            "storage": disk_storage,
1021
            "resource_class_id": resource_class.id,
1022
            "resource_pool_id": resource_pool.id or "",
1023
            "resource_class_name": f"{resource_pool.name}.{resource_class.name}",
1024
            "session_id": server_name,
1025
        },
1026
    )
1027
    return session, True
1✔
1028

1029

1030
async def patch_session(
2✔
1031
    body: apispec.SessionPatchRequest,
1032
    session_id: str,
1033
    user: AnonymousAPIUser | AuthenticatedAPIUser,
1034
    internal_gitlab_user: APIUser,
1035
    nb_config: NotebooksConfig,
1036
    git_provider_helper: GitProviderHelperProto,
1037
    project_repo: ProjectRepository,
1038
    project_session_secret_repo: ProjectSessionSecretRepository,
1039
    rp_repo: ResourcePoolRepository,
1040
    session_repo: SessionRepository,
1041
    connected_svcs_repo: ConnectedServicesRepository,
1042
    metrics: MetricsService,
1043
) -> AmaltheaSessionV1Alpha1:
1044
    """Patch an Amalthea session."""
1045
    session = await nb_config.k8s_v2_client.get_session(session_id, user.id)
1✔
1046
    if session is None:
1✔
1047
        raise errors.MissingResourceError(message=f"The session with ID {session_id} does not exist")
1✔
1048
    if session.spec is None:
1✔
1049
        raise errors.ProgrammingError(
×
1050
            message=f"The session {session_id} being patched is missing the expected 'spec' field.", quiet=True
1051
        )
1052
    cluster = await nb_config.k8s_v2_client.cluster_by_class_id(session.resource_class_id(), user)
1✔
1053

1054
    patch = AmaltheaSessionV1Alpha1Patch(spec=AmaltheaSessionV1Alpha1SpecPatch())
1✔
1055
    is_getting_hibernated: bool = False
1✔
1056

1057
    # Hibernation
1058
    # TODO: Some patching should only be done when the session is in some states to avoid inadvertent restarts
1059
    # Refresh tokens for git proxy
1060
    if (
1✔
1061
        body.state is not None
1062
        and body.state.value.lower() == State.Hibernated.value.lower()
1063
        and body.state.value.lower() != session.status.state.value.lower()
1064
    ):
1065
        # Session is being hibernated
1066
        patch.spec.hibernated = True
1✔
1067
        is_getting_hibernated = True
1✔
1068
    elif (
1✔
1069
        body.state is not None
1070
        and body.state.value.lower() == State.Running.value.lower()
1071
        and session.status.state.value.lower() != body.state.value.lower()
1072
    ):
1073
        # Session is being resumed
1074
        patch.spec.hibernated = False
×
1075
        await metrics.user_requested_session_resume(user, metadata={"session_id": session_id})
×
1076

1077
    # Resource class
1078
    if body.resource_class_id is not None:
1✔
1079
        new_cluster = await nb_config.k8s_v2_client.cluster_by_class_id(body.resource_class_id, user)
1✔
1080
        if new_cluster.id != cluster.id:
1✔
1081
            raise errors.ValidationError(
×
1082
                message=(
1083
                    f"The requested resource class {body.resource_class_id} is not in the "
1084
                    f"same cluster {cluster.id} as the current resource class {session.resource_class_id()}."
1085
                )
1086
            )
1087
        rp = await rp_repo.get_resource_pool_from_class(user, body.resource_class_id)
1✔
1088
        rc = rp.get_resource_class(body.resource_class_id)
1✔
1089
        if not rc:
1✔
1090
            raise errors.MissingResourceError(
×
1091
                message=f"The resource class you requested with ID {body.resource_class_id} does not exist"
1092
            )
1093
        if not patch.metadata:
1✔
1094
            patch.metadata = AmaltheaSessionV1Alpha1MetadataPatch()
1✔
1095
        # Patch the resource pool and class ID in the annotations
1096
        patch.metadata.annotations = {"renku.io/resource_pool_id": str(rp.id)}
1✔
1097
        patch.metadata.annotations = {"renku.io/resource_class_id": str(body.resource_class_id)}
1✔
1098
        if not patch.spec.session:
1✔
1099
            patch.spec.session = AmaltheaSessionV1Alpha1SpecSessionPatch()
1✔
1100
        patch.spec.session.resources = resources_patch_from_resource_class(rc)
1✔
1101
        # Tolerations
1102
        tolerations = tolerations_from_resource_class(rc, nb_config.sessions.tolerations_model)
1✔
1103
        patch.spec.tolerations = tolerations
1✔
1104
        # Affinities
1105
        patch.spec.affinity = node_affinity_patch_from_resource_class(rc, nb_config.sessions.affinity_model)
1✔
1106
        # Priority class (if a quota is being used)
1107
        if rc.quota is None:
1✔
1108
            patch.spec.priorityClassName = RESET
×
1109
        patch.spec.culling = get_culling_patch(user, rp, nb_config)
1✔
1110
        # Service account name
1111
        if rp.cluster is not None:
1✔
1112
            patch.spec.service_account_name = (
×
1113
                rp.cluster.service_account_name if rp.cluster.service_account_name is not None else RESET
1114
            )
1115

1116
    # If the session is being hibernated we do not need to patch anything else that is
1117
    # not specifically called for in the request body, we can refresh things when the user resumes.
1118
    if is_getting_hibernated:
1✔
1119
        return await nb_config.k8s_v2_client.patch_session(session_id, user.id, patch.to_rfc7386())
1✔
1120

1121
    server_name = session.metadata.name
1✔
1122
    launcher = await session_repo.get_launcher(user, session.launcher_id)
1✔
1123
    project = await project_repo.get_project(user=user, project_id=session.project_id)
1✔
1124
    environment = launcher.environment
1✔
1125
    work_dir = environment.working_directory
1✔
1126
    if not work_dir:
1✔
1127
        image_workdir = await core.docker_image_workdir(nb_config, environment.container_image, internal_gitlab_user)
1✔
1128
        work_dir_fallback = PurePosixPath("/home/jovyan")
1✔
1129
        work_dir = image_workdir or work_dir_fallback
1✔
1130
    storage_mount_fallback = work_dir / "work"
1✔
1131
    storage_mount = launcher.environment.mount_directory or storage_mount_fallback
1✔
1132
    secrets_mount_directory = storage_mount / project.secrets_mount_directory
1✔
1133
    session_secrets = await project_session_secret_repo.get_all_session_secrets_from_project(
1✔
1134
        user=user, project_id=project.id
1135
    )
1136
    git_providers = await git_provider_helper.get_providers(user=user)
1✔
1137
    repositories = repositories_from_project(project, git_providers)
1✔
1138

1139
    # User secrets
1140
    session_extras = SessionExtraResources()
1✔
1141
    session_extras = session_extras.concat(
1✔
1142
        user_secrets_extras(
1143
            user=user,
1144
            config=nb_config,
1145
            secrets_mount_directory=secrets_mount_directory.as_posix(),
1146
            k8s_secret_name=f"{server_name}-secrets",
1147
            session_secrets=session_secrets,
1148
        )
1149
    )
1150

1151
    # Data connectors: skip
1152
    # TODO: How can we patch data connectors? Should we even patch them?
1153
    # TODO: The fact that `start_session()` accepts overrides for data connectors
1154
    # TODO: but that we do not save these overrides (e.g. as annotations) means that
1155
    # TODO: we cannot patch data connectors upon resume.
1156
    # TODO: If we did, we would lose the user's provided overrides (e.g. unsaved credentials).
1157

    # More init containers
    session_extras = session_extras.concat(
        await get_extra_init_containers(
            nb_config,
            user,
            repositories,
            git_providers,
            storage_mount,
            work_dir,
            uid=environment.uid,
            gid=environment.gid,
        )
    )

    # Extra containers
    session_extras = session_extras.concat(await get_extra_containers(nb_config, user, repositories, git_providers))

    # Patching the image pull secret
    image = session.spec.session.image
    image_pull_secret = await get_image_pull_secret(
        image=image,
        server_name=server_name,
        nb_config=nb_config,
        connected_svcs_repo=connected_svcs_repo,
        user=user,
        internal_gitlab_user=internal_gitlab_user,
    )
    if image_pull_secret:
        # NOTE: concat() returns a new SessionExtraResources instance, so the result
        # must be reassigned (otherwise the secret is never created below).
        session_extras = session_extras.concat(SessionExtraResources(secrets=[image_pull_secret]))
        patch.spec.imagePullSecrets = [ImagePullSecret(name=image_pull_secret.name, adopt=image_pull_secret.adopt)]
    else:
        patch.spec.imagePullSecrets = RESET
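        # NOTE: RESET is a sentinel that presumably serializes to null in the
        # RFC 7386 merge patch, explicitly clearing previously set image pull secrets.
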
    # Construct session patch
    patch.spec.extraContainers = _make_patch_spec_list(
        existing=session.spec.extraContainers or [], updated=session_extras.containers
    )
    patch.spec.initContainers = _make_patch_spec_list(
        existing=session.spec.initContainers or [], updated=session_extras.init_containers
    )
    patch.spec.extraVolumes = _make_patch_spec_list(
        existing=session.spec.extraVolumes or [], updated=session_extras.volumes
    )
    if not patch.spec.session:
        patch.spec.session = AmaltheaSessionV1Alpha1SpecSessionPatch()
    patch.spec.session.extraVolumeMounts = _make_patch_spec_list(
        existing=session.spec.session.extraVolumeMounts or [], updated=session_extras.volume_mounts
    )

    secrets_to_create = session_extras.secrets or []
    for s in secrets_to_create:
        await nb_config.k8s_v2_client.create_secret(K8sSecret.from_v1_secret(s.secret, cluster))

    patch_serialized = patch.to_rfc7386()
    if len(patch_serialized) == 0:
        return session

    return await nb_config.k8s_v2_client.patch_session(session_id, user.id, patch_serialized)


def _deduplicate_target_paths(dcs: dict[str, RCloneStorage]) -> dict[str, RCloneStorage]:
    """Ensures that the target paths for all storages are unique.

    This function attempts to de-duplicate the target_path of all items passed in
    and raises an error if it fails to generate unique target paths.
    """
    result_dcs: dict[str, RCloneStorage] = {}
    mount_folders: dict[str, list[str]] = {}

    def _find_mount_folder(dc: RCloneStorage) -> str:
        mount_folder = dc.mount_folder
        if mount_folder not in mount_folders:
            return mount_folder
        # 1. Try with a "-1", "-2", etc. suffix
        mount_folder_try = f"{mount_folder}-{len(mount_folders[mount_folder])}"
        if mount_folder_try not in mount_folders:
            return mount_folder_try
        # 2. Try with a random suffix
        suffix = "".join([random.choice(string.ascii_lowercase + string.digits) for _ in range(4)])  # nosec B311
        mount_folder_try = f"{mount_folder}-{suffix}"
        if mount_folder_try not in mount_folders:
            return mount_folder_try
        raise errors.ValidationError(
            message=f"Could not start session because two or more data connectors ({', '.join(mount_folders[mount_folder])}) share the same mount point '{mount_folder}'"  # noqa E501
        )

    for dc_id, dc in dcs.items():
        original_mount_folder = dc.mount_folder
        new_mount_folder = _find_mount_folder(dc)
        # Keep track of the original mount folder here
        if new_mount_folder != original_mount_folder:
            logger.warning(f"Re-assigning data connector {dc_id} to mount point '{new_mount_folder}'")
            dc_ids = mount_folders.get(original_mount_folder, [])
            dc_ids.append(dc_id)
            mount_folders[original_mount_folder] = dc_ids
        # Keep track of the assigned mount folder here
        dc_ids = mount_folders.get(new_mount_folder, [])
        dc_ids.append(dc_id)
        mount_folders[new_mount_folder] = dc_ids
        result_dcs[dc_id] = dc.with_override(
            override=SessionDataConnectorOverride(
                skip=False,
                data_connector_id=ULID.from_str(dc_id),
                target_path=new_mount_folder,
                configuration=None,
                source_path=None,
                readonly=None,
            )
        )

    return result_dcs
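

# For illustration (hypothetical paths): if two data connectors both resolve to the
# mount folder "work/data", the first keeps "work/data" and the second is re-assigned
# to "work/data-1". If the numbered candidate is itself already taken, a random
# four-character suffix is tried (e.g. "work/data-x7z2"), and a ValidationError is
# raised only if that also collides.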


class _NamedResource(Protocol):
    """Represents a resource with a name."""

    name: str


_T = TypeVar("_T", bound=_NamedResource)


def _make_patch_spec_list(existing: Sequence[_T], updated: Sequence[_T]) -> list[_T] | None:
    """Merges `updated` into `existing` by upserting items identified by their name.

    This function is used to construct session patches, merging session resources by
    name (containers, volumes, etc.). It returns None when there is nothing to update.
    """
    patch_list = None
    if updated:
        patch_list = list(existing)
        for upsert_item in updated:
            # Find out whether the upsert_item needs to be added or updated in place
            found = next(filter(lambda t: t[1].name == upsert_item.name, enumerate(patch_list)), None)
            if found is not None:
                idx, _ = found
                patch_list[idx] = upsert_item
            else:
                patch_list.append(upsert_item)
    return patch_list
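
# For illustration with a hypothetical named type c: merging
# existing=[c("a"), c("b")] with updated=[c("b'"), c("c")] yields
# [c("a"), c("b'"), c("c")], i.e. "b" is replaced in place and "c" is appended.
# When `updated` is empty the function returns None, which presumably leaves the
# corresponding field out of the serialized merge patch (clearing a field is done
# explicitly with the RESET sentinel instead).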


def validate_session_post_request(body: apispec.SessionPostRequest) -> SessionLaunchRequest:
    """Validate a session launch request."""
    data_connectors_overrides = (
        [
            SessionDataConnectorOverride(
                skip=dc.skip,
                data_connector_id=ULID.from_str(dc.data_connector_id),
                configuration=dc.configuration,
                source_path=dc.source_path,
                target_path=dc.target_path,
                readonly=dc.readonly,
            )
            for dc in body.data_connectors_overrides
        ]
        if body.data_connectors_overrides
        else None
    )
    env_variable_overrides = (
        [SessionEnvVar(name=ev.name, value=ev.value) for ev in body.env_variable_overrides]
        if body.env_variable_overrides
        else None
    )
    return SessionLaunchRequest(
        launcher_id=ULID.from_str(body.launcher_id),
        disk_storage=body.disk_storage,
        resource_class_id=body.resource_class_id,
        data_connectors_overrides=data_connectors_overrides,
        env_variable_overrides=env_variable_overrides,
    )
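

# For illustration (hypothetical values): a request body such as
# {"launcher_id": "01AN4Z79ZS5XN0F25N3DB94T35", "resource_class_id": 3,
#  "env_variable_overrides": [{"name": "MY_VAR", "value": "42"}]}
# validates into a SessionLaunchRequest with a typed ULID launcher_id, typed
# SessionEnvVar overrides, and data_connectors_overrides left as None.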