• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

SwissDataScienceCenter / renku-data-services / 20339056754

18 Dec 2025 01:46PM UTC coverage: 86.009%. First build
20339056754

Pull #1128

github

web-flow
Merge fe560b56c into c5f960729
Pull Request #1128: feat: set lastInteraction time via session patch and return willHibernateAt

84 of 96 new or added lines in 11 files covered. (87.5%)

24037 of 27947 relevant lines covered (86.01%)

1.51 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

87.04
/components/renku_data_services/crc/db.py
1
"""Adapter based on SQLAlchemy.
2

3
These adapters currently do a few things (1) generate SQL queries, (2) apply resource access controls,
4
(3) fetch the SQL results and (4) format them into a workable representation. In the future and as the code
5
grows it is worth looking into separating this functionality into separate classes rather than having
6
it all in one place.
7
"""
8

9
from asyncio import gather
2✔
10
from collections.abc import AsyncGenerator, Callable, Collection, Coroutine, Sequence
2✔
11
from dataclasses import asdict, dataclass, field
2✔
12
from functools import wraps
2✔
13
from typing import Any, Concatenate, Optional, ParamSpec, TypeVar
2✔
14

15
from sqlalchemy import NullPool, delete, false, select, true
2✔
16
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
2✔
17
from sqlalchemy.orm import selectinload
2✔
18
from sqlalchemy.sql import Select, and_, not_, or_
2✔
19
from ulid import ULID
2✔
20

21
import renku_data_services.base_models as base_models
2✔
22
from renku_data_services import errors
2✔
23
from renku_data_services.base_models import RESET
2✔
24
from renku_data_services.crc import models
2✔
25
from renku_data_services.crc import orm as schemas
2✔
26
from renku_data_services.crc.core import validate_resource_class_update, validate_resource_pool_update
2✔
27
from renku_data_services.crc.models import ClusterPatch, ClusterSettings, SavedClusterSettings, SessionProtocol
2✔
28
from renku_data_services.crc.orm import ClusterORM
2✔
29
from renku_data_services.k8s.db import QuotaRepository
2✔
30
from renku_data_services.users.db import UserRepo
2✔
31

32

33
class _Base:
    """Common state shared by the repositories defined in this module."""

    def __init__(self, session_maker: Callable[..., AsyncSession], quotas_repo: QuotaRepository) -> None:
        # Both collaborators are kept on the instance so that subclasses can open
        # database sessions and look up or modify quotas.
        self.session_maker, self.quotas_repo = session_maker, quotas_repo
2✔
37

38

39
def _resource_pool_access_control(
    api_user: base_models.APIUser,
    stmt: Select[tuple[schemas.ResourcePoolORM]],
) -> Select[tuple[schemas.ResourcePoolORM]]:
    """Restrict a resource pool query to the pools the given user is allowed to list.

    * Anonymous users only get the public resource pools.
    * Admins get the statement back unchanged.
    * Logged-in users who are not admins get the pools they are a member of, the public pools that are
      not the default one, and the default pool unless they have been explicitly barred from it.
    """
    is_authenticated, is_admin = api_user.is_authenticated, api_user.is_admin

    if is_authenticated is False:
        # The user is not logged in, they can see only the public resource pools
        return stmt.where(schemas.ResourcePoolORM.public == true())

    if is_authenticated is not True or is_admin is not False:
        # The user is logged in and is an admin, they can see all resource pools
        return stmt

    # From here on the user is logged in but not an admin.
    # NOTE: The only way to check that a user is allowed to access the default pool is that a record
    # barring them from it does NOT EXIST in the database
    barred_from_default_pool = (
        select(schemas.UserORM.no_default_access)
        .where(and_(schemas.UserORM.keycloak_id == api_user.id, schemas.UserORM.no_default_access == true()))
        .exists()
    )
    user_is_member = schemas.UserORM.keycloak_id == api_user.id
    pool_is_public_and_not_default = and_(
        schemas.ResourcePoolORM.default != true(), schemas.ResourcePoolORM.public == true()
    )
    pool_is_accessible_default = and_(
        schemas.ResourcePoolORM.default == true(),
        not_(barred_from_default_pool),
    )
    with_users = stmt.join(schemas.UserORM, schemas.ResourcePoolORM.users, isouter=True)
    return with_users.where(or_(user_is_member, pool_is_public_and_not_default, pool_is_accessible_default))
2✔
74

75

76
def _classes_user_access_control(
2✔
77
    api_user: base_models.APIUser,
78
    stmt: Select[tuple[schemas.ResourceClassORM]],
79
) -> Select[tuple[schemas.ResourceClassORM]]:
80
    """Adjust the select statement for classes based on whether the user is logged in or not."""
81
    output = stmt
2✔
82
    match (api_user.is_authenticated, api_user.is_admin):
2✔
83
        case True, False:
2✔
84
            # The user is logged in but is not an admin
85
            api_user_has_default_pool_access = not_(
1✔
86
                # NOTE: The only way to check that a user is allowed to access the default pool is that such a
87
                # record does NOT EXIST in the database
88
                select(schemas.UserORM.no_default_access)
89
                .where(and_(schemas.UserORM.keycloak_id == api_user.id, schemas.UserORM.no_default_access == true()))
90
                .exists()
91
            )
92
            output = output.join(schemas.UserORM, schemas.ResourcePoolORM.users, isouter=True).where(
1✔
93
                or_(
94
                    schemas.UserORM.keycloak_id == api_user.id,  # the user is part of the pool
95
                    and_(  # the pool is not default but is public
96
                        schemas.ResourcePoolORM.default != true(), schemas.ResourcePoolORM.public == true()
97
                    ),
98
                    and_(  # the pool is default and the user is not prohibited from accessing it
99
                        schemas.ResourcePoolORM.default == true(),
100
                        api_user_has_default_pool_access,
101
                    ),
102
                )
103
            )
104
        case True, True:
2✔
105
            # The user is logged in and is an admin, they can see all resource classes
106
            pass
2✔
107
        case False, _:
×
108
            # The user is not logged in, they can see only the classes from public resource pools
109
            output = output.join(schemas.UserORM, schemas.ResourcePoolORM.users, isouter=True).where(
×
110
                schemas.ResourcePoolORM.public == true(),
111
            )
112
    return output
2✔
113

114

115
_P = ParamSpec("_P")
2✔
116
_T = TypeVar("_T")
2✔
117

118

119
def _only_admins(
    f: Callable[Concatenate[Any, _P], Coroutine[Any, Any, _T]],
) -> Callable[Concatenate[Any, _P], Coroutine[Any, Any, _T]]:
    """Decorator that lets only admins call the wrapped coroutine method.

    The APIUser is taken from the ``api_user`` keyword argument when there is one, otherwise from the
    first positional argument (after self). The wrapped method is awaited only when that user is an admin.

    Raises:
        errors.UnauthorizedError: if no user was passed in.
        errors.ProgrammingError: if the value found in the user position is not an APIUser.
        errors.ForbiddenError: if the user is not an admin.
    """

    @wraps(f)
    async def decorated_function(self: Any, *args: _P.args, **kwargs: _P.kwargs) -> _T:
        # The keyword argument takes precedence over the positional one
        if "api_user" in kwargs:
            candidate = kwargs["api_user"]
        else:
            candidate = args[0] if args else None

        if candidate is None:
            raise errors.UnauthorizedError(message="You do not have the required permissions for this operation.")
        if not isinstance(candidate, base_models.APIUser):
            raise errors.ProgrammingError(message="Expected user parameter is not of type APIUser.")
        if not candidate.is_admin:
            raise errors.ForbiddenError(message="You do not have the required permissions for this operation.")

        # the user is authenticated and is an admin
        return await f(self, *args, **kwargs)

    return decorated_function
2✔
147

148

149
class ResourcePoolRepository(_Base):
2✔
150
    """The adapter used for accessing resource pools with SQLAlchemy."""
151

152
    def __init__(self, session_maker: Callable[..., AsyncSession], quotas_repo: QuotaRepository):
        """Store the session factory and the quota repository and create the cluster repository."""
        super().__init__(session_maker=session_maker, quotas_repo=quotas_repo)
        # The cluster repository opens its sessions with the same factory as this repository
        self.__cluster_repo = ClusterRepository(session_maker=self.session_maker)
2✔
155

156
    async def initialize(self, async_connection_url: str, rp: models.UnsavedResourcePool) -> None:
        """Add the default resource pool if it does not already exist.

        A dedicated engine without connection pooling is created just for this call and is disposed of
        before returning.

        Args:
            async_connection_url: the URL handed to ``create_async_engine``.
            rp: the resource pool that is inserted when no default resource pool exists yet.
        """
        engine = create_async_engine(async_connection_url, poolclass=NullPool)
        try:
            session_maker = async_sessionmaker(
                engine,
                expire_on_commit=True,
            )
            async with session_maker() as session, session.begin():
                # NOTE: The default flag has to be a filter. Selecting the comparison itself returns one
                # boolean per existing pool, which is never None as soon as any pool exists, so the
                # default pool would not be created when only non-default pools are present.
                stmt = select(schemas.ResourcePoolORM.id).where(schemas.ResourcePoolORM.default == true()).limit(1)
                default_rp_id = await session.scalar(stmt)
                if default_rp_id is None:
                    orm = schemas.ResourcePoolORM.from_unsaved_model(new_resource_pool=rp, quota=None, cluster=None)
                    session.add(orm)
        finally:
            # The engine only lives for the duration of this call
            await engine.dispose()
×
170

171
    async def get_resource_pools(
        self, api_user: base_models.APIUser, id: Optional[int] = None, name: Optional[str] = None
    ) -> list[models.ResourcePool]:
        """Get resource pools from database.

        Args:
            api_user: the user making the request, used to limit the pools that are returned.
            id: when given, only the resource pool with this ID is returned.
            name: when given, only resource pools with this name are returned.

        Returns:
            The matching resource pools the user has access to, each one listed once.
        """
        async with self.session_maker() as session:
            stmt = (
                select(schemas.ResourcePoolORM)
                .options(selectinload(schemas.ResourcePoolORM.classes))
                .options(selectinload(schemas.ResourcePoolORM.cluster))
            )
            if name is not None:
                stmt = stmt.where(schemas.ResourcePoolORM.name == name)
            if id is not None:
                stmt = stmt.where(schemas.ResourcePoolORM.id == id)
            # NOTE: The line below ensures that the right users can access the right resources, do not remove.
            stmt = _resource_pool_access_control(api_user, stmt)
            res = await session.execute(stmt)
            # NOTE: For logged-in non-admin users the access control outer joins the users of each pool,
            # so the same pool can come back in several rows. Each pool must be reported only once.
            orms = res.unique().scalars().all()
            return [rp.dump(self.quotas_repo.get_quota(rp.quota) if rp.quota else None) for rp in orms]
2✔
194

195
    async def get_resource_pool_from_class(
        self, api_user: base_models.APIUser, resource_class_id: int
    ) -> models.ResourcePool:
        """Look up the resource pool that contains the resource class with the given ID.

        Raises:
            errors.MissingResourceError: if no such resource pool is found among the pools the user can access.
        """
        async with self.session_maker() as session:
            contains_class = schemas.ResourcePoolORM.classes.any(schemas.ResourceClassORM.id == resource_class_id)
            query = (
                select(schemas.ResourcePoolORM)
                .where(contains_class)
                .options(
                    selectinload(schemas.ResourcePoolORM.classes),
                    selectinload(schemas.ResourcePoolORM.cluster),
                )
            )
            # NOTE: The line below ensures that the right users can access the right resources, do not remove.
            query = _resource_pool_access_control(api_user, query)
            pool = (await session.execute(query)).scalar()
            if pool is None:
                raise errors.MissingResourceError(
                    message=f"Could not find the resource pool where a class with ID {resource_class_id} exists."
                )
            if pool.quota:
                return pool.dump(self.quotas_repo.get_quota(pool.quota))
            return pool.dump(None)
1✔
216

217
    async def get_default_resource_pool(self) -> models.ResourcePool:
        """Return the resource pool that is flagged as the default one.

        No access control is applied here.

        Raises:
            errors.ProgrammingError: if there is no default resource pool.
        """
        async with self.session_maker() as session:
            is_default = schemas.ResourcePoolORM.default == true()
            query = select(schemas.ResourcePoolORM).where(is_default)
            # NOTE(review): unlike get_resource_pools, the cluster is not eagerly loaded here - confirm
            # that dump() does not need it.
            query = query.options(selectinload(schemas.ResourcePoolORM.classes))
            default_pool = await session.scalar(query)
            if default_pool is None:
                raise errors.ProgrammingError(
                    message="Could not find the default resource pool, but this has to exist."
                )
            if default_pool.quota:
                return default_pool.dump(self.quotas_repo.get_quota(default_pool.quota))
            return default_pool.dump(None)
×
232

233
    async def get_default_resource_class(self) -> models.ResourceClass:
        """Return the default resource class of the default resource pool.

        No access control is applied here.

        Raises:
            errors.ProgrammingError: if there is no such resource class.
        """
        async with self.session_maker() as session:
            class_is_default = schemas.ResourceClassORM.default == true()
            pool_is_default = schemas.ResourceClassORM.resource_pool.has(schemas.ResourcePoolORM.default == true())
            query = select(schemas.ResourceClassORM).where(class_is_default, pool_is_default)
            default_class = await session.scalar(query)
            if default_class is not None:
                return default_class.dump()
            raise errors.ProgrammingError(
                message="Could not find the default class from the default resource pool, but this has to exist."
            )
×
247

248
    async def filter_resource_pools(
        self,
        api_user: base_models.APIUser,
        cpu: float = 0,
        memory: int = 0,
        max_storage: int = 0,
        gpu: int = 0,
    ) -> list[models.ResourcePool]:
        """List the accessible resource pools, indicating which resource classes match the given criteria.

        The requested resources are wrapped in an unsaved resource class which is handed to the dump of
        every pool as the matching criteria. The pools are ordered by ID and then by name.
        """
        async with self.session_maker() as session:
            # NOTE: the default storage has to be <= max_storage but is not used for filtering classes,
            # only the max_storage is used to filter resource classes that match a request
            requested = models.UnsavedResourceClass(
                name="criteria",
                cpu=cpu,
                gpu=gpu,
                memory=memory,
                max_storage=max_storage,
                default_storage=max_storage,
            )
            query = select(schemas.ResourcePoolORM).distinct()
            query = query.options(selectinload(schemas.ResourcePoolORM.classes))
            query = query.order_by(schemas.ResourcePoolORM.id, schemas.ResourcePoolORM.name)
            # NOTE: The line below ensures that the right users can access the right resources, do not remove.
            query = _resource_pool_access_control(api_user, query)
            rows = await session.execute(query)
            pools: list[models.ResourcePool] = []
            for pool in rows.scalars().all():
                pool_quota = self.quotas_repo.get_quota(pool.quota)
                pools.append(pool.dump(quota=pool_quota, class_match_criteria=requested))
            return pools
284

285
    @_only_admins
    async def insert_resource_pool(
        self, api_user: base_models.APIUser, new_resource_pool: models.UnsavedResourcePool
    ) -> models.ResourcePool:
        """Insert resource pool into database.

        When the new resource pool comes with a quota, the quota is created first and is removed again if
        saving the resource pool fails, so that no quota without a resource pool is left behind.

        Args:
            api_user: the user making the request, has to be an admin.
            new_resource_pool: the resource pool to create.

        Returns:
            The resource pool as it was saved.

        Raises:
            errors.ValidationError: if the new pool is a default pool and a default pool already exists.
        """
        cluster = None
        if new_resource_pool.cluster_id:
            cluster = await self.__cluster_repo.select(cluster_id=new_resource_pool.cluster_id)

        quota = None
        if new_resource_pool.quota is not None:
            quota = self.quotas_repo.create_quota(new_quota=new_resource_pool.quota)

        try:
            async with self.session_maker() as session, session.begin():
                resource_pool = schemas.ResourcePoolORM.from_unsaved_model(
                    new_resource_pool=new_resource_pool, quota=quota, cluster=cluster
                )
                if resource_pool.default:
                    stmt = select(schemas.ResourcePoolORM).where(schemas.ResourcePoolORM.default == true())
                    res = await session.execute(stmt)
                    if res.unique().scalars().all():
                        raise errors.ValidationError(
                            message="There can only be one default resource pool and one already exists."
                        )

                session.add(resource_pool)
                await session.flush()
                await session.refresh(resource_pool)
                return resource_pool.dump(quota=quota)
        except Exception:
            # The resource pool was not saved, do not leave the quota created above behind
            if quota is not None:
                self.quotas_repo.delete_quota(name=quota.id)
            raise
1✔
316

317
    async def get_classes(
        self,
        api_user: Optional[base_models.APIUser] = None,
        id: Optional[int] = None,
        name: Optional[str] = None,
        resource_pool_id: Optional[int] = None,
    ) -> list[models.ResourceClass]:
        """Get classes from the database.

        Args:
            api_user: the user making the request. When it is None no access control is applied.
            id: when given, only the resource class with this ID is returned.
            name: when given, only resource classes with this name are returned.
            resource_pool_id: when given, only the classes of this resource pool are returned.

        Returns:
            The matching resource classes, each one listed once.
        """
        async with self.session_maker() as session:
            stmt = select(schemas.ResourceClassORM).join(
                schemas.ResourcePoolORM, schemas.ResourceClassORM.resource_pool, isouter=True
            )
            if resource_pool_id is not None:
                stmt = stmt.where(schemas.ResourcePoolORM.id == resource_pool_id)
            if id is not None:
                stmt = stmt.where(schemas.ResourceClassORM.id == id)
            if name is not None:
                stmt = stmt.where(schemas.ResourceClassORM.name == name)

            # Apply user access control if api_user is provided
            if api_user is not None:
                # NOTE: The line below ensures that the right users can access the right resources, do not remove.
                stmt = _classes_user_access_control(api_user, stmt)

            res = await session.execute(stmt)
            # NOTE: The access control can outer join the users of the resource pool, so the same class
            # can come back in several rows. Each class must be reported only once.
            orms = res.unique().scalars().all()
            return [orm.dump() for orm in orms]
2✔
344

345
    async def get_resource_class(self, api_user: base_models.APIUser, id: int) -> models.ResourceClass:
        """Return the resource class with the given ID, provided the user can access it.

        Raises:
            errors.MissingResourceError: if the lookup does not return any resource class.
        """
        matches = await self.get_classes(api_user, id)
        if not matches:
            raise errors.MissingResourceError(message=f"The resource class with ID {id} cannot be found")
        first_match, *_ = matches
        return first_match
×
351

352
    @_only_admins
    async def insert_resource_class(
        self,
        api_user: base_models.APIUser,
        new_resource_class: models.UnsavedResourceClass,
        *,
        resource_pool_id: Optional[int] = None,
    ) -> models.ResourceClass:
        """Insert a resource class in the database.

        Args:
            api_user: the user making the request, has to be an admin.
            new_resource_class: the resource class to create.
            resource_pool_id: the resource pool the class is added to. When it is None the class is
                saved without being validated against a resource pool.

        Returns:
            The resource class as it was saved.

        Raises:
            errors.MissingResourceError: if the resource pool does not exist.
            errors.ValidationError: if the pool already has a default class and the new class is also a
                default one, or if the new class is not compatible with the quota of the pool.
        """
        async with self.session_maker() as session, session.begin():
            resource_class = schemas.ResourceClassORM.from_unsaved_model(
                new_resource_class=new_resource_class, resource_pool_id=resource_pool_id
            )

            if resource_pool_id is not None:
                stmt = (
                    select(schemas.ResourcePoolORM)
                    .where(schemas.ResourcePoolORM.id == resource_pool_id)
                    # The classes of the pool are inspected below, load them together with the pool
                    .options(selectinload(schemas.ResourcePoolORM.classes))
                )
                res = await session.execute(stmt)
                rp = res.scalars().first()
                if rp is None:
                    raise errors.MissingResourceError(
                        message=f"Resource pool with id {resource_pool_id} does not exist."
                    )
                # NOTE: The checks run before the new class is linked to the pool, so that the new class
                # itself can never be counted among the classes the pool already has.
                if resource_class.default and any(existing.default for existing in rp.classes):
                    raise errors.ValidationError(
                        message="There can only be one default resource class per resource pool."
                    )
                quota = self.quotas_repo.get_quota(rp.quota) if rp.quota else None
                if quota and not quota.is_resource_class_compatible(new_resource_class):
                    raise errors.ValidationError(
                        message=f"The resource class {new_resource_class} is not compatible with the quota {quota}."
                    )
                resource_class.resource_pool = rp

            session.add(resource_class)
            await session.flush()
            await session.refresh(resource_class)
            return resource_class.dump()
1✔
390

391
    @_only_admins
    async def update_resource_pool(
        self, api_user: base_models.APIUser, resource_pool_id: int, update: models.ResourcePoolPatch
    ) -> models.ResourcePool:
        """Update an existing resource pool in the database.

        Fields of the patch that are None are left untouched. The thresholds and the hibernation warning
        period are removed when the patch carries 0 or RESET for them. The cluster, the quota and the
        remote configuration are removed when the patch carries RESET for them.

        Args:
            api_user: the user making the request, has to be an admin.
            resource_pool_id: the ID of the resource pool to update.
            update: the changes to apply.

        Returns:
            The resource pool after the update.

        Raises:
            errors.MissingResourceError: if the resource pool does not exist.
        """
        async with self.session_maker() as session, session.begin():
            stmt = (
                select(schemas.ResourcePoolORM)
                .where(schemas.ResourcePoolORM.id == resource_pool_id)
                .options(selectinload(schemas.ResourcePoolORM.classes))
            )
            res = await session.scalars(stmt)
            rp = res.one_or_none()
            if rp is None:
                raise errors.MissingResourceError(message=f"Resource pool with id {resource_pool_id} cannot be found")
            quota = self.quotas_repo.get_quota(rp.quota) if rp.quota else None

            validate_resource_pool_update(existing=rp.dump(quota=quota), update=update)

            if update.name is not None:
                rp.name = update.name
            if update.public is not None:
                rp.public = update.public
            if update.default is not None:
                rp.default = update.default
            if update.idle_threshold == 0 or update.idle_threshold is RESET:
                # Using "0" removes the value
                rp.idle_threshold = None
            elif isinstance(update.idle_threshold, int):
                rp.idle_threshold = update.idle_threshold
            if update.hibernation_threshold == 0 or update.hibernation_threshold is RESET:
                # Using "0" removes the value
                rp.hibernation_threshold = None
            elif isinstance(update.hibernation_threshold, int):
                rp.hibernation_threshold = update.hibernation_threshold
            if update.hibernation_warning_period == 0 or update.hibernation_warning_period is RESET:
                # Using "0" removes the value
                rp.hibernation_warning_period = None
            elif isinstance(update.hibernation_warning_period, int):
                rp.hibernation_warning_period = update.hibernation_warning_period
            if update.platform is not None:
                rp.platform = update.platform

            if update.cluster_id is RESET:
                rp.cluster_id = None
            elif update.cluster_id is not None:
                cluster = await self.__cluster_repo.select(update.cluster_id)
                rp.cluster_id = cluster.id

            if update.quota is RESET and rp.quota:
                # Remove the existing quota. The pool must not keep pointing at the deleted quota and the
                # deleted quota must not be part of the returned resource pool either.
                self.quotas_repo.delete_quota(name=rp.quota)
                rp.quota = None
                quota = None
            elif isinstance(update.quota, models.QuotaPatch) and rp.quota is None:
                # Create a new quota, the `update.quota` object has already been validated
                assert update.quota.cpu is not None
                assert update.quota.memory is not None
                assert update.quota.gpu is not None
                new_quota = models.UnsavedQuota(
                    cpu=update.quota.cpu,
                    memory=update.quota.memory,
                    gpu=update.quota.gpu,
                )
                quota = self.quotas_repo.create_quota(new_quota=new_quota)
                rp.quota = quota.id
            elif isinstance(update.quota, models.QuotaPatch):
                assert rp.quota is not None
                assert quota is not None
                # Update the existing quota, values missing from the patch are taken from the current quota
                updated_quota = models.Quota(
                    cpu=update.quota.cpu if update.quota.cpu is not None else quota.cpu,
                    memory=update.quota.memory if update.quota.memory is not None else quota.memory,
                    gpu=update.quota.gpu if update.quota.gpu is not None else quota.gpu,
                    gpu_kind=update.quota.gpu_kind if update.quota.gpu_kind is not None else quota.gpu_kind,
                    id=quota.id,
                )
                quota = self.quotas_repo.update_quota(quota=updated_quota)
                rp.quota = quota.id

            if update.remote is RESET:
                rp.remote_provider_id = None
                rp.remote_json = None
            elif update.remote is not None:
                if update.remote.provider_id is not None:
                    rp.remote_provider_id = update.remote.provider_id
                # NOTE: The merge is done on a copy and the copy is assigned. Changing the stored dict in
                # place and assigning the very same object back may not be picked up as a change by
                # SQLAlchemy unless the column tracks mutations (the column definition is not visible here).
                remote_json = dict(rp.remote_json) if rp.remote_json is not None else {}
                remote_json.update(update.remote.to_dict())
                # The provider ID has its own column, it is not repeated in the JSON
                remote_json.pop("provider_id", None)
                rp.remote_json = remote_json

            if update.classes is not None:
                # NOTE: The coroutines are created right where they are awaited so that none of them is
                # left un-awaited if something above raises.
                await gather(
                    *(
                        self.update_resource_class(
                            api_user=api_user, resource_pool_id=resource_pool_id, resource_class_id=rc.id, update=rc
                        )
                        for rc in update.classes
                    )
                )
            await session.flush()
            await session.refresh(rp)
            return rp.dump(quota=quota)
1✔
493

494
    @_only_admins
    async def delete_resource_pool(self, api_user: base_models.APIUser, id: int) -> None:
        """Remove the resource pool with the given ID together with its quota.

        Nothing happens when there is no resource pool with that ID.

        Raises:
            errors.ValidationError: if the resource pool is the default one.
        """
        async with self.session_maker() as session, session.begin():
            query = select(schemas.ResourcePoolORM).where(schemas.ResourcePoolORM.id == id)
            pool = (await session.execute(query)).scalars().first()
            if pool is None:
                # There is nothing to delete
                return None
            if pool.default:
                raise errors.ValidationError(message="The default resource pool cannot be deleted.")
            await session.delete(pool)
            if pool.quota:
                self.quotas_repo.delete_quota(pool.quota)
            return None
2✔
508

509
    @_only_admins
    async def delete_resource_class(
        self, api_user: base_models.APIUser, resource_pool_id: int, resource_class_id: int
    ) -> None:
        """Remove a resource class from the given resource pool.

        Nothing happens when the pool has no resource class with that ID.

        Raises:
            errors.ValidationError: if the resource class is a default one.
        """
        async with self.session_maker() as session, session.begin():
            query = select(schemas.ResourceClassORM).where(
                schemas.ResourceClassORM.id == resource_class_id,
                schemas.ResourceClassORM.resource_pool_id == resource_pool_id,
            )
            found = (await session.execute(query)).scalars().first()
            if found is None:
                # There is nothing to delete
                return
            if found.default:
                raise errors.ValidationError(message="The default resource class cannot be deleted.")
            await session.delete(found)
1✔
526

527
    @_only_admins
2✔
528
    async def update_resource_class(
2✔
529
        self,
530
        api_user: base_models.APIUser,
531
        resource_pool_id: int,
532
        resource_class_id: int,
533
        update: models.ResourceClassPatch,
534
    ) -> models.ResourceClass:
535
        """Update a specific resource class."""
536
        async with self.session_maker() as session, session.begin():
2✔
537
            stmt = (
2✔
538
                select(schemas.ResourceClassORM)
539
                .where(schemas.ResourceClassORM.id == resource_class_id)
540
                .where(schemas.ResourceClassORM.resource_pool_id == resource_pool_id)
541
                .join(schemas.ResourcePoolORM, schemas.ResourceClassORM.resource_pool)
542
                .options(selectinload(schemas.ResourceClassORM.resource_pool))
543
            )
544
            res = await session.scalars(stmt)
2✔
545
            cls = res.one_or_none()
2✔
546
            if cls is None:
2✔
547
                raise errors.MissingResourceError(
1✔
548
                    message=(
549
                        f"The resource class with id {resource_class_id} does not exist, the resource pool with "
550
                        f"id {resource_pool_id} does not exist or the requested resource class is not "
551
                        "associated with the resource pool"
552
                    )
553
                )
554

555
            validate_resource_class_update(existing=cls.dump(), update=update)
1✔
556

557
            # NOTE: updating the 'default' field is not supported, so it is skipped below
558
            if update.name is not None:
1✔
559
                cls.name = update.name
1✔
560
            if update.cpu is not None:
1✔
561
                cls.cpu = update.cpu
1✔
562
            if update.memory is not None:
1✔
563
                cls.memory = update.memory
1✔
564
            if update.max_storage is not None:
1✔
565
                cls.max_storage = update.max_storage
1✔
566
            if update.gpu is not None:
1✔
567
                cls.gpu = update.gpu
1✔
568
            if update.default_storage is not None:
1✔
569
                cls.default_storage = update.default_storage
1✔
570

571
            if update.node_affinities is not None:
1✔
572
                existing_affinities: dict[str, schemas.NodeAffintyORM] = {i.key: i for i in cls.node_affinities}
1✔
573
                new_affinities: dict[str, schemas.NodeAffintyORM] = {
1✔
574
                    i.key: schemas.NodeAffintyORM(
575
                        key=i.key,
576
                        required_during_scheduling=i.required_during_scheduling,
577
                    )
578
                    for i in update.node_affinities
579
                }
580
                for new_affinity_key, new_affinity in new_affinities.items():
1✔
581
                    if new_affinity_key in existing_affinities:
1✔
582
                        # UPDATE existing affinity
583
                        existing_affinity = existing_affinities[new_affinity_key]
1✔
584
                        if new_affinity.required_during_scheduling != existing_affinity.required_during_scheduling:
1✔
585
                            existing_affinity.required_during_scheduling = new_affinity.required_during_scheduling
1✔
586
                    else:
587
                        # CREATE a brand new affinity
588
                        cls.node_affinities.append(new_affinity)
1✔
589
                # REMOVE an affinity
590
                for existing_affinity_key, existing_affinity in existing_affinities.items():
1✔
591
                    if existing_affinity_key not in new_affinities:
1✔
592
                        cls.node_affinities.remove(existing_affinity)
1✔
593

594
            if update.tolerations is not None:
1✔
595
                existing_tolerations: dict[str, schemas.TolerationORM] = {tol.key: tol for tol in cls.tolerations}
1✔
596
                new_tolerations: dict[str, schemas.TolerationORM] = {
1✔
597
                    tol: schemas.TolerationORM(key=tol) for tol in update.tolerations
598
                }
599
                for new_tol_key, new_tol in new_tolerations.items():
1✔
600
                    if new_tol_key not in existing_tolerations:
1✔
601
                        # CREATE a brand new toleration
602
                        cls.tolerations.append(new_tol)
1✔
603
                # REMOVE a toleration
604
                for existing_tol_key, existing_tol in existing_tolerations.items():
1✔
605
                    if existing_tol_key not in new_tolerations:
1✔
606
                        cls.tolerations.remove(existing_tol)
1✔
607

608
            # NOTE: do we need to perform this check?
609
            if cls.resource_pool is None:
1✔
610
                raise errors.BaseError(
×
611
                    message="Unexpected internal error.",
612
                    detail=f"The resource class {resource_class_id} is not associated with any resource pool.",
613
                )
614

615
            await session.flush()
1✔
616
            await session.refresh(cls)
1✔
617

618
            cls_model = cls.dump()
1✔
619
            quota = self.quotas_repo.get_quota(cls_model.quota) if cls_model.quota else None
1✔
620
            if quota and not quota.is_resource_class_compatible(cls_model):
1✔
621
                raise errors.ValidationError(
×
622
                    message=f"The resource class {cls_model} is not compatible with the quota {quota}"
623
                )
624

625
            return cls_model
1✔
626

627
    @_only_admins
    async def get_tolerations(self, api_user: base_models.APIUser, resource_pool_id: int, class_id: int) -> list[str]:
        """Return the keys of every toleration stored for a resource class.

        Raises:
            errors.MissingResourceError: if ``get_classes`` finds no class with ``class_id``
                inside the resource pool ``resource_pool_id``.
        """
        async with self.session_maker() as session:
            # Confirm that the class exists and belongs to the pool before reading its tolerations.
            matching_classes = await self.get_classes(api_user, class_id, resource_pool_id=resource_pool_id)
            if not matching_classes:
                raise errors.MissingResourceError(
                    message=f"The resource pool with ID {resource_pool_id} or the resource "
                    f"class with ID {class_id} do not exist, or they are not related."
                )
            query = select(schemas.TolerationORM).where(schemas.TolerationORM.resource_class_id == class_id)
            tolerations = (await session.execute(query)).scalars().all()
            return [toleration.key for toleration in tolerations]
1✔
640

641
    @_only_admins
    async def delete_tolerations(self, api_user: base_models.APIUser, resource_pool_id: int, class_id: int) -> None:
        """Remove every toleration attached to a resource class.

        Raises:
            errors.MissingResourceError: if ``get_classes`` finds no class with ``class_id``
                inside the resource pool ``resource_pool_id``.
        """
        async with self.session_maker() as session, session.begin():
            # Confirm that the class exists and belongs to the pool before deleting anything.
            matching_classes = await self.get_classes(api_user, class_id, resource_pool_id=resource_pool_id)
            if not matching_classes:
                raise errors.MissingResourceError(
                    message=f"The resource pool with ID {resource_pool_id} or the resource "
                    f"class with ID {class_id} do not exist, or they are not related."
                )
            removal = delete(schemas.TolerationORM).where(schemas.TolerationORM.resource_class_id == class_id)
            await session.execute(removal)
1✔
653

654
    @_only_admins
    async def get_affinities(
        self, api_user: base_models.APIUser, resource_pool_id: int, class_id: int
    ) -> list[models.NodeAffinity]:
        """Return every node affinity stored for a resource class.

        Raises:
            errors.MissingResourceError: if ``get_classes`` finds no class with ``class_id``
                inside the resource pool ``resource_pool_id``.
        """
        async with self.session_maker() as session:
            # Confirm that the class exists and belongs to the pool before reading its affinities.
            matching_classes = await self.get_classes(api_user, class_id, resource_pool_id=resource_pool_id)
            if not matching_classes:
                raise errors.MissingResourceError(
                    message=f"The resource pool with ID {resource_pool_id} or the resource "
                    f"class with ID {class_id} do not exist, or they are not related."
                )
            query = select(schemas.NodeAffintyORM).where(schemas.NodeAffintyORM.resource_class_id == class_id)
            affinities = (await session.execute(query)).scalars().all()
            return [affinity.dump() for affinity in affinities]
1✔
669

670
    @_only_admins
    async def delete_affinities(self, api_user: base_models.APIUser, resource_pool_id: int, class_id: int) -> None:
        """Remove every node affinity attached to a resource class.

        Raises:
            errors.MissingResourceError: if ``get_classes`` finds no class with ``class_id``
                inside the resource pool ``resource_pool_id``.
        """
        async with self.session_maker() as session, session.begin():
            # Confirm that the class exists and belongs to the pool before deleting anything.
            matching_classes = await self.get_classes(api_user, class_id, resource_pool_id=resource_pool_id)
            if not matching_classes:
                raise errors.MissingResourceError(
                    message=f"The resource pool with ID {resource_pool_id} or the resource "
                    f"class with ID {class_id} do not exist, or they are not related."
                )
            removal = delete(schemas.NodeAffintyORM).where(schemas.NodeAffintyORM.resource_class_id == class_id)
            await session.execute(removal)
1✔
682

683
    async def get_quota(self, api_user: base_models.APIUser, resource_pool_id: int) -> models.Quota:
        """Return the quota attached to a resource pool.

        Raises:
            errors.MissingResourceError: if ``get_resource_pools`` returns no pool for
                ``resource_pool_id``, or if the pool found has no quota.
        """
        pools = await self.get_resource_pools(api_user=api_user, id=resource_pool_id)
        if not pools:
            raise errors.MissingResourceError(message=f"Cannot find the resource pool with ID {resource_pool_id}.")
        quota = pools[0].quota
        if quota is None:
            raise errors.MissingResourceError(
                message=f"The resource pool with ID {resource_pool_id} does not have a quota."
            )
        return quota
1✔
694

695
    @_only_admins
    async def update_quota(
        self,
        api_user: base_models.APIUser,
        resource_pool_id: int,
        update: models.QuotaPatch,
        quota_put_id: str | None = None,
    ) -> models.Quota:
        """Apply a patch to the quota of a resource pool.

        Fields left as ``None`` in ``update`` keep the value from the current quota.

        Args:
            api_user: the user performing the request.
            resource_pool_id: ID of the pool whose quota is modified.
            update: the new values for the quota.
            quota_put_id: optional quota ID supplied by the caller; when set it has to
                be equal to the ID of the existing quota.

        Returns:
            The result of ``quotas_repo.update_quota`` for the merged quota.

        Raises:
            errors.MissingResourceError: if the pool is not found or has no quota.
            errors.ValidationError: if the quota ID would change, or if the merged quota
                is not compatible with one of the classes of the pool.
        """
        pools = await self.get_resource_pools(api_user=api_user, id=resource_pool_id)
        if not pools:
            raise errors.MissingResourceError(message=f"Cannot find the resource pool with ID {resource_pool_id}.")
        pool = pools[0]
        current = pool.quota
        if current is None:
            raise errors.MissingResourceError(
                message=f"The resource pool with ID {resource_pool_id} does not have a quota."
            )

        # Merge the patch on top of the current quota, field by field.
        merged = models.Quota(
            cpu=current.cpu if update.cpu is None else update.cpu,
            memory=current.memory if update.memory is None else update.memory,
            gpu=current.gpu if update.gpu is None else update.gpu,
            gpu_kind=current.gpu_kind if update.gpu_kind is None else update.gpu_kind,
            id=quota_put_id or current.id,
        )
        if merged.id != current.id:
            raise errors.ValidationError(message="The 'id' field of a quota is immutable.")

        # Every class already in the pool has to remain compatible with the new quota.
        for resource_class in pool.classes:
            if merged.is_resource_class_compatible(resource_class):
                continue
            raise errors.ValidationError(
                message=f"The quota {merged} is not compatible with the resource class {resource_class}."
            )

        return self.quotas_repo.update_quota(quota=merged)
1✔
730

731

732
@dataclass
class Respository2Users:
    """Users of one resource pool, split by whether they are allowed or disallowed.

    Attributes:
        resource_pool_id: ID of the resource pool the two lists refer to.
        allowed: users listed as having access to the pool.
        disallowed: users listed as not having access to the pool.
    """

    # NOTE: the field order is significant, instances are built with positional arguments.
    resource_pool_id: int
    allowed: list[base_models.User] = field(default_factory=list)
    disallowed: list[base_models.User] = field(default_factory=list)
2✔
739

740

741
class UserRepository(_Base):
2✔
742
    """The adapter used for accessing resource pool users with SQLAlchemy."""
743

744
    def __init__(
        self, session_maker: Callable[..., AsyncSession], quotas_repo: QuotaRepository, user_repo: UserRepo
    ) -> None:
        """Initialise the base repository and keep ``user_repo`` as ``kc_user_repo``.

        Args:
            session_maker: factory for database sessions, forwarded to the base class.
            quotas_repo: quota repository, forwarded to the base class.
            user_repo: repository used to look up users by their keycloak ID.
        """
        super().__init__(session_maker, quotas_repo)
        self.kc_user_repo = user_repo
2✔
749

750
    @_only_admins
    async def get_resource_pool_users(
        self,
        *,
        api_user: base_models.APIUser,
        resource_pool_id: int,
        keycloak_id: Optional[str] = None,
    ) -> Respository2Users:
        """Get users of a specific resource pool from the database.

        How the result is filled depends on the kind of pool:

        * default pool: ``disallowed`` holds the users flagged with ``no_default_access``
          (only the one matching ``keycloak_id`` when it is given); ``allowed`` holds the
          requested user when that user exists and is not disallowed.
        * public, non-default pool: ``allowed`` holds only the requested user, if any.
        * private pool: ``allowed`` holds every user associated with the pool.

        Args:
            api_user: the user performing the request.
            resource_pool_id: ID of the resource pool to inspect.
            keycloak_id: when given, restrict the lookup to this user.

        Raises:
            errors.MissingResourceError: if no resource pool matches the query.
        """
        async with self.session_maker() as session, session.begin():
            stmt = (
                select(schemas.ResourcePoolORM)
                .where(schemas.ResourcePoolORM.id == resource_pool_id)
                .options(selectinload(schemas.ResourcePoolORM.users))
            )
            if keycloak_id is not None:
                # The pool is relevant for the user when the user is a member, or when the pool is public
                # or default. The last condition has to test the *pool* column: filtering on
                # ResourceClassORM, which is not part of this statement, adds an implicit cross join and
                # is true whenever any default resource class exists.
                stmt = stmt.join(schemas.ResourcePoolORM.users, isouter=True).where(
                    or_(
                        schemas.UserORM.keycloak_id == keycloak_id,
                        schemas.ResourcePoolORM.public == true(),
                        schemas.ResourcePoolORM.default == true(),
                    )
                )
            res = await session.execute(stmt)
            rp = res.scalars().first()
            if rp is None:
                raise errors.MissingResourceError(message=f"Resource pool with id {resource_pool_id} does not exist")

            specific_user: base_models.User | None = None
            if keycloak_id:
                specific_user_res = (
                    await session.execute(select(schemas.UserORM).where(schemas.UserORM.keycloak_id == keycloak_id))
                ).scalar_one_or_none()
                specific_user = specific_user_res.dump() if specific_user_res else None

            allowed: list[base_models.User] = []
            disallowed: list[base_models.User] = []
            if rp.default:
                disallowed_stmt = select(schemas.UserORM).where(schemas.UserORM.no_default_access == true())
                if keycloak_id:
                    disallowed_stmt = disallowed_stmt.where(schemas.UserORM.keycloak_id == keycloak_id)
                disallowed_res = await session.execute(disallowed_stmt)
                disallowed = [user.dump() for user in disallowed_res.scalars().all()]
                if specific_user and specific_user not in disallowed:
                    allowed = [specific_user]
            elif rp.public:
                # Public pools have no explicit member list, only the requested user can be reported.
                if specific_user:
                    allowed = [specific_user]
            else:
                allowed = [user.dump() for user in rp.users]
            return Respository2Users(rp.id, allowed, disallowed)
1✔
799

800
    async def get_user_resource_pools(
        self,
        api_user: base_models.APIUser,
        keycloak_id: str,
        resource_pool_id: Optional[int] = None,
        resource_pool_name: Optional[str] = None,
    ) -> list[models.ResourcePool]:
        """List the resource pools that the user ``keycloak_id`` can use.

        A pool is selected when it is public or when the user is one of its members. The
        result can be narrowed further by pool name and/or pool ID.

        Raises:
            errors.ValidationError: if a non-admin ``api_user`` asks about another user.
        """
        async with self.session_maker() as session, session.begin():
            asking_for_self = api_user.id == keycloak_id
            if not (api_user.is_admin or asking_for_self):
                raise errors.ValidationError(
                    message="Users cannot query for resource pools that belong to other users."
                )

            is_public = schemas.ResourcePoolORM.public == true()
            is_member = schemas.ResourcePoolORM.users.any(schemas.UserORM.keycloak_id == keycloak_id)
            query = (
                select(schemas.ResourcePoolORM)
                .options(selectinload(schemas.ResourcePoolORM.classes))
                .where(or_(is_public, is_member))
            )
            if resource_pool_name is not None:
                query = query.where(schemas.ResourcePoolORM.name == resource_pool_name)
            if resource_pool_id is not None:
                query = query.where(schemas.ResourcePoolORM.id == resource_pool_id)
            # NOTE: The line below ensures that the right users can access the right resources, do not remove.
            query = _resource_pool_access_control(api_user, query)
            pools: Sequence[schemas.ResourcePoolORM] = (await session.execute(query)).scalars().all()
            return [
                pool.dump(self.quotas_repo.get_quota(pool.quota) if pool.quota else None) for pool in pools
            ]
2✔
834

835
    @_only_admins
    async def update_user_resource_pools(
        self, api_user: base_models.APIUser, keycloak_id: str, resource_pool_ids: list[int], append: bool = True
    ) -> list[models.ResourcePool]:
        """Update the resource pools that a specific user has access to.

        Args:
            api_user: the user performing the request.
            keycloak_id: keycloak ID of the user whose access is modified. A database row
                is created for the user when none exists yet.
            resource_pool_ids: IDs of the resource pools to grant; duplicates are ignored.
            append: when true the pools are added to the ones the user already has,
                otherwise they replace them.

        Returns:
            With ``append`` the pools that were newly added by this call, otherwise all
            the requested pools.

        Raises:
            errors.MissingResourceError: if the user is unknown to ``kc_user_repo``, if a
                requested pool does not exist, or if a default pool is requested for a
                user flagged with ``no_default_access``.
        """
        # Repeated IDs must not be mistaken for missing resource pools further below.
        requested_ids = set(resource_pool_ids)
        async with self.session_maker() as session, session.begin():
            kc_user = await self.kc_user_repo.get_user(keycloak_id)
            if kc_user is None:
                raise errors.MissingResourceError(message=f"The user with ID {keycloak_id} does not exist")
            stmt = (
                select(schemas.UserORM)
                .where(schemas.UserORM.keycloak_id == keycloak_id)
                .options(selectinload(schemas.UserORM.resource_pools))
            )
            res = await session.execute(stmt)
            user = res.scalars().first()
            if user is None:
                user = schemas.UserORM(keycloak_id=keycloak_id)
                session.add(user)
            stmt_rp = (
                select(schemas.ResourcePoolORM)
                .where(schemas.ResourcePoolORM.id.in_(resource_pool_ids))
                .options(selectinload(schemas.ResourcePoolORM.classes))
            )
            if user.no_default_access:
                # Default pools are filtered out here, so requesting one for such a user is reported
                # through the "missing" error below. No separate check for default pools is needed
                # afterwards because none can be part of the result.
                stmt_rp = stmt_rp.where(schemas.ResourcePoolORM.default == false())
            res_rp = await session.execute(stmt_rp)
            rps_to_add = res_rp.scalars().all()
            if len(rps_to_add) != len(requested_ids):
                missing_rps = requested_ids.difference(rp.id for rp in rps_to_add)
                raise errors.MissingResourceError(
                    message=(
                        f"The resource pools with ids: {missing_rps} do not exist or user doesn't have access to "
                        "default resource pool."
                    )
                )
            if append:
                # Only attach the pools that the user does not have yet.
                user_rp_ids = {rp.id for rp in user.resource_pools}
                rps_to_add = [rp for rp in rps_to_add if rp.id not in user_rp_ids]
                user.resource_pools.extend(rps_to_add)
            else:
                user.resource_pools = list(rps_to_add)
            output: list[models.ResourcePool] = []
            for rp in rps_to_add:
                quota = self.quotas_repo.get_quota(rp.quota) if rp.quota else None
                output.append(rp.dump(quota))
            return output
1✔
888

889
    @_only_admins
    async def delete_resource_pool_user(
        self, api_user: base_models.APIUser, resource_pool_id: int, keycloak_id: str
    ) -> None:
        """Remove a user from a specific resource pool.

        Only the association between the user and the pool ``resource_pool_id`` is removed;
        the user's access to other resource pools is left untouched. Nothing happens when
        the user does not exist or is not a member of the pool.

        Args:
            api_user: the user performing the request.
            resource_pool_id: ID of the pool the user is removed from.
            keycloak_id: keycloak ID of the user to remove.
        """
        async with self.session_maker() as session, session.begin():
            stmt = (
                select(schemas.UserORM)
                .where(schemas.UserORM.keycloak_id == keycloak_id)
                .options(selectinload(schemas.UserORM.resource_pools))
            )
            user = (await session.execute(stmt)).scalars().first()
            if user is None:
                return
            # NOTE: Deleting from the association table with a filter on the user ID alone would revoke
            # the access of the user to *all* resource pools, so the single association is dropped via
            # the relationship instead.
            remaining = [rp for rp in user.resource_pools if rp.id != resource_pool_id]
            if len(remaining) != len(user.resource_pools):
                user.resource_pools = remaining
1✔
903

904
    @_only_admins
    async def update_resource_pool_users(
        self, api_user: base_models.APIUser, resource_pool_id: int, user_ids: Collection[str], append: bool = True
    ) -> list[base_models.User]:
        """Update the users to have access to a specific resource pool.

        For the default resource pool no membership is stored: the only effect is that the
        ``no_default_access`` restriction is lifted for the listed users that currently
        have it, and those users are returned.

        Args:
            api_user: the user performing the request.
            resource_pool_id: ID of the pool to modify.
            user_ids: keycloak IDs of the users to grant access to; duplicates are ignored.
                Users without a database row get one created.
            append: when true the users are added to the existing members, otherwise they
                replace them.

        Returns:
            For a non-default pool all the members of the pool after the update.

        Raises:
            errors.MissingResourceError: if the resource pool does not exist.
        """
        # Drop repeated IDs (keeping the order) so that a user is never created or added twice.
        requested_ids = list(dict.fromkeys(user_ids))
        requested_id_set = set(requested_ids)
        async with self.session_maker() as session, session.begin():
            stmt = (
                select(schemas.ResourcePoolORM)
                .where(schemas.ResourcePoolORM.id == resource_pool_id)
                .options(
                    selectinload(schemas.ResourcePoolORM.users),
                    selectinload(schemas.ResourcePoolORM.classes),
                )
            )
            res = await session.execute(stmt)
            rp: Optional[schemas.ResourcePoolORM] = res.scalars().first()
            if rp is None:
                raise errors.MissingResourceError(
                    message=f"The resource pool with id {resource_pool_id} does not exist"
                )
            if rp.default:
                # NOTE: If the resource pool is default just check if any users are prevented from having
                # default resource pool access - and remove the restriction.
                all_existing_users = await self.get_resource_pool_users(
                    api_user=api_user, resource_pool_id=resource_pool_id
                )
                users_to_modify = [
                    user for user in all_existing_users.disallowed if user.keycloak_id in requested_id_set
                ]
                return await gather(
                    *[
                        self.update_user(
                            api_user=api_user, keycloak_id=no_default_user.keycloak_id, no_default_access=False
                        )
                        for no_default_user in users_to_modify
                    ]
                )
            stmt_usr = select(schemas.UserORM).where(schemas.UserORM.keycloak_id.in_(requested_ids))
            res_usr = await session.execute(stmt_usr)
            users_to_add_exist = list(res_usr.scalars().all())
            existing_keycloak_ids = {user.keycloak_id for user in users_to_add_exist}
            # Users that are requested but have no row in the database yet.
            users_to_add_missing = [
                schemas.UserORM(keycloak_id=user_id)
                for user_id in requested_ids
                if user_id not in existing_keycloak_ids
            ]
            candidates = users_to_add_exist + users_to_add_missing
            if append:
                member_ids = {member.id for member in rp.users}
                rp.users.extend(user for user in candidates if user.id not in member_ids)
            else:
                rp.users = candidates
            return [usr.dump() for usr in rp.users]
1✔
953

954
    @_only_admins
    async def update_user(self, api_user: base_models.APIUser, keycloak_id: str, **kwargs: Any) -> base_models.User:
        """Update a specific user.

        A database row is created for the user when none exists yet.

        Args:
            api_user: the user performing the request.
            keycloak_id: keycloak ID of the user to update.
            **kwargs: the fields to change. Only ``no_default_access`` is accepted; a value
                of ``None`` leaves the field untouched.

        Raises:
            errors.ValidationError: if a field other than the accepted ones is passed.
        """
        # Reject unsupported fields up front, before any database work is done.
        allowed_updates = {"no_default_access"}
        if not set(kwargs).issubset(allowed_updates):
            raise errors.ValidationError(
                message=f"Only the following fields {allowed_updates} can be updated for a resource pool user.."
            )
        async with self.session_maker() as session, session.begin():
            stmt = select(schemas.UserORM).where(schemas.UserORM.keycloak_id == keycloak_id)
            res = await session.execute(stmt)
            user: Optional[schemas.UserORM] = res.scalars().first()
            if not user:
                user = schemas.UserORM(keycloak_id=keycloak_id)
                session.add(user)
            if (no_default_access := kwargs.get("no_default_access")) is not None:
                user.no_default_access = no_default_access
            return user.dump()
1✔
972

973

974
@dataclass
2✔
975
class ClusterRepository:
2✔
976
    """Repository for cluster configurations."""
977

978
    session_maker: Callable[..., AsyncSession]
2✔
979

980
    async def select_all(self, cluster_id: ULID | None = None) -> AsyncGenerator[SavedClusterSettings, Any]:
2✔
981
        """Get cluster configurations from the database."""
982
        async with self.session_maker() as session:
2✔
983
            query = select(ClusterORM)
2✔
984
            if cluster_id is not None:
2✔
985
                query = query.where(ClusterORM.id == cluster_id)
2✔
986

987
            clusters = await session.stream_scalars(query)
2✔
988
            async for cluster in clusters:
2✔
989
                yield cluster.dump()
1✔
990

991
    async def select(self, cluster_id: ULID) -> SavedClusterSettings:
2✔
992
        """Get cluster configurations from the database."""
993
        async for cluster in self.select_all(cluster_id):
2✔
994
            return cluster
1✔
995

996
        raise errors.MissingResourceError(message=f"Cluster definition id='{cluster_id}' does not exist.")
2✔
997

998
    @_only_admins
2✔
999
    async def insert(self, api_user: base_models.APIUser, cluster: ClusterSettings) -> ClusterSettings:
2✔
1000
        """Creates a new cluster configuration."""
1001

1002
        cluster_orm = ClusterORM.load(cluster)
2✔
1003
        async with self.session_maker() as session, session.begin():
2✔
1004
            session.add(cluster_orm)
2✔
1005
            await session.flush()
2✔
1006
            await session.refresh(cluster_orm)
2✔
1007

1008
            return cluster_orm.dump()
2✔
1009

1010
    @_only_admins
2✔
1011
    async def update(self, api_user: base_models.APIUser, cluster: ClusterPatch, cluster_id: ULID) -> ClusterSettings:
2✔
1012
        """Updates a cluster configuration."""
1013

1014
        async with self.session_maker() as session, session.begin():
2✔
1015
            saved_cluster = (await session.scalars(select(ClusterORM).where(ClusterORM.id == cluster_id))).one_or_none()
2✔
1016
            if saved_cluster is None:
2✔
1017
                raise errors.MissingResourceError(message=f"Cluster definition id='{cluster_id}' does not exist.")
2✔
1018

1019
            for key, value in asdict(cluster).items():
1✔
1020
                match key, value:
1✔
1021
                    case "session_protocol", SessionProtocol():
1✔
1022
                        setattr(saved_cluster, key, value.value)
1✔
1023
                    case "session_storage_class", "":
1✔
1024
                        # If we received an empty string in the storage class, reset it to the default storage class by
1025
                        # setting it to None.
1026
                        setattr(saved_cluster, key, None)
×
1027
                    case "service_account_name", "":
1✔
1028
                        # If we received an empty string in the service account name, set it back to None.
1029
                        setattr(saved_cluster, key, None)
×
1030
                    case _, None:
1✔
1031
                        # Do not modify a value which has not been set in the patch
1032
                        pass
1✔
1033
                    case _, _:
1✔
1034
                        setattr(saved_cluster, key, value)
1✔
1035

1036
            await session.flush()
1✔
1037
            await session.refresh(saved_cluster)
1✔
1038

1039
            return saved_cluster.dump()
1✔
1040

1041
    @_only_admins
2✔
1042
    async def delete(self, api_user: base_models.APIUser, cluster_id: ULID) -> None:
2✔
1043
        """Get cluster configurations from the database."""
1044

1045
        async with self.session_maker() as session, session.begin():
2✔
1046
            r = await session.scalars(select(ClusterORM).where(ClusterORM.id == cluster_id))
2✔
1047
            cluster = r.one_or_none()
2✔
1048
            if cluster is not None:
2✔
1049
                await session.delete(cluster)
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc