• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

SwissDataScienceCenter / renku-data-services / 12318338392

13 Dec 2024 03:21PM UTC coverage: 86.388% (-0.001%) from 86.389%
12318338392

Pull #572

github

web-flow
Merge 0eb4f5dce into 9cb5d8146
Pull Request #572: feat: manage v1 sessions

17 of 32 new or added lines in 5 files covered. (53.13%)

4 existing lines in 4 files now uncovered.

14641 of 16948 relevant lines covered (86.39%)

1.52 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

85.68
/components/renku_data_services/crc/db.py
1
"""Adapter based on SQLAlchemy.
2

3
These adapters currently do a few things (1) generate SQL queries, (2) apply resource access controls,
4
(3) fetch the SQL results and (4) format them into a workable representation. In the future and as the code
5
grows it is worth looking into separating this functionality into separate classes rather than having
6
it all in one place.
7
"""
8

9
from asyncio import gather
2✔
10
from collections.abc import Callable, Collection, Coroutine, Sequence
2✔
11
from dataclasses import dataclass, field
2✔
12
from functools import wraps
2✔
13
from typing import Any, Concatenate, Optional, ParamSpec, TypeVar, cast
2✔
14

15
from sqlalchemy import NullPool, delete, false, select, true
2✔
16
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
2✔
17
from sqlalchemy.orm import selectinload
2✔
18
from sqlalchemy.sql import Select, and_, not_, or_
2✔
19

20
import renku_data_services.base_models as base_models
2✔
21
from renku_data_services import errors
2✔
22
from renku_data_services.crc import models
2✔
23
from renku_data_services.crc import orm as schemas
2✔
24
from renku_data_services.k8s.quota import QuotaRepository
2✔
25
from renku_data_services.users.db import UserRepo
2✔
26

27

28
class _Base:
    """Shared state for the repositories in this module: a session factory and the quota repository."""

    def __init__(self, session_maker: Callable[..., AsyncSession], quotas_repo: QuotaRepository) -> None:
        # Factory used by every repository method to open a new database session.
        self.session_maker = session_maker
        # Repository used to read, create, update and delete quotas.
        self.quotas_repo = quotas_repo
32

33

34
def _resource_pool_access_control(
2✔
35
    api_user: base_models.APIUser,
36
    stmt: Select[tuple[schemas.ResourcePoolORM]],
37
) -> Select[tuple[schemas.ResourcePoolORM]]:
38
    """Modifies a select query to list resource pools based on whether the user is logged in or not."""
39
    output = stmt
2✔
40
    match (api_user.is_authenticated, api_user.is_admin):
2✔
41
        case True, False:
2✔
42
            # The user is logged in but not an admin
43
            api_user_has_default_pool_access = not_(
1✔
44
                # NOTE: The only way to check that a user is allowed to access the default pool is that such a
45
                # record does NOT EXIST in the database
46
                select(schemas.RPUserORM.no_default_access)
47
                .where(
48
                    and_(schemas.RPUserORM.keycloak_id == api_user.id, schemas.RPUserORM.no_default_access == true())
49
                )
50
                .exists()
51
            )
52
            output = output.join(schemas.RPUserORM, schemas.ResourcePoolORM.users, isouter=True).where(
1✔
53
                or_(
54
                    schemas.RPUserORM.keycloak_id == api_user.id,  # the user is part of the pool
55
                    and_(  # the pool is not default but is public
56
                        schemas.ResourcePoolORM.default != true(), schemas.ResourcePoolORM.public == true()
57
                    ),
58
                    and_(  # the pool is default and the user is not prohibited from accessing it
59
                        schemas.ResourcePoolORM.default == true(),
60
                        api_user_has_default_pool_access,
61
                    ),
62
                )
63
            )
64
        case True, True:
2✔
65
            # The user is logged in and is an admin, they can see all resource pools
66
            pass
2✔
67
        case False, _:
1✔
68
            # The user is not logged in, they can see only the public resource pools
69
            output = output.where(schemas.ResourcePoolORM.public == true())
1✔
70
    return output
2✔
71

72

73
def _classes_user_access_control(
2✔
74
    api_user: base_models.APIUser,
75
    stmt: Select[tuple[schemas.ResourceClassORM]],
76
) -> Select[tuple[schemas.ResourceClassORM]]:
77
    """Adjust the select statement for classes based on whether the user is logged in or not."""
78
    output = stmt
2✔
79
    match (api_user.is_authenticated, api_user.is_admin):
2✔
80
        case True, False:
2✔
81
            # The user is logged in but is not an admin
82
            api_user_has_default_pool_access = not_(
1✔
83
                # NOTE: The only way to check that a user is allowed to access the default pool is that such a
84
                # record does NOT EXIST in the database
85
                select(schemas.RPUserORM.no_default_access)
86
                .where(
87
                    and_(schemas.RPUserORM.keycloak_id == api_user.id, schemas.RPUserORM.no_default_access == true())
88
                )
89
                .exists()
90
            )
91
            output = output.join(schemas.RPUserORM, schemas.ResourcePoolORM.users, isouter=True).where(
1✔
92
                or_(
93
                    schemas.RPUserORM.keycloak_id == api_user.id,  # the user is part of the pool
94
                    and_(  # the pool is not default but is public
95
                        schemas.ResourcePoolORM.default != true(), schemas.ResourcePoolORM.public == true()
96
                    ),
97
                    and_(  # the pool is default and the user is not prohibited from accessing it
98
                        schemas.ResourcePoolORM.default == true(),
99
                        api_user_has_default_pool_access,
100
                    ),
101
                )
102
            )
103
        case True, True:
2✔
104
            # The user is logged in and is an admin, they can see all resource classes
105
            pass
2✔
106
        case False, _:
×
107
            # The user is not logged in, they can see only the classes from public resource pools
108
            output = output.join(schemas.RPUserORM, schemas.ResourcePoolORM.users, isouter=True).where(
×
109
                schemas.ResourcePoolORM.public == true(),
110
            )
111
    return output
2✔
112

113

114
_P = ParamSpec("_P")
2✔
115
_T = TypeVar("_T")
2✔
116

117

118
def _only_admins(
2✔
119
    f: Callable[Concatenate[Any, _P], Coroutine[Any, Any, _T]],
120
) -> Callable[Concatenate[Any, _P], Coroutine[Any, Any, _T]]:
121
    """Decorator that errors out if the user is not an admin.
122

123
    It expects the APIUser model to be a named parameter in the decorated function or
124
    to be the first parameter (after self).
125
    """
126

127
    @wraps(f)
2✔
128
    async def decorated_function(self: Any, *args: _P.args, **kwargs: _P.kwargs) -> _T:
2✔
129
        api_user = None
2✔
130
        if "api_user" in kwargs:
2✔
131
            api_user = kwargs["api_user"]
2✔
132
        elif len(args) >= 1:
2✔
133
            api_user = args[0]
2✔
134
        if api_user is not None and not isinstance(api_user, base_models.APIUser):
2✔
135
            raise errors.ProgrammingError(message="Expected user parameter is not of type APIUser.")
×
136
        if api_user is None:
2✔
137
            raise errors.UnauthorizedError(message="You do not have the required permissions for this operation.")
×
138
        if not api_user.is_admin:
2✔
139
            raise errors.ForbiddenError(message="You do not have the required permissions for this operation.")
×
140

141
        # the user is authenticated and is an admin
142
        response = await f(self, *args, **kwargs)
2✔
143
        return response
2✔
144

145
    return decorated_function
2✔
146

147

148
class ResourcePoolRepository(_Base):
2✔
149
    """The adapter used for accessing resource pools with SQLAlchemy."""
150

151
    async def initialize(self, async_connection_url: str, rp: models.ResourcePool) -> None:
2✔
152
        """Add the default resource pool if it does not already exist."""
153
        engine = create_async_engine(async_connection_url, poolclass=NullPool)
×
154
        session_maker = async_sessionmaker(
×
155
            engine,
156
            expire_on_commit=True,
157
        )
158
        async with session_maker() as session, session.begin():
×
159
            stmt = select(schemas.ResourcePoolORM.default == true())
×
160
            res = await session.execute(stmt)
×
161
            default_rp = res.scalars().first()
×
162
            if default_rp is None:
×
163
                orm = schemas.ResourcePoolORM.load(rp)
×
164
                session.add(orm)
×
165

166
    async def get_resource_pools(
2✔
167
        self, api_user: base_models.APIUser, id: Optional[int] = None, name: Optional[str] = None
168
    ) -> list[models.ResourcePool]:
169
        """Get resource pools from database."""
170
        async with self.session_maker() as session:
2✔
171
            stmt = select(schemas.ResourcePoolORM).options(selectinload(schemas.ResourcePoolORM.classes))
2✔
172
            if name is not None:
2✔
173
                stmt = stmt.where(schemas.ResourcePoolORM.name == name)
1✔
174
            if id is not None:
2✔
175
                stmt = stmt.where(schemas.ResourcePoolORM.id == id)
2✔
176
            # NOTE: The line below ensures that the right users can access the right resources, do not remove.
177
            stmt = _resource_pool_access_control(api_user, stmt)
2✔
178
            res = await session.execute(stmt)
2✔
179
            orms = res.scalars().all()
2✔
180
            output: list[models.ResourcePool] = []
2✔
181
            for rp in orms:
2✔
182
                quota = self.quotas_repo.get_quota(rp.quota) if rp.quota else None
1✔
183
                output.append(rp.dump(quota))
1✔
184
            return output
2✔
185

186
    async def get_default_resource_class(self) -> models.ResourceClass:
2✔
187
        """Get the default resource class in the default resource pool."""
188
        async with self.session_maker() as session:
×
189
            stmt = (
×
190
                select(schemas.ResourceClassORM)
191
                .where(schemas.ResourceClassORM.default == true())
192
                .where(schemas.ResourceClassORM.resource_pool.has(schemas.ResourcePoolORM.default == true()))
193
            )
194
            res = await session.scalar(stmt)
×
195
            if res is None:
×
196
                raise errors.ProgrammingError(
×
197
                    message="Could not find the default class from the default resource pool, but this has to exist."
198
                )
199
            return res.dump()
×
200

201
    async def filter_resource_pools(
2✔
202
        self,
203
        api_user: base_models.APIUser,
204
        cpu: float = 0,
205
        memory: int = 0,
206
        max_storage: int = 0,
207
        gpu: int = 0,
208
    ) -> list[models.ResourcePool]:
209
        """Get resource pools from database with indication of which resource class matches the specified criteria."""
210
        async with self.session_maker() as session:
2✔
211
            criteria = models.ResourceClass(
2✔
212
                name="criteria",
213
                cpu=cpu,
214
                gpu=gpu,
215
                memory=memory,
216
                max_storage=max_storage,
217
                # NOTE: the default storage has to be <= max_storage but is not used for filtering classes,
218
                # only the max_storage is used to filter resource classes that match a request
219
                default_storage=max_storage,
220
            )
221
            stmt = (
2✔
222
                select(schemas.ResourcePoolORM)
223
                .distinct()
224
                .options(selectinload(schemas.ResourcePoolORM.classes))
225
                .order_by(
226
                    schemas.ResourcePoolORM.id,
227
                    schemas.ResourcePoolORM.name,
228
                )
229
            )
230
            # NOTE: The line below ensures that the right users can access the right resources, do not remove.
231
            stmt = _resource_pool_access_control(api_user, stmt)
2✔
232
            res = await session.execute(stmt)
2✔
233
            return [i.dump(self.quotas_repo.get_quota(i.quota), criteria) for i in res.scalars().all()]
2✔
234

235
    @_only_admins
2✔
236
    async def insert_resource_pool(
2✔
237
        self, api_user: base_models.APIUser, resource_pool: models.ResourcePool
238
    ) -> models.ResourcePool:
239
        """Insert resource pool into database."""
240
        quota = None
2✔
241
        if resource_pool.quota:
2✔
242
            for rc in resource_pool.classes:
2✔
243
                if not resource_pool.quota.is_resource_class_compatible(rc):
2✔
244
                    raise errors.ValidationError(
×
245
                        message=f"The quota {quota} is not compatible with resource class {rc}"
246
                    )
247
            quota = self.quotas_repo.create_quota(resource_pool.quota)
2✔
248
            resource_pool = resource_pool.set_quota(quota)
2✔
249
        orm = schemas.ResourcePoolORM.load(resource_pool)
2✔
250
        async with self.session_maker() as session, session.begin():
2✔
251
            if orm.idle_threshold == 0:
2✔
252
                orm.idle_threshold = None
×
253
            if orm.hibernation_threshold == 0:
2✔
254
                orm.hibernation_threshold = None
×
255
            if orm.default:
2✔
256
                stmt = select(schemas.ResourcePoolORM).where(schemas.ResourcePoolORM.default == true())
1✔
257
                res = await session.execute(stmt)
1✔
258
                default_rps = res.unique().scalars().all()
1✔
259
                if len(default_rps) >= 1:
1✔
260
                    raise errors.ValidationError(
×
261
                        message="There can only be one default resource pool and one already exists."
262
                    )
263
            session.add(orm)
2✔
264
        return orm.dump(quota)
2✔
265

266
    async def get_classes(
2✔
267
        self,
268
        api_user: Optional[base_models.APIUser] = None,
269
        id: Optional[int] = None,
270
        name: Optional[str] = None,
271
        resource_pool_id: Optional[int] = None,
272
    ) -> list[models.ResourceClass]:
273
        """Get classes from the database."""
274
        async with self.session_maker() as session:
2✔
275
            stmt = select(schemas.ResourceClassORM).join(
2✔
276
                schemas.ResourcePoolORM, schemas.ResourceClassORM.resource_pool, isouter=True
277
            )
278
            if resource_pool_id is not None:
2✔
279
                stmt = stmt.where(schemas.ResourcePoolORM.id == resource_pool_id)
2✔
280
            if id is not None:
2✔
281
                stmt = stmt.where(schemas.ResourceClassORM.id == id)
2✔
282
            if name is not None:
2✔
283
                stmt = stmt.where(schemas.ResourceClassORM.name == name)
2✔
284

285
            # Apply user access control if api_user is provided
286
            if api_user is not None:
2✔
287
                # NOTE: The line below ensures that the right users can access the right resources, do not remove.
288
                stmt = _classes_user_access_control(api_user, stmt)
2✔
289

290
            res = await session.execute(stmt)
2✔
291
            orms = res.scalars().all()
2✔
292
            return [orm.dump() for orm in orms]
2✔
293

294
    async def get_resource_class(self, api_user: base_models.APIUser, id: int) -> models.ResourceClass:
2✔
295
        """Get a specific resource class by its ID."""
296
        classes = await self.get_classes(api_user, id)
×
297
        if len(classes) == 0:
×
298
            raise errors.MissingResourceError(message=f"The resource class with ID {id} cannot be found", quiet=True)
×
299
        return classes[0]
×
300

301
    @_only_admins
2✔
302
    async def insert_resource_class(
2✔
303
        self,
304
        api_user: base_models.APIUser,
305
        resource_class: models.ResourceClass,
306
        *,
307
        resource_pool_id: Optional[int] = None,
308
    ) -> models.ResourceClass:
309
        """Insert a resource class in the database."""
310
        cls = schemas.ResourceClassORM.load(resource_class)
2✔
311
        async with self.session_maker() as session, session.begin():
2✔
312
            if resource_pool_id is not None:
2✔
313
                stmt = select(schemas.ResourcePoolORM).where(schemas.ResourcePoolORM.id == resource_pool_id)
2✔
314
                res = await session.execute(stmt)
2✔
315
                rp = res.scalars().first()
2✔
316
                if rp is None:
2✔
317
                    raise errors.MissingResourceError(
2✔
318
                        message=f"Resource pool with id {resource_pool_id} does not exist."
319
                    )
320
                if cls.default and len(rp.classes) > 0 and any([icls.default for icls in rp.classes]):
1✔
321
                    raise errors.ValidationError(
×
322
                        message="There can only be one default resource class per resource pool."
323
                    )
324
                quota = self.quotas_repo.get_quota(rp.quota) if rp.quota else None
1✔
325
                if quota and not quota.is_resource_class_compatible(resource_class):
1✔
326
                    raise errors.ValidationError(
×
327
                        message="The resource class {resource_class} is not compatible with the quota {quota}."
328
                    )
329
                cls.resource_pool = rp
1✔
330
                cls.resource_pool_id = rp.id
1✔
331

332
            session.add(cls)
1✔
333
        return cls.dump()
1✔
334

335
    @_only_admins
2✔
336
    async def update_resource_pool(self, api_user: base_models.APIUser, id: int, **kwargs: Any) -> models.ResourcePool:
2✔
337
        """Update an existing resource pool in the database."""
338
        rp: Optional[schemas.ResourcePoolORM] = None
2✔
339
        async with self.session_maker() as session, session.begin():
2✔
340
            stmt = (
2✔
341
                select(schemas.ResourcePoolORM)
342
                .where(schemas.ResourcePoolORM.id == id)
343
                .options(selectinload(schemas.ResourcePoolORM.classes))
344
            )
345
            res = await session.execute(stmt)
2✔
346
            rp = res.scalars().first()
2✔
347
            if rp is None:
2✔
348
                raise errors.MissingResourceError(message=f"Resource pool with id {id} cannot be found")
2✔
349
            quota = self.quotas_repo.get_quota(rp.quota) if rp.quota else None
1✔
350
            if len(kwargs) == 0:
1✔
351
                return rp.dump(quota)
×
352

353
            if kwargs.get("idle_threshold") == 0:
1✔
354
                kwargs["idle_threshold"] = None
×
355
            if kwargs.get("hibernation_threshold") == 0:
1✔
356
                kwargs["hibernation_threshold"] = None
×
357
            # NOTE: The .update method on the model validates the update to the resource pool
358
            old_rp_model = rp.dump(quota)
1✔
359
            new_rp_model = old_rp_model.update(**kwargs)
1✔
360
            new_classes = None
1✔
361
            new_classes_coroutines = []
1✔
362
            for key, val in kwargs.items():
1✔
363
                match key:
1✔
364
                    case "name" | "public" | "default" | "idle_threshold" | "hibernation_threshold":
1✔
365
                        setattr(rp, key, val)
1✔
366
                    case "quota":
1✔
367
                        if val is None:
1✔
368
                            continue
×
369

370
                        # For updating a quota, there are two options:
371
                        # 1. no quota exists --> create a new one
372
                        # 2. a quota exists and can only be updated, not replaced (the ids, if provided, must match)
373

374
                        new_id = val.get("id")
1✔
375

376
                        if quota and quota.id is not None and new_id is not None and quota.id != new_id:
1✔
377
                            raise errors.ValidationError(
×
378
                                message="The ID of an existing quota cannot be updated, "
379
                                f"please remove the ID field from the request or use ID {quota.id}."
380
                            )
381

382
                        # the id must match for update
383
                        if quota:
1✔
384
                            val["id"] = quota.id or new_id
1✔
385

386
                        new_quota = models.Quota.from_dict(val)
1✔
387

388
                        if new_id or quota:
1✔
389
                            new_quota = self.quotas_repo.update_quota(new_quota)
1✔
390
                        else:
391
                            new_quota = self.quotas_repo.create_quota(new_quota)
×
392
                        rp.quota = new_quota.id
1✔
393
                        new_rp_model = new_rp_model.update(quota=new_quota)
1✔
394
                    case "classes":
1✔
395
                        new_classes = []
1✔
396
                        for cls in val:
1✔
397
                            class_id = cls.pop("id")
1✔
398
                            cls.pop("matching", None)
1✔
399
                            if len(cls) == 0:
1✔
400
                                raise errors.ValidationError(
×
401
                                    message="More fields than the id of the class "
402
                                    "should be provided when updating it"
403
                                )
404
                            new_classes_coroutines.append(
1✔
405
                                self.update_resource_class(
406
                                    api_user, resource_pool_id=id, resource_class_id=class_id, **cls
407
                                )
408
                            )
409
                    case _:
×
410
                        pass
×
411
            new_classes = await gather(*new_classes_coroutines)
1✔
412
            if new_classes is not None and len(new_classes) > 0:
1✔
413
                new_rp_model = new_rp_model.update(classes=new_classes)
1✔
414
            return new_rp_model
1✔
415

416
    @_only_admins
2✔
417
    async def delete_resource_pool(self, api_user: base_models.APIUser, id: int) -> None:
2✔
418
        """Delete a resource pool from the database."""
419
        async with self.session_maker() as session, session.begin():
2✔
420
            stmt = select(schemas.ResourcePoolORM).where(schemas.ResourcePoolORM.id == id)
2✔
421
            res = await session.execute(stmt)
2✔
422
            rp = res.scalars().first()
2✔
423
            if rp is not None:
2✔
424
                if rp.default:
1✔
425
                    raise errors.ValidationError(message="The default resource pool cannot be deleted.")
×
426
                await session.delete(rp)
1✔
427
                if rp.quota:
1✔
428
                    self.quotas_repo.delete_quota(rp.quota)
1✔
429
            return None
2✔
430

431
    @_only_admins
2✔
432
    async def delete_resource_class(
2✔
433
        self, api_user: base_models.APIUser, resource_pool_id: int, resource_class_id: int
434
    ) -> None:
435
        """Delete a specific resource class."""
436
        async with self.session_maker() as session, session.begin():
1✔
437
            stmt = (
1✔
438
                select(schemas.ResourceClassORM)
439
                .where(schemas.ResourceClassORM.id == resource_class_id)
440
                .where(schemas.ResourceClassORM.resource_pool_id == resource_pool_id)
441
            )
442
            res = await session.execute(stmt)
1✔
443
            cls = res.scalars().first()
1✔
444
            if cls is not None:
1✔
445
                if cls.default:
1✔
446
                    raise errors.ValidationError(message="The default resource class cannot be deleted.")
×
447
                await session.delete(cls)
1✔
448

449
    @_only_admins
2✔
450
    async def update_resource_class(
2✔
451
        self, api_user: base_models.APIUser, resource_pool_id: int, resource_class_id: int, **kwargs: Any
452
    ) -> models.ResourceClass:
453
        """Update a specific resource class."""
454
        async with self.session_maker() as session, session.begin():
1✔
455
            stmt = (
1✔
456
                select(schemas.ResourceClassORM)
457
                .where(schemas.ResourceClassORM.id == resource_class_id)
458
                .where(schemas.ResourceClassORM.resource_pool_id == resource_pool_id)
459
                .join(schemas.ResourcePoolORM, schemas.ResourceClassORM.resource_pool)
460
                .options(selectinload(schemas.ResourceClassORM.resource_pool))
461
            )
462
            res = await session.execute(stmt)
1✔
463
            cls: Optional[schemas.ResourceClassORM] = res.scalars().first()
1✔
464
            if cls is None:
1✔
465
                raise errors.MissingResourceError(
×
466
                    message=(
467
                        f"The resource class with id {resource_class_id} does not exist, the resource pool with "
468
                        f"id {resource_pool_id} does not exist or the requested resource class is not "
469
                        "associated with the resource pool"
470
                    )
471
                )
472
            for k, v in kwargs.items():
1✔
473
                match k:
1✔
474
                    case "node_affinities":
1✔
475
                        v = cast(list[dict[str, str | bool]], v)
1✔
476
                        existing_affinities: dict[str, schemas.NodeAffintyORM] = {i.key: i for i in cls.node_affinities}
1✔
477
                        new_affinities: dict[str, schemas.NodeAffintyORM] = {
1✔
478
                            i["key"]: schemas.NodeAffintyORM(**i) for i in v
479
                        }
480
                        for new_affinity_key, new_affinity in new_affinities.items():
1✔
481
                            if new_affinity_key in existing_affinities:
1✔
482
                                # UPDATE existing affinity
483
                                existing_affinity = existing_affinities[new_affinity_key]
1✔
484
                                if (
1✔
485
                                    new_affinity.required_during_scheduling
486
                                    != existing_affinity.required_during_scheduling
487
                                ):
488
                                    existing_affinity.required_during_scheduling = (
1✔
489
                                        new_affinity.required_during_scheduling
490
                                    )
491
                            else:
492
                                # CREATE a brand new affinity
493
                                cls.node_affinities.append(new_affinity)
1✔
494
                        # REMOVE an affinity
495
                        for existing_affinity_key, existing_affinity in existing_affinities.items():
1✔
496
                            if existing_affinity_key not in new_affinities:
1✔
497
                                cls.node_affinities.remove(existing_affinity)
1✔
498
                    case "tolerations":
1✔
499
                        v = cast(list[str], v)
1✔
500
                        existing_tolerations: dict[str, schemas.TolerationORM] = {
1✔
501
                            tol.key: tol for tol in cls.tolerations
502
                        }
503
                        new_tolerations: dict[str, schemas.TolerationORM] = {
1✔
504
                            tol: schemas.TolerationORM(key=tol) for tol in v
505
                        }
506
                        for new_tol_key, new_tol in new_tolerations.items():
1✔
507
                            if new_tol_key not in existing_tolerations:
1✔
508
                                # CREATE a brand new toleration
509
                                cls.tolerations.append(new_tol)
1✔
510
                        # REMOVE a toleration
511
                        for existing_tol_key, existing_tol in existing_tolerations.items():
1✔
512
                            if existing_tol_key not in new_tolerations:
1✔
513
                                cls.tolerations.remove(existing_tol)
1✔
514
                    case _:
1✔
515
                        setattr(cls, k, v)
1✔
516
            if cls.resource_pool is None:
1✔
517
                raise errors.BaseError(
×
518
                    message="Unexpected internal error.",
519
                    detail=f"The resource class {resource_class_id} is not associated with any resource pool.",
520
                )
521
            quota = self.quotas_repo.get_quota(cls.resource_pool.quota) if cls.resource_pool.quota else None
1✔
522
            cls_model = cls.dump()
1✔
523
            if quota and not quota.is_resource_class_compatible(cls_model):
1✔
524
                raise errors.ValidationError(
×
525
                    message=f"The resource class {cls_model} is not compatible with the quota {quota}"
526
                )
527
            return cls_model
1✔
528

529
    @_only_admins
2✔
530
    async def get_tolerations(self, api_user: base_models.APIUser, resource_pool_id: int, class_id: int) -> list[str]:
2✔
531
        """Get all tolerations of a resource class."""
532
        async with self.session_maker() as session:
1✔
533
            res_classes = await self.get_classes(api_user, class_id, resource_pool_id=resource_pool_id)
1✔
534
            if len(res_classes) == 0:
1✔
535
                raise errors.MissingResourceError(
×
536
                    message=f"The resource pool with ID {resource_pool_id} or the resource "
537
                    f"class with ID {class_id} do not exist, or they are not related."
538
                )
539
            stmt = select(schemas.TolerationORM).where(schemas.TolerationORM.resource_class_id == class_id)
1✔
540
            res = await session.execute(stmt)
1✔
541
            return [i.key for i in res.scalars().all()]
1✔
542

543
    @_only_admins
2✔
544
    async def delete_tolerations(self, api_user: base_models.APIUser, resource_pool_id: int, class_id: int) -> None:
2✔
545
        """Delete all tolerations for a specific resource class."""
546
        async with self.session_maker() as session, session.begin():
2✔
547
            res_classes = await self.get_classes(api_user, class_id, resource_pool_id=resource_pool_id)
2✔
548
            if len(res_classes) == 0:
2✔
549
                raise errors.MissingResourceError(
1✔
550
                    message=f"The resource pool with ID {resource_pool_id} or the resource "
551
                    f"class with ID {class_id} do not exist, or they are not related."
552
                )
553
            stmt = delete(schemas.TolerationORM).where(schemas.TolerationORM.resource_class_id == class_id)
1✔
554
            await session.execute(stmt)
1✔
555

556
    @_only_admins
2✔
557
    async def get_affinities(
2✔
558
        self, api_user: base_models.APIUser, resource_pool_id: int, class_id: int
559
    ) -> list[models.NodeAffinity]:
560
        """Get all affinities for a resource class."""
561
        async with self.session_maker() as session:
1✔
562
            res_classes = await self.get_classes(api_user, class_id, resource_pool_id=resource_pool_id)
1✔
563
            if len(res_classes) == 0:
1✔
564
                raise errors.MissingResourceError(
×
565
                    message=f"The resource pool with ID {resource_pool_id} or the resource "
566
                    f"class with ID {class_id} do not exist, or they are not related."
567
                )
568
            stmt = select(schemas.NodeAffintyORM).where(schemas.NodeAffintyORM.resource_class_id == class_id)
1✔
569
            res = await session.execute(stmt)
1✔
570
            return [i.dump() for i in res.scalars().all()]
1✔
571

572
    @_only_admins
2✔
573
    async def delete_affinities(self, api_user: base_models.APIUser, resource_pool_id: int, class_id: int) -> None:
2✔
574
        """Delete all affinities from a resource class."""
575
        async with self.session_maker() as session, session.begin():
1✔
576
            res_classes = await self.get_classes(api_user, class_id, resource_pool_id=resource_pool_id)
1✔
577
            if len(res_classes) == 0:
1✔
578
                raise errors.MissingResourceError(
×
579
                    message=f"The resource pool with ID {resource_pool_id} or the resource "
580
                    f"class with ID {class_id} do not exist, or they are not related."
581
                )
582
            stmt = delete(schemas.NodeAffintyORM).where(schemas.NodeAffintyORM.resource_class_id == class_id)
1✔
583
            await session.execute(stmt)
1✔
584

585

586
@dataclass()
class RespositoryUsers:
    """Allowed and disallowed users of a single resource pool.

    Built by ``UserRepository.get_resource_pool_users``. The (misspelled) class name is kept
    as-is so that existing references to it keep working.
    """

    # ID of the resource pool the two lists below refer to.
    resource_pool_id: int
    # Users that can access the pool, as computed by ``get_resource_pool_users``.
    allowed: list[base_models.User] = field(default_factory=list)
    # Users denied access; for the default pool these are users with ``no_default_access`` set.
    disallowed: list[base_models.User] = field(default_factory=list)
593

594

595
class UserRepository(_Base):
    """The adapter used for accessing resource pool users with SQLAlchemy."""

    def __init__(
        self, session_maker: Callable[..., AsyncSession], quotas_repo: QuotaRepository, user_repo: UserRepo
    ) -> None:
        """Create the repository.

        Args:
            session_maker: Factory producing async SQLAlchemy sessions.
            quotas_repo: Repository used to resolve the quotas of resource pools.
            user_repo: Repository used to look up users; stored as ``kc_user_repo``.
        """
        super().__init__(session_maker, quotas_repo)
        self.kc_user_repo = user_repo

    @_only_admins
    async def get_resource_pool_users(
        self,
        *,
        api_user: base_models.APIUser,
        resource_pool_id: int,
        keycloak_id: Optional[str] = None,
    ) -> RespositoryUsers:
        """Get users of a specific resource pool from the database.

        Args:
            api_user: The user performing the request.
            resource_pool_id: ID of the resource pool.
            keycloak_id: When given, only this user is considered.

        Returns:
            The allowed/disallowed user lists of the pool:
            - default pool: ``disallowed`` holds users with ``no_default_access`` set (only the
              requested user when ``keycloak_id`` is given); ``allowed`` holds the requested user
              when they exist and are not disallowed.
            - public, non-default pool: ``allowed`` holds the requested user if they exist.
            - private pool: ``allowed`` holds every user attached to the pool.

        Raises:
            errors.MissingResourceError: If the pool does not exist, or (when ``keycloak_id`` is
                given) the pool is private, not default and the user is not attached to it.
        """
        async with self.session_maker() as session, session.begin():
            stmt = (
                select(schemas.ResourcePoolORM)
                .where(schemas.ResourcePoolORM.id == resource_pool_id)
                .options(selectinload(schemas.ResourcePoolORM.users))
            )
            if keycloak_id is not None:
                # NOTE: Only columns of the joined tables may appear here. Referencing an un-joined
                # table (e.g. resource classes) would add it to the FROM clause as a cartesian product.
                stmt = stmt.join(schemas.ResourcePoolORM.users, isouter=True).where(
                    or_(
                        schemas.RPUserORM.keycloak_id == keycloak_id,
                        schemas.ResourcePoolORM.public == true(),
                        schemas.ResourcePoolORM.default == true(),
                    )
                )
            res = await session.execute(stmt)
            rp = res.scalars().first()
            if rp is None:
                raise errors.MissingResourceError(message=f"Resource pool with id {resource_pool_id} does not exist")
            specific_user: base_models.User | None = None
            if keycloak_id:
                specific_user_res = (
                    await session.execute(select(schemas.RPUserORM).where(schemas.RPUserORM.keycloak_id == keycloak_id))
                ).scalar_one_or_none()
                specific_user = None if not specific_user_res else specific_user_res.dump()
            allowed: list[base_models.User] = []
            disallowed: list[base_models.User] = []
            if rp.default:
                disallowed_stmt = select(schemas.RPUserORM).where(schemas.RPUserORM.no_default_access == true())
                if keycloak_id:
                    disallowed_stmt = disallowed_stmt.where(schemas.RPUserORM.keycloak_id == keycloak_id)
                disallowed_res = await session.execute(disallowed_stmt)
                disallowed = [user.dump() for user in disallowed_res.scalars().all()]
                if specific_user and specific_user not in disallowed:
                    allowed = [specific_user]
            elif rp.public and not rp.default:
                if specific_user:
                    allowed = [specific_user]
            elif not rp.public and not rp.default:
                allowed = [user.dump() for user in rp.users]
            return RespositoryUsers(rp.id, allowed, disallowed)

    async def get_user_resource_pools(
        self,
        api_user: base_models.APIUser,
        keycloak_id: str,
        resource_pool_id: Optional[int] = None,
        resource_pool_name: Optional[str] = None,
    ) -> list[models.ResourcePool]:
        """Get resource pools that a specific user has access to.

        A pool matches when it is public or the user is attached to it. ``resource_pool_id`` and
        ``resource_pool_name`` optionally narrow the result further.

        Raises:
            errors.ValidationError: If a non-admin asks for another user's resource pools.
        """
        # Non-admins may only query their own pools; reject before touching the database.
        if not api_user.is_admin and api_user.id != keycloak_id:
            raise errors.ValidationError(message="Users cannot query for resource pools that belong to other users.")

        async with self.session_maker() as session, session.begin():
            stmt = select(schemas.ResourcePoolORM).options(selectinload(schemas.ResourcePoolORM.classes))
            stmt = stmt.where(
                or_(
                    schemas.ResourcePoolORM.public == true(),
                    schemas.ResourcePoolORM.users.any(schemas.RPUserORM.keycloak_id == keycloak_id),
                )
            )
            if resource_pool_name is not None:
                stmt = stmt.where(schemas.ResourcePoolORM.name == resource_pool_name)
            if resource_pool_id is not None:
                stmt = stmt.where(schemas.ResourcePoolORM.id == resource_pool_id)
            # NOTE: The line below ensures that the right users can access the right resources, do not remove.
            stmt = _resource_pool_access_control(api_user, stmt)
            res = await session.execute(stmt)
            rps: Sequence[schemas.ResourcePoolORM] = res.scalars().all()
            output: list[models.ResourcePool] = []
            for rp in rps:
                quota = self.quotas_repo.get_quota(rp.quota) if rp.quota else None
                output.append(rp.dump(quota))
            return output

    @_only_admins
    async def update_user_resource_pools(
        self, api_user: base_models.APIUser, keycloak_id: str, resource_pool_ids: list[int], append: bool = True
    ) -> list[models.ResourcePool]:
        """Update the resource pools that a specific user has access to.

        Args:
            api_user: The user performing the request.
            keycloak_id: ID of the user to update; a pool-user row is created if missing.
            resource_pool_ids: IDs of the pools to grant. Duplicates are tolerated.
            append: Add to the user's existing pools (True) or replace them (False).

        Returns:
            With ``append`` only the newly added pools, otherwise all requested pools.

        Raises:
            errors.MissingResourceError: If the user is unknown to ``kc_user_repo``, or some
                requested pools do not exist (or are the default pool for a user with
                ``no_default_access``).
            errors.ForbiddenError: Defensive check against granting the default pool to a user
                with ``no_default_access``.
        """
        async with self.session_maker() as session, session.begin():
            kc_user = await self.kc_user_repo.get_user(keycloak_id)
            if kc_user is None:
                raise errors.MissingResourceError(message=f"The user with ID {keycloak_id} does not exist")
            stmt = (
                select(schemas.RPUserORM)
                .where(schemas.RPUserORM.keycloak_id == keycloak_id)
                .options(selectinload(schemas.RPUserORM.resource_pools))
            )
            res = await session.execute(stmt)
            user = res.scalars().first()
            if user is None:
                user = schemas.RPUserORM(keycloak_id=keycloak_id)
                session.add(user)
            stmt_rp = (
                select(schemas.ResourcePoolORM)
                .where(schemas.ResourcePoolORM.id.in_(resource_pool_ids))
                .options(selectinload(schemas.ResourcePoolORM.classes))
            )
            if user.no_default_access:
                stmt_rp = stmt_rp.where(schemas.ResourcePoolORM.default == false())
            res_rp = await session.execute(stmt_rp)
            rps_to_add = res_rp.scalars().all()
            # Compare as sets: the query returns each pool once, so duplicate requested IDs must not
            # be mistaken for missing pools.
            missing_rps = set(resource_pool_ids) - {rp.id for rp in rps_to_add}
            if missing_rps:
                raise errors.MissingResourceError(
                    message=(
                        f"The resource pools with ids: {missing_rps} do not exist or user doesn't have access to "
                        "default resource pool."
                    )
                )
            if user.no_default_access:
                # Defensive re-check; the query above already excludes default pools for this user.
                default_rp = next((rp for rp in rps_to_add if rp.default), None)
                if default_rp:
                    raise errors.ForbiddenError(
                        message=f"User with keycloak id {keycloak_id} cannot access the default resource pool"
                    )
            if append:
                user_rp_ids = {rp.id for rp in user.resource_pools}
                rps_to_add = [rp for rp in rps_to_add if rp.id not in user_rp_ids]
                user.resource_pools.extend(rps_to_add)
            else:
                user.resource_pools = list(rps_to_add)
            output: list[models.ResourcePool] = []
            for rp in rps_to_add:
                quota = self.quotas_repo.get_quota(rp.quota) if rp.quota else None
                output.append(rp.dump(quota))
            return output

    @_only_admins
    async def delete_resource_pool_user(
        self, api_user: base_models.APIUser, resource_pool_id: int, keycloak_id: str
    ) -> None:
        """Remove a user from a specific resource pool.

        Only the user's membership in ``resource_pool_id`` is removed; memberships in other
        resource pools are left untouched. Nothing happens when the user does not exist or is
        not attached to the pool.
        """
        async with self.session_maker() as session, session.begin():
            stmt = (
                select(schemas.RPUserORM)
                .where(schemas.RPUserORM.keycloak_id == keycloak_id)
                .options(selectinload(schemas.RPUserORM.resource_pools))
            )
            res = await session.execute(stmt)
            user = res.scalars().first()
            if user is None:
                return
            # NOTE: Deleting association rows by user id alone would detach the user from every
            # resource pool. Dropping only this pool from the relationship deletes just one row.
            remaining = [rp for rp in user.resource_pools if rp.id != resource_pool_id]
            if len(remaining) != len(user.resource_pools):
                user.resource_pools = remaining

    @_only_admins
    async def update_resource_pool_users(
        self, api_user: base_models.APIUser, resource_pool_id: int, user_ids: Collection[str], append: bool = True
    ) -> list[base_models.User]:
        """Update the users to have access to a specific resource pool.

        For the default pool, nobody is attached. Instead, requested users that are currently
        denied default access get ``no_default_access`` cleared, and those updated users are
        returned. For other pools, the requested users (created if missing) are added
        (``append=True``) or replace the current users. All users of the pool are returned.

        Raises:
            errors.MissingResourceError: If the resource pool does not exist.
        """
        async with self.session_maker() as session, session.begin():
            stmt = (
                select(schemas.ResourcePoolORM)
                .where(schemas.ResourcePoolORM.id == resource_pool_id)
                .options(
                    selectinload(schemas.ResourcePoolORM.users),
                    selectinload(schemas.ResourcePoolORM.classes),
                )
            )
            res = await session.execute(stmt)
            rp: Optional[schemas.ResourcePoolORM] = res.scalars().first()
            if rp is None:
                raise errors.MissingResourceError(message=f"The resource pool with id {resource_pool_id} does not exist")
            if rp.default:
                # NOTE: If the resource pool is default just check if any users are prevented from having
                # default resource pool access - and remove the restriction.
                all_existing_users = await self.get_resource_pool_users(
                    api_user=api_user, resource_pool_id=resource_pool_id
                )
                users_to_modify = [user for user in all_existing_users.disallowed if user.keycloak_id in user_ids]
                return await gather(
                    *[
                        self.update_user(
                            api_user=api_user, keycloak_id=no_default_user.keycloak_id, no_default_access=False
                        )
                        for no_default_user in users_to_modify
                    ]
                )
            # Deduplicate (keeping order) so that the same missing user is never created twice.
            requested_ids = list(dict.fromkeys(user_ids))
            stmt_usr = select(schemas.RPUserORM).where(schemas.RPUserORM.keycloak_id.in_(requested_ids))
            res_usr = await session.execute(stmt_usr)
            users_to_add_exist = list(res_usr.scalars().all())
            existing_ids = {u.keycloak_id for u in users_to_add_exist}
            users_to_add_missing = [
                schemas.RPUserORM(keycloak_id=user_id) for user_id in requested_ids if user_id not in existing_ids
            ]
            if append:
                # Newly created users have no id yet, so they never match an existing member.
                rp_user_ids = {member.id for member in rp.users}
                users_to_add = [u for u in users_to_add_exist + users_to_add_missing if u.id not in rp_user_ids]
                rp.users.extend(users_to_add)
            else:
                rp.users = users_to_add_exist + users_to_add_missing
            return [usr.dump() for usr in rp.users]

    @_only_admins
    async def update_user(self, api_user: base_models.APIUser, keycloak_id: str, **kwargs: Any) -> base_models.User:
        """Update a specific user, creating the pool-user row if it does not exist yet.

        Args:
            api_user: The user performing the request.
            keycloak_id: ID of the user to update.
            **kwargs: Fields to update; only ``no_default_access`` is accepted. A value of
                ``None`` leaves the field unchanged.

        Raises:
            errors.ValidationError: If any other field is passed.
        """
        allowed_updates = set(["no_default_access"])
        # Validate before querying so invalid requests never touch the database.
        if not set(kwargs.keys()).issubset(allowed_updates):
            raise errors.ValidationError(
                message=f"Only the following fields {allowed_updates} can be updated for a resource pool user."
            )
        async with self.session_maker() as session, session.begin():
            stmt = select(schemas.RPUserORM).where(schemas.RPUserORM.keycloak_id == keycloak_id)
            res = await session.execute(stmt)
            user: Optional[schemas.RPUserORM] = res.scalars().first()
            if not user:
                user = schemas.RPUserORM(keycloak_id=keycloak_id)
                session.add(user)
            if (no_default_access := kwargs.get("no_default_access")) is not None:
                user.no_default_access = no_default_access
            return user.dump()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc