
SwissDataScienceCenter / renku-data-services / 18585923528

17 Oct 2025 07:36AM UTC coverage: 83.21% (-0.2%) from 83.365%

Pull Request #1068: feat: update data connectors when resuming a session
Merge 5051fe076 into 6cc42274c

47 of 118 new or added lines in 4 files covered. (39.83%)
11 existing lines in 5 files now uncovered.
21787 of 26183 relevant lines covered (83.21%)
1.49 hits per line

Source File (coverage: 16.44%): /components/renku_data_services/notebooks/core_sessions.py

"""A selection of core functions for AmaltheaSessions."""

import base64
import json
import os
import random
import string
from collections.abc import AsyncIterator, Sequence
from datetime import timedelta
from pathlib import PurePosixPath
from typing import Protocol, TypeVar, cast
from urllib.parse import urljoin, urlparse

import httpx
from kubernetes.client import V1ObjectMeta, V1Secret
from sanic import Request
from toml import dumps
from ulid import ULID
from yaml import safe_dump

import renku_data_services.notebooks.image_check as ic
from renku_data_services.app_config import logging
from renku_data_services.base_models import AnonymousAPIUser, APIUser, AuthenticatedAPIUser
from renku_data_services.base_models.metrics import MetricsService
from renku_data_services.connected_services.db import ConnectedServicesRepository
from renku_data_services.crc.db import ClusterRepository, ResourcePoolRepository
from renku_data_services.crc.models import (
    ClusterSettings,
    GpuKind,
    RemoteConfigurationFirecrest,
    ResourceClass,
    ResourcePool,
)
from renku_data_services.data_connectors.db import (
    DataConnectorSecretRepository,
)
from renku_data_services.data_connectors.models import DataConnectorSecret, DataConnectorWithSecrets
from renku_data_services.errors import errors
from renku_data_services.k8s.models import K8sSecret, sanitizer
from renku_data_services.notebooks import apispec, core
from renku_data_services.notebooks.api.amalthea_patches import git_proxy, init_containers
from renku_data_services.notebooks.api.amalthea_patches.init_containers import user_secrets_extras
from renku_data_services.notebooks.api.classes.image import Image
from renku_data_services.notebooks.api.classes.repository import GitProvider, Repository
from renku_data_services.notebooks.api.schemas.cloud_storage import RCloneStorage
from renku_data_services.notebooks.config import GitProviderHelperProto, NotebooksConfig
from renku_data_services.notebooks.crs import (
    AmaltheaSessionSpec,
    AmaltheaSessionV1Alpha1,
    AmaltheaSessionV1Alpha1MetadataPatch,
    AmaltheaSessionV1Alpha1Patch,
    AmaltheaSessionV1Alpha1SpecPatch,
    AmaltheaSessionV1Alpha1SpecSessionPatch,
    Authentication,
    AuthenticationType,
    Culling,
    DataSource,
    ExtraContainer,
    ExtraVolume,
    ExtraVolumeMount,
    ImagePullPolicy,
    ImagePullSecret,
    Ingress,
    InitContainer,
    Limits,
    LimitsStr,
    Metadata,
    ReconcileStrategy,
    Requests,
    RequestsStr,
    Resources,
    SecretAsVolume,
    SecretAsVolumeItem,
    Session,
    SessionEnvItem,
    SessionLocation,
    ShmSizeStr,
    SizeStr,
    Storage,
)
from renku_data_services.notebooks.models import (
    ExtraSecret,
    SessionDataConnectorOverride,
    SessionEnvVar,
    SessionExtraResources,
    SessionLaunchRequest,
    SessionPatchRequest,
    SessionState,
)
from renku_data_services.notebooks.util.kubernetes_ import (
    renku_2_make_server_name,
)
from renku_data_services.notebooks.utils import (
    node_affinity_from_resource_class,
    tolerations_from_resource_class,
)
from renku_data_services.project.db import ProjectRepository, ProjectSessionSecretRepository
from renku_data_services.project.models import Project, SessionSecret
from renku_data_services.session.db import SessionRepository
from renku_data_services.session.models import SessionLauncher
from renku_data_services.users.db import UserRepo
from renku_data_services.utils.cryptography import get_encryption_key

logger = logging.getLogger(__name__)


async def get_extra_init_containers(
    nb_config: NotebooksConfig,
    user: AnonymousAPIUser | AuthenticatedAPIUser,
    repositories: list[Repository],
    git_providers: list[GitProvider],
    storage_mount: PurePosixPath,
    work_dir: PurePosixPath,
    uid: int = 1000,
    gid: int = 1000,
) -> SessionExtraResources:
    """Get all extra init containers that should be added to an amalthea session."""
    # TODO: The above statement is not correct: the init container for user secrets is not included here
    cert_init, cert_vols = init_containers.certificates_container(nb_config)
    session_init_containers = [InitContainer.model_validate(sanitizer(cert_init))]
    extra_volumes = [ExtraVolume.model_validate(sanitizer(volume)) for volume in cert_vols]
    git_clone = await init_containers.git_clone_container_v2(
        user=user,
        config=nb_config,
        repositories=repositories,
        git_providers=git_providers,
        workspace_mount_path=storage_mount,
        work_dir=work_dir,
        uid=uid,
        gid=gid,
    )
    if git_clone is not None:
        session_init_containers.append(InitContainer.model_validate(git_clone))
    return SessionExtraResources(
        init_containers=session_init_containers,
        volumes=extra_volumes,
    )


async def get_extra_containers(
    nb_config: NotebooksConfig,
    user: AnonymousAPIUser | AuthenticatedAPIUser,
    repositories: list[Repository],
    git_providers: list[GitProvider],
) -> SessionExtraResources:
    """Get the extra containers added to amalthea sessions."""
    conts: list[ExtraContainer] = []
    git_proxy_container = await git_proxy.main_container(
        user=user, config=nb_config, repositories=repositories, git_providers=git_providers
    )
    if git_proxy_container:
        conts.append(ExtraContainer.model_validate(sanitizer(git_proxy_container)))
    return SessionExtraResources(containers=conts)


async def get_auth_secret_authenticated(
    nb_config: NotebooksConfig,
    user: AuthenticatedAPIUser,
    server_name: str,
    base_server_url: str,
    base_server_https_url: str,
    base_server_path: str,
) -> ExtraSecret:
    """Get the extra secrets that need to be added to the session for an authenticated user."""
    secret_data = {}

    parsed_proxy_url = urlparse(urljoin(base_server_url + "/", "oauth2"))
    vol = ExtraVolume(
        name="renku-authorized-emails",
        secret=SecretAsVolume(
            secretName=server_name,
            items=[SecretAsVolumeItem(key="authorized_emails", path="authorized_emails")],
        ),
    )
    secret_data["auth"] = dumps(
        {
            "provider": "oidc",
            "client_id": nb_config.sessions.oidc.client_id,
            "oidc_issuer_url": nb_config.sessions.oidc.issuer_url,
            "session_cookie_minimal": True,
            "skip_provider_button": True,
            # NOTE: If the redirect url is not HTTPS then some identity providers will fail.
            "redirect_url": urljoin(base_server_https_url + "/", "oauth2/callback"),
            "cookie_path": base_server_path,
            "proxy_prefix": parsed_proxy_url.path,
            "authenticated_emails_file": "/authorized_emails",
            "client_secret": nb_config.sessions.oidc.client_secret,
            "cookie_secret": base64.urlsafe_b64encode(os.urandom(32)).decode(),
            "insecure_oidc_allow_unverified_email": nb_config.sessions.oidc.allow_unverified_email,
        }
    )
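    # NOTE: The "auth" entry above appears to be a TOML-serialized oauth2-proxy
    # configuration; it is referenced further down via auth_secret.key_ref("auth")
    # together with AuthenticationType.oauth2proxy.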
    secret_data["authorized_emails"] = user.email
    secret = V1Secret(metadata=V1ObjectMeta(name=server_name), string_data=secret_data)
    vol_mount = ExtraVolumeMount(
        name="renku-authorized-emails",
        mountPath="/authorized_emails",
        subPath="authorized_emails",
    )
    return ExtraSecret(secret, vol, vol_mount)


def get_auth_secret_anonymous(nb_config: NotebooksConfig, server_name: str, request: Request) -> ExtraSecret:
    """Get the extra secrets that need to be added to the session for an anonymous user."""
    # NOTE: We extract the session cookie value here in order to avoid creating a cookie.
    # The gateway encrypts and signs cookies so the user ID injected in the request headers does not
    # match the value of the session cookie.
    session_id = cast(str | None, request.cookies.get(nb_config.session_id_cookie_name))
    if not session_id:
        raise errors.UnauthorizedError(
            message=f"You have to have a renku session cookie at {nb_config.session_id_cookie_name} "
            "in order to launch an anonymous session."
        )
    # NOTE: Amalthea looks for the token value first in the cookie and then in the authorization header
    secret_data = {
        "auth": safe_dump(
            {
                "authproxy": {
                    "token": session_id,
                    "cookie_key": nb_config.session_id_cookie_name,
                    "verbose": True,
                }
            }
        )
    }
    secret = V1Secret(metadata=V1ObjectMeta(name=server_name), string_data=secret_data)
    return ExtraSecret(secret)


def __get_gitlab_image_pull_secret(
    nb_config: NotebooksConfig, user: AuthenticatedAPIUser, image_pull_secret_name: str, access_token: str
) -> ExtraSecret:
    """Create a Kubernetes secret for private GitLab registry authentication."""

    k8s_namespace = nb_config.k8s_client.namespace()

    registry_secret = {
        "auths": {
            nb_config.git.registry: {
                "Username": "oauth2",
                "Password": access_token,
                "Email": user.email,
            }
        }
    }
    registry_secret = json.dumps(registry_secret)

    secret_data = {".dockerconfigjson": registry_secret}
    secret = V1Secret(
        metadata=V1ObjectMeta(name=image_pull_secret_name, namespace=k8s_namespace),
        string_data=secret_data,
        type="kubernetes.io/dockerconfigjson",
    )

    return ExtraSecret(secret)


async def get_data_sources(
    nb_config: NotebooksConfig,
    user: AnonymousAPIUser | AuthenticatedAPIUser,
    server_name: str,
    data_connectors_stream: AsyncIterator[DataConnectorWithSecrets],
    work_dir: PurePosixPath,
    data_connectors_overrides: list[SessionDataConnectorOverride],
    user_repo: UserRepo,
) -> SessionExtraResources:
    """Generate cloud storage related resources."""
    data_sources: list[DataSource] = []
    secrets: list[ExtraSecret] = []
    dcs: dict[str, RCloneStorage] = {}
    dcs_secrets: dict[str, list[DataConnectorSecret]] = {}
    skipped_dcs: set[str] = set()
    user_secret_key: str | None = None
    async for dc in data_connectors_stream:
        mount_folder = (
            dc.data_connector.storage.target_path
            if PurePosixPath(dc.data_connector.storage.target_path).is_absolute()
            else (work_dir / dc.data_connector.storage.target_path).as_posix()
        )
        dcs[str(dc.data_connector.id)] = RCloneStorage(
            source_path=dc.data_connector.storage.source_path,
            mount_folder=mount_folder,
            configuration=dc.data_connector.storage.configuration,
            readonly=dc.data_connector.storage.readonly,
            name=dc.data_connector.name,
            secrets={str(secret.secret_id): secret.name for secret in dc.secrets},
            storage_class=nb_config.cloud_storage.storage_class,
        )
        if len(dc.secrets) > 0:
            dcs_secrets[str(dc.data_connector.id)] = dc.secrets
    if isinstance(user, AuthenticatedAPIUser) and len(dcs_secrets) > 0:
        secret_key = await user_repo.get_or_create_user_secret_key(user)
        user_secret_key = get_encryption_key(secret_key.encode(), user.id.encode()).decode("utf-8")
    # NOTE: Check the cloud storage overrides from the request body and if any match
    # then overwrite the project's cloud storages
    # NOTE: Cloud storages in the session launch request body that are not from the DB will cause a 404 error
    # TODO: Is this correct? -> NOTE: Overriding the configuration when a saved secret is there will cause a 422 error
    for dco in data_connectors_overrides:
        dc_id = str(dco.data_connector_id)
        if dc_id not in dcs:
            raise errors.MissingResourceError(
                message=f"You have requested a data connector with ID {dc_id} which does not exist "
                "or which you do not have access to."
            )
        # NOTE: if 'skip' is true, we do not mount that data connector
        if dco.skip:
            skipped_dcs.add(dc_id)
            del dcs[dc_id]
            continue
        if dco.target_path is not None and not PurePosixPath(dco.target_path).is_absolute():
            dco.target_path = (work_dir / dco.target_path).as_posix()
        dcs[dc_id] = dcs[dc_id].with_override(dco)

    # Handle potential duplicate target_path
    dcs = _deduplicate_target_paths(dcs)

    for cs_id, cs in dcs.items():
        secret_name = f"{server_name}-ds-{cs_id.lower()}"
        secret_key_needed = len(dcs_secrets.get(cs_id, [])) > 0
        if secret_key_needed and user_secret_key is None:
            raise errors.ProgrammingError(
                message=f"You have saved storage secrets for data connector {cs_id} "
                f"associated with your user ID {user.id} but no key to decrypt them, "
                "therefore we cannot mount the requested data connector. "
                "Please report this to the renku administrators."
            )
        secret = ExtraSecret(
            cs.secret(
                secret_name,
                await nb_config.k8s_client.namespace(),
                user_secret_key=user_secret_key if secret_key_needed else None,
            )
        )
        secrets.append(secret)
        data_sources.append(
            DataSource(
                mountPath=cs.mount_folder,
                secretRef=secret.ref(),
                accessMode="ReadOnlyMany" if cs.readonly else "ReadWriteOnce",
            )
        )

    # Add annotations to track skipped data connectors
    # annotations: dict[str, str] = {"renku.io/mounted_data_connectors_ids": json.dumps(sorted(dcs.keys()))}
    annotations: dict[str, str] = dict()
    if skipped_dcs:
        annotations["renku.io/skipped_data_connectors_ids"] = json.dumps(sorted(skipped_dcs))
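    # Example (hypothetical IDs): skipping two data connectors produces
    #   {"renku.io/skipped_data_connectors_ids": '["01ABC...", "01DEF..."]'}
    # patch_data_sources() below reads this annotation back to recover the skip list.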

    return SessionExtraResources(
        annotations=annotations,
        data_sources=data_sources,
        secrets=secrets,
        data_connector_secrets=dcs_secrets,
    )


async def patch_data_sources(
    existing_session: AmaltheaSessionV1Alpha1,
    nb_config: NotebooksConfig,
    user: AnonymousAPIUser | AuthenticatedAPIUser,
    server_name: str,
    data_connectors_stream: AsyncIterator[DataConnectorWithSecrets],
    work_dir: PurePosixPath,
    data_connectors_overrides: list[SessionDataConnectorOverride],
    user_repo: UserRepo,
) -> None:  # -> SessionExtraResources:
    """Handle patching data sources."""

    # First, collect the data connectors we already mount in the session
    existing_dcs: set[str] = set()
    secret_name_prefix = f"{server_name}-ds-"
    for ds in existing_session.spec.dataSources or []:
        if not ds.secretRef:
            continue
        if ds.secretRef.name.startswith(secret_name_prefix):
            dc_id = str(ULID.from_str(ds.secretRef.name[len(secret_name_prefix) :].upper()))
            existing_dcs.add(dc_id)
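    # NOTE: get_data_sources() names data source secrets f"{server_name}-ds-{id.lower()}",
    # so stripping the prefix and upper-casing the remainder recovers the data
    # connector's ULID.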
    logger.warning(f"existing_dcs = {existing_dcs}")

    # Collect the data connectors we already skip
    existing_skipped_dcs: set[str] = set(
        json.loads(existing_session.metadata.annotations.get("renku.io/skipped_data_connectors_ids", "[]"))
    )
    logger.warning(f"existing_skipped_dcs = {existing_skipped_dcs}")

    # Collect the previously skipped data connectors we should mount now
    newly_unskipped_dcs: set[str] = set()
    for dco in data_connectors_overrides:
        dc_id = str(dco.data_connector_id)
        if dc_id in existing_skipped_dcs and not dco.skip:
            newly_unskipped_dcs.add(dc_id)
    logger.warning(f"newly_unskipped_dcs = {newly_unskipped_dcs}")

    # Collect the new data connectors
    new_dcs: dict[str, DataConnectorWithSecrets] = dict()
    async for dc in data_connectors_stream:
        dc_id = str(dc.data_connector.id)
        if (dc_id in newly_unskipped_dcs) or ((dc_id not in existing_dcs) and (dc_id not in existing_skipped_dcs)):
            new_dcs[dc_id] = dc
    logger.warning(f"new_dcs = {sorted(new_dcs.keys())}")

    async def new_dcs_stream() -> AsyncIterator[DataConnectorWithSecrets]:
        for dc in new_dcs.values():
            yield dc

    session_extras = await get_data_sources(
        nb_config=nb_config,
        server_name=server_name,
        user=user,
        data_connectors_stream=new_dcs_stream(),
        work_dir=work_dir,
        data_connectors_overrides=data_connectors_overrides,
        user_repo=user_repo,
    )
    logger.warning(f"session_extras.annotations = {session_extras.annotations}")
    logger.warning(f"session_extras.data_sources = {session_extras.data_sources}")

    pass


async def request_dc_secret_creation(
    user: AuthenticatedAPIUser | AnonymousAPIUser,
    nb_config: NotebooksConfig,
    manifest: AmaltheaSessionV1Alpha1,
    dc_secrets: dict[str, list[DataConnectorSecret]],
) -> None:
    """Request the specified data connector secrets to be created by the secret service."""
    if isinstance(user, AnonymousAPIUser):
        return
    owner_reference = {
        "apiVersion": manifest.apiVersion,
        "kind": manifest.kind,
        "name": manifest.metadata.name,
        "uid": manifest.metadata.uid,
    }
    secrets_url = nb_config.user_secrets.secrets_storage_service_url + "/api/secrets/kubernetes"
    headers = {"Authorization": f"bearer {user.access_token}"}

    cluster_id = None
    namespace = await nb_config.k8s_v2_client.namespace()
    if (cluster := await nb_config.k8s_v2_client.cluster_by_class_id(manifest.resource_class_id(), user)) is not None:
        cluster_id = cluster.id
        namespace = cluster.namespace

    for s_id, secrets in dc_secrets.items():
        if len(secrets) == 0:
            continue
        request_data = {
            "name": f"{manifest.metadata.name}-ds-{s_id.lower()}-secrets",
            "namespace": namespace,
            "secret_ids": [str(secret.secret_id) for secret in secrets],
            "owner_references": [owner_reference],
            "key_mapping": {str(secret.secret_id): secret.name for secret in secrets},
            "cluster_id": str(cluster_id),
        }
        async with httpx.AsyncClient(timeout=10) as client:
            res = await client.post(secrets_url, headers=headers, json=request_data)
            if res.status_code >= 300 or res.status_code < 200:
                raise errors.ProgrammingError(
                    message=f"The secret for data connector with ID {s_id} could not be "
                    f"successfully created, the status code was {res.status_code}. "
                    "Please contact a Renku administrator.",
                    detail=res.text,
                )


def get_launcher_env_variables(launcher: SessionLauncher, launch_request: SessionLaunchRequest) -> list[SessionEnvItem]:
    """Get the environment variables from the launcher, with overrides from the request."""
    output: list[SessionEnvItem] = []
    env_overrides = {i.name: i.value for i in launch_request.env_variable_overrides or []}
    for env in launcher.env_variables or []:
        if env.name in env_overrides:
            output.append(SessionEnvItem(name=env.name, value=env_overrides[env.name]))
        else:
            output.append(SessionEnvItem(name=env.name, value=env.value))
    return output
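
# Example (hypothetical values): launcher env vars [A=1, B=2] with a request
# override [B=3] yield [SessionEnvItem("A", "1"), SessionEnvItem("B", "3")];
# overrides naming variables the launcher does not define are rejected by
# verify_launcher_env_variable_overrides() below.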


def verify_launcher_env_variable_overrides(launcher: SessionLauncher, launch_request: SessionLaunchRequest) -> None:
    """Raise an error if there are env variables that are not defined in the launcher."""
    env_overrides = {i.name: i.value for i in launch_request.env_variable_overrides or []}
    known_env_names = {i.name for i in launcher.env_variables or []}
    unknown_env_names = set(env_overrides.keys()) - known_env_names
    if unknown_env_names:
        message = f"""The following environment variables are not defined in the session launcher: {unknown_env_names}.
            Please remove them from the launch request or add them to the session launcher."""
        raise errors.ValidationError(message=message)


async def request_session_secret_creation(
    user: AuthenticatedAPIUser | AnonymousAPIUser,
    nb_config: NotebooksConfig,
    manifest: AmaltheaSessionV1Alpha1,
    session_secrets: list[SessionSecret],
) -> None:
    """Request the specified user session secrets to be created by the secret service."""
    if isinstance(user, AnonymousAPIUser):
        return
    if not session_secrets:
        return
    owner_reference = {
        "apiVersion": manifest.apiVersion,
        "kind": manifest.kind,
        "name": manifest.metadata.name,
        "uid": manifest.metadata.uid,
    }
    key_mapping: dict[str, list[str]] = dict()
    for s in session_secrets:
        secret_id = str(s.secret_id)
        if secret_id not in key_mapping:
            key_mapping[secret_id] = list()
        key_mapping[secret_id].append(s.secret_slot.filename)

    cluster_id = None
    namespace = await nb_config.k8s_v2_client.namespace()
    if (cluster := await nb_config.k8s_v2_client.cluster_by_class_id(manifest.resource_class_id(), user)) is not None:
        cluster_id = cluster.id
        namespace = cluster.namespace

    request_data = {
        "name": f"{manifest.metadata.name}-secrets",
        "namespace": namespace,
        "secret_ids": [str(s.secret_id) for s in session_secrets],
        "owner_references": [owner_reference],
        "key_mapping": key_mapping,
        "cluster_id": str(cluster_id),
    }
    secrets_url = nb_config.user_secrets.secrets_storage_service_url + "/api/secrets/kubernetes"
    headers = {"Authorization": f"bearer {user.access_token}"}
    async with httpx.AsyncClient(timeout=10) as client:
        res = await client.post(secrets_url, headers=headers, json=request_data)
        if res.status_code >= 300 or res.status_code < 200:
            raise errors.ProgrammingError(
                message="The session secrets could not be successfully created, "
                f"the status code was {res.status_code}. "
                "Please contact a Renku administrator.",
                detail=res.text,
            )


def resources_from_resource_class(resource_class: ResourceClass) -> Resources:
    """Convert the resource class to a k8s resources spec."""
    requests: dict[str, Requests | RequestsStr] = {
        "cpu": RequestsStr(str(round(resource_class.cpu * 1000)) + "m"),
        "memory": RequestsStr(f"{resource_class.memory}Gi"),
    }
    limits: dict[str, Limits | LimitsStr] = {"memory": LimitsStr(f"{resource_class.memory}Gi")}
    if resource_class.gpu > 0:
        gpu_name = GpuKind.NVIDIA.value + "/gpu"
        requests[gpu_name] = Requests(resource_class.gpu)
        # NOTE: GPUs have to be set in limits too since GPUs cannot be overcommitted;
        # if they are not, on some clusters the session will fail to start.
        limits[gpu_name] = Limits(resource_class.gpu)
    return Resources(requests=requests, limits=limits if len(limits) > 0 else None)
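
# Illustrative example (hypothetical values): a resource class with cpu=2.0,
# memory=8 and gpu=1 maps to requests {"cpu": "2000m", "memory": "8Gi",
# "nvidia.com/gpu": 1} and limits {"memory": "8Gi", "nvidia.com/gpu": 1},
# assuming GpuKind.NVIDIA.value == "nvidia.com".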


def repositories_from_project(project: Project, git_providers: list[GitProvider]) -> list[Repository]:
    """Get the list of git repositories from a project."""
    repositories: list[Repository] = []
    for repo in project.repositories:
        found_provider_id: str | None = None
        for provider in git_providers:
            if urlparse(provider.url).netloc == urlparse(repo).netloc:
                found_provider_id = provider.id
                break
        repositories.append(Repository(url=repo, provider=found_provider_id))
    return repositories
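
# Example (hypothetical URLs): a repository "https://gitlab.example.com/group/repo.git"
# is matched to the provider whose URL has the same network location
# ("gitlab.example.com"); repositories without a matching provider keep provider=None.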


async def repositories_from_session(
    user: AnonymousAPIUser | AuthenticatedAPIUser,
    session: AmaltheaSessionV1Alpha1,
    project_repo: ProjectRepository,
    git_providers: list[GitProvider],
) -> list[Repository]:
    """Get the list of git repositories from a session."""
    try:
        project = await project_repo.get_project(user, session.project_id)
    except errors.MissingResourceError:
        return []
    return repositories_from_project(project, git_providers)


def get_culling(
    user: AuthenticatedAPIUser | AnonymousAPIUser, resource_pool: ResourcePool, nb_config: NotebooksConfig
) -> Culling:
    """Create the culling specification for an AmaltheaSession."""
    idle_threshold_seconds = resource_pool.idle_threshold or nb_config.sessions.culling.registered.idle_seconds
    if user.is_anonymous:
        # NOTE: Anonymous sessions should not be hibernated at all, but there is no such option in Amalthea
        # So in this case we set a very low hibernation threshold so the session is deleted quickly after
        # it is hibernated.
        hibernation_threshold_seconds = 1
    else:
        hibernation_threshold_seconds = (
            resource_pool.hibernation_threshold or nb_config.sessions.culling.registered.hibernated_seconds
        )
    return Culling(
        maxAge=timedelta(seconds=nb_config.sessions.culling.registered.max_age_seconds),
        maxFailedDuration=timedelta(seconds=nb_config.sessions.culling.registered.failed_seconds),
        maxHibernatedDuration=timedelta(seconds=hibernation_threshold_seconds),
        maxIdleDuration=timedelta(seconds=idle_threshold_seconds),
        maxStartingDuration=timedelta(seconds=nb_config.sessions.culling.registered.pending_seconds),
    )
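
# Example (hypothetical values): for an anonymous user in a pool with
# idle_threshold=3600, this yields maxIdleDuration=1h and maxHibernatedDuration=1s,
# so a hibernated anonymous session is deleted almost immediately.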


async def __requires_image_pull_secret(nb_config: NotebooksConfig, image: str, internal_gitlab_user: APIUser) -> bool:
    """Determines if an image requires a pull secret based on its visibility and their GitLab access token."""

    parsed_image = Image.from_path(image)
    image_repo = parsed_image.repo_api()

    image_exists_publicly = await image_repo.image_exists(parsed_image)
    if image_exists_publicly:
        return False

    if parsed_image.hostname == nb_config.git.registry and internal_gitlab_user.access_token:
        image_repo = image_repo.with_oauth2_token(internal_gitlab_user.access_token)
        image_exists_privately = await image_repo.image_exists(parsed_image)
        if image_exists_privately:
            return True
    # No pull secret needed if the image is private and the user cannot access it
    return False


def __format_image_pull_secret(secret_name: str, access_token: str, registry_domain: str) -> ExtraSecret:
    registry_secret = {
        "auths": {registry_domain: {"auth": base64.b64encode(f"oauth2:{access_token}".encode()).decode()}}
    }
    registry_secret = json.dumps(registry_secret)
    registry_secret = base64.b64encode(registry_secret.encode()).decode()
    return ExtraSecret(
        V1Secret(
            data={".dockerconfigjson": registry_secret},
            metadata=V1ObjectMeta(name=secret_name),
            type="kubernetes.io/dockerconfigjson",
        )
    )
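
# NOTE: the secret above decodes to a standard Docker config JSON, e.g.
#   {"auths": {"<registry_domain>": {"auth": base64("oauth2:<access_token>")}}},
# which kubelet consumes via the "kubernetes.io/dockerconfigjson" secret type.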


async def __get_connected_services_image_pull_secret(
    secret_name: str, connected_svcs_repo: ConnectedServicesRepository, image: str, user: APIUser
) -> ExtraSecret | None:
    """Return a secret for accessing the image if one is available for the given user."""
    image_parsed = Image.from_path(image)
    image_check_result = await ic.check_image(image_parsed, user, connected_svcs_repo, None)
    logger.debug(f"Set pull secret for {image} to connection {image_check_result.image_provider}")
    if not image_check_result.token:
        return None

    if not image_check_result.image_provider:
        return None

    return __format_image_pull_secret(
        secret_name=secret_name,
        access_token=image_check_result.token,
        registry_domain=image_check_result.image_provider.registry_url,
    )


async def get_image_pull_secret(
    image: str,
    server_name: str,
    nb_config: NotebooksConfig,
    user: APIUser,
    internal_gitlab_user: APIUser,
    connected_svcs_repo: ConnectedServicesRepository,
) -> ExtraSecret | None:
    """Get an image pull secret."""

    v2_secret = await __get_connected_services_image_pull_secret(
        f"{server_name}-image-secret", connected_svcs_repo, image, user
    )
    if v2_secret:
        return v2_secret

    if (
        nb_config.enable_internal_gitlab
        and isinstance(user, AuthenticatedAPIUser)
        and internal_gitlab_user.access_token is not None
    ):
        needs_pull_secret = await __requires_image_pull_secret(nb_config, image, internal_gitlab_user)
        if needs_pull_secret:
            v1_secret = __get_gitlab_image_pull_secret(
                nb_config, user, f"{server_name}-image-secret-v1", internal_gitlab_user.access_token
            )
            return v1_secret

    return None


def get_remote_secret(
    user: AuthenticatedAPIUser | AnonymousAPIUser,
    config: NotebooksConfig,
    server_name: str,
    remote_provider_id: str,
    git_providers: list[GitProvider],
) -> ExtraSecret | None:
    """Returns the secret containing the configuration for the remote session controller."""
    if not user.is_authenticated or user.access_token is None or user.refresh_token is None:
        return None
    remote_provider = next(filter(lambda p: p.id == remote_provider_id, git_providers), None)
    if not remote_provider:
        return None
    renku_base_url = "https://" + config.sessions.ingress.host
    renku_base_url = renku_base_url.rstrip("/")
    renku_auth_token_uri = f"{renku_base_url}/auth/realms/{config.keycloak_realm}/protocol/openid-connect/token"
    secret_data = {
        "RSC_AUTH_KIND": "renku",
        "RSC_AUTH_TOKEN_URI": remote_provider.access_token_url,
        "RSC_AUTH_RENKU_ACCESS_TOKEN": user.access_token,
        "RSC_AUTH_RENKU_REFRESH_TOKEN": user.refresh_token,
        "RSC_AUTH_RENKU_TOKEN_URI": renku_auth_token_uri,
        "RSC_AUTH_RENKU_CLIENT_ID": config.sessions.git_proxy.renku_client_id,
        "RSC_AUTH_RENKU_CLIENT_SECRET": config.sessions.git_proxy.renku_client_secret,
    }
    secret_name = f"{server_name}-remote-secret"
    secret = V1Secret(metadata=V1ObjectMeta(name=secret_name), string_data=secret_data)
    return ExtraSecret(secret)


def get_remote_env(
    remote: RemoteConfigurationFirecrest,
) -> list[SessionEnvItem]:
    """Returns env variables used for remote sessions."""
    env = [
        SessionEnvItem(name="RSC_REMOTE_KIND", value=remote.kind.value),
        SessionEnvItem(name="RSC_FIRECREST_API_URL", value=remote.api_url),
        SessionEnvItem(name="RSC_FIRECREST_SYSTEM_NAME", value=remote.system_name),
    ]
    if remote.partition:
        env.append(SessionEnvItem(name="RSC_FIRECREST_PARTITION", value=remote.partition))
    return env


async def start_session(
    request: Request,
    launch_request: SessionLaunchRequest,
    user: AnonymousAPIUser | AuthenticatedAPIUser,
    internal_gitlab_user: APIUser,
    nb_config: NotebooksConfig,
    git_provider_helper: GitProviderHelperProto,
    cluster_repo: ClusterRepository,
    data_connector_secret_repo: DataConnectorSecretRepository,
    project_repo: ProjectRepository,
    project_session_secret_repo: ProjectSessionSecretRepository,
    rp_repo: ResourcePoolRepository,
    session_repo: SessionRepository,
    user_repo: UserRepo,
    metrics: MetricsService,
    connected_svcs_repo: ConnectedServicesRepository,
) -> tuple[AmaltheaSessionV1Alpha1, bool]:
    """Start an Amalthea session.

    Returns a tuple where the first item is an instance of an Amalthea session
    and the second item is a boolean set to true iff a new session was created.
    """
    launcher = await session_repo.get_launcher(user=user, launcher_id=launch_request.launcher_id)
    launcher_id = launcher.id
    project = await project_repo.get_project(user=user, project_id=launcher.project_id)

    # Determine resource_class_id: the class can be overwritten at the user's request
    resource_class_id = launch_request.resource_class_id or launcher.resource_class_id

    cluster = await nb_config.k8s_v2_client.cluster_by_class_id(resource_class_id, user)

    server_name = renku_2_make_server_name(
        user=user, project_id=str(launcher.project_id), launcher_id=str(launcher_id), cluster_id=str(cluster.id)
    )
    existing_session = await nb_config.k8s_v2_client.get_session(name=server_name, safe_username=user.id)
    if existing_session is not None and existing_session.spec is not None:
        return existing_session, False
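    # NOTE: the server name is derived deterministically from (user, project,
    # launcher, cluster), so launching again while a session exists returns the
    # existing session with created=False instead of starting a duplicate.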
772

773
    # Fully determine the resource pool and resource class
774
    if resource_class_id is None:
×
775
        resource_pool = await rp_repo.get_default_resource_pool()
×
776
        resource_class = resource_pool.get_default_resource_class()
×
777
        if not resource_class and len(resource_pool.classes) > 0:
×
778
            resource_class = resource_pool.classes[0]
×
779
        if not resource_class or not resource_class.id:
×
780
            raise errors.ProgrammingError(message="Cannot find any resource classes in the default pool.")
×
781
        resource_class_id = resource_class.id
×
782
    else:
783
        resource_pool = await rp_repo.get_resource_pool_from_class(user, resource_class_id)
×
784
        resource_class = resource_pool.get_resource_class(resource_class_id)
×
785
        if not resource_class or not resource_class.id:
×
786
            raise errors.MissingResourceError(message=f"The resource class with ID {resource_class_id} does not exist.")
×
NEW
787
    await nb_config.crc_validator.validate_class_storage(user, resource_class.id, launch_request.disk_storage)
×
NEW
788
    disk_storage = launch_request.disk_storage or resource_class.default_storage
×
789

790
    # Determine session location
791
    session_location = SessionLocation.remote if resource_pool.remote else SessionLocation.local
×
792
    if session_location == SessionLocation.remote and not user.is_authenticated:
×
793
        raise errors.ValidationError(message="Anonymous users cannot start remote sessions.")
×
794

795
    environment = launcher.environment
×
796
    image = environment.container_image
×
797
    work_dir = environment.working_directory
×
798
    if not work_dir:
×
799
        image_workdir = await core.docker_image_workdir(nb_config, environment.container_image, internal_gitlab_user)
×
800
        work_dir_fallback = PurePosixPath("/home/jovyan")
×
801
        work_dir = image_workdir or work_dir_fallback
×
802
    storage_mount_fallback = work_dir / "work"
×
803
    storage_mount = launcher.environment.mount_directory or storage_mount_fallback
×
804
    secrets_mount_directory = storage_mount / project.secrets_mount_directory
×
805
    session_secrets = await project_session_secret_repo.get_all_session_secrets_from_project(
×
806
        user=user, project_id=project.id
807
    )
808
    data_connectors_stream = data_connector_secret_repo.get_data_connectors_with_secrets(user, project.id)
×
809
    git_providers = await git_provider_helper.get_providers(user=user)
×
810
    repositories = repositories_from_project(project, git_providers)
×
811

812
    # User secrets
813
    session_extras = SessionExtraResources()
×
814
    session_extras = session_extras.concat(
×
815
        user_secrets_extras(
816
            user=user,
817
            config=nb_config,
818
            secrets_mount_directory=secrets_mount_directory.as_posix(),
819
            k8s_secret_name=f"{server_name}-secrets",
820
            session_secrets=session_secrets,
821
        )
822
    )
823

824
    # Data connectors
825
    session_extras = session_extras.concat(
×
826
        await get_data_sources(
827
            nb_config=nb_config,
828
            server_name=server_name,
829
            user=user,
830
            data_connectors_stream=data_connectors_stream,
831
            work_dir=work_dir,
832
            data_connectors_overrides=launch_request.data_connectors_overrides or [],
833
            user_repo=user_repo,
834
        )
835
    )
836

837
    # More init containers
838
    session_extras = session_extras.concat(
×
839
        await get_extra_init_containers(
840
            nb_config,
841
            user,
842
            repositories,
843
            git_providers,
844
            storage_mount,
845
            work_dir,
846
            uid=environment.uid,
847
            gid=environment.gid,
848
        )
849
    )
850

851
    # Extra containers
852
    session_extras = session_extras.concat(await get_extra_containers(nb_config, user, repositories, git_providers))
×
853

854
    # Cluster settings (ingress, storage class, etc)
855
    cluster_settings: ClusterSettings
856
    try:
×
857
        cluster_settings = await cluster_repo.select(cluster.id)
×
858
    except errors.MissingResourceError:
×
859
        # Fallback to global, main cluster parameters
860
        cluster_settings = nb_config.local_cluster_settings()
×
861

862
    (
×
863
        base_server_path,
864
        base_server_url,
865
        base_server_https_url,
866
        host,
867
        tls_secret,
868
        ingress_class_name,
869
        ingress_annotations,
870
    ) = cluster_settings.get_ingress_parameters(server_name)
871
    storage_class = cluster_settings.get_storage_class()
×
872
    service_account_name = cluster_settings.service_account_name
×
873

874
    ui_path = f"{base_server_path}/{environment.default_url.lstrip('/')}"
×
875

876
    ingress = Ingress(
×
877
        host=host,
878
        ingressClassName=ingress_class_name,
879
        annotations=ingress_annotations,
880
        tlsSecret=tls_secret,
881
        pathPrefix=base_server_path,
882
    )
883

884
    # Annotations
NEW
885
    session_extras = session_extras.concat(
×
886
        SessionExtraResources(
887
            annotations={
888
                "renku.io/project_id": str(launcher.project_id),
889
                "renku.io/launcher_id": str(launcher_id),
890
                "renku.io/resource_class_id": str(resource_class_id),
891
            }
892
        )
893
    )
894

895
    # Authentication
896
    if isinstance(user, AuthenticatedAPIUser):
×
897
        auth_secret = await get_auth_secret_authenticated(
×
898
            nb_config, user, server_name, base_server_url, base_server_https_url, base_server_path
899
        )
900
    else:
901
        auth_secret = get_auth_secret_anonymous(nb_config, server_name, request)
×
902
    session_extras = session_extras.concat(
×
903
        SessionExtraResources(
904
            secrets=[auth_secret],
905
            volumes=[auth_secret.volume] if auth_secret.volume else [],
906
        )
907
    )
908
    authn_extra_volume_mounts: list[ExtraVolumeMount] = []
×
909
    if auth_secret.volume_mount:
×
910
        authn_extra_volume_mounts.append(auth_secret.volume_mount)
×
911

912
    cert_vol_mounts = init_containers.certificates_volume_mounts(nb_config)
×
913
    if cert_vol_mounts:
×
914
        authn_extra_volume_mounts.extend(cert_vol_mounts)
×
915

916
    image_secret = await get_image_pull_secret(
×
917
        image=image,
918
        server_name=server_name,
919
        nb_config=nb_config,
920
        user=user,
921
        internal_gitlab_user=internal_gitlab_user,
922
        connected_svcs_repo=connected_svcs_repo,
923
    )
924
    if image_secret:
×
925
        session_extras = session_extras.concat(SessionExtraResources(secrets=[image_secret]))
×
926

927
    # Remote session configuration
928
    remote_secret = None
×
929
    if session_location == SessionLocation.remote:
×
930
        assert resource_pool.remote is not None
×
931
        if resource_pool.remote.provider_id is None:
×
932
            raise errors.ProgrammingError(
×
933
                message=f"The resource pool {resource_pool.id} configuration is not valid (missing field 'remote_provider_id')."  # noqa E501
934
            )
935
        remote_secret = get_remote_secret(
×
936
            user=user,
937
            config=nb_config,
938
            server_name=server_name,
939
            remote_provider_id=resource_pool.remote.provider_id,
940
            git_providers=git_providers,
941
        )
942
    if remote_secret is not None:
×
943
        session_extras = session_extras.concat(SessionExtraResources(secrets=[remote_secret]))
×
944

945
    # Raise an error if there are invalid environment variables in the request body
NEW
946
    verify_launcher_env_variable_overrides(launcher, launch_request)
×
947
    env = [
×
948
        SessionEnvItem(name="RENKU_BASE_URL_PATH", value=base_server_path),
949
        SessionEnvItem(name="RENKU_BASE_URL", value=base_server_url),
950
        SessionEnvItem(name="RENKU_MOUNT_DIR", value=storage_mount.as_posix()),
951
        SessionEnvItem(name="RENKU_SESSION", value="1"),
952
        SessionEnvItem(name="RENKU_SESSION_IP", value="0.0.0.0"),  # nosec B104
953
        SessionEnvItem(name="RENKU_SESSION_PORT", value=f"{environment.port}"),
954
        SessionEnvItem(name="RENKU_WORKING_DIR", value=work_dir.as_posix()),
955
        SessionEnvItem(name="RENKU_SECRETS_PATH", value=project.secrets_mount_directory.as_posix()),
956
        SessionEnvItem(name="RENKU_PROJECT_ID", value=str(project.id)),
957
        SessionEnvItem(name="RENKU_PROJECT_PATH", value=project.path.serialize()),
958
        SessionEnvItem(name="RENKU_LAUNCHER_ID", value=str(launcher.id)),
959
    ]
960
    if session_location == SessionLocation.remote:
×
961
        assert resource_pool.remote is not None
×
962
        env.extend(
×
963
            get_remote_env(
964
                remote=resource_pool.remote,
965
            )
966
        )
NEW
967
    launcher_env_variables = get_launcher_env_variables(launcher, launch_request)
×
968
    env.extend(launcher_env_variables)
×
969

970
    session = AmaltheaSessionV1Alpha1(
×
971
        metadata=Metadata(name=server_name, annotations=session_extras.annotations),
972
        spec=AmaltheaSessionSpec(
973
            location=session_location,
974
            imagePullSecrets=[ImagePullSecret(name=image_secret.name, adopt=True)] if image_secret else [],
975
            codeRepositories=[],
976
            hibernated=False,
977
            reconcileStrategy=ReconcileStrategy.whenFailedOrHibernated,
978
            priorityClassName=resource_class.quota,
979
            session=Session(
980
                image=image,
981
                imagePullPolicy=ImagePullPolicy.Always,
982
                urlPath=ui_path,
983
                port=environment.port,
984
                storage=Storage(
985
                    className=storage_class,
986
                    size=SizeStr(str(disk_storage) + "G"),
987
                    mountPath=storage_mount.as_posix(),
988
                ),
989
                workingDir=work_dir.as_posix(),
990
                runAsUser=environment.uid,
991
                runAsGroup=environment.gid,
992
                resources=resources_from_resource_class(resource_class),
993
                extraVolumeMounts=session_extras.volume_mounts,
994
                command=environment.command,
995
                args=environment.args,
996
                shmSize=ShmSizeStr("1G"),
997
                stripURLPath=environment.strip_path_prefix,
998
                env=env,
999
                remoteSecretRef=remote_secret.ref() if remote_secret else None,
1000
            ),
1001
            ingress=ingress,
1002
            extraContainers=session_extras.containers,
1003
            initContainers=session_extras.init_containers,
1004
            extraVolumes=session_extras.volumes,
1005
            culling=get_culling(user, resource_pool, nb_config),
1006
            authentication=Authentication(
1007
                enabled=True,
1008
                type=AuthenticationType.oauth2proxy
1009
                if isinstance(user, AuthenticatedAPIUser)
1010
                else AuthenticationType.token,
1011
                secretRef=auth_secret.key_ref("auth"),
1012
                extraVolumeMounts=authn_extra_volume_mounts,
1013
            ),
1014
            dataSources=session_extras.data_sources,
1015
            tolerations=tolerations_from_resource_class(resource_class, nb_config.sessions.tolerations_model),
1016
            affinity=node_affinity_from_resource_class(resource_class, nb_config.sessions.affinity_model),
1017
            serviceAccountName=service_account_name,
1018
        ),
1019
    )
1020
    secrets_to_create = session_extras.secrets or []
×
1021
    for s in secrets_to_create:
×
1022
        await nb_config.k8s_v2_client.create_secret(K8sSecret.from_v1_secret(s.secret, cluster))
×
1023
    try:
×
1024
        session = await nb_config.k8s_v2_client.create_session(session, user)
×
1025
    except Exception as err:
×
1026
        for s in secrets_to_create:
×
1027
            await nb_config.k8s_v2_client.delete_secret(K8sSecret.from_v1_secret(s.secret, cluster))
×
1028
        raise errors.ProgrammingError(message="Could not start the amalthea session") from err
×
1029
    else:
1030
        try:
×
1031
            await request_session_secret_creation(user, nb_config, session, session_secrets)
×
1032
            data_connector_secrets = session_extras.data_connector_secrets or dict()
×
1033
            await request_dc_secret_creation(user, nb_config, session, data_connector_secrets)
×
1034
        except Exception:
×
1035
            await nb_config.k8s_v2_client.delete_session(server_name, user.id)
×
1036
            raise
×
1037

1038
    await metrics.user_requested_session_launch(
×
1039
        user=user,
1040
        metadata={
1041
            "cpu": int(resource_class.cpu * 1000),
1042
            "memory": resource_class.memory,
1043
            "gpu": resource_class.gpu,
1044
            "storage": disk_storage,
1045
            "resource_class_id": resource_class.id,
1046
            "resource_pool_id": resource_pool.id or "",
1047
            "resource_class_name": f"{resource_pool.name}.{resource_class.name}",
1048
            "session_id": server_name,
1049
        },
1050
    )
1051
    return session, True
×
1052

1053

1054
async def patch_session(
2✔
1055
    patch_request: SessionPatchRequest,
1056
    session_id: str,
1057
    user: AnonymousAPIUser | AuthenticatedAPIUser,
1058
    internal_gitlab_user: APIUser,
1059
    nb_config: NotebooksConfig,
1060
    git_provider_helper: GitProviderHelperProto,
1061
    connected_svcs_repo: ConnectedServicesRepository,
1062
    data_connector_secret_repo: DataConnectorSecretRepository,
1063
    project_repo: ProjectRepository,
1064
    project_session_secret_repo: ProjectSessionSecretRepository,
1065
    rp_repo: ResourcePoolRepository,
1066
    session_repo: SessionRepository,
1067
    user_repo: UserRepo,
1068
    metrics: MetricsService,
1069
) -> AmaltheaSessionV1Alpha1:
1070
    """Patch an Amalthea session."""
1071
    session = await nb_config.k8s_v2_client.get_session(session_id, user.id)
×
1072
    if session is None:
×
1073
        raise errors.MissingResourceError(message=f"The session with ID {session_id} does not exist")
×
1074
    if session.spec is None:
×
1075
        raise errors.ProgrammingError(
×
1076
            message=f"The session {session_id} being patched is missing the expected 'spec' field.", quiet=True
1077
        )
1078
    cluster = await nb_config.k8s_v2_client.cluster_by_class_id(session.resource_class_id(), user)
×
1079

1080
    patch = AmaltheaSessionV1Alpha1Patch(spec=AmaltheaSessionV1Alpha1SpecPatch())
×
1081
    is_getting_hibernated: bool = False
×
1082

1083
    # Hibernation
1084
    # TODO: Some patching should only be done when the session is in some states to avoid inadvertent restarts
1085
    # Refresh tokens for git proxy
1086
    if (
×
1087
        patch_request.state is not None
1088
        and patch_request.state == SessionState.hibernated
1089
        and patch_request.state.value.lower() != session.status.state.value.lower()
1090
    ):
1091
        # Session is being hibernated
1092
        patch.spec.hibernated = True
×
1093
        is_getting_hibernated = True
×
1094
    elif (
×
1095
        patch_request.state is not None
1096
        and patch_request.state == SessionState.running
1097
        and session.status.state.value.lower() != patch_request.state.value.lower()
1098
    ):
1099
        # Session is being resumed
1100
        patch.spec.hibernated = False
×
1101
        await metrics.user_requested_session_resume(user, metadata={"session_id": session_id})
×
1102

1103
    # Resource class
NEW
1104
    if patch_request.resource_class_id is not None:
×
NEW
1105
        new_cluster = await nb_config.k8s_v2_client.cluster_by_class_id(patch_request.resource_class_id, user)
×
1106
        if new_cluster.id != cluster.id:
×
1107
            raise errors.ValidationError(
×
1108
                message=(
1109
                    f"The requested resource class {patch_request.resource_class_id} is not in the "
1110
                    f"same cluster {cluster.id} as the current resource class {session.resource_class_id()}."
1111
                )
1112
            )
NEW
1113
        rp = await rp_repo.get_resource_pool_from_class(user, patch_request.resource_class_id)
×
NEW
1114
        rc = rp.get_resource_class(patch_request.resource_class_id)
×
1115
        if not rc:
×
1116
            raise errors.MissingResourceError(
×
1117
                message=f"The resource class you requested with ID {patch_request.resource_class_id} does not exist"
1118
            )
1119
        # TODO: reject session classes which change the cluster
1120
        if not patch.metadata:
×
1121
            patch.metadata = AmaltheaSessionV1Alpha1MetadataPatch()
×
1122
        # Patch the resource class ID in the annotations
NEW
1123
        patch.metadata.annotations = {"renku.io/resource_class_id": str(patch_request.resource_class_id)}
×
1124
        if not patch.spec.session:
×
1125
            patch.spec.session = AmaltheaSessionV1Alpha1SpecSessionPatch()
×
1126
        patch.spec.session.resources = resources_from_resource_class(rc)
×
1127
        # Tolerations
1128
        tolerations = tolerations_from_resource_class(rc, nb_config.sessions.tolerations_model)
×
1129
        patch.spec.tolerations = tolerations
×
1130
        # Affinities
1131
        patch.spec.affinity = node_affinity_from_resource_class(rc, nb_config.sessions.affinity_model)
×
1132
        # Priority class (if a quota is being used)
1133
        patch.spec.priorityClassName = rc.quota
×
1134
        patch.spec.culling = get_culling(user, rp, nb_config)
×
1135
        if rp.cluster is not None:
×
1136
            patch.spec.service_account_name = rp.cluster.service_account_name
×
1137

1138
    # If the session is being hibernated we do not need to patch anything else that is
1139
    # not specifically called for in the request body, we can refresh things when the user resumes.
1140
    if is_getting_hibernated:
×
1141
        return await nb_config.k8s_v2_client.patch_session(session_id, user.id, patch.to_rfc7386())
×

    server_name = session.metadata.name
    launcher = await session_repo.get_launcher(user, session.launcher_id)
    project = await project_repo.get_project(user=user, project_id=session.project_id)
    environment = launcher.environment
    work_dir = environment.working_directory
    if not work_dir:
        image_workdir = await core.docker_image_workdir(nb_config, environment.container_image, internal_gitlab_user)
        work_dir_fallback = PurePosixPath("/home/jovyan")
        work_dir = image_workdir or work_dir_fallback
    storage_mount_fallback = work_dir / "work"
    storage_mount = launcher.environment.mount_directory or storage_mount_fallback
    secrets_mount_directory = storage_mount / project.secrets_mount_directory
    session_secrets = await project_session_secret_repo.get_all_session_secrets_from_project(
        user=user, project_id=project.id
    )
    data_connectors_stream = data_connector_secret_repo.get_data_connectors_with_secrets(user, project.id)
    git_providers = await git_provider_helper.get_providers(user=user)
    repositories = repositories_from_project(project, git_providers)

    # User secrets
    session_extras = SessionExtraResources()
    session_extras = session_extras.concat(
        user_secrets_extras(
            user=user,
            config=nb_config,
            secrets_mount_directory=secrets_mount_directory.as_posix(),
            k8s_secret_name=f"{server_name}-secrets",
            session_secrets=session_secrets,
        )
    )

    # Data connectors
    await patch_data_sources(
        existing_session=session,
        nb_config=nb_config,
        server_name=server_name,
        user=user,
        data_connectors_stream=data_connectors_stream,
        work_dir=work_dir,
        # TODO: allow 'data_connectors_overrides' to be passed on the PATCH endpoint
        data_connectors_overrides=[],  # patch_request.data_connectors_overrides or [],
        user_repo=user_repo,
    )
    # session_extras = session_extras.concat(
    #     await get_data_sources(
    #         nb_config=nb_config,
    #         server_name=server_name,
    #         user=user,
    #         data_connectors_stream=data_connectors_stream,
    #         work_dir=work_dir,
    #         data_connectors_overrides=launch_request.data_connectors_overrides or [],
    #         user_repo=user_repo,
    #     )
    # )
    # TODO: How can we patch data connectors? Should we even patch them?
    # TODO: The fact that `start_session()` accepts overrides for data connectors
    # TODO: but that we do not save these overrides (e.g. as annotations) means that
    # TODO: we cannot patch data connectors upon resume.
    # TODO: If we did, we would lose the user's provided overrides (e.g. unsaved credentials).
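    # One possible direction (purely a sketch; the annotation key and `to_dict()` helper
    # are hypothetical, not part of the current code): persist the overrides at launch, e.g.
    #     annotations["renku.io/data_connector_overrides"] = json.dumps(
    #         [override.to_dict() for override in launch_request.data_connectors_overrides or []]
    #     )
    # and read them back here before calling patch_data_sources(). Secret values would
    # still need to live in a Kubernetes Secret rather than in an annotation.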

    # More init containers
    session_extras = session_extras.concat(
        await get_extra_init_containers(
            nb_config,
            user,
            repositories,
            git_providers,
            storage_mount,
            work_dir,
            uid=environment.uid,
            gid=environment.gid,
        )
    )

    # Extra containers
    session_extras = session_extras.concat(await get_extra_containers(nb_config, user, repositories, git_providers))

    # Patching the image pull secret
    image = session.spec.session.image
    image_pull_secret = await get_image_pull_secret(
        image=image,
        server_name=server_name,
        nb_config=nb_config,
        connected_svcs_repo=connected_svcs_repo,
        user=user,
        internal_gitlab_user=internal_gitlab_user,
    )
    if image_pull_secret:
        # Assign the result back: concat() returns a new SessionExtraResources rather than mutating in place.
        session_extras = session_extras.concat(SessionExtraResources(secrets=[image_pull_secret]))
        patch.spec.imagePullSecrets = [ImagePullSecret(name=image_pull_secret.name, adopt=image_pull_secret.adopt)]

    # Construct session patch
    patch.spec.extraContainers = _make_patch_spec_list(
        existing=session.spec.extraContainers or [], updated=session_extras.containers
    )
    patch.spec.initContainers = _make_patch_spec_list(
        existing=session.spec.initContainers or [], updated=session_extras.init_containers
    )
    patch.spec.extraVolumes = _make_patch_spec_list(
        existing=session.spec.extraVolumes or [], updated=session_extras.volumes
    )
    if not patch.spec.session:
        patch.spec.session = AmaltheaSessionV1Alpha1SpecSessionPatch()
    patch.spec.session.extraVolumeMounts = _make_patch_spec_list(
        existing=session.spec.session.extraVolumeMounts or [], updated=session_extras.volume_mounts
    )

    secrets_to_create = session_extras.secrets or []
    for s in secrets_to_create:
        await nb_config.k8s_v2_client.create_secret(K8sSecret.from_v1_secret(s.secret, cluster))

    patch_serialized = patch.to_rfc7386()
    if len(patch_serialized) == 0:
        return session

    return await nb_config.k8s_v2_client.patch_session(session_id, user.id, patch_serialized)

def _deduplicate_target_paths(dcs: dict[str, RCloneStorage]) -> dict[str, RCloneStorage]:
    """Ensures that the target paths for all storages are unique.

    This function attempts to de-duplicate the target_path of all items passed in
    and raises an error if it fails to generate a unique target_path.
    """
    result_dcs: dict[str, RCloneStorage] = {}
    mount_folders: dict[str, list[str]] = {}

    def _find_mount_folder(dc: RCloneStorage) -> str:
        mount_folder = dc.mount_folder
        if mount_folder not in mount_folders:
            return mount_folder
        # 1. Try with a "-1", "-2", etc. suffix
        mount_folder_try = f"{mount_folder}-{len(mount_folders[mount_folder])}"
        if mount_folder_try not in mount_folders:
            return mount_folder_try
        # 2. Try with a random suffix
        suffix = "".join([random.choice(string.ascii_lowercase + string.digits) for _ in range(4)])  # nosec B311
        mount_folder_try = f"{mount_folder}-{suffix}"
        if mount_folder_try not in mount_folders:
            return mount_folder_try
        raise errors.ValidationError(
            message=f"Could not start session because two or more data connectors ({', '.join(mount_folders[mount_folder])}) share the same mount point '{mount_folder}'"  # noqa E501
        )

    for dc_id, dc in dcs.items():
        original_mount_folder = dc.mount_folder
        new_mount_folder = _find_mount_folder(dc)
        # Keep track of the original mount folder here
        if new_mount_folder != original_mount_folder:
            logger.warning(f"Re-assigning data connector {dc_id} to mount point '{new_mount_folder}'")
            dc_ids = mount_folders.get(original_mount_folder, [])
            dc_ids.append(dc_id)
            mount_folders[original_mount_folder] = dc_ids
        # Keep track of the assigned mount folder here
        dc_ids = mount_folders.get(new_mount_folder, [])
        dc_ids.append(dc_id)
        mount_folders[new_mount_folder] = dc_ids
        result_dcs[dc_id] = dc.with_override(
            # override=apispec.SessionCloudStoragePost(storage_id=dc_id, target_path=new_mount_folder)
            override=SessionDataConnectorOverride(
                skip=False,
                data_connector_id=ULID.from_str(dc_id),
                target_path=new_mount_folder,
                configuration=None,
                source_path=None,
                readonly=None,
            )
        )

    return result_dcs


class _NamedResource(Protocol):
    """Represents a resource with a name."""

    name: str
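
    # Structural typing: any object exposing a `name: str` attribute satisfies this
    # protocol without subclassing, e.g. the container and volume models merged by
    # _make_patch_spec_list below.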


_T = TypeVar("_T", bound=_NamedResource)


def _make_patch_spec_list(existing: Sequence[_T], updated: Sequence[_T]) -> list[_T] | None:
    """Merges updated into existing by upserting items identified by their name.

    This function is used to construct session patches, merging session resources by name (containers, volumes, etc.).
    """
    patch_list = None
    if updated:
        patch_list = list(existing)
        upsert_list = list(updated)
        for upsert_item in upsert_list:
            # Find out if the upsert_item needs to be added or updated
            found = next(filter(lambda t: t[1].name == upsert_item.name, enumerate(patch_list)), None)
            if found is not None:
                idx, _ = found
                patch_list[idx] = upsert_item
            else:
                patch_list.append(upsert_item)
    return patch_list
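
# Illustrative behaviour: with existing = [a, b] and updated = [b2, c] (where b2 has the
# same .name as b), the result is [a, b2, c]: "b" is replaced in place and "c" is appended.
# An empty `updated` returns None, which is presumably omitted when the patch is
# serialized, so the session's existing list is left untouched.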


def validate_session_post_request(body: apispec.SessionPostRequest) -> SessionLaunchRequest:
    """Validate a session launch request."""
    data_connectors_overrides = (
        [
            SessionDataConnectorOverride(
                skip=dc.skip,
                data_connector_id=ULID.from_str(dc.data_connector_id),
                configuration=dc.configuration,
                source_path=dc.source_path,
                target_path=dc.target_path,
                readonly=dc.readonly,
            )
            for dc in body.data_connectors_overrides
        ]
        if body.data_connectors_overrides
        else None
    )
    env_variable_overrides = (
        [SessionEnvVar(name=ev.name, value=ev.value) for ev in body.env_variable_overrides]
        if body.env_variable_overrides
        else None
    )
    return SessionLaunchRequest(
        launcher_id=ULID.from_str(body.launcher_id),
        disk_storage=body.disk_storage,
        resource_class_id=body.resource_class_id,
        data_connectors_overrides=data_connectors_overrides,
        env_variable_overrides=env_variable_overrides,
    )


def validate_session_patch_request(body: apispec.SessionPatchRequest) -> SessionPatchRequest:
    """Validate a session patch request."""
    return SessionPatchRequest(
        resource_class_id=body.resource_class_id,
        state=SessionState(body.state.value) if body.state else None,
    )