• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

SwissDataScienceCenter / renku-data-services / 10353957042

12 Aug 2024 02:35PM UTC coverage: 90.758% (+0.4%) from 90.398%
10353957042

Pull #338

github

web-flow
Merge 3a49eb6c2 into 8afb94949
Pull Request #338: feat!: expand environments specification

227 of 237 new or added lines in 7 files covered. (95.78%)

48 existing lines in 9 files now uncovered.

9202 of 10139 relevant lines covered (90.76%)

1.61 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

87.81
/components/renku_data_services/crc/db.py
1
"""Adapter based on SQLAlchemy.
2✔
2

3
These adapters currently do a few things (1) generate SQL queries, (2) apply resource access controls,
4
(3) fetch the SQL results and (4) format them into a workable representation. In the future and as the code
5
grows it is worth looking into separating this functionality into separate classes rather than having
6
it all in one place.
7
"""
8

9
from asyncio import gather
2✔
10
from collections.abc import Awaitable, Callable, Collection, Sequence
2✔
11
from dataclasses import dataclass, field
2✔
12
from functools import wraps
2✔
13
from typing import Any, Concatenate, Optional, ParamSpec, TypeVar, cast
2✔
14

15
from sqlalchemy import NullPool, create_engine, delete, select
2✔
16
from sqlalchemy.ext.asyncio import AsyncSession
2✔
17
from sqlalchemy.orm import Session, selectinload, sessionmaker
2✔
18
from sqlalchemy.sql import Select, and_, not_, or_
2✔
19
from sqlalchemy.sql.expression import false, true
2✔
20

21
import renku_data_services.base_models as base_models
2✔
22
from renku_data_services import errors
2✔
23
from renku_data_services.crc import models
2✔
24
from renku_data_services.crc import orm as schemas
2✔
25
from renku_data_services.k8s.quota import QuotaRepository
2✔
26
from renku_data_services.users.db import UserRepo
2✔
27

28

29
class _Base:
    """Shared constructor state for the SQLAlchemy-backed repositories in this module."""

    def __init__(self, session_maker: Callable[..., AsyncSession], quotas_repo: QuotaRepository) -> None:
        # session_maker: factory for AsyncSession instances, used as ``self.session_maker()``.
        # quotas_repo: repository used to read, create, update and delete resource pool quotas.
        self.session_maker, self.quotas_repo = session_maker, quotas_repo
33

34

35
def _resource_pool_access_control(
2✔
36
    api_user: base_models.APIUser,
37
    stmt: Select[tuple[schemas.ResourcePoolORM]],
38
) -> Select[tuple[schemas.ResourcePoolORM]]:
39
    """Modifies a select query to list resource pools based on whether the user is logged in or not."""
40
    output = stmt
2✔
41
    match (api_user.is_authenticated, api_user.is_admin):
2✔
42
        case True, False:
2✔
43
            # The user is logged in but not an admin
44
            api_user_has_default_pool_access = not_(
1✔
45
                # NOTE: The only way to check that a user is allowed to access the default pool is that such a
46
                # record does NOT EXIST in the database
47
                select(schemas.RPUserORM.no_default_access)
48
                .where(
49
                    and_(schemas.RPUserORM.keycloak_id == api_user.id, schemas.RPUserORM.no_default_access == true())
50
                )
51
                .exists()
52
            )
53
            output = output.join(schemas.RPUserORM, schemas.ResourcePoolORM.users, isouter=True).where(
1✔
54
                or_(
55
                    schemas.RPUserORM.keycloak_id == api_user.id,  # the user is part of the pool
56
                    and_(  # the pool is not default but is public
57
                        schemas.ResourcePoolORM.default != true(), schemas.ResourcePoolORM.public == true()
58
                    ),
59
                    and_(  # the pool is default and the user is not prohibited from accessing it
60
                        schemas.ResourcePoolORM.default == true(),
61
                        api_user_has_default_pool_access,
62
                    ),
63
                )
64
            )
65
        case True, True:
2✔
66
            # The user is logged in and is an admin, they can see all resource pools
67
            pass
2✔
68
        case False, _:
1✔
69
            # The user is not logged in, they can see only the public resource pools
70
            output = output.where(schemas.ResourcePoolORM.public == true())
1✔
71
    return output
2✔
72

73

74
def _classes_user_access_control(
2✔
75
    api_user: base_models.APIUser,
76
    stmt: Select[tuple[schemas.ResourceClassORM]],
77
) -> Select[tuple[schemas.ResourceClassORM]]:
78
    """Adjust the select statement for classes based on whether the user is logged in or not."""
79
    output = stmt
2✔
80
    match (api_user.is_authenticated, api_user.is_admin):
2✔
81
        case True, False:
2✔
82
            # The user is logged in but is not an admin
83
            api_user_has_default_pool_access = not_(
1✔
84
                # NOTE: The only way to check that a user is allowed to access the default pool is that such a
85
                # record does NOT EXIST in the database
86
                select(schemas.RPUserORM.no_default_access)
87
                .where(
88
                    and_(schemas.RPUserORM.keycloak_id == api_user.id, schemas.RPUserORM.no_default_access == true())
89
                )
90
                .exists()
91
            )
92
            output = output.join(schemas.RPUserORM, schemas.ResourcePoolORM.users, isouter=True).where(
1✔
93
                or_(
94
                    schemas.RPUserORM.keycloak_id == api_user.id,  # the user is part of the pool
95
                    and_(  # the pool is not default but is public
96
                        schemas.ResourcePoolORM.default != true(), schemas.ResourcePoolORM.public == true()
97
                    ),
98
                    and_(  # the pool is default and the user is not prohibited from accessing it
99
                        schemas.ResourcePoolORM.default == true(),
100
                        api_user_has_default_pool_access,
101
                    ),
102
                )
103
            )
104
        case True, True:
2✔
105
            # The user is logged in and is an admin, they can see all resource classes
106
            pass
2✔
107
        case False, _:
×
108
            # The user is not logged in, they can see only the classes from public resource pools
109
            output = output.join(schemas.RPUserORM, schemas.ResourcePoolORM.users, isouter=True).where(
×
110
                schemas.ResourcePoolORM.public == true(),
111
            )
112
    return output
2✔
113

114

115
# Parameter specification and return type of the coroutine wrapped by ``_only_admins``
_P = ParamSpec("_P")
_T = TypeVar("_T")


def _only_admins(f: Callable[Concatenate[Any, _P], Awaitable[_T]]) -> Callable[Concatenate[Any, _P], Awaitable[_T]]:
    """Decorate a repository coroutine so that only admin users may call it.

    The APIUser is taken from the ``api_user`` keyword argument when present; otherwise it is the
    first positional argument after ``self``.

    Raises:
        errors.UnauthorizedError: if no user can be found in the call arguments.
        errors.ProgrammingError: if the located argument is not an ``APIUser``.
        errors.ForbiddenError: if the user is not an admin.
    """

    @wraps(f)
    async def _admin_guard(self: Any, *args: _P.args, **kwargs: _P.kwargs) -> _T:
        if "api_user" in kwargs:
            user = kwargs["api_user"]
        elif args:
            user = args[0]
        else:
            user = None

        if user is None:
            raise errors.UnauthorizedError(message="You do not have the required permissions for this operation.")
        if not isinstance(user, base_models.APIUser):
            raise errors.ProgrammingError(message="Expected user parameter is not of type APIUser.")
        if not user.is_admin:
            raise errors.ForbiddenError(message="You do not have the required permissions for this operation.")

        # Only authenticated admins reach the wrapped coroutine
        return await f(self, *args, **kwargs)

    return _admin_guard
145

146

147
class ResourcePoolRepository(_Base):
2✔
148
    """The adapter used for accessing resource pools with SQLAlchemy."""
2✔
149

150
    def initialize(self, sync_connection_url: str, rp: models.ResourcePool) -> None:
2✔
151
        """Add the default resource pool if it does not already exist."""
152
        engine = create_engine(sync_connection_url, poolclass=NullPool)
×
153
        session_maker = sessionmaker(
×
154
            engine,
155
            class_=Session,
156
            expire_on_commit=True,
157
        )
158
        with session_maker() as session, session.begin():
×
159
            stmt = select(schemas.ResourcePoolORM.default == true())
×
160
            res = session.execute(stmt)
×
161
            default_rp = res.scalars().first()
×
162
            if default_rp is None:
×
163
                orm = schemas.ResourcePoolORM.load(rp)
×
164
                session.add(orm)
×
165

166
    async def get_resource_pools(
2✔
167
        self, api_user: base_models.APIUser, id: Optional[int] = None, name: Optional[str] = None
168
    ) -> list[models.ResourcePool]:
169
        """Get resource pools from database."""
170
        async with self.session_maker() as session:
2✔
171
            stmt = select(schemas.ResourcePoolORM).options(selectinload(schemas.ResourcePoolORM.classes))
2✔
172
            if name is not None:
2✔
173
                stmt = stmt.where(schemas.ResourcePoolORM.name == name)
1✔
174
            if id is not None:
2✔
175
                stmt = stmt.where(schemas.ResourcePoolORM.id == id)
2✔
176
            # NOTE: The line below ensures that the right users can access the right resources, do not remove.
177
            stmt = _resource_pool_access_control(api_user, stmt)
2✔
178
            res = await session.execute(stmt)
2✔
179
            orms = res.scalars().all()
2✔
180
            output: list[models.ResourcePool] = []
2✔
181
            for rp in orms:
2✔
182
                quota = self.quotas_repo.get_quota(rp.quota) if rp.quota else None
1✔
183
                output.append(rp.dump(quota))
1✔
184
            return output
2✔
185

186
    async def filter_resource_pools(
2✔
187
        self,
188
        api_user: base_models.APIUser,
189
        cpu: float = 0,
190
        memory: int = 0,
191
        max_storage: int = 0,
192
        gpu: int = 0,
193
    ) -> list[models.ResourcePool]:
194
        """Get resource pools from database with indication of which resource class matches the specified crtieria."""
195
        async with self.session_maker() as session:
2✔
196
            criteria = models.ResourceClass(
2✔
197
                name="criteria",
198
                cpu=cpu,
199
                gpu=gpu,
200
                memory=memory,
201
                max_storage=max_storage,
202
                # NOTE: the default storage has to be <= max_storage but is not used for filtering classes,
203
                # only the max_storage is used to filter resource classes that match a request
204
                default_storage=max_storage,
205
            )
206
            stmt = (
2✔
207
                select(schemas.ResourcePoolORM)
208
                .join(schemas.ResourcePoolORM.classes)
209
                .order_by(
210
                    schemas.ResourcePoolORM.id,
211
                    schemas.ResourcePoolORM.name,
212
                    schemas.ResourceClassORM.id,
213
                    schemas.ResourceClassORM.name,
214
                )
215
            )
216
            # NOTE: The line below ensures that the right users can access the right resources, do not remove.
217
            stmt = _resource_pool_access_control(api_user, stmt)
2✔
218
            res = await session.execute(stmt)
2✔
219
            return [i.dump(self.quotas_repo.get_quota(i.quota), criteria) for i in res.unique().scalars().all()]
2✔
220

221
    @_only_admins
2✔
222
    async def insert_resource_pool(
2✔
223
        self, api_user: base_models.APIUser, resource_pool: models.ResourcePool
224
    ) -> models.ResourcePool:
225
        """Insert resource pool into database."""
226
        quota = None
2✔
227
        if resource_pool.quota:
2✔
228
            for rc in resource_pool.classes:
2✔
229
                if not resource_pool.quota.is_resource_class_compatible(rc):
2✔
230
                    raise errors.ValidationError(
×
231
                        message=f"The quota {quota} is not compatible with resource class {rc}"
232
                    )
233
            quota = self.quotas_repo.create_quota(resource_pool.quota)
2✔
234
            resource_pool = resource_pool.set_quota(quota)
2✔
235
        orm = schemas.ResourcePoolORM.load(resource_pool)
2✔
236
        async with self.session_maker() as session, session.begin():
2✔
237
            if orm.idle_threshold == 0:
2✔
238
                orm.idle_threshold = None
×
239
            if orm.hibernation_threshold == 0:
2✔
240
                orm.hibernation_threshold = None
×
241
            if orm.default:
2✔
242
                stmt = select(schemas.ResourcePoolORM).where(schemas.ResourcePoolORM.default == true())
1✔
243
                res = await session.execute(stmt)
1✔
244
                default_rps = res.unique().scalars().all()
1✔
245
                if len(default_rps) >= 1:
1✔
246
                    raise errors.ValidationError(
×
247
                        message="There can only be one default resource pool and one already exists."
248
                    )
249
            session.add(orm)
2✔
250
        return orm.dump(quota)
2✔
251

252
    async def get_classes(
2✔
253
        self,
254
        api_user: Optional[base_models.APIUser] = None,
255
        id: Optional[int] = None,
256
        name: Optional[str] = None,
257
        resource_pool_id: Optional[int] = None,
258
    ) -> list[models.ResourceClass]:
259
        """Get classes from the database."""
260
        async with self.session_maker() as session:
2✔
261
            stmt = select(schemas.ResourceClassORM).join(
2✔
262
                schemas.ResourcePoolORM, schemas.ResourceClassORM.resource_pool, isouter=True
263
            )
264
            if resource_pool_id is not None:
2✔
265
                stmt = stmt.where(schemas.ResourcePoolORM.id == resource_pool_id)
2✔
266
            if id is not None:
2✔
267
                stmt = stmt.where(schemas.ResourceClassORM.id == id)
1✔
268
            if name is not None:
2✔
269
                stmt = stmt.where(schemas.ResourceClassORM.name == name)
2✔
270

271
            # Apply user access control if api_user is provided
272
            if api_user is not None:
2✔
273
                # NOTE: The line below ensures that the right users can access the right resources, do not remove.
274
                stmt = _classes_user_access_control(api_user, stmt)
2✔
275

276
            res = await session.execute(stmt)
2✔
277
            orms = res.scalars().all()
2✔
278
            return [orm.dump() for orm in orms]
2✔
279

280
    @_only_admins
2✔
281
    async def insert_resource_class(
2✔
282
        self,
283
        api_user: base_models.APIUser,
284
        resource_class: models.ResourceClass,
285
        *,
286
        resource_pool_id: Optional[int] = None,
287
    ) -> models.ResourceClass:
288
        """Insert a resource class in the database."""
289
        cls = schemas.ResourceClassORM.load(resource_class)
2✔
290
        async with self.session_maker() as session, session.begin():
2✔
291
            if resource_pool_id is not None:
2✔
292
                stmt = select(schemas.ResourcePoolORM).where(schemas.ResourcePoolORM.id == resource_pool_id)
2✔
293
                res = await session.execute(stmt)
2✔
294
                rp = res.scalars().first()
2✔
295
                if rp is None:
2✔
296
                    raise errors.MissingResourceError(
2✔
297
                        message=f"Resource pool with id {resource_pool_id} does not exist."
298
                    )
299
                if cls.default and len(rp.classes) > 0 and any([icls.default for icls in rp.classes]):
1✔
300
                    raise errors.ValidationError(
×
301
                        message="There can only be one default resource class per resource pool."
302
                    )
303
                quota = self.quotas_repo.get_quota(rp.quota) if rp.quota else None
1✔
304
                if quota and not quota.is_resource_class_compatible(resource_class):
1✔
305
                    raise errors.ValidationError(
×
306
                        message="The resource class {resource_class} is not compatible with the quota {quota}."
307
                    )
308
                cls.resource_pool = rp
1✔
309
                cls.resource_pool_id = rp.id
1✔
310

311
            session.add(cls)
1✔
312
        return cls.dump()
1✔
313

314
    @_only_admins
2✔
315
    async def update_resource_pool(self, api_user: base_models.APIUser, id: int, **kwargs: Any) -> models.ResourcePool:
2✔
316
        """Update an existing resource pool in the database."""
317
        rp: Optional[schemas.ResourcePoolORM] = None
2✔
318
        async with self.session_maker() as session, session.begin():
2✔
319
            stmt = (
2✔
320
                select(schemas.ResourcePoolORM)
321
                .where(schemas.ResourcePoolORM.id == id)
322
                .options(selectinload(schemas.ResourcePoolORM.classes))
323
            )
324
            res = await session.execute(stmt)
2✔
325
            rp = res.scalars().first()
2✔
326
            if rp is None:
2✔
327
                raise errors.MissingResourceError(message=f"Resource pool with id {id} cannot be found")
2✔
328
            quota = self.quotas_repo.get_quota(rp.quota) if rp.quota else None
1✔
329
            if len(kwargs) == 0:
1✔
330
                return rp.dump(quota)
×
331

332
            if kwargs.get("idle_threshold", None) == 0:
1✔
333
                kwargs["idle_threshold"] = None
×
334
            if kwargs.get("hibernation_threshold", None) == 0:
1✔
335
                kwargs["hibernation_threshold"] = None
×
336
            # NOTE: The .update method on the model validates the update to the resource pool
337
            old_rp_model = rp.dump(quota)
1✔
338
            new_rp_model = old_rp_model.update(**kwargs)
1✔
339
            new_classes = None
1✔
340
            new_classes_coroutines = []
1✔
341
            for key, val in kwargs.items():
1✔
342
                match key:
1✔
343
                    case "name" | "public" | "default" | "idle_threshold" | "hibernation_threshold":
1✔
344
                        setattr(rp, key, val)
1✔
345
                    case "quota":
1✔
346
                        if val is None:
1✔
347
                            continue
×
348

349
                        # For updating a quota, there are two options:
350
                        # 1. no quota exists --> create a new one
351
                        # 2. a quota exists and can only be updated, not replaced (the ids, if provided, must match)
352

353
                        new_id = val.get("id")
1✔
354

355
                        if quota and quota.id is not None and new_id is not None and quota.id != new_id:
1✔
356
                            raise errors.ValidationError(
×
357
                                message="The ID of an existing quota cannot be updated, "
358
                                f"please remove the ID field from the request or use ID {quota.id}."
359
                            )
360

361
                        # the id must match for update
362
                        if quota:
1✔
363
                            val["id"] = quota.id or new_id
1✔
364

365
                        new_quota = models.Quota.from_dict(val)
1✔
366

367
                        if new_id or quota:
1✔
368
                            new_quota = self.quotas_repo.update_quota(new_quota)
1✔
369
                        else:
370
                            new_quota = self.quotas_repo.create_quota(new_quota)
×
371
                        rp.quota = new_quota.id
1✔
372
                        new_rp_model = new_rp_model.update(quota=new_quota)
1✔
373
                    case "classes":
1✔
374
                        new_classes = []
1✔
375
                        for cls in val:
1✔
376
                            class_id = cls.pop("id")
1✔
377
                            cls.pop("matching", None)
1✔
378
                            if len(cls) == 0:
1✔
379
                                raise errors.ValidationError(
×
380
                                    message="More fields than the id of the class "
381
                                    "should be provided when updating it"
382
                                )
383
                            new_classes_coroutines.append(
1✔
384
                                self.update_resource_class(
385
                                    api_user, resource_pool_id=id, resource_class_id=class_id, **cls
386
                                )
387
                            )
388
                    case _:
×
389
                        pass
×
390
            new_classes = await gather(*new_classes_coroutines)
1✔
391
            if new_classes is not None and len(new_classes) > 0:
1✔
392
                new_rp_model = new_rp_model.update(classes=new_classes)
1✔
393
            return new_rp_model
1✔
394

395
    @_only_admins
2✔
396
    async def delete_resource_pool(self, api_user: base_models.APIUser, id: int) -> Optional[models.ResourcePool]:
2✔
397
        """Delete a resource pool from the database."""
398
        async with self.session_maker() as session, session.begin():
2✔
399
            stmt = select(schemas.ResourcePoolORM).where(schemas.ResourcePoolORM.id == id)
2✔
400
            res = await session.execute(stmt)
2✔
401
            rp = res.scalars().first()
2✔
402
            if rp is not None:
2✔
403
                if rp.default:
1✔
404
                    raise errors.ValidationError(message="The default resource pool cannot be deleted.")
×
405
                await session.delete(rp)
1✔
406
                quota = None
1✔
407
                if rp.quota:
1✔
408
                    quota = self.quotas_repo.get_quota(rp.quota)
1✔
409
                    self.quotas_repo.delete_quota(rp.quota)
1✔
410
                return rp.dump(quota)
1✔
411
            return None
1✔
412

413
    @_only_admins
2✔
414
    async def delete_resource_class(
2✔
415
        self, api_user: base_models.APIUser, resource_pool_id: int, resource_class_id: int
416
    ) -> None:
417
        """Delete a specific resource class."""
418
        async with self.session_maker() as session, session.begin():
1✔
419
            stmt = (
1✔
420
                select(schemas.ResourceClassORM)
421
                .where(schemas.ResourceClassORM.id == resource_class_id)
422
                .where(schemas.ResourceClassORM.resource_pool_id == resource_pool_id)
423
            )
424
            res = await session.execute(stmt)
1✔
425
            cls = res.scalars().first()
1✔
426
            if cls is not None:
1✔
427
                if cls.default:
1✔
428
                    raise errors.ValidationError(message="The default resource class cannot be deleted.")
×
429
                await session.delete(cls)
1✔
430

431
    @_only_admins
2✔
432
    async def update_resource_class(
2✔
433
        self, api_user: base_models.APIUser, resource_pool_id: int, resource_class_id: int, **kwargs: Any
434
    ) -> models.ResourceClass:
435
        """Update a specific resource class."""
436
        async with self.session_maker() as session, session.begin():
1✔
437
            stmt = (
1✔
438
                select(schemas.ResourceClassORM)
439
                .where(schemas.ResourceClassORM.id == resource_class_id)
440
                .where(schemas.ResourceClassORM.resource_pool_id == resource_pool_id)
441
                .join(schemas.ResourcePoolORM, schemas.ResourceClassORM.resource_pool)
442
                .options(selectinload(schemas.ResourceClassORM.resource_pool))
443
            )
444
            res = await session.execute(stmt)
1✔
445
            cls: Optional[schemas.ResourceClassORM] = res.scalars().first()
1✔
446
            if cls is None:
1✔
447
                raise errors.MissingResourceError(
×
448
                    message=(
449
                        f"The resource class with id {resource_class_id} does not exist, the resource pool with "
450
                        f"id {resource_pool_id} does not exist or the requested resource class is not "
451
                        "associated with the resource pool"
452
                    )
453
                )
454
            for k, v in kwargs.items():
1✔
455
                match k:
1✔
456
                    case "node_affinities":
1✔
457
                        v = cast(list[dict[str, str | bool]], v)
1✔
458
                        existing_affinities: dict[str, schemas.NodeAffintyORM] = {i.key: i for i in cls.node_affinities}
1✔
459
                        new_affinities: dict[str, schemas.NodeAffintyORM] = {
1✔
460
                            i["key"]: schemas.NodeAffintyORM(**i) for i in v
461
                        }
462
                        for new_affinity_key, new_affinity in new_affinities.items():
1✔
463
                            if new_affinity_key in existing_affinities:
1✔
464
                                # UPDATE existing affinity
465
                                existing_affinity = existing_affinities[new_affinity_key]
1✔
466
                                if (
1✔
467
                                    new_affinity.required_during_scheduling
468
                                    != existing_affinity.required_during_scheduling
469
                                ):
470
                                    existing_affinity.required_during_scheduling = (
1✔
471
                                        new_affinity.required_during_scheduling
472
                                    )
473
                            else:
474
                                # CREATE a brand new affinity
475
                                cls.node_affinities.append(new_affinity)
1✔
476
                        # REMOVE an affinity
477
                        for existing_affinity_key, existing_affinity in existing_affinities.items():
1✔
478
                            if existing_affinity_key not in new_affinities:
1✔
479
                                cls.node_affinities.remove(existing_affinity)
1✔
480
                    case "tolerations":
1✔
481
                        v = cast(list[str], v)
1✔
482
                        existing_tolerations: dict[str, schemas.TolerationORM] = {
1✔
483
                            tol.key: tol for tol in cls.tolerations
484
                        }
485
                        new_tolerations: dict[str, schemas.TolerationORM] = {
1✔
486
                            tol: schemas.TolerationORM(key=tol) for tol in v
487
                        }
488
                        for new_tol_key, new_tol in new_tolerations.items():
1✔
489
                            if new_tol_key not in existing_tolerations:
1✔
490
                                # CREATE a brand new toleration
491
                                cls.tolerations.append(new_tol)
1✔
492
                        # REMOVE a toleration
493
                        for existing_tol_key, existing_tol in existing_tolerations.items():
1✔
494
                            if existing_tol_key not in new_tolerations:
1✔
495
                                cls.tolerations.remove(existing_tol)
1✔
496
                    case _:
1✔
497
                        setattr(cls, k, v)
1✔
498
            if cls.resource_pool is None:
1✔
499
                raise errors.BaseError(
×
500
                    message="Unexpected internal error.",
501
                    detail=f"The resource class {resource_class_id} is not associated with any resource pool.",
502
                )
503
            quota = self.quotas_repo.get_quota(cls.resource_pool.quota) if cls.resource_pool.quota else None
1✔
504
            cls_model = cls.dump()
1✔
505
            if quota and not quota.is_resource_class_compatible(cls_model):
1✔
506
                raise errors.ValidationError(
×
507
                    message=f"The resource class {cls_model} is not compatible with the quota {quota}"
508
                )
509
            return cls_model
1✔
510

511
    @_only_admins
2✔
512
    async def get_tolerations(self, api_user: base_models.APIUser, resource_pool_id: int, class_id: int) -> list[str]:
2✔
513
        """Get all tolerations of a resource class."""
514
        async with self.session_maker() as session:
1✔
515
            res_classes = await self.get_classes(api_user, class_id, resource_pool_id=resource_pool_id)
1✔
516
            if len(res_classes) == 0:
1✔
UNCOV
517
                raise errors.MissingResourceError(
×
518
                    message=f"The resource pool with ID {resource_pool_id} or the resource "
519
                    f"class with ID {class_id} do not exist, or they are not related."
520
                )
521
            stmt = select(schemas.TolerationORM).where(schemas.TolerationORM.resource_class_id == class_id)
1✔
522
            res = await session.execute(stmt)
1✔
523
            return [i.key for i in res.scalars().all()]
1✔
524

525
    @_only_admins
2✔
526
    async def delete_tolerations(self, api_user: base_models.APIUser, resource_pool_id: int, class_id: int) -> None:
2✔
527
        """Delete all tolerations for a specific resource class."""
528
        async with self.session_maker() as session, session.begin():
1✔
529
            res_classes = await self.get_classes(api_user, class_id, resource_pool_id=resource_pool_id)
1✔
530
            if len(res_classes) == 0:
1✔
531
                raise errors.MissingResourceError(
×
532
                    message=f"The resource pool with ID {resource_pool_id} or the resource "
533
                    f"class with ID {class_id} do not exist, or they are not related."
534
                )
535
            stmt = delete(schemas.TolerationORM).where(schemas.TolerationORM.resource_class_id == class_id)
1✔
536
            await session.execute(stmt)
1✔
537

538
    @_only_admins
2✔
539
    async def get_affinities(
2✔
540
        self, api_user: base_models.APIUser, resource_pool_id: int, class_id: int
541
    ) -> list[models.NodeAffinity]:
542
        """Get all affinities for a resource class."""
543
        async with self.session_maker() as session:
1✔
544
            res_classes = await self.get_classes(api_user, class_id, resource_pool_id=resource_pool_id)
1✔
545
            if len(res_classes) == 0:
1✔
546
                raise errors.MissingResourceError(
×
547
                    message=f"The resource pool with ID {resource_pool_id} or the resource "
548
                    f"class with ID {class_id} do not exist, or they are not related."
549
                )
550
            stmt = select(schemas.NodeAffintyORM).where(schemas.NodeAffintyORM.resource_class_id == class_id)
1✔
551
            res = await session.execute(stmt)
1✔
552
            return [i.dump() for i in res.scalars().all()]
1✔
553

554
    @_only_admins
2✔
555
    async def delete_affinities(self, api_user: base_models.APIUser, resource_pool_id: int, class_id: int) -> None:
2✔
556
        """Delete all affinities from a resource class."""
557
        async with self.session_maker() as session, session.begin():
1✔
558
            res_classes = await self.get_classes(api_user, class_id, resource_pool_id=resource_pool_id)
1✔
559
            if len(res_classes) == 0:
1✔
560
                raise errors.MissingResourceError(
×
561
                    message=f"The resource pool with ID {resource_pool_id} or the resource "
562
                    f"class with ID {class_id} do not exist, or they are not related."
563
                )
564
            stmt = delete(schemas.NodeAffintyORM).where(schemas.NodeAffintyORM.resource_class_id == class_id)
1✔
565
            await session.execute(stmt)
1✔
566

567

568
@dataclass
class RespositoryUsers:
    """Which users can, and which users cannot, access one resource pool."""

    # Identifier of the resource pool these lists refer to.
    resource_pool_id: int
    # Users reported as having access to the pool.
    allowed: list[base_models.User] = field(default_factory=list)
    # Users reported as explicitly denied access to the pool.
    disallowed: list[base_models.User] = field(default_factory=list)
575

576

577
class UserRepository(_Base):
    """The adapter used for accessing resource pool users with SQLAlchemy."""

    def __init__(
        self, session_maker: Callable[..., AsyncSession], quotas_repo: QuotaRepository, user_repo: UserRepo
    ) -> None:
        super().__init__(session_maker, quotas_repo)
        # Used to check that a user exists before resource pools are assigned to them.
        self.kc_user_repo = user_repo

    @_only_admins
    async def get_resource_pool_users(
        self,
        *,
        api_user: base_models.APIUser,
        resource_pool_id: int,
        keycloak_id: Optional[str] = None,
    ) -> RespositoryUsers:
        """Get users of a specific resource pool from the database.

        The returned lists depend on the kind of pool:

        * default pool: ``disallowed`` holds users flagged with ``no_default_access`` and ``allowed``
          holds the requested user (if they are known and not disallowed);
        * public, non-default pool: ``allowed`` holds the requested user (if they are known);
        * private, non-default pool: ``allowed`` holds the pool members.

        When ``keycloak_id`` is given, every list is scoped to that single user.

        Raises:
            errors.MissingResourceError: if the pool does not exist or, when ``keycloak_id`` is given,
                the pool is neither public nor default and the user is not one of its members.
        """
        async with self.session_maker() as session, session.begin():
            stmt = (
                select(schemas.ResourcePoolORM)
                .where(schemas.ResourcePoolORM.id == resource_pool_id)
                .options(selectinload(schemas.ResourcePoolORM.users))
            )
            if keycloak_id is not None:
                # NOTE: Only reference tables that are part of the FROM clause here. The condition on the
                # pool's own ``default`` flag must not reference ``ResourceClassORM``: that table is not
                # joined, so doing so produces an implicit cartesian product with all resource classes.
                stmt = stmt.join(schemas.ResourcePoolORM.users, isouter=True).where(
                    or_(
                        schemas.RPUserORM.keycloak_id == keycloak_id,
                        schemas.ResourcePoolORM.public == true(),
                        schemas.ResourcePoolORM.default == true(),
                    )
                )
            res = await session.execute(stmt)
            rp = res.scalars().first()
            if rp is None:
                raise errors.MissingResourceError(message=f"Resource pool with id {resource_pool_id} does not exist")
            specific_user: base_models.User | None = None
            if keycloak_id:
                specific_user_res = (
                    await session.execute(select(schemas.RPUserORM).where(schemas.RPUserORM.keycloak_id == keycloak_id))
                ).scalar_one_or_none()
                specific_user = None if not specific_user_res else specific_user_res.dump()
            allowed: list[base_models.User] = []
            disallowed: list[base_models.User] = []
            if rp.default:
                disallowed_stmt = select(schemas.RPUserORM).where(schemas.RPUserORM.no_default_access == true())
                if keycloak_id:
                    disallowed_stmt = disallowed_stmt.where(schemas.RPUserORM.keycloak_id == keycloak_id)
                disallowed_res = await session.execute(disallowed_stmt)
                disallowed = [user.dump() for user in disallowed_res.scalars().all()]
                if specific_user and specific_user not in disallowed:
                    allowed = [specific_user]
            elif rp.public:
                if specific_user:
                    allowed = [specific_user]
            else:
                # Private pool: only explicit members have access. When a specific user was requested,
                # report only that user (consistent with the other branches), not every member.
                allowed = [user.dump() for user in rp.users if not keycloak_id or user.keycloak_id == keycloak_id]
            return RespositoryUsers(rp.id, allowed, disallowed)

    async def get_user_resource_pools(
        self,
        api_user: base_models.APIUser,
        keycloak_id: str,
        resource_pool_id: Optional[int] = None,
        resource_pool_name: Optional[str] = None,
    ) -> list[models.ResourcePool]:
        """Get resource pools that a specific user has access to.

        A pool is included when it is public or the user is one of its members, optionally narrowed
        by ``resource_pool_id`` and/or ``resource_pool_name``.

        Raises:
            errors.ValidationError: if a non-admin asks for the pools of a different user.
        """
        # Reject forbidden queries before opening a database transaction.
        if not api_user.is_admin and api_user.id != keycloak_id:
            raise errors.ValidationError(message="Users cannot query for resource pools that belong to other users.")

        async with self.session_maker() as session, session.begin():
            stmt = select(schemas.ResourcePoolORM).options(selectinload(schemas.ResourcePoolORM.classes))
            stmt = stmt.where(
                or_(
                    schemas.ResourcePoolORM.public == true(),
                    schemas.ResourcePoolORM.users.any(schemas.RPUserORM.keycloak_id == keycloak_id),
                )
            )
            if resource_pool_name is not None:
                stmt = stmt.where(schemas.ResourcePoolORM.name == resource_pool_name)
            if resource_pool_id is not None:
                stmt = stmt.where(schemas.ResourcePoolORM.id == resource_pool_id)
            # NOTE: The line below ensures that the right users can access the right resources, do not remove.
            stmt = _resource_pool_access_control(api_user, stmt)
            res = await session.execute(stmt)
            rps: Sequence[schemas.ResourcePoolORM] = res.scalars().all()
            return [rp.dump(self.quotas_repo.get_quota(rp.quota) if rp.quota else None) for rp in rps]

    @_only_admins
    async def update_user_resource_pools(
        self, api_user: base_models.APIUser, keycloak_id: str, resource_pool_ids: list[int], append: bool = True
    ) -> list[models.ResourcePool]:
        """Update the resource pools that a specific user has access to.

        With ``append=True`` the pools are added to the user's existing ones and only the newly added
        pools are returned. With ``append=False`` the user's pools are replaced by the requested ones
        and all of them are returned.

        Raises:
            errors.MissingResourceError: if the user is unknown, or if any requested pool does not exist
                (or is a default pool while the user has ``no_default_access`` set).
        """
        async with self.session_maker() as session, session.begin():
            kc_user = await self.kc_user_repo.get_user(keycloak_id)
            if kc_user is None:
                raise errors.MissingResourceError(message=f"The user with ID {keycloak_id} does not exist")
            stmt = (
                select(schemas.RPUserORM)
                .where(schemas.RPUserORM.keycloak_id == keycloak_id)
                .options(selectinload(schemas.RPUserORM.resource_pools))
            )
            res = await session.execute(stmt)
            user = res.scalars().first()
            if user is None:
                user = schemas.RPUserORM(keycloak_id=keycloak_id)
                session.add(user)
            stmt_rp = (
                select(schemas.ResourcePoolORM)
                .where(schemas.ResourcePoolORM.id.in_(resource_pool_ids))
                .options(selectinload(schemas.ResourcePoolORM.classes))
            )
            if user.no_default_access:
                # Users without default access can never be granted the default pool; filtering it out
                # here makes such a request fail with the "missing" error below.
                stmt_rp = stmt_rp.where(schemas.ResourcePoolORM.default == false())
            res_rp = await session.execute(stmt_rp)
            rps_to_add = res_rp.scalars().all()
            # Compare sets rather than lengths so that duplicate ids in the request are not reported
            # as missing pools.
            missing_rps = set(resource_pool_ids).difference({rp.id for rp in rps_to_add})
            if missing_rps:
                raise errors.MissingResourceError(
                    message=(
                        f"The resource pools with ids: {missing_rps} do not exist or user doesn't have access to "
                        "default resource pool."
                    )
                )
            if append:
                user_rp_ids = {rp.id for rp in user.resource_pools}
                rps_to_add = [rp for rp in rps_to_add if rp.id not in user_rp_ids]
                user.resource_pools.extend(rps_to_add)
            else:
                user.resource_pools = list(rps_to_add)
            return [rp.dump(self.quotas_repo.get_quota(rp.quota) if rp.quota else None) for rp in rps_to_add]

    @_only_admins
    async def delete_resource_pool_user(
        self, api_user: base_models.APIUser, resource_pool_id: int, keycloak_id: str
    ) -> None:
        """Remove a user from a specific resource pool.

        Only the membership in ``resource_pool_id`` is removed; memberships of the same user in other
        pools are kept. Unknown users or non-members are a no-op.
        """
        async with self.session_maker() as session, session.begin():
            stmt = (
                select(schemas.RPUserORM)
                .where(schemas.RPUserORM.keycloak_id == keycloak_id)
                .options(selectinload(schemas.RPUserORM.resource_pools))
            )
            res = await session.execute(stmt)
            user = res.scalars().first()
            if user is None:
                return
            # Dropping the pool from the relationship deletes just that association row on commit.
            # (Deleting by user id alone would remove the user from every resource pool.)
            user.resource_pools = [rp for rp in user.resource_pools if rp.id != resource_pool_id]

    @_only_admins
    async def update_resource_pool_users(
        self, api_user: base_models.APIUser, resource_pool_id: int, user_ids: Collection[str], append: bool = True
    ) -> list[base_models.User]:
        """Update the users to have access to a specific resource pool.

        For the default pool, nobody is added. Instead, any of the given users who are flagged with
        ``no_default_access`` get that flag cleared, and those users are returned. For any other pool,
        the users are added to (``append=True``) or replace (``append=False``) the pool's members, and
        the resulting member list is returned.

        Raises:
            errors.MissingResourceError: if the resource pool does not exist.
        """
        async with self.session_maker() as session, session.begin():
            stmt = (
                select(schemas.ResourcePoolORM)
                .where(schemas.ResourcePoolORM.id == resource_pool_id)
                .options(
                    selectinload(schemas.ResourcePoolORM.users),
                    selectinload(schemas.ResourcePoolORM.classes),
                )
            )
            res = await session.execute(stmt)
            rp: Optional[schemas.ResourcePoolORM] = res.scalars().first()
            if rp is None:
                raise errors.MissingResourceError(
                    message=f"The resource pool with id {resource_pool_id} does not exist"
                )
            if rp.default:
                # NOTE: If the resource pool is default just check if any users are prevented from having
                # default resource pool access - and remove the restriction.
                all_existing_users = await self.get_resource_pool_users(
                    api_user=api_user, resource_pool_id=resource_pool_id
                )
                users_to_modify = [user for user in all_existing_users.disallowed if user.keycloak_id in user_ids]
                return await gather(
                    *[
                        self.update_user(
                            api_user=api_user, keycloak_id=no_default_user.keycloak_id, no_default_access=False
                        )
                        for no_default_user in users_to_modify
                    ]
                )
            # De-duplicate (keeping order) so the same missing user is not inserted twice, which would
            # violate the uniqueness of keycloak ids.
            unique_user_ids = list(dict.fromkeys(user_ids))
            stmt_usr = select(schemas.RPUserORM).where(schemas.RPUserORM.keycloak_id.in_(unique_user_ids))
            res_usr = await session.execute(stmt_usr)
            users_to_add_exist = res_usr.scalars().all()
            existing_keycloak_ids = {existing.keycloak_id for existing in users_to_add_exist}
            users_to_add_missing = [
                schemas.RPUserORM(keycloak_id=user_id)
                for user_id in unique_user_ids
                if user_id not in existing_keycloak_ids
            ]
            if append:
                rp_user_ids = {member.id for member in rp.users}
                # Newly created users have no id yet (None), so they are always added.
                users_to_add = [u for u in list(users_to_add_exist) + users_to_add_missing if u.id not in rp_user_ids]
                rp.users.extend(users_to_add)
            else:
                rp.users = list(users_to_add_exist) + users_to_add_missing
            return [usr.dump() for usr in rp.users]

    @_only_admins
    async def update_user(self, api_user: base_models.APIUser, keycloak_id: str, **kwargs: Any) -> base_models.User:
        """Update a specific user, creating its resource pool user record if needed.

        Only ``no_default_access`` may be passed in ``kwargs``; a ``None`` value leaves it unchanged.

        Raises:
            errors.ValidationError: if any other field is passed.
        """
        # Validate the requested fields before doing any database work.
        allowed_updates = {"no_default_access"}
        if not set(kwargs.keys()).issubset(allowed_updates):
            raise errors.ValidationError(
                message=f"Only the following fields {allowed_updates} can be updated for a resource pool user.."
            )
        async with self.session_maker() as session, session.begin():
            stmt = select(schemas.RPUserORM).where(schemas.RPUserORM.keycloak_id == keycloak_id)
            res = await session.execute(stmt)
            user: Optional[schemas.RPUserORM] = res.scalars().first()
            if not user:
                user = schemas.RPUserORM(keycloak_id=keycloak_id)
                session.add(user)
            if (no_default_access := kwargs.get("no_default_access", None)) is not None:
                user.no_default_access = no_default_access
            return user.dump()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc