SwissDataScienceCenter / renku-data-services · Coveralls build 18707793088

22 Oct 2025 06:47AM UTC. Coverage: 86.825% (+0.006%) from 86.819%.

Pull Request #1065: feat: allow data connectors to be skipped when launching sessions
Merge 30515e0e4 into d4f265455 (github / web-flow)

36 of 59 new or added lines in 4 files covered (61.02%).
5 existing lines in 3 files now uncovered.
22723 of 26171 relevant lines covered (86.83%).
1.52 hits per line.

Source file: /components/renku_data_services/notebooks/core_sessions.py

"""A selection of core functions for AmaltheaSessions."""

import base64
import json
import os
import random
import string
from collections.abc import AsyncIterator, Sequence
from datetime import timedelta
from pathlib import PurePosixPath
from typing import Protocol, TypeVar, cast
from urllib.parse import urljoin, urlparse

import httpx
from kubernetes.client import V1ObjectMeta, V1Secret
from sanic import Request
from toml import dumps
from ulid import ULID
from yaml import safe_dump

import renku_data_services.notebooks.image_check as ic
from renku_data_services.app_config import logging
from renku_data_services.base_models import AnonymousAPIUser, APIUser, AuthenticatedAPIUser
from renku_data_services.base_models.metrics import MetricsService
from renku_data_services.connected_services.db import ConnectedServicesRepository
from renku_data_services.crc.db import ClusterRepository, ResourcePoolRepository
from renku_data_services.crc.models import (
    ClusterSettings,
    GpuKind,
    RemoteConfigurationFirecrest,
    ResourceClass,
    ResourcePool,
)
from renku_data_services.data_connectors.db import (
    DataConnectorSecretRepository,
)
from renku_data_services.data_connectors.models import DataConnectorSecret, DataConnectorWithSecrets
from renku_data_services.errors import errors
from renku_data_services.k8s.models import K8sSecret, sanitizer
from renku_data_services.notebooks import apispec, core
from renku_data_services.notebooks.api.amalthea_patches import git_proxy, init_containers
from renku_data_services.notebooks.api.amalthea_patches.init_containers import user_secrets_extras
from renku_data_services.notebooks.api.classes.image import Image
from renku_data_services.notebooks.api.classes.repository import GitProvider, Repository
from renku_data_services.notebooks.api.schemas.cloud_storage import RCloneStorage
from renku_data_services.notebooks.config import GitProviderHelperProto, NotebooksConfig
from renku_data_services.notebooks.crs import (
    AmaltheaSessionSpec,
    AmaltheaSessionV1Alpha1,
    AmaltheaSessionV1Alpha1MetadataPatch,
    AmaltheaSessionV1Alpha1Patch,
    AmaltheaSessionV1Alpha1SpecPatch,
    AmaltheaSessionV1Alpha1SpecSessionPatch,
    Authentication,
    AuthenticationType,
    Culling,
    DataSource,
    ExtraContainer,
    ExtraVolume,
    ExtraVolumeMount,
    ImagePullPolicy,
    ImagePullSecret,
    Ingress,
    InitContainer,
    Limits,
    LimitsStr,
    Metadata,
    ReconcileStrategy,
    Requests,
    RequestsStr,
    Resources,
    SecretAsVolume,
    SecretAsVolumeItem,
    Session,
    SessionEnvItem,
    SessionLocation,
    ShmSizeStr,
    SizeStr,
    State,
    Storage,
)
from renku_data_services.notebooks.models import (
    ExtraSecret,
    SessionDataConnectorOverride,
    SessionEnvVar,
    SessionExtraResources,
    SessionLaunchRequest,
)
from renku_data_services.notebooks.util.kubernetes_ import (
    renku_2_make_server_name,
)
from renku_data_services.notebooks.utils import (
    node_affinity_from_resource_class,
    tolerations_from_resource_class,
)
from renku_data_services.project.db import ProjectRepository, ProjectSessionSecretRepository
from renku_data_services.project.models import Project, SessionSecret
from renku_data_services.session.db import SessionRepository
from renku_data_services.session.models import SessionLauncher
from renku_data_services.users.db import UserRepo
from renku_data_services.utils.cryptography import get_encryption_key

logger = logging.getLogger(__name__)


async def get_extra_init_containers(
    nb_config: NotebooksConfig,
    user: AnonymousAPIUser | AuthenticatedAPIUser,
    repositories: list[Repository],
    git_providers: list[GitProvider],
    storage_mount: PurePosixPath,
    work_dir: PurePosixPath,
    uid: int = 1000,
    gid: int = 1000,
) -> SessionExtraResources:
    """Get all extra init containers that should be added to an amalthea session."""
    # TODO: The above statement is not correct: the init container for user secrets is not included here
    cert_init, cert_vols = init_containers.certificates_container(nb_config)
    session_init_containers = [InitContainer.model_validate(sanitizer(cert_init))]
    extra_volumes = [ExtraVolume.model_validate(sanitizer(volume)) for volume in cert_vols]
    git_clone = await init_containers.git_clone_container_v2(
        user=user,
        config=nb_config,
        repositories=repositories,
        git_providers=git_providers,
        workspace_mount_path=storage_mount,
        work_dir=work_dir,
        uid=uid,
        gid=gid,
    )
    if git_clone is not None:
        session_init_containers.append(InitContainer.model_validate(git_clone))
    return SessionExtraResources(
        init_containers=session_init_containers,
        volumes=extra_volumes,
    )


async def get_extra_containers(
    nb_config: NotebooksConfig,
    user: AnonymousAPIUser | AuthenticatedAPIUser,
    repositories: list[Repository],
    git_providers: list[GitProvider],
) -> SessionExtraResources:
    """Get the extra containers added to amalthea sessions."""
    conts: list[ExtraContainer] = []
    git_proxy_container = await git_proxy.main_container(
        user=user, config=nb_config, repositories=repositories, git_providers=git_providers
    )
    if git_proxy_container:
        conts.append(ExtraContainer.model_validate(sanitizer(git_proxy_container)))
    return SessionExtraResources(containers=conts)


async def get_auth_secret_authenticated(
    nb_config: NotebooksConfig,
    user: AuthenticatedAPIUser,
    server_name: str,
    base_server_url: str,
    base_server_https_url: str,
    base_server_path: str,
) -> ExtraSecret:
    """Get the extra secrets that need to be added to the session for an authenticated user."""
    secret_data = {}

    parsed_proxy_url = urlparse(urljoin(base_server_url + "/", "oauth2"))
    vol = ExtraVolume(
        name="renku-authorized-emails",
        secret=SecretAsVolume(
            secretName=server_name,
            items=[SecretAsVolumeItem(key="authorized_emails", path="authorized_emails")],
        ),
    )
    secret_data["auth"] = dumps(
        {
            "provider": "oidc",
            "client_id": nb_config.sessions.oidc.client_id,
            "oidc_issuer_url": nb_config.sessions.oidc.issuer_url,
            "session_cookie_minimal": True,
            "skip_provider_button": True,
            # NOTE: If the redirect url is not HTTPS then some identity providers will fail.
            "redirect_url": urljoin(base_server_https_url + "/", "oauth2/callback"),
            "cookie_path": base_server_path,
            "proxy_prefix": parsed_proxy_url.path,
            "authenticated_emails_file": "/authorized_emails",
            "client_secret": nb_config.sessions.oidc.client_secret,
            "cookie_secret": base64.urlsafe_b64encode(os.urandom(32)).decode(),
            "insecure_oidc_allow_unverified_email": nb_config.sessions.oidc.allow_unverified_email,
        }
    )
    secret_data["authorized_emails"] = user.email
    secret = V1Secret(metadata=V1ObjectMeta(name=server_name), string_data=secret_data)
    vol_mount = ExtraVolumeMount(
        name="renku-authorized-emails",
        mountPath="/authorized_emails",
        subPath="authorized_emails",
    )
    return ExtraSecret(secret, vol, vol_mount)


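# Editor's sketch (illustrative, not part of this module): the "auth" entry built above is
# oauth2-proxy configuration serialized as TOML by toml.dumps(). Assuming a server named
# "renku-abc123" served under https://renku.example.org/sessions/renku-abc123 and a client
# id of "renku" (both hypothetical values), it would look roughly like:
#
#   provider = "oidc"
#   client_id = "renku"
#   session_cookie_minimal = true
#   skip_provider_button = true
#   redirect_url = "https://renku.example.org/sessions/renku-abc123/oauth2/callback"
#   cookie_path = "/sessions/renku-abc123"
#   proxy_prefix = "/sessions/renku-abc123/oauth2"
#   authenticated_emails_file = "/authorized_emails"

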
def get_auth_secret_anonymous(nb_config: NotebooksConfig, server_name: str, request: Request) -> ExtraSecret:
    """Get the extra secrets that need to be added to the session for an anonymous user."""
    # NOTE: We extract the session cookie value here in order to avoid creating a cookie.
    # The gateway encrypts and signs cookies so the user ID injected in the request headers does not
    # match the value of the session cookie.
    session_id = cast(str | None, request.cookies.get(nb_config.session_id_cookie_name))
    if not session_id:
        raise errors.UnauthorizedError(
            message=f"You must have a Renku session cookie at {nb_config.session_id_cookie_name} "
            "in order to launch an anonymous session."
        )
    # NOTE: Amalthea looks for the token value first in the cookie and then in the authorization header
    secret_data = {
        "auth": safe_dump(
            {
                "authproxy": {
                    "token": session_id,
                    "cookie_key": nb_config.session_id_cookie_name,
                    "verbose": True,
                }
            }
        )
    }
    secret = V1Secret(metadata=V1ObjectMeta(name=server_name), string_data=secret_data)
    return ExtraSecret(secret)


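# Editor's sketch (illustrative, not part of this module): the "auth" entry above is the
# Amalthea authproxy configuration serialized as YAML by safe_dump(). Assuming the session
# cookie is named "_renku_session" (the real name comes from
# nb_config.session_id_cookie_name), the secret would contain roughly:
#
#   authproxy:
#     cookie_key: _renku_session
#     token: <value of the caller's _renku_session cookie>
#     verbose: true

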
def __get_gitlab_image_pull_secret(
    nb_config: NotebooksConfig, user: AuthenticatedAPIUser, image_pull_secret_name: str, access_token: str
) -> ExtraSecret:
    """Create a Kubernetes secret for private GitLab registry authentication."""

    k8s_namespace = nb_config.k8s_client.namespace()

    registry_secret = {
        "auths": {
            nb_config.git.registry: {
                "Username": "oauth2",
                "Password": access_token,
                "Email": user.email,
            }
        }
    }
    registry_secret = json.dumps(registry_secret)

    secret_data = {".dockerconfigjson": registry_secret}
    secret = V1Secret(
        metadata=V1ObjectMeta(name=image_pull_secret_name, namespace=k8s_namespace),
        string_data=secret_data,
        type="kubernetes.io/dockerconfigjson",
    )

    return ExtraSecret(secret)


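# Editor's sketch (illustrative, not part of this module): the secret above uses the
# standard kubernetes.io/dockerconfigjson layout. For a hypothetical registry at
# registry.gitlab.example.org, the .dockerconfigjson payload would be the JSON string:
#
#   {
#     "auths": {
#       "registry.gitlab.example.org": {
#         "Username": "oauth2",
#         "Password": "<the user's GitLab access token>",
#         "Email": "user@example.org"
#       }
#     }
#   }

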
async def get_data_sources(
    nb_config: NotebooksConfig,
    user: AnonymousAPIUser | AuthenticatedAPIUser,
    server_name: str,
    data_connectors_stream: AsyncIterator[DataConnectorWithSecrets],
    work_dir: PurePosixPath,
    data_connectors_overrides: list[SessionDataConnectorOverride],
    user_repo: UserRepo,
) -> SessionExtraResources:
    """Generate cloud storage related resources."""
    data_sources: list[DataSource] = []
    secrets: list[ExtraSecret] = []
    dcs: dict[str, RCloneStorage] = {}
    dcs_secrets: dict[str, list[DataConnectorSecret]] = {}
    user_secret_key: str | None = None
    async for dc in data_connectors_stream:
        mount_folder = (
            dc.data_connector.storage.target_path
            if PurePosixPath(dc.data_connector.storage.target_path).is_absolute()
            else (work_dir / dc.data_connector.storage.target_path).as_posix()
        )
        dcs[str(dc.data_connector.id)] = RCloneStorage(
            source_path=dc.data_connector.storage.source_path,
            mount_folder=mount_folder,
            configuration=dc.data_connector.storage.configuration,
            readonly=dc.data_connector.storage.readonly,
            name=dc.data_connector.name,
            secrets={str(secret.secret_id): secret.name for secret in dc.secrets},
            storage_class=nb_config.cloud_storage.storage_class,
        )
        if len(dc.secrets) > 0:
            dcs_secrets[str(dc.data_connector.id)] = dc.secrets
    if isinstance(user, AuthenticatedAPIUser) and len(dcs_secrets) > 0:
        secret_key = await user_repo.get_or_create_user_secret_key(user)
        user_secret_key = get_encryption_key(secret_key.encode(), user.id.encode()).decode("utf-8")
    # NOTE: Check the data connector overrides from the request body and, if any match,
    # overwrite the project's cloud storage configuration.
    # NOTE: Overrides in the session launch request body that do not refer to a data connector
    # from the DB will cause a 404 error.
    # TODO: Is this correct? -> NOTE: Overriding the configuration when a saved secret is there will cause a 422 error
    for dco in data_connectors_overrides:
        dc_id = str(dco.data_connector_id)
        if dc_id not in dcs:
            raise errors.MissingResourceError(
                message=f"You have requested a data connector with ID {dc_id} which does not exist "
                "or which you do not have access to."
            )
        # NOTE: if 'skip' is true, we do not mount that data connector
        if dco.skip:
            del dcs[dc_id]
            continue
        if dco.target_path is not None and not PurePosixPath(dco.target_path).is_absolute():
            dco.target_path = (work_dir / dco.target_path).as_posix()
        dcs[dc_id] = dcs[dc_id].with_override(dco)

    # Handle potential duplicate target_path
    dcs = _deduplicate_target_paths(dcs)

    for cs_id, cs in dcs.items():
        secret_name = f"{server_name}-ds-{cs_id.lower()}"
        secret_key_needed = len(dcs_secrets.get(cs_id, [])) > 0
        if secret_key_needed and user_secret_key is None:
            raise errors.ProgrammingError(
                message=f"You have saved storage secrets for data connector {cs_id} "
                f"associated with your user ID {user.id} but no key to decrypt them, "
                "therefore we cannot mount the requested data connector. "
                "Please report this to the renku administrators."
            )
        secret = ExtraSecret(
            cs.secret(
                secret_name,
                await nb_config.k8s_client.namespace(),
                user_secret_key=user_secret_key if secret_key_needed else None,
            )
        )
        secrets.append(secret)
        data_sources.append(
            DataSource(
                mountPath=cs.mount_folder,
                secretRef=secret.ref(),
                accessMode="ReadOnlyMany" if cs.readonly else "ReadWriteOnce",
            )
        )
    return SessionExtraResources(
        data_sources=data_sources,
        secrets=secrets,
        data_connector_secrets=dcs_secrets,
    )


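# Editor's sketch (illustrative, not part of this module): this is the behaviour added by
# PR #1065. Given a project whose data connector 01ARZ3NDEKTSV4RRFFQ69G5FAV (hypothetical
# ID) is normally mounted, a launch request can opt out of mounting it:
#
#   overrides = [
#       SessionDataConnectorOverride(
#           skip=True,
#           data_connector_id=ULID.from_str("01ARZ3NDEKTSV4RRFFQ69G5FAV"),
#           target_path=None, configuration=None, source_path=None, readonly=None,
#       ),
#   ]
#
# With skip=True the connector is dropped from `dcs`, so no DataSource, secret or mount is
# generated for it. A relative target_path override such as "scratch" would instead be
# resolved to <work_dir>/scratch and applied via RCloneStorage.with_override().

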
async def request_dc_secret_creation(
    user: AuthenticatedAPIUser | AnonymousAPIUser,
    nb_config: NotebooksConfig,
    manifest: AmaltheaSessionV1Alpha1,
    dc_secrets: dict[str, list[DataConnectorSecret]],
) -> None:
    """Request the specified data connector secrets to be created by the secret service."""
    if isinstance(user, AnonymousAPIUser):
        return
    owner_reference = {
        "apiVersion": manifest.apiVersion,
        "kind": manifest.kind,
        "name": manifest.metadata.name,
        "uid": manifest.metadata.uid,
    }
    secrets_url = nb_config.user_secrets.secrets_storage_service_url + "/api/secrets/kubernetes"
    headers = {"Authorization": f"bearer {user.access_token}"}

    cluster_id = None
    namespace = await nb_config.k8s_v2_client.namespace()
    if (cluster := await nb_config.k8s_v2_client.cluster_by_class_id(manifest.resource_class_id(), user)) is not None:
        cluster_id = cluster.id
        namespace = cluster.namespace

    for s_id, secrets in dc_secrets.items():
        if len(secrets) == 0:
            continue
        request_data = {
            "name": f"{manifest.metadata.name}-ds-{s_id.lower()}-secrets",
            "namespace": namespace,
            "secret_ids": [str(secret.secret_id) for secret in secrets],
            "owner_references": [owner_reference],
            "key_mapping": {str(secret.secret_id): secret.name for secret in secrets},
            "cluster_id": str(cluster_id),
        }
        async with httpx.AsyncClient(timeout=10) as client:
            res = await client.post(secrets_url, headers=headers, json=request_data)
            if res.status_code >= 300 or res.status_code < 200:
                raise errors.ProgrammingError(
                    message=f"The secret for data connector with ID {s_id} could not be "
                    f"successfully created, the status code was {res.status_code}. "
                    "Please contact a Renku administrator.",
                    detail=res.text,
                )


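# Editor's sketch (illustrative, not part of this module): the payload POSTed to the secret
# service above would look roughly like this for one data connector with two saved secrets
# (names and IDs are hypothetical):
#
#   {
#     "name": "renku-abc123-ds-01arz3ndektsv4rrffq69g5fav-secrets",
#     "namespace": "renku-sessions",
#     "secret_ids": ["<secret ULID 1>", "<secret ULID 2>"],
#     "owner_references": [{"apiVersion": ..., "kind": ..., "name": "renku-abc123", "uid": ...}],
#     "key_mapping": {"<secret ULID 1>": "access_key_id", "<secret ULID 2>": "secret_access_key"},
#     "cluster_id": "None"
#   }
#
# Note that "cluster_id" is str(cluster_id), i.e. the literal string "None" when no
# dedicated cluster is resolved. The owner reference ties the created Kubernetes secret to
# the session so it is cleaned up together with it.

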
def get_launcher_env_variables(launcher: SessionLauncher, launch_request: SessionLaunchRequest) -> list[SessionEnvItem]:
    """Get the environment variables from the launcher, with overrides from the request."""
    output: list[SessionEnvItem] = []
    env_overrides = {i.name: i.value for i in launch_request.env_variable_overrides or []}
    for env in launcher.env_variables or []:
        if env.name in env_overrides:
            output.append(SessionEnvItem(name=env.name, value=env_overrides[env.name]))
        else:
            output.append(SessionEnvItem(name=env.name, value=env.value))
    return output


def verify_launcher_env_variable_overrides(launcher: SessionLauncher, launch_request: SessionLaunchRequest) -> None:
    """Raise an error if there are env variables that are not defined in the launcher."""
    env_overrides = {i.name: i.value for i in launch_request.env_variable_overrides or []}
    known_env_names = {i.name for i in launcher.env_variables or []}
    unknown_env_names = set(env_overrides.keys()) - known_env_names
    if unknown_env_names:
        message = f"""The following environment variables are not defined in the session launcher: {unknown_env_names}.
            Please remove them from the launch request or add them to the session launcher."""
        raise errors.ValidationError(message=message)


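# Editor's sketch (illustrative, not part of this module): overrides can only change the
# values of variables the launcher already defines; unknown names are rejected. Assuming a
# launcher that defines DEBUG=0 and THREADS=4:
#
#   overrides [THREADS=8]  -> get_launcher_env_variables() yields [DEBUG=0, THREADS=8]
#   overrides [FOO=bar]    -> verify_launcher_env_variable_overrides() raises a
#                             ValidationError, because FOO is not defined in the launcher.

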
async def request_session_secret_creation(
    user: AuthenticatedAPIUser | AnonymousAPIUser,
    nb_config: NotebooksConfig,
    manifest: AmaltheaSessionV1Alpha1,
    session_secrets: list[SessionSecret],
) -> None:
    """Request the specified user session secrets to be created by the secret service."""
    if isinstance(user, AnonymousAPIUser):
        return
    if not session_secrets:
        return
    owner_reference = {
        "apiVersion": manifest.apiVersion,
        "kind": manifest.kind,
        "name": manifest.metadata.name,
        "uid": manifest.metadata.uid,
    }
    key_mapping: dict[str, list[str]] = dict()
    for s in session_secrets:
        secret_id = str(s.secret_id)
        if secret_id not in key_mapping:
            key_mapping[secret_id] = list()
        key_mapping[secret_id].append(s.secret_slot.filename)

    cluster_id = None
    namespace = await nb_config.k8s_v2_client.namespace()
    if (cluster := await nb_config.k8s_v2_client.cluster_by_class_id(manifest.resource_class_id(), user)) is not None:
        cluster_id = cluster.id
        namespace = cluster.namespace

    request_data = {
        "name": f"{manifest.metadata.name}-secrets",
        "namespace": namespace,
        "secret_ids": [str(s.secret_id) for s in session_secrets],
        "owner_references": [owner_reference],
        "key_mapping": key_mapping,
        "cluster_id": str(cluster_id),
    }
    secrets_url = nb_config.user_secrets.secrets_storage_service_url + "/api/secrets/kubernetes"
    headers = {"Authorization": f"bearer {user.access_token}"}
    async with httpx.AsyncClient(timeout=10) as client:
        res = await client.post(secrets_url, headers=headers, json=request_data)
        if res.status_code >= 300 or res.status_code < 200:
            raise errors.ProgrammingError(
                message="The session secrets could not be successfully created, "
                f"the status code was {res.status_code}. "
                "Please contact a Renku administrator.",
                detail=res.text,
            )


def resources_from_resource_class(resource_class: ResourceClass) -> Resources:
    """Convert the resource class to a k8s resources spec."""
    requests: dict[str, Requests | RequestsStr] = {
        "cpu": RequestsStr(str(round(resource_class.cpu * 1000)) + "m"),
        "memory": RequestsStr(f"{resource_class.memory}Gi"),
    }
    limits: dict[str, Limits | LimitsStr] = {"memory": LimitsStr(f"{resource_class.memory}Gi")}
    if resource_class.gpu > 0:
        gpu_name = GpuKind.NVIDIA.value + "/gpu"
        requests[gpu_name] = Requests(resource_class.gpu)
        # NOTE: GPUs have to be set in limits too since GPUs cannot be overcommitted; if they
        # are not, on some clusters the session will fail to start entirely.
        limits[gpu_name] = Limits(resource_class.gpu)
    return Resources(requests=requests, limits=limits if len(limits) > 0 else None)


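# Editor's sketch (illustrative, not part of this module): a worked example of the
# conversion, assuming GpuKind.NVIDIA.value == "nvidia.com". A resource class with cpu=0.5,
# memory=4 (GiB) and gpu=1 becomes:
#
#   requests: {"cpu": "500m", "memory": "4Gi", "nvidia.com/gpu": 1}
#   limits:   {"memory": "4Gi", "nvidia.com/gpu": 1}
#
# CPU is requested but deliberately not limited (it can be overcommitted), while memory and
# GPU are limited to exactly the requested amounts.

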
def repositories_from_project(project: Project, git_providers: list[GitProvider]) -> list[Repository]:
    """Get the list of git repositories from a project."""
    repositories: list[Repository] = []
    for repo in project.repositories:
        found_provider_id: str | None = None
        for provider in git_providers:
            if urlparse(provider.url).netloc == urlparse(repo).netloc:
                found_provider_id = provider.id
                break
        repositories.append(Repository(url=repo, provider=found_provider_id))
    return repositories


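# Editor's sketch (illustrative, not part of this module): providers are matched to
# repositories purely on the URL's network location, e.g.
#
#   provider.url = "https://gitlab.example.org"
#   repo         = "https://gitlab.example.org/group/project.git"
#
# Both parse to the netloc "gitlab.example.org", so the repository is associated with that
# provider's id; a repository with no matching provider keeps provider=None.

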
async def repositories_from_session(
    user: AnonymousAPIUser | AuthenticatedAPIUser,
    session: AmaltheaSessionV1Alpha1,
    project_repo: ProjectRepository,
    git_providers: list[GitProvider],
) -> list[Repository]:
    """Get the list of git repositories from a session."""
    try:
        project = await project_repo.get_project(user, session.project_id)
    except errors.MissingResourceError:
        return []
    return repositories_from_project(project, git_providers)


def get_culling(
    user: AuthenticatedAPIUser | AnonymousAPIUser, resource_pool: ResourcePool, nb_config: NotebooksConfig
) -> Culling:
    """Create the culling specification for an AmaltheaSession."""
    idle_threshold_seconds = resource_pool.idle_threshold or nb_config.sessions.culling.registered.idle_seconds
    if user.is_anonymous:
        # NOTE: Anonymous sessions should not be hibernated at all, but there is no such option in
        # Amalthea. So in this case we set a very low hibernation threshold so the session is
        # deleted quickly after it is hibernated.
        hibernation_threshold_seconds = 1
    else:
        hibernation_threshold_seconds = (
            resource_pool.hibernation_threshold or nb_config.sessions.culling.registered.hibernated_seconds
        )
    return Culling(
        maxAge=timedelta(seconds=nb_config.sessions.culling.registered.max_age_seconds),
        maxFailedDuration=timedelta(seconds=nb_config.sessions.culling.registered.failed_seconds),
        maxHibernatedDuration=timedelta(seconds=hibernation_threshold_seconds),
        maxIdleDuration=timedelta(seconds=idle_threshold_seconds),
        maxStartingDuration=timedelta(seconds=nb_config.sessions.culling.registered.pending_seconds),
    )


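# Editor's sketch (illustrative, not part of this module): the resource pool can override
# the idle and hibernation thresholds, everything else comes from the global culling
# config. Assuming a pool with idle_threshold=3600 and no hibernation_threshold, a
# registered user gets:
#
#   maxIdleDuration       = timedelta(seconds=3600)             # from the pool
#   maxHibernatedDuration = timedelta(seconds=<config default>) # no pool override
#
# An anonymous user always gets maxHibernatedDuration = timedelta(seconds=1), so anonymous
# sessions are deleted almost immediately once they hibernate.

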
async def __requires_image_pull_secret(nb_config: NotebooksConfig, image: str, internal_gitlab_user: APIUser) -> bool:
    """Determines if an image requires a pull secret based on its visibility and the user's GitLab access token."""

    parsed_image = Image.from_path(image)
    image_repo = parsed_image.repo_api()

    image_exists_publicly = await image_repo.image_exists(parsed_image)
    if image_exists_publicly:
        return False

    if parsed_image.hostname == nb_config.git.registry and internal_gitlab_user.access_token:
        image_repo = image_repo.with_oauth2_token(internal_gitlab_user.access_token)
        image_exists_privately = await image_repo.image_exists(parsed_image)
        if image_exists_privately:
            return True
    # No pull secret needed if the image is private and the user cannot access it
    return False


def __format_image_pull_secret(secret_name: str, access_token: str, registry_domain: str) -> ExtraSecret:
    """Format an image pull secret in the kubernetes.io/dockerconfigjson format."""
    registry_secret = {
        "auths": {registry_domain: {"auth": base64.b64encode(f"oauth2:{access_token}".encode()).decode()}}
    }
    registry_secret = json.dumps(registry_secret)
    registry_secret = base64.b64encode(registry_secret.encode()).decode()
    return ExtraSecret(
        V1Secret(
            data={".dockerconfigjson": registry_secret},
            metadata=V1ObjectMeta(name=secret_name),
            type="kubernetes.io/dockerconfigjson",
        )
    )


async def __get_connected_services_image_pull_secret(
    secret_name: str, connected_svcs_repo: ConnectedServicesRepository, image: str, user: APIUser
) -> ExtraSecret | None:
    """Return a secret for accessing the image if one is available for the given user."""
    image_parsed = Image.from_path(image)
    image_check_result = await ic.check_image(image_parsed, user, connected_svcs_repo, None)
    logger.debug(f"Set pull secret for {image} to connection {image_check_result.image_provider}")
    if not image_check_result.token:
        return None

    if not image_check_result.image_provider:
        return None

    return __format_image_pull_secret(
        secret_name=secret_name,
        access_token=image_check_result.token,
        registry_domain=image_check_result.image_provider.registry_url,
    )


async def get_image_pull_secret(
    image: str,
    server_name: str,
    nb_config: NotebooksConfig,
    user: APIUser,
    internal_gitlab_user: APIUser,
    connected_svcs_repo: ConnectedServicesRepository,
) -> ExtraSecret | None:
    """Get an image pull secret."""

    v2_secret = await __get_connected_services_image_pull_secret(
        f"{server_name}-image-secret", connected_svcs_repo, image, user
    )
    if v2_secret:
        return v2_secret

    if (
        nb_config.enable_internal_gitlab
        and isinstance(user, AuthenticatedAPIUser)
        and internal_gitlab_user.access_token is not None
    ):
        needs_pull_secret = await __requires_image_pull_secret(nb_config, image, internal_gitlab_user)
        if needs_pull_secret:
            v1_secret = __get_gitlab_image_pull_secret(
                nb_config, user, f"{server_name}-image-secret-v1", internal_gitlab_user.access_token
            )
            return v1_secret

    return None


def get_remote_secret(
    user: AuthenticatedAPIUser | AnonymousAPIUser,
    config: NotebooksConfig,
    server_name: str,
    remote_provider_id: str,
    git_providers: list[GitProvider],
) -> ExtraSecret | None:
    """Returns the secret containing the configuration for the remote session controller."""
    if not user.is_authenticated or user.access_token is None or user.refresh_token is None:
        return None
    remote_provider = next(filter(lambda p: p.id == remote_provider_id, git_providers), None)
    if not remote_provider:
        return None
    renku_base_url = "https://" + config.sessions.ingress.host
    renku_base_url = renku_base_url.rstrip("/")
    renku_auth_token_uri = f"{renku_base_url}/auth/realms/{config.keycloak_realm}/protocol/openid-connect/token"
    secret_data = {
        "RSC_AUTH_KIND": "renku",
        "RSC_AUTH_TOKEN_URI": remote_provider.access_token_url,
        "RSC_AUTH_RENKU_ACCESS_TOKEN": user.access_token,
        "RSC_AUTH_RENKU_REFRESH_TOKEN": user.refresh_token,
        "RSC_AUTH_RENKU_TOKEN_URI": renku_auth_token_uri,
        "RSC_AUTH_RENKU_CLIENT_ID": config.sessions.git_proxy.renku_client_id,
        "RSC_AUTH_RENKU_CLIENT_SECRET": config.sessions.git_proxy.renku_client_secret,
    }
    secret_name = f"{server_name}-remote-secret"
    secret = V1Secret(metadata=V1ObjectMeta(name=secret_name), string_data=secret_data)
    return ExtraSecret(secret)


def get_remote_env(
    remote: RemoteConfigurationFirecrest,
) -> list[SessionEnvItem]:
    """Returns env variables used for remote sessions."""
    env = [
        SessionEnvItem(name="RSC_REMOTE_KIND", value=remote.kind.value),
        SessionEnvItem(name="RSC_FIRECREST_API_URL", value=remote.api_url),
        SessionEnvItem(name="RSC_FIRECREST_SYSTEM_NAME", value=remote.system_name),
    ]
    if remote.partition:
        env.append(SessionEnvItem(name="RSC_FIRECREST_PARTITION", value=remote.partition))
    return env


async def start_session(
    request: Request,
    launch_request: SessionLaunchRequest,
    user: AnonymousAPIUser | AuthenticatedAPIUser,
    internal_gitlab_user: APIUser,
    nb_config: NotebooksConfig,
    git_provider_helper: GitProviderHelperProto,
    cluster_repo: ClusterRepository,
    data_connector_secret_repo: DataConnectorSecretRepository,
    project_repo: ProjectRepository,
    project_session_secret_repo: ProjectSessionSecretRepository,
    rp_repo: ResourcePoolRepository,
    session_repo: SessionRepository,
    user_repo: UserRepo,
    metrics: MetricsService,
    connected_svcs_repo: ConnectedServicesRepository,
) -> tuple[AmaltheaSessionV1Alpha1, bool]:
    """Start an Amalthea session.

    Returns a tuple where the first item is an instance of an Amalthea session
    and the second item is a boolean set to true iff a new session was created.
    """
    launcher = await session_repo.get_launcher(user=user, launcher_id=launch_request.launcher_id)
    launcher_id = launcher.id
    project = await project_repo.get_project(user=user, project_id=launcher.project_id)

    # Determine resource_class_id: the class can be overwritten at the user's request
    resource_class_id = launch_request.resource_class_id or launcher.resource_class_id

    cluster = await nb_config.k8s_v2_client.cluster_by_class_id(resource_class_id, user)

    server_name = renku_2_make_server_name(
        user=user, project_id=str(launcher.project_id), launcher_id=str(launcher_id), cluster_id=str(cluster.id)
    )
    existing_session = await nb_config.k8s_v2_client.get_session(name=server_name, safe_username=user.id)
    if existing_session is not None and existing_session.spec is not None:
        return existing_session, False

    # Fully determine the resource pool and resource class
    if resource_class_id is None:
        resource_pool = await rp_repo.get_default_resource_pool()
        resource_class = resource_pool.get_default_resource_class()
        if not resource_class and len(resource_pool.classes) > 0:
            resource_class = resource_pool.classes[0]
        if not resource_class or not resource_class.id:
            raise errors.ProgrammingError(message="Cannot find any resource classes in the default pool.")
        resource_class_id = resource_class.id
    else:
        resource_pool = await rp_repo.get_resource_pool_from_class(user, resource_class_id)
        resource_class = resource_pool.get_resource_class(resource_class_id)
        if not resource_class or not resource_class.id:
            raise errors.MissingResourceError(message=f"The resource class with ID {resource_class_id} does not exist.")
    await nb_config.crc_validator.validate_class_storage(user, resource_class.id, launch_request.disk_storage)
    disk_storage = launch_request.disk_storage or resource_class.default_storage

    # Determine session location
    session_location = SessionLocation.remote if resource_pool.remote else SessionLocation.local
    if session_location == SessionLocation.remote and not user.is_authenticated:
        raise errors.ValidationError(message="Anonymous users cannot start remote sessions.")

    environment = launcher.environment
    image = environment.container_image
    work_dir = environment.working_directory
    if not work_dir:
        image_workdir = await core.docker_image_workdir(nb_config, environment.container_image, internal_gitlab_user)
        work_dir_fallback = PurePosixPath("/home/jovyan")
        work_dir = image_workdir or work_dir_fallback
    storage_mount_fallback = work_dir / "work"
    storage_mount = launcher.environment.mount_directory or storage_mount_fallback
    secrets_mount_directory = storage_mount / project.secrets_mount_directory
    session_secrets = await project_session_secret_repo.get_all_session_secrets_from_project(
        user=user, project_id=project.id
    )
    data_connectors_stream = data_connector_secret_repo.get_data_connectors_with_secrets(user, project.id)
    git_providers = await git_provider_helper.get_providers(user=user)
    repositories = repositories_from_project(project, git_providers)

    # User secrets
    session_extras = SessionExtraResources()
    session_extras = session_extras.concat(
        user_secrets_extras(
            user=user,
            config=nb_config,
            secrets_mount_directory=secrets_mount_directory.as_posix(),
            k8s_secret_name=f"{server_name}-secrets",
            session_secrets=session_secrets,
        )
    )

    # Data connectors
    session_extras = session_extras.concat(
        await get_data_sources(
            nb_config=nb_config,
            server_name=server_name,
            user=user,
            data_connectors_stream=data_connectors_stream,
            work_dir=work_dir,
            data_connectors_overrides=launch_request.data_connectors_overrides or [],
            user_repo=user_repo,
        )
    )

    # More init containers
    session_extras = session_extras.concat(
        await get_extra_init_containers(
            nb_config,
            user,
            repositories,
            git_providers,
            storage_mount,
            work_dir,
            uid=environment.uid,
            gid=environment.gid,
        )
    )

    # Extra containers
    session_extras = session_extras.concat(await get_extra_containers(nb_config, user, repositories, git_providers))

    # Cluster settings (ingress, storage class, etc)
    cluster_settings: ClusterSettings
    try:
        cluster_settings = await cluster_repo.select(cluster.id)
    except errors.MissingResourceError:
        # Fallback to global, main cluster parameters
        cluster_settings = nb_config.local_cluster_settings()

    (
        base_server_path,
        base_server_url,
        base_server_https_url,
        host,
        tls_secret,
        ingress_class_name,
        ingress_annotations,
    ) = cluster_settings.get_ingress_parameters(server_name)
    storage_class = cluster_settings.get_storage_class()
    service_account_name = cluster_settings.service_account_name

    ui_path = f"{base_server_path}/{environment.default_url.lstrip('/')}"

    ingress = Ingress(
        host=host,
        ingressClassName=ingress_class_name,
        annotations=ingress_annotations,
        tlsSecret=tls_secret,
        pathPrefix=base_server_path,
    )

    # Annotations
    annotations: dict[str, str] = {
        "renku.io/project_id": str(launcher.project_id),
        "renku.io/launcher_id": str(launcher_id),
        "renku.io/resource_class_id": str(resource_class_id),
    }

    # Authentication
    if isinstance(user, AuthenticatedAPIUser):
        auth_secret = await get_auth_secret_authenticated(
            nb_config, user, server_name, base_server_url, base_server_https_url, base_server_path
        )
    else:
        auth_secret = get_auth_secret_anonymous(nb_config, server_name, request)
    session_extras = session_extras.concat(
        SessionExtraResources(
            secrets=[auth_secret],
            volumes=[auth_secret.volume] if auth_secret.volume else [],
        )
    )
    authn_extra_volume_mounts: list[ExtraVolumeMount] = []
    if auth_secret.volume_mount:
        authn_extra_volume_mounts.append(auth_secret.volume_mount)

    cert_vol_mounts = init_containers.certificates_volume_mounts(nb_config)
    if cert_vol_mounts:
        authn_extra_volume_mounts.extend(cert_vol_mounts)

    image_secret = await get_image_pull_secret(
        image=image,
        server_name=server_name,
        nb_config=nb_config,
        user=user,
        internal_gitlab_user=internal_gitlab_user,
        connected_svcs_repo=connected_svcs_repo,
    )
    if image_secret:
        session_extras = session_extras.concat(SessionExtraResources(secrets=[image_secret]))

    # Remote session configuration
    remote_secret = None
    if session_location == SessionLocation.remote:
        assert resource_pool.remote is not None
        if resource_pool.remote.provider_id is None:
            raise errors.ProgrammingError(
                message=f"The resource pool {resource_pool.id} configuration is not valid (missing field 'remote_provider_id')."  # noqa E501
            )
        remote_secret = get_remote_secret(
            user=user,
            config=nb_config,
            server_name=server_name,
            remote_provider_id=resource_pool.remote.provider_id,
            git_providers=git_providers,
        )
    if remote_secret is not None:
        session_extras = session_extras.concat(SessionExtraResources(secrets=[remote_secret]))

    # Raise an error if there are invalid environment variables in the request body
    verify_launcher_env_variable_overrides(launcher, launch_request)
    env = [
        SessionEnvItem(name="RENKU_BASE_URL_PATH", value=base_server_path),
        SessionEnvItem(name="RENKU_BASE_URL", value=base_server_url),
        SessionEnvItem(name="RENKU_MOUNT_DIR", value=storage_mount.as_posix()),
        SessionEnvItem(name="RENKU_SESSION", value="1"),
        SessionEnvItem(name="RENKU_SESSION_IP", value="0.0.0.0"),  # nosec B104
        SessionEnvItem(name="RENKU_SESSION_PORT", value=f"{environment.port}"),
        SessionEnvItem(name="RENKU_WORKING_DIR", value=work_dir.as_posix()),
        SessionEnvItem(name="RENKU_SECRETS_PATH", value=project.secrets_mount_directory.as_posix()),
        SessionEnvItem(name="RENKU_PROJECT_ID", value=str(project.id)),
        SessionEnvItem(name="RENKU_PROJECT_PATH", value=project.path.serialize()),
        SessionEnvItem(name="RENKU_LAUNCHER_ID", value=str(launcher.id)),
    ]
    if session_location == SessionLocation.remote:
        assert resource_pool.remote is not None
        env.extend(
            get_remote_env(
                remote=resource_pool.remote,
            )
        )
    launcher_env_variables = get_launcher_env_variables(launcher, launch_request)
    env.extend(launcher_env_variables)

    session = AmaltheaSessionV1Alpha1(
        metadata=Metadata(name=server_name, annotations=annotations),
        spec=AmaltheaSessionSpec(
            location=session_location,
            imagePullSecrets=[ImagePullSecret(name=image_secret.name, adopt=True)] if image_secret else [],
            codeRepositories=[],
            hibernated=False,
            reconcileStrategy=ReconcileStrategy.whenFailedOrHibernated,
            priorityClassName=resource_class.quota,
            session=Session(
                image=image,
                imagePullPolicy=ImagePullPolicy.Always,
                urlPath=ui_path,
                port=environment.port,
                storage=Storage(
                    className=storage_class,
                    size=SizeStr(str(disk_storage) + "G"),
                    mountPath=storage_mount.as_posix(),
                ),
                workingDir=work_dir.as_posix(),
                runAsUser=environment.uid,
                runAsGroup=environment.gid,
                resources=resources_from_resource_class(resource_class),
                extraVolumeMounts=session_extras.volume_mounts,
                command=environment.command,
                args=environment.args,
                shmSize=ShmSizeStr("1G"),
                stripURLPath=environment.strip_path_prefix,
                env=env,
                remoteSecretRef=remote_secret.ref() if remote_secret else None,
            ),
            ingress=ingress,
            extraContainers=session_extras.containers,
            initContainers=session_extras.init_containers,
            extraVolumes=session_extras.volumes,
            culling=get_culling(user, resource_pool, nb_config),
            authentication=Authentication(
                enabled=True,
                type=AuthenticationType.oauth2proxy
                if isinstance(user, AuthenticatedAPIUser)
                else AuthenticationType.token,
                secretRef=auth_secret.key_ref("auth"),
                extraVolumeMounts=authn_extra_volume_mounts,
            ),
            dataSources=session_extras.data_sources,
            tolerations=tolerations_from_resource_class(resource_class, nb_config.sessions.tolerations_model),
            affinity=node_affinity_from_resource_class(resource_class, nb_config.sessions.affinity_model),
            serviceAccountName=service_account_name,
        ),
    )
    secrets_to_create = session_extras.secrets or []
    for s in secrets_to_create:
        await nb_config.k8s_v2_client.create_secret(K8sSecret.from_v1_secret(s.secret, cluster))
    try:
        session = await nb_config.k8s_v2_client.create_session(session, user)
    except Exception as err:
        for s in secrets_to_create:
            await nb_config.k8s_v2_client.delete_secret(K8sSecret.from_v1_secret(s.secret, cluster))
        raise errors.ProgrammingError(message="Could not start the amalthea session") from err
    else:
        try:
            await request_session_secret_creation(user, nb_config, session, session_secrets)
            data_connector_secrets = session_extras.data_connector_secrets or dict()
            await request_dc_secret_creation(user, nb_config, session, data_connector_secrets)
        except Exception:
            await nb_config.k8s_v2_client.delete_session(server_name, user.id)
            raise

    await metrics.user_requested_session_launch(
        user=user,
        metadata={
            "cpu": int(resource_class.cpu * 1000),
            "memory": resource_class.memory,
            "gpu": resource_class.gpu,
            "storage": disk_storage,
            "resource_class_id": resource_class.id,
            "resource_pool_id": resource_pool.id or "",
            "resource_class_name": f"{resource_pool.name}.{resource_class.name}",
            "session_id": server_name,
        },
    )
    return session, True


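# Editor's note (summary of the function above, not part of this module): start_session()
# follows a create-with-rollback pattern so a failed launch does not leak resources:
#
#   1. create all extra Kubernetes secrets;
#   2. create the AmaltheaSession custom resource;
#      on failure, delete the secrets from step 1 and re-raise;
#   3. ask the secret service to create the user and data connector secrets;
#      on failure, delete the session (which owns the secrets) and re-raise.
#
# Only after all three steps succeed is the launch metric recorded and (session, True)
# returned; an existing session short-circuits much earlier with (session, False).

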
async def patch_session(
    body: apispec.SessionPatchRequest,
    session_id: str,
    user: AnonymousAPIUser | AuthenticatedAPIUser,
    internal_gitlab_user: APIUser,
    nb_config: NotebooksConfig,
    git_provider_helper: GitProviderHelperProto,
    project_repo: ProjectRepository,
    project_session_secret_repo: ProjectSessionSecretRepository,
    rp_repo: ResourcePoolRepository,
    session_repo: SessionRepository,
    connected_svcs_repo: ConnectedServicesRepository,
    metrics: MetricsService,
) -> AmaltheaSessionV1Alpha1:
    """Patch an Amalthea session."""
    session = await nb_config.k8s_v2_client.get_session(session_id, user.id)
    if session is None:
        raise errors.MissingResourceError(message=f"The session with ID {session_id} does not exist")
    if session.spec is None:
        raise errors.ProgrammingError(
            message=f"The session {session_id} being patched is missing the expected 'spec' field.", quiet=True
        )
    cluster = await nb_config.k8s_v2_client.cluster_by_class_id(session.resource_class_id(), user)

    patch = AmaltheaSessionV1Alpha1Patch(spec=AmaltheaSessionV1Alpha1SpecPatch())
    is_getting_hibernated: bool = False

    # Hibernation
    # TODO: Some patching should only be done when the session is in some states to avoid inadvertent restarts
    # Refresh tokens for git proxy
    if (
        body.state is not None
        and body.state.value.lower() == State.Hibernated.value.lower()
        and body.state.value.lower() != session.status.state.value.lower()
    ):
        # Session is being hibernated
        patch.spec.hibernated = True
        is_getting_hibernated = True
    elif (
        body.state is not None
        and body.state.value.lower() == State.Running.value.lower()
        and session.status.state.value.lower() != body.state.value.lower()
    ):
        # Session is being resumed
        patch.spec.hibernated = False
        await metrics.user_requested_session_resume(user, metadata={"session_id": session_id})

    # Resource class
    if body.resource_class_id is not None:
        new_cluster = await nb_config.k8s_v2_client.cluster_by_class_id(body.resource_class_id, user)
        if new_cluster.id != cluster.id:
            raise errors.ValidationError(
                message=(
                    f"The requested resource class {body.resource_class_id} is not in the "
                    f"same cluster {cluster.id} as the current resource class {session.resource_class_id()}."
                )
            )
        rp = await rp_repo.get_resource_pool_from_class(user, body.resource_class_id)
        rc = rp.get_resource_class(body.resource_class_id)
        if not rc:
            raise errors.MissingResourceError(
                message=f"The resource class you requested with ID {body.resource_class_id} does not exist"
            )
        # TODO: reject session classes which change the cluster
        if not patch.metadata:
            patch.metadata = AmaltheaSessionV1Alpha1MetadataPatch()
        # Patch the resource class ID in the annotations
        patch.metadata.annotations = {"renku.io/resource_class_id": str(body.resource_class_id)}
        if not patch.spec.session:
            patch.spec.session = AmaltheaSessionV1Alpha1SpecSessionPatch()
        patch.spec.session.resources = resources_from_resource_class(rc)
        # Tolerations
        tolerations = tolerations_from_resource_class(rc, nb_config.sessions.tolerations_model)
        patch.spec.tolerations = tolerations
        # Affinities
        patch.spec.affinity = node_affinity_from_resource_class(rc, nb_config.sessions.affinity_model)
        # Priority class (if a quota is being used)
        patch.spec.priorityClassName = rc.quota
        patch.spec.culling = get_culling(user, rp, nb_config)
        if rp.cluster is not None:
            patch.spec.service_account_name = rp.cluster.service_account_name

    # If the session is being hibernated we do not need to patch anything else that is
    # not specifically called for in the request body, we can refresh things when the user resumes.
    if is_getting_hibernated:
        return await nb_config.k8s_v2_client.patch_session(session_id, user.id, patch.to_rfc7386())

    server_name = session.metadata.name
    launcher = await session_repo.get_launcher(user, session.launcher_id)
    project = await project_repo.get_project(user=user, project_id=session.project_id)
    environment = launcher.environment
    work_dir = environment.working_directory
    if not work_dir:
        image_workdir = await core.docker_image_workdir(nb_config, environment.container_image, internal_gitlab_user)
        work_dir_fallback = PurePosixPath("/home/jovyan")
        work_dir = image_workdir or work_dir_fallback
    storage_mount_fallback = work_dir / "work"
    storage_mount = launcher.environment.mount_directory or storage_mount_fallback
    secrets_mount_directory = storage_mount / project.secrets_mount_directory
    session_secrets = await project_session_secret_repo.get_all_session_secrets_from_project(
        user=user, project_id=project.id
    )
    git_providers = await git_provider_helper.get_providers(user=user)
    repositories = repositories_from_project(project, git_providers)

    # User secrets
    session_extras = SessionExtraResources()
    session_extras = session_extras.concat(
        user_secrets_extras(
            user=user,
            config=nb_config,
            secrets_mount_directory=secrets_mount_directory.as_posix(),
            k8s_secret_name=f"{server_name}-secrets",
            session_secrets=session_secrets,
        )
    )

    # Data connectors: skip
    # TODO: How can we patch data connectors? Should we even patch them?
    # TODO: The fact that `start_session()` accepts overrides for data connectors
    # TODO: but that we do not save these overrides (e.g. as annotations) means that
    # TODO: we cannot patch data connectors upon resume.
    # TODO: If we did, we would lose the user's provided overrides (e.g. unsaved credentials).

    # More init containers
    session_extras = session_extras.concat(
        await get_extra_init_containers(
            nb_config,
            user,
            repositories,
            git_providers,
            storage_mount,
            work_dir,
            uid=environment.uid,
            gid=environment.gid,
        )
    )

    # Extra containers
    session_extras = session_extras.concat(await get_extra_containers(nb_config, user, repositories, git_providers))

    # Patching the image pull secret
    image = session.spec.session.image
    image_pull_secret = await get_image_pull_secret(
        image=image,
        server_name=server_name,
        nb_config=nb_config,
        connected_svcs_repo=connected_svcs_repo,
        user=user,
        internal_gitlab_user=internal_gitlab_user,
    )
    if image_pull_secret:
        # NOTE: concat() returns a new object, so the result must be re-assigned
        session_extras = session_extras.concat(SessionExtraResources(secrets=[image_pull_secret]))
        patch.spec.imagePullSecrets = [ImagePullSecret(name=image_pull_secret.name, adopt=image_pull_secret.adopt)]

    # Construct session patch
    patch.spec.extraContainers = _make_patch_spec_list(
        existing=session.spec.extraContainers or [], updated=session_extras.containers
    )
    patch.spec.initContainers = _make_patch_spec_list(
        existing=session.spec.initContainers or [], updated=session_extras.init_containers
    )
    patch.spec.extraVolumes = _make_patch_spec_list(
        existing=session.spec.extraVolumes or [], updated=session_extras.volumes
    )
    if not patch.spec.session:
        patch.spec.session = AmaltheaSessionV1Alpha1SpecSessionPatch()
    patch.spec.session.extraVolumeMounts = _make_patch_spec_list(
        existing=session.spec.session.extraVolumeMounts or [], updated=session_extras.volume_mounts
    )

    secrets_to_create = session_extras.secrets or []
    for s in secrets_to_create:
        await nb_config.k8s_v2_client.create_secret(K8sSecret.from_v1_secret(s.secret, cluster))

    patch_serialized = patch.to_rfc7386()
    if len(patch_serialized) == 0:
        return session

    return await nb_config.k8s_v2_client.patch_session(session_id, user.id, patch_serialized)


def _deduplicate_target_paths(dcs: dict[str, RCloneStorage]) -> dict[str, RCloneStorage]:
    """Ensures that the target paths for all storages are unique.

    This method attempts to de-duplicate the target_path of all items passed in,
    and raises an error if it fails to generate a unique target_path for any of them.
    """
    result_dcs: dict[str, RCloneStorage] = {}
    mount_folders: dict[str, list[str]] = {}

    def _find_mount_folder(dc: RCloneStorage) -> str:
        mount_folder = dc.mount_folder
        if mount_folder not in mount_folders:
            return mount_folder
        # 1. Try with a "-1", "-2", etc. suffix
        mount_folder_try = f"{mount_folder}-{len(mount_folders[mount_folder])}"
        if mount_folder_try not in mount_folders:
            return mount_folder_try
        # 2. Try with a random suffix
        suffix = "".join([random.choice(string.ascii_lowercase + string.digits) for _ in range(4)])  # nosec B311
        mount_folder_try = f"{mount_folder}-{suffix}"
        if mount_folder_try not in mount_folders:
            return mount_folder_try
        raise errors.ValidationError(
            message=f"Could not start session because two or more data connectors ({', '.join(mount_folders[mount_folder])}) share the same mount point '{mount_folder}'"  # noqa E501
        )

    for dc_id, dc in dcs.items():
        original_mount_folder = dc.mount_folder
        new_mount_folder = _find_mount_folder(dc)
        # Keep track of the original mount folder here
        if new_mount_folder != original_mount_folder:
            logger.warning(f"Re-assigning data connector {dc_id} to mount point '{new_mount_folder}'")
            dc_ids = mount_folders.get(original_mount_folder, [])
            dc_ids.append(dc_id)
            mount_folders[original_mount_folder] = dc_ids
        # Keep track of the assigned mount folder here
        dc_ids = mount_folders.get(new_mount_folder, [])
        dc_ids.append(dc_id)
        mount_folders[new_mount_folder] = dc_ids
        result_dcs[dc_id] = dc.with_override(
            override=SessionDataConnectorOverride(
                skip=False,
                data_connector_id=ULID.from_str(dc_id),
                target_path=new_mount_folder,
                configuration=None,
                source_path=None,
                readonly=None,
            )
        )

    return result_dcs


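# Editor's sketch (illustrative, not part of this module): three connectors that all
# resolve to the mount folder "work/data" would be assigned, in iteration order:
#
#   connector A -> "work/data"    (first claim on the folder)
#   connector B -> "work/data-1"  (numeric suffix; one previous claimant)
#   connector C -> "work/data-2"  (numeric suffix; two previous claimants)
#
# The random 4-character suffix is only tried when the numeric candidate is itself taken,
# and a ValidationError is raised if even that collides.

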
class _NamedResource(Protocol):
    """Represents a resource with a name."""

    name: str


_T = TypeVar("_T", bound=_NamedResource)


def _make_patch_spec_list(existing: Sequence[_T], updated: Sequence[_T]) -> list[_T] | None:
    """Merges updated into existing by upserting items identified by their name.

    This method is used to construct session patches, merging session resources by name (containers, volumes, etc.).
    """
    patch_list = None
    if updated:
        patch_list = list(existing)
        upsert_list = list(updated)
        for upsert_item in upsert_list:
            # Find out if the upsert_item needs to be added or updated
            found = next(filter(lambda t: t[1].name == upsert_item.name, enumerate(patch_list)), None)
            if found is not None:
                idx, _ = found
                patch_list[idx] = upsert_item
            else:
                patch_list.append(upsert_item)
    return patch_list


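# Editor's sketch (illustrative, not part of this module): upserting by name, with
# hypothetical named items a, b and c:
#
#   existing = [a, b]
#   updated  = [b', c]          # b' has the same .name as b
#   _make_patch_spec_list(existing, updated)
#   # -> [a, b', c]             # b is replaced in place, c is appended
#
# When `updated` is empty the function returns None rather than [], presumably so that the
# serialized merge patch omits the field and leaves the server-side list untouched instead
# of clearing it.

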
def validate_session_post_request(body: apispec.SessionPostRequest) -> SessionLaunchRequest:
    """Validate a session launch request."""
    data_connectors_overrides = (
        [
            SessionDataConnectorOverride(
                skip=dc.skip,
                data_connector_id=ULID.from_str(dc.data_connector_id),
                configuration=dc.configuration,
                source_path=dc.source_path,
                target_path=dc.target_path,
                readonly=dc.readonly,
            )
            for dc in body.data_connectors_overrides
        ]
        if body.data_connectors_overrides
        else None
    )
    env_variable_overrides = (
        [SessionEnvVar(name=ev.name, value=ev.value) for ev in body.env_variable_overrides]
        if body.env_variable_overrides
        else None
    )
    return SessionLaunchRequest(
        launcher_id=ULID.from_str(body.launcher_id),
        disk_storage=body.disk_storage,
        resource_class_id=body.resource_class_id,
        data_connectors_overrides=data_connectors_overrides,
        env_variable_overrides=env_variable_overrides,
    )
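

# Editor's sketch (illustrative, not part of this module): a hypothetical POST body that
# this function would turn into a SessionLaunchRequest, skipping one data connector and
# overriding one launcher-defined environment variable (IDs are made up):
#
#   {
#     "launcher_id": "01ARZ3NDEKTSV4RRFFQ69G5FAV",
#     "disk_storage": 16,
#     "data_connectors_overrides": [
#       {"data_connector_id": "01BX5ZZKBKACTAV9WEVGEMMVRZ", "skip": true}
#     ],
#     "env_variable_overrides": [{"name": "THREADS", "value": "8"}]
#   }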