• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

SwissDataScienceCenter / renku-data-services / 9446003648

10 Jun 2024 09:35AM UTC coverage: 90.298% (+0.06%) from 90.239%
9446003648

Pull #248

github

web-flow
Merge 9ff3e8a6c into 1e340ea36
Pull Request #248: feat: add support for bitbucket

40 of 46 new or added lines in 3 files covered. (86.96%)

4 existing lines in 4 files now uncovered.

8488 of 9400 relevant lines covered (90.3%)

1.6 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

86.82
/components/renku_data_services/crc/db.py
1
"""Adapter based on SQLAlchemy.
2✔
2

3
These adapters currently do a few things (1) generate SQL queries, (2) apply resource access controls,
4
(3) fetch the SQL results and (4) format them into a workable representation. In the future and as the code
5
grows it is worth looking into separating this functionality into separate classes rather than having
6
it all in one place.
7
"""
8

9
from asyncio import gather
2✔
10
from collections.abc import Awaitable, Callable, Collection, Sequence
2✔
11
from dataclasses import dataclass, field
2✔
12
from functools import wraps
2✔
13
from typing import Any, Concatenate, Optional, ParamSpec, TypeVar, cast
2✔
14

15
from sqlalchemy import NullPool, create_engine, delete, select
2✔
16
from sqlalchemy.ext.asyncio import AsyncSession
2✔
17
from sqlalchemy.orm import Session, selectinload, sessionmaker
2✔
18
from sqlalchemy.sql import Select, and_, not_, or_
2✔
19
from sqlalchemy.sql.expression import false, true
2✔
20

21
import renku_data_services.base_models as base_models
2✔
22
from renku_data_services import errors
2✔
23
from renku_data_services.crc import models
2✔
24
from renku_data_services.crc import orm as schemas
2✔
25
from renku_data_services.k8s.quota import QuotaRepository
2✔
26
from renku_data_services.users.db import UserRepo
2✔
27

28

29
class _Base:
    """Shared state for the repositories defined in this module.

    Stores the callable that creates ``AsyncSession`` objects and the quota repository so that
    subclasses can reach both through ``self``.
    """

    def __init__(self, session_maker: Callable[..., AsyncSession], quotas_repo: QuotaRepository) -> None:
        self.session_maker, self.quotas_repo = session_maker, quotas_repo
33

34

35
def _resource_pool_access_control(
    api_user: base_models.APIUser,
    stmt: Select[tuple[schemas.ResourcePoolORM]],
) -> Select[tuple[schemas.ResourcePoolORM]]:
    """Restrict a resource pool query to the pools that ``api_user`` is allowed to see.

    - Anonymous users see only public pools.
    - Admins see every pool (the statement is returned unchanged).
    - Authenticated non-admins see the pools they are a member of, public non-default pools and the
      default pool unless they have been explicitly denied access to it.

    Pool membership is tested with an EXISTS subquery (``relationship.any()``) instead of an outer join on
    ``ResourcePoolORM.users``. The join emitted one row per pool member, so a pool with several members
    was returned several times to callers that do not de-duplicate their results.

    Args:
        api_user: The user issuing the request.
        stmt: A select over ``ResourcePoolORM`` to be filtered.

    Returns:
        The filtered select statement.
    """
    # Truthiness (rather than matching the literals True/False) makes the check fail closed: any
    # unexpected falsy value is treated as anonymous instead of skipping the filter entirely.
    if not api_user.is_authenticated:
        # The user is not logged in, they can see only the public resource pools
        return stmt.where(schemas.ResourcePoolORM.public == true())
    if api_user.is_admin:
        # The user is logged in and is an admin, they can see all resource pools
        return stmt

    # The user is logged in but not an admin
    api_user_has_default_pool_access = not_(
        # NOTE: The only way to check that a user is allowed to access the default pool is that such a
        # record does NOT EXIST in the database
        select(schemas.RPUserORM.no_default_access)
        .where(and_(schemas.RPUserORM.keycloak_id == api_user.id, schemas.RPUserORM.no_default_access == true()))
        .exists()
    )
    return stmt.where(
        or_(
            # the user is part of the pool
            schemas.ResourcePoolORM.users.any(schemas.RPUserORM.keycloak_id == api_user.id),
            # the pool is not default but is public
            and_(schemas.ResourcePoolORM.default != true(), schemas.ResourcePoolORM.public == true()),
            # the pool is default and the user is not prohibited from accessing it
            and_(schemas.ResourcePoolORM.default == true(), api_user_has_default_pool_access),
        )
    )
72

73

74
def _classes_user_access_control(
    api_user: base_models.APIUser,
    stmt: Select[tuple[schemas.ResourceClassORM]],
) -> Select[tuple[schemas.ResourceClassORM]]:
    """Restrict a resource class query to the classes of pools that ``api_user`` is allowed to see.

    The statement must already include ``ResourcePoolORM`` in its FROM clause (e.g. by joining
    ``ResourceClassORM.resource_pool``), because the filters below reference resource pool columns.

    - Anonymous users see only the classes of public pools.
    - Admins see every class (the statement is returned unchanged).
    - Authenticated non-admins see the classes of pools they are a member of, of public non-default
      pools and of the default pool unless they have been explicitly denied access to it.

    Pool membership is tested with an EXISTS subquery (``relationship.any()``) instead of an outer join on
    the pool's users. The join emitted one row per pool member and therefore returned duplicate classes.

    Args:
        api_user: The user issuing the request.
        stmt: A select over ``ResourceClassORM`` (joined with ``ResourcePoolORM``) to be filtered.

    Returns:
        The filtered select statement.
    """
    # Truthiness (rather than matching the literals True/False) makes the check fail closed.
    if not api_user.is_authenticated:
        # The user is not logged in, they can see only the classes from public resource pools.
        # No join on the pool users is needed here: only pool columns are filtered on.
        return stmt.where(schemas.ResourcePoolORM.public == true())
    if api_user.is_admin:
        # The user is logged in and is an admin, they can see all resource classes
        return stmt

    # The user is logged in but is not an admin
    api_user_has_default_pool_access = not_(
        # NOTE: The only way to check that a user is allowed to access the default pool is that such a
        # record does NOT EXIST in the database
        select(schemas.RPUserORM.no_default_access)
        .where(and_(schemas.RPUserORM.keycloak_id == api_user.id, schemas.RPUserORM.no_default_access == true()))
        .exists()
    )
    return stmt.where(
        or_(
            # the user is part of the pool
            schemas.ResourcePoolORM.users.any(schemas.RPUserORM.keycloak_id == api_user.id),
            # the pool is not default but is public
            and_(schemas.ResourcePoolORM.default != true(), schemas.ResourcePoolORM.public == true()),
            # the pool is default and the user is not prohibited from accessing it
            and_(schemas.ResourcePoolORM.default == true(), api_user_has_default_pool_access),
        )
    )
113

114

115
# Parameter specification of the wrapped coroutine, excluding its leading ``self``.
_P = ParamSpec("_P")
# Result type produced by awaiting the wrapped coroutine.
_T = TypeVar("_T")


def _only_admins(f: Callable[Concatenate[Any, _P], Awaitable[_T]]) -> Callable[Concatenate[Any, _P], Awaitable[_T]]:
    """Wrap a repository coroutine so that only admin users may invoke it.

    The user is taken from the ``api_user`` keyword argument when present; otherwise the first
    positional argument following ``self`` is used.

    Raises:
        errors.Unauthorized: no user argument was found, or the user is not an admin.
        errors.ProgrammingError: the argument found is not an ``APIUser``.
    """

    @wraps(f)
    async def decorated_function(self: Any, *args: _P.args, **kwargs: _P.kwargs) -> _T:
        positional_candidate = args[0] if len(args) > 0 else None
        api_user = kwargs.get("api_user", positional_candidate)

        if api_user is None:
            raise errors.Unauthorized(message="You do not have the required permissions for this operation.")
        if not isinstance(api_user, base_models.APIUser):
            raise errors.ProgrammingError(message="Expected user parameter is not of type APIUser.")
        if not api_user.is_admin:
            raise errors.Unauthorized(message="You do not have the required permissions for this operation.")

        # Only authenticated admins reach the wrapped coroutine.
        return await f(self, *args, **kwargs)

    return decorated_function
143

144

145
class ResourcePoolRepository(_Base):
2✔
146
    """The adapter used for accessing resource pools with SQLAlchemy."""
2✔
147

148
    def initialize(self, sync_connection_url: str, rp: models.ResourcePool) -> None:
2✔
149
        """Add the default resource pool if it does not already exist."""
150
        engine = create_engine(sync_connection_url, poolclass=NullPool)
×
151
        session_maker = sessionmaker(
×
152
            engine,
153
            class_=Session,
154
            expire_on_commit=True,
155
        )
156
        with session_maker() as session, session.begin():
×
157
            stmt = select(schemas.ResourcePoolORM.default == true())
×
158
            res = session.execute(stmt)
×
159
            default_rp = res.scalars().first()
×
160
            if default_rp is None:
×
161
                orm = schemas.ResourcePoolORM.load(rp)
×
162
                session.add(orm)
×
163

164
    async def get_resource_pools(
2✔
165
        self, api_user: base_models.APIUser, id: Optional[int] = None, name: Optional[str] = None
166
    ) -> list[models.ResourcePool]:
167
        """Get resource pools from database."""
168
        async with self.session_maker() as session:
1✔
169
            stmt = select(schemas.ResourcePoolORM).options(selectinload(schemas.ResourcePoolORM.classes))
1✔
170
            if name is not None:
1✔
171
                stmt = stmt.where(schemas.ResourcePoolORM.name == name)
1✔
172
            if id is not None:
1✔
173
                stmt = stmt.where(schemas.ResourcePoolORM.id == id)
1✔
174
            # NOTE: The line below ensures that the right users can access the right resources, do not remove.
175
            stmt = _resource_pool_access_control(api_user, stmt)
1✔
176
            res = await session.execute(stmt)
1✔
177
            orms = res.scalars().all()
1✔
178
            output: list[models.ResourcePool] = []
1✔
179
            for rp in orms:
1✔
180
                quota = self.quotas_repo.get_quota(rp.quota) if rp.quota else None
1✔
181
                output.append(rp.dump(quota))
1✔
182
            return output
1✔
183

184
    async def filter_resource_pools(
2✔
185
        self,
186
        api_user: base_models.APIUser,
187
        cpu: float = 0,
188
        memory: int = 0,
189
        max_storage: int = 0,
190
        gpu: int = 0,
191
    ) -> list[models.ResourcePool]:
192
        """Get resource pools from database with indication of which resource class matches the specified crtieria."""
193
        async with self.session_maker() as session:
2✔
194
            criteria = models.ResourceClass(
2✔
195
                name="criteria",
196
                cpu=cpu,
197
                gpu=gpu,
198
                memory=memory,
199
                max_storage=max_storage,
200
                # NOTE: the default storage has to be <= max_storage but is not used for filtering classes,
201
                # only the max_storage is used to filter resource classes that match a request
202
                default_storage=max_storage,
203
            )
204
            stmt = (
2✔
205
                select(schemas.ResourcePoolORM)
206
                .join(schemas.ResourcePoolORM.classes)
207
                .order_by(
208
                    schemas.ResourcePoolORM.id,
209
                    schemas.ResourcePoolORM.name,
210
                    schemas.ResourceClassORM.id,
211
                    schemas.ResourceClassORM.name,
212
                )
213
            )
214
            # NOTE: The line below ensures that the right users can access the right resources, do not remove.
215
            stmt = _resource_pool_access_control(api_user, stmt)
2✔
216
            res = await session.execute(stmt)
2✔
217
            return [i.dump(self.quotas_repo.get_quota(i.quota), criteria) for i in res.unique().scalars().all()]
2✔
218

219
    @_only_admins
2✔
220
    async def insert_resource_pool(
2✔
221
        self, api_user: base_models.APIUser, resource_pool: models.ResourcePool
222
    ) -> models.ResourcePool:
223
        """Insert resource pool into database."""
224
        quota = None
2✔
225
        if resource_pool.quota:
2✔
226
            for rc in resource_pool.classes:
2✔
227
                if not resource_pool.quota.is_resource_class_compatible(rc):
2✔
228
                    raise errors.ValidationError(
×
229
                        message=f"The quota {quota} is not compatible with resource class {rc}"
230
                    )
231
            quota = self.quotas_repo.create_quota(resource_pool.quota)
2✔
232
            resource_pool = resource_pool.set_quota(quota)
2✔
233
        orm = schemas.ResourcePoolORM.load(resource_pool)
2✔
234
        async with self.session_maker() as session, session.begin():
2✔
235
            if orm.idle_threshold == 0:
2✔
236
                orm.idle_threshold = None
×
237
            if orm.hibernation_threshold == 0:
2✔
238
                orm.hibernation_threshold = None
×
239
            if orm.default:
2✔
240
                stmt = select(schemas.ResourcePoolORM).where(schemas.ResourcePoolORM.default == true())
1✔
241
                res = await session.execute(stmt)
1✔
242
                default_rps = res.unique().scalars().all()
1✔
243
                if len(default_rps) >= 1:
1✔
244
                    raise errors.ValidationError(
×
245
                        message="There can only be one default resource pool and one already exists."
246
                    )
247
            session.add(orm)
2✔
248
        return orm.dump(quota)
2✔
249

250
    async def get_classes(
2✔
251
        self,
252
        api_user: base_models.APIUser,
253
        id: Optional[int] = None,
254
        name: Optional[str] = None,
255
        resource_pool_id: Optional[int] = None,
256
    ) -> list[models.ResourceClass]:
257
        """Get classes from the database."""
258
        async with self.session_maker() as session:
1✔
259
            stmt = select(schemas.ResourceClassORM).join(
1✔
260
                schemas.ResourcePoolORM, schemas.ResourceClassORM.resource_pool, isouter=True
261
            )
262
            if resource_pool_id is not None:
1✔
263
                stmt = stmt.where(schemas.ResourcePoolORM.id == resource_pool_id)
1✔
264
            if id is not None:
1✔
265
                stmt = stmt.where(schemas.ResourceClassORM.id == id)
1✔
266
            if name is not None:
1✔
267
                stmt = stmt.where(schemas.ResourceClassORM.name == name)
1✔
268
            # NOTE: The line below ensures that the right users can access the right resources, do not remove.
269
            stmt = _classes_user_access_control(api_user, stmt)
1✔
270
            res = await session.execute(stmt)
1✔
271
            orms = res.scalars().all()
1✔
272
            return [orm.dump() for orm in orms]
1✔
273

274
    @_only_admins
2✔
275
    async def insert_resource_class(
2✔
276
        self,
277
        api_user: base_models.APIUser,
278
        resource_class: models.ResourceClass,
279
        *,
280
        resource_pool_id: Optional[int] = None,
281
    ) -> models.ResourceClass:
282
        """Insert a resource class in the database."""
283
        cls = schemas.ResourceClassORM.load(resource_class)
1✔
284
        async with self.session_maker() as session, session.begin():
1✔
285
            if resource_pool_id is not None:
1✔
286
                stmt = select(schemas.ResourcePoolORM).where(schemas.ResourcePoolORM.id == resource_pool_id)
1✔
287
                res = await session.execute(stmt)
1✔
288
                rp = res.scalars().first()
1✔
289
                if rp is None:
1✔
290
                    raise errors.MissingResourceError(
1✔
291
                        message=f"Resource pool with id {resource_pool_id} does not exist."
292
                    )
293
                if cls.default and len(rp.classes) > 0 and any([icls.default for icls in rp.classes]):
1✔
294
                    raise errors.ValidationError(
×
295
                        message="There can only be one default resource class per resource pool."
296
                    )
297
                quota = self.quotas_repo.get_quota(rp.quota) if rp.quota else None
1✔
298
                if quota and not quota.is_resource_class_compatible(resource_class):
1✔
299
                    raise errors.ValidationError(
×
300
                        message="The resource class {resource_class} is not compatible with the quota {quota}."
301
                    )
302
                cls.resource_pool = rp
1✔
303
                cls.resource_pool_id = rp.id
1✔
304

305
            session.add(cls)
1✔
306
        return cls.dump()
1✔
307

308
    @_only_admins
2✔
309
    async def update_resource_pool(self, api_user: base_models.APIUser, id: int, **kwargs: Any) -> models.ResourcePool:
2✔
310
        """Update an existing resource pool in the database."""
311
        rp: Optional[schemas.ResourcePoolORM] = None
1✔
312
        async with self.session_maker() as session, session.begin():
1✔
313
            stmt = (
1✔
314
                select(schemas.ResourcePoolORM)
315
                .where(schemas.ResourcePoolORM.id == id)
316
                .options(selectinload(schemas.ResourcePoolORM.classes))
317
            )
318
            res = await session.execute(stmt)
1✔
319
            rp = res.scalars().first()
1✔
320
            if rp is None:
1✔
321
                raise errors.MissingResourceError(message=f"Resource pool with id {id} cannot be found")
1✔
322
            quota = self.quotas_repo.get_quota(rp.quota) if rp.quota else None
1✔
323
            if len(kwargs) == 0:
1✔
324
                return rp.dump(quota)
×
325

326
            if kwargs.get("idle_threshold", None) == 0:
1✔
327
                kwargs["idle_threshold"] = None
×
328
            if kwargs.get("hibernation_threshold", None) == 0:
1✔
329
                kwargs["hibernation_threshold"] = None
×
330
            # NOTE: The .update method on the model validates the update to the resource pool
331
            old_rp_model = rp.dump(quota)
1✔
332
            new_rp_model = old_rp_model.update(**kwargs)
1✔
333
            new_classes = None
1✔
334
            new_classes_coroutines = []
1✔
335
            for key, val in kwargs.items():
1✔
336
                match key:
1✔
337
                    case "name" | "public" | "default" | "idle_threshold" | "hibernation_threshold":
1✔
338
                        setattr(rp, key, val)
1✔
339
                    case "quota":
1✔
340
                        if val is None:
1✔
341
                            continue
×
342

343
                        # For updating a quota, there are two options:
344
                        # 1. no quota exists --> create a new one
345
                        # 2. a quota exists and can only be updated, not replaced (the ids, if provided, must match)
346

347
                        new_id = val.get("id")
1✔
348

349
                        if quota and quota.id is not None and new_id is not None and quota.id != new_id:
1✔
350
                            raise errors.ValidationError(
×
351
                                message="The ID of an existing quota cannot be updated, "
352
                                f"please remove the ID field from the request or use ID {quota.id}."
353
                            )
354

355
                        # the id must match for update
356
                        if quota:
1✔
357
                            val["id"] = quota.id or new_id
1✔
358

359
                        new_quota = models.Quota.from_dict(val)
1✔
360

361
                        if new_id or quota:
1✔
362
                            new_quota = self.quotas_repo.update_quota(new_quota)
1✔
363
                        else:
364
                            new_quota = self.quotas_repo.create_quota(new_quota)
×
365
                        rp.quota = new_quota.id
1✔
366
                        new_rp_model = new_rp_model.update(quota=new_quota)
1✔
367
                    case "classes":
1✔
368
                        new_classes = []
1✔
369
                        for cls in val:
1✔
370
                            class_id = cls.pop("id")
1✔
371
                            cls.pop("matching", None)
1✔
372
                            if len(cls) == 0:
1✔
373
                                raise errors.ValidationError(
×
374
                                    message="More fields than the id of the class "
375
                                    "should be provided when updating it"
376
                                )
377
                            new_classes_coroutines.append(
1✔
378
                                self.update_resource_class(
379
                                    api_user, resource_pool_id=id, resource_class_id=class_id, **cls
380
                                )
381
                            )
382
                    case _:
×
383
                        pass
×
384
            new_classes = await gather(*new_classes_coroutines)
1✔
385
            if new_classes is not None and len(new_classes) > 0:
1✔
386
                new_rp_model = new_rp_model.update(classes=new_classes)
1✔
387
            return new_rp_model
1✔
388

389
    @_only_admins
2✔
390
    async def delete_resource_pool(self, api_user: base_models.APIUser, id: int) -> Optional[models.ResourcePool]:
2✔
391
        """Delete a resource pool from the database."""
392
        async with self.session_maker() as session, session.begin():
2✔
393
            stmt = select(schemas.ResourcePoolORM).where(schemas.ResourcePoolORM.id == id)
2✔
394
            res = await session.execute(stmt)
2✔
395
            rp = res.scalars().first()
2✔
396
            if rp is not None:
2✔
397
                if rp.default:
1✔
398
                    raise errors.ValidationError(message="The default resource pool cannot be deleted.")
×
399
                await session.delete(rp)
1✔
400
                quota = None
1✔
401
                if rp.quota:
1✔
402
                    quota = self.quotas_repo.get_quota(rp.quota)
1✔
403
                    self.quotas_repo.delete_quota(rp.quota)
1✔
404
                return rp.dump(quota)
1✔
405
            return None
1✔
406

407
    @_only_admins
2✔
408
    async def delete_resource_class(
2✔
409
        self, api_user: base_models.APIUser, resource_pool_id: int, resource_class_id: int
410
    ) -> None:
411
        """Delete a specific resource class."""
412
        async with self.session_maker() as session, session.begin():
1✔
413
            stmt = (
1✔
414
                select(schemas.ResourceClassORM)
415
                .where(schemas.ResourceClassORM.id == resource_class_id)
416
                .where(schemas.ResourceClassORM.resource_pool_id == resource_pool_id)
417
            )
418
            res = await session.execute(stmt)
1✔
419
            cls = res.scalars().first()
1✔
420
            if cls is not None:
1✔
421
                if cls.default:
1✔
422
                    raise errors.ValidationError(message="The default resource class cannot be deleted.")
×
423
                await session.delete(cls)
1✔
424

425
    @_only_admins
2✔
426
    async def update_resource_class(
2✔
427
        self, api_user: base_models.APIUser, resource_pool_id: int, resource_class_id: int, **kwargs: Any
428
    ) -> models.ResourceClass:
429
        """Update a specific resource class."""
430
        async with self.session_maker() as session, session.begin():
1✔
431
            stmt = (
1✔
432
                select(schemas.ResourceClassORM)
433
                .where(schemas.ResourceClassORM.id == resource_class_id)
434
                .where(schemas.ResourceClassORM.resource_pool_id == resource_pool_id)
435
                .join(schemas.ResourcePoolORM, schemas.ResourceClassORM.resource_pool)
436
                .options(selectinload(schemas.ResourceClassORM.resource_pool))
437
            )
438
            res = await session.execute(stmt)
1✔
439
            cls: Optional[schemas.ResourceClassORM] = res.scalars().first()
1✔
440
            if cls is None:
1✔
441
                raise errors.MissingResourceError(
×
442
                    message=(
443
                        f"The resource class with id {resource_class_id} does not exist, the resource pool with "
444
                        f"id {resource_pool_id} does not exist or the requested resource class is not "
445
                        "associated with the resource pool"
446
                    )
447
                )
448
            for k, v in kwargs.items():
1✔
449
                match k:
1✔
450
                    case "node_affinities":
1✔
451
                        v = cast(list[dict[str, str | bool]], v)
1✔
452
                        existing_affinities: dict[str, schemas.NodeAffintyORM] = {i.key: i for i in cls.node_affinities}
1✔
453
                        new_affinities: dict[str, schemas.NodeAffintyORM] = {
1✔
454
                            i["key"]: schemas.NodeAffintyORM(**i) for i in v
455
                        }
456
                        for new_affinity_key, new_affinity in new_affinities.items():
1✔
457
                            if new_affinity_key in existing_affinities:
1✔
458
                                # UPDATE existing affinity
459
                                existing_affinity = existing_affinities[new_affinity_key]
1✔
460
                                if (
1✔
461
                                    new_affinity.required_during_scheduling
462
                                    != existing_affinity.required_during_scheduling
463
                                ):
464
                                    existing_affinity.required_during_scheduling = (
1✔
465
                                        new_affinity.required_during_scheduling
466
                                    )
467
                            else:
468
                                # CREATE a brand new affinity
469
                                cls.node_affinities.append(new_affinity)
1✔
470
                        # REMOVE an affinity
471
                        for existing_affinity_key, existing_affinity in existing_affinities.items():
1✔
472
                            if existing_affinity_key not in new_affinities:
1✔
473
                                cls.node_affinities.remove(existing_affinity)
1✔
474
                    case "tolerations":
1✔
475
                        v = cast(list[str], v)
1✔
476
                        existing_tolerations: dict[str, schemas.TolerationORM] = {
1✔
477
                            tol.key: tol for tol in cls.tolerations
478
                        }
479
                        new_tolerations: dict[str, schemas.TolerationORM] = {
1✔
480
                            tol: schemas.TolerationORM(key=tol) for tol in v
481
                        }
482
                        for new_tol_key, new_tol in new_tolerations.items():
1✔
483
                            if new_tol_key not in existing_tolerations:
1✔
484
                                # CREATE a brand new toleration
485
                                cls.tolerations.append(new_tol)
1✔
486
                        # REMOVE a toleration
487
                        for existing_tol_key, existing_tol in existing_tolerations.items():
1✔
488
                            if existing_tol_key not in new_tolerations:
1✔
489
                                cls.tolerations.remove(existing_tol)
1✔
490
                    case _:
1✔
491
                        setattr(cls, k, v)
1✔
492
            if cls.resource_pool is None:
1✔
493
                raise errors.BaseError(
×
494
                    message="Unexpected internal error.",
495
                    detail=f"The resource class {resource_class_id} is not associated with any resource pool.",
496
                )
497
            quota = self.quotas_repo.get_quota(cls.resource_pool.quota) if cls.resource_pool.quota else None
1✔
498
            cls_model = cls.dump()
1✔
499
            if quota and not quota.is_resource_class_compatible(cls_model):
1✔
500
                raise errors.ValidationError(
×
501
                    message=f"The resource class {cls_model} is not compatible with the quota {quota}"
502
                )
503
            return cls_model
1✔
504

505
    @_only_admins
2✔
506
    async def get_tolerations(self, api_user: base_models.APIUser, resource_pool_id: int, class_id: int) -> list[str]:
2✔
507
        """Get all tolerations of a resource class."""
508
        async with self.session_maker() as session:
1✔
509
            res_classes = await self.get_classes(api_user, class_id, resource_pool_id=resource_pool_id)
1✔
510
            if len(res_classes) == 0:
1✔
511
                raise errors.MissingResourceError(
×
512
                    message=f"The resource pool with ID {resource_pool_id} or the resource "
513
                    f"class with ID {class_id} do not exist, or they are not related."
514
                )
515
            stmt = select(schemas.TolerationORM).where(schemas.TolerationORM.resource_class_id == class_id)
1✔
516
            res = await session.execute(stmt)
1✔
517
            return [i.key for i in res.scalars().all()]
1✔
518

519
    @_only_admins
2✔
520
    async def delete_tolerations(self, api_user: base_models.APIUser, resource_pool_id: int, class_id: int) -> None:
2✔
521
        """Delete all tolerations for a specific resource class."""
522
        async with self.session_maker() as session, session.begin():
1✔
523
            res_classes = await self.get_classes(api_user, class_id, resource_pool_id=resource_pool_id)
1✔
524
            if len(res_classes) == 0:
1✔
525
                raise errors.MissingResourceError(
×
526
                    message=f"The resource pool with ID {resource_pool_id} or the resource "
527
                    f"class with ID {class_id} do not exist, or they are not related."
528
                )
529
            stmt = delete(schemas.TolerationORM).where(schemas.TolerationORM.resource_class_id == class_id)
1✔
530
            await session.execute(stmt)
1✔
531

532
    @_only_admins
2✔
533
    async def get_affinities(
2✔
534
        self, api_user: base_models.APIUser, resource_pool_id: int, class_id: int
535
    ) -> list[models.NodeAffinity]:
536
        """Get all affinities for a resource class."""
537
        async with self.session_maker() as session:
1✔
538
            res_classes = await self.get_classes(api_user, class_id, resource_pool_id=resource_pool_id)
1✔
539
            if len(res_classes) == 0:
1✔
UNCOV
540
                raise errors.MissingResourceError(
×
541
                    message=f"The resource pool with ID {resource_pool_id} or the resource "
542
                    f"class with ID {class_id} do not exist, or they are not related."
543
                )
544
            stmt = select(schemas.NodeAffintyORM).where(schemas.NodeAffintyORM.resource_class_id == class_id)
1✔
545
            res = await session.execute(stmt)
1✔
546
            return [i.dump() for i in res.scalars().all()]
1✔
547

548
    @_only_admins
2✔
549
    async def delete_affinities(self, api_user: base_models.APIUser, resource_pool_id: int, class_id: int) -> None:
2✔
550
        """Delete all affinities from a resource class."""
551
        async with self.session_maker() as session, session.begin():
1✔
552
            res_classes = await self.get_classes(api_user, class_id, resource_pool_id=resource_pool_id)
1✔
553
            if len(res_classes) == 0:
1✔
554
                raise errors.MissingResourceError(
×
555
                    message=f"The resource pool with ID {resource_pool_id} or the resource "
556
                    f"class with ID {class_id} do not exist, or they are not related."
557
                )
558
            stmt = delete(schemas.NodeAffintyORM).where(schemas.NodeAffintyORM.resource_class_id == class_id)
1✔
559
            await session.execute(stmt)
1✔
560

561

562
@dataclass
class RespositoryUsers:
    """Users explicitly allowed or disallowed for one resource pool.

    NOTE: the misspelling in the class name is kept on purpose; renaming it would change the public
    interface of this module.
    """

    # Identifier of the resource pool the two lists below refer to.
    resource_pool_id: int
    # Users that are allowed to access the resource pool (a fresh list per instance).
    allowed: list[base_models.User] = field(default_factory=list)
    # Users that are not allowed to access the resource pool (a fresh list per instance).
    disallowed: list[base_models.User] = field(default_factory=list)
569

570

571
class UserRepository(_Base):
    """The adapter used for accessing resource pool users with SQLAlchemy."""

    def __init__(
        self, session_maker: Callable[..., AsyncSession], quotas_repo: QuotaRepository, user_repo: UserRepo
    ) -> None:
        super().__init__(session_maker, quotas_repo)
        # Used to confirm that a user exists before granting access (see update_user_resource_pools).
        self.kc_user_repo = user_repo

    @_only_admins
    async def get_resource_pool_users(
        self,
        *,
        api_user: base_models.APIUser,
        resource_pool_id: int,
        keycloak_id: Optional[str] = None,
    ) -> RespositoryUsers:
        """Get users of a specific resource pool from the database.

        The result depends on the pool type:

        * default pool: ``disallowed`` holds users flagged with ``no_default_access``. ``allowed`` holds
          the requested user, if one was given, exists and is not disallowed.
        * public pool: ``allowed`` holds the requested user, if one was given and exists.
        * private pool: ``allowed`` holds the pool's explicit members, restricted to the requested
          user when ``keycloak_id`` is given.

        Args:
            api_user: The user making the request (guarded by ``_only_admins``).
            resource_pool_id: The ID of the resource pool.
            keycloak_id: When given, restrict the result to this single user.

        Returns:
            The allowed and disallowed users of the resource pool.

        Raises:
            errors.MissingResourceError: If the pool does not exist. When ``keycloak_id`` is given, also
                raised if the pool is private and the user is not one of its members.
        """
        async with self.session_maker() as session, session.begin():
            stmt = (
                select(schemas.ResourcePoolORM)
                .where(schemas.ResourcePoolORM.id == resource_pool_id)
                .options(selectinload(schemas.ResourcePoolORM.users))
            )
            if keycloak_id is not None:
                # Only return the pool if the user can access it: explicit member, public pool or default pool.
                # NOTE: The last condition must use the pool's own ``default`` column. Referencing
                # ``ResourceClassORM.default`` (a table that is not joined here) puts resource_classes into
                # the FROM clause as a cartesian product. The condition then holds whenever any default
                # resource class exists.
                stmt = stmt.join(schemas.ResourcePoolORM.users, isouter=True).where(
                    or_(
                        schemas.RPUserORM.keycloak_id == keycloak_id,
                        schemas.ResourcePoolORM.public == true(),
                        schemas.ResourcePoolORM.default == true(),
                    )
                )
            res = await session.execute(stmt)
            rp = res.scalars().first()
            if rp is None:
                raise errors.MissingResourceError(message=f"Resource pool with id {resource_pool_id} does not exist")

            specific_user: base_models.User | None = None
            if keycloak_id:
                specific_user_res = (
                    await session.execute(select(schemas.RPUserORM).where(schemas.RPUserORM.keycloak_id == keycloak_id))
                ).scalar_one_or_none()
                specific_user = None if not specific_user_res else specific_user_res.dump()

            allowed: list[base_models.User] = []
            disallowed: list[base_models.User] = []
            if rp.default:
                # For the default pool, only users flagged with no_default_access are reported as excluded.
                disallowed_stmt = select(schemas.RPUserORM).where(schemas.RPUserORM.no_default_access == true())
                if keycloak_id:
                    disallowed_stmt = disallowed_stmt.where(schemas.RPUserORM.keycloak_id == keycloak_id)
                disallowed_res = await session.execute(disallowed_stmt)
                disallowed = [user.dump() for user in disallowed_res.scalars().all()]
                if specific_user and specific_user not in disallowed:
                    allowed = [specific_user]
            elif rp.public:
                if specific_user:
                    allowed = [specific_user]
            else:
                # Private pool: report explicit members only. When a specific user was requested, report
                # just that user, as the branches above do.
                allowed = [user.dump() for user in rp.users if not keycloak_id or user.keycloak_id == keycloak_id]
            return RespositoryUsers(rp.id, allowed, disallowed)

    async def get_user_resource_pools(
        self,
        api_user: base_models.APIUser,
        keycloak_id: str,
        resource_pool_id: Optional[int] = None,
        resource_pool_name: Optional[str] = None,
    ) -> list[models.ResourcePool]:
        """Get resource pools that a specific user has access to.

        A pool is returned if it is public or if the user is one of its explicit members. The result can
        optionally be narrowed by pool ID and/or name.

        Raises:
            errors.ValidationError: If a non-admin asks for the pools of a different user.
        """
        # Check permissions before touching the database.
        if not api_user.is_admin and api_user.id != keycloak_id:
            raise errors.ValidationError(
                message="Users cannot query for resource pools that belong to other users."
            )
        async with self.session_maker() as session, session.begin():
            stmt = select(schemas.ResourcePoolORM).options(selectinload(schemas.ResourcePoolORM.classes))
            stmt = stmt.where(
                or_(
                    schemas.ResourcePoolORM.public == true(),
                    schemas.ResourcePoolORM.users.any(schemas.RPUserORM.keycloak_id == keycloak_id),
                )
            )
            if resource_pool_name is not None:
                stmt = stmt.where(schemas.ResourcePoolORM.name == resource_pool_name)
            if resource_pool_id is not None:
                stmt = stmt.where(schemas.ResourcePoolORM.id == resource_pool_id)
            # NOTE: The line below ensures that the right users can access the right resources, do not remove.
            stmt = _resource_pool_access_control(api_user, stmt)
            res = await session.execute(stmt)
            rps: Sequence[schemas.ResourcePoolORM] = res.scalars().all()
            output: list[models.ResourcePool] = []
            for rp in rps:
                quota = self.quotas_repo.get_quota(rp.quota) if rp.quota else None
                output.append(rp.dump(quota))
            return output

    @_only_admins
    async def update_user_resource_pools(
        self, api_user: base_models.APIUser, keycloak_id: str, resource_pool_ids: list[int], append: bool = True
    ) -> list[models.ResourcePool]:
        """Update the resource pools that a specific user has access to.

        Args:
            api_user: The user making the request (guarded by ``_only_admins``).
            keycloak_id: The user whose pool memberships are updated. The user must exist in ``kc_user_repo``.
            resource_pool_ids: The pools to grant. Duplicate IDs are tolerated.
            append: If true, add the pools to the user's existing ones. Otherwise replace them.

        Returns:
            With ``append``, the pools that were newly added. Otherwise, all requested pools.

        Raises:
            errors.MissingResourceError: If the user does not exist, or if any requested pool does not
                exist or is the default pool while the user has ``no_default_access`` set.
            errors.NoDefaultPoolAccessError: Defensive check for a default pool slipping through.
        """
        async with self.session_maker() as session, session.begin():
            kc_user = await self.kc_user_repo.get_user(api_user, keycloak_id)
            if kc_user is None:
                raise errors.MissingResourceError(message=f"The user with ID {keycloak_id} does not exist")
            stmt = (
                select(schemas.RPUserORM)
                .where(schemas.RPUserORM.keycloak_id == keycloak_id)
                .options(selectinload(schemas.RPUserORM.resource_pools))
            )
            res = await session.execute(stmt)
            user = res.scalars().first()
            if user is None:
                # No local record for this user yet: create one to attach the pool memberships to.
                user = schemas.RPUserORM(keycloak_id=keycloak_id)
                session.add(user)
            stmt_rp = (
                select(schemas.ResourcePoolORM)
                .where(schemas.ResourcePoolORM.id.in_(resource_pool_ids))
                .options(selectinload(schemas.ResourcePoolORM.classes))
            )
            if user.no_default_access:
                stmt_rp = stmt_rp.where(schemas.ResourcePoolORM.default == false())
            res_rp = await session.execute(stmt_rp)
            rps_to_add = res_rp.scalars().all()
            # Compare as sets so that duplicate IDs in the request are not mistaken for missing pools.
            missing_rps = set(resource_pool_ids).difference(found.id for found in rps_to_add)
            if missing_rps:
                raise errors.MissingResourceError(
                    message=(
                        f"The resource pools with ids: {missing_rps} do not exist or user doesn't have access to "
                        "default resource pool."
                    )
                )
            if user.no_default_access:
                # Defensive: the query above already filters out default pools for such users.
                default_rp = next((rp for rp in rps_to_add if rp.default), None)
                if default_rp:
                    raise errors.NoDefaultPoolAccessError(
                        message=f"User with keycloak id {keycloak_id} cannot access the default resource pool"
                    )
            if append:
                user_rp_ids = {rp.id for rp in user.resource_pools}
                rps_to_add = [rp for rp in rps_to_add if rp.id not in user_rp_ids]
                user.resource_pools.extend(rps_to_add)
            else:
                user.resource_pools = list(rps_to_add)
            output: list[models.ResourcePool] = []
            for rp in rps_to_add:
                quota = self.quotas_repo.get_quota(rp.quota) if rp.quota else None
                output.append(rp.dump(quota))
            return output

    @_only_admins
    async def delete_resource_pool_user(
        self, api_user: base_models.APIUser, resource_pool_id: int, keycloak_id: str
    ) -> None:
        """Remove a user from a specific resource pool.

        Only the membership in ``resource_pool_id`` is removed. The user's memberships in other pools
        are untouched. This is a no-op if the pool does not exist or the user is not a member.
        """
        async with self.session_maker() as session, session.begin():
            # NOTE: A bulk DELETE on the association table filtered only by user_id would drop the user
            # from *every* pool. Removing the user from this pool's ``users`` collection instead lets
            # the ORM delete exactly one association row. Replacing ``rp.users`` is the same pattern
            # that update_resource_pool_users uses.
            stmt = (
                select(schemas.ResourcePoolORM)
                .where(schemas.ResourcePoolORM.id == resource_pool_id)
                .options(selectinload(schemas.ResourcePoolORM.users))
            )
            res = await session.execute(stmt)
            rp = res.scalars().first()
            if rp is None:
                return
            rp.users = [usr for usr in rp.users if usr.keycloak_id != keycloak_id]

    @_only_admins
    async def update_resource_pool_users(
        self, api_user: base_models.APIUser, resource_pool_id: int, user_ids: Collection[str], append: bool = True
    ) -> list[base_models.User]:
        """Update the users to have access to a specific resource pool.

        For the default pool, no memberships are stored. Instead, the ``no_default_access`` restriction
        is lifted for any of ``user_ids`` that had it, and those updated users are returned.

        For other pools, ``user_ids`` are added to (``append``) or replace the pool's members. Users
        without a local record get one. The pool's resulting member list is returned.

        Raises:
            errors.MissingResourceError: If the resource pool does not exist.
        """
        async with self.session_maker() as session, session.begin():
            stmt = (
                select(schemas.ResourcePoolORM)
                .where(schemas.ResourcePoolORM.id == resource_pool_id)
                .options(
                    selectinload(schemas.ResourcePoolORM.users),
                    selectinload(schemas.ResourcePoolORM.classes),
                )
            )
            res = await session.execute(stmt)
            rp: Optional[schemas.ResourcePoolORM] = res.scalars().first()
            if rp is None:
                raise errors.MissingResourceError(
                    message=f"The resource pool with id {resource_pool_id} does not exist"
                )
            if rp.default:
                # NOTE: If the resource pool is default just check if any users are prevented from having
                # default resource pool access - and remove the restriction.
                all_existing_users = await self.get_resource_pool_users(
                    api_user=api_user, resource_pool_id=resource_pool_id
                )
                users_to_modify = [user for user in all_existing_users.disallowed if user.keycloak_id in user_ids]
                return await gather(
                    *[
                        self.update_user(
                            api_user=api_user, keycloak_id=no_default_user.keycloak_id, no_default_access=False
                        )
                        for no_default_user in users_to_modify
                    ]
                )
            # De-duplicate (preserving order). A repeated ID would otherwise create two RPUserORM objects
            # for the same keycloak_id below.
            unique_user_ids = list(dict.fromkeys(user_ids))
            stmt_usr = select(schemas.RPUserORM).where(schemas.RPUserORM.keycloak_id.in_(unique_user_ids))
            res_usr = await session.execute(stmt_usr)
            users_to_add_exist = res_usr.scalars().all()
            existing_keycloak_ids = {usr.keycloak_id for usr in users_to_add_exist}
            users_to_add_missing = [
                schemas.RPUserORM(keycloak_id=user_id)
                for user_id in unique_user_ids
                if user_id not in existing_keycloak_ids
            ]
            candidates = list(users_to_add_exist) + users_to_add_missing
            if append:
                member_ids = {member.id for member in rp.users}
                users_to_add = [usr for usr in candidates if usr.id not in member_ids]
                rp.users.extend(users_to_add)
            else:
                rp.users = candidates
            return [usr.dump() for usr in rp.users]

    @_only_admins
    async def update_user(self, api_user: base_models.APIUser, keycloak_id: str, **kwargs: Any) -> base_models.User:
        """Update a specific user, creating its local record if it does not exist yet.

        Only ``no_default_access`` may be passed in ``kwargs``. A value of ``None`` leaves it unchanged.

        Raises:
            errors.ValidationError: If ``kwargs`` contains any other field.
        """
        # Validate the requested fields before touching the database.
        allowed_updates = {"no_default_access"}
        if not set(kwargs).issubset(allowed_updates):
            raise errors.ValidationError(
                message=f"Only the following fields {allowed_updates} can be updated for a resource pool user."
            )
        async with self.session_maker() as session, session.begin():
            stmt = select(schemas.RPUserORM).where(schemas.RPUserORM.keycloak_id == keycloak_id)
            res = await session.execute(stmt)
            user: Optional[schemas.RPUserORM] = res.scalars().first()
            if not user:
                user = schemas.RPUserORM(keycloak_id=keycloak_id)
                session.add(user)
            if (no_default_access := kwargs.get("no_default_access")) is not None:
                user.no_default_access = no_default_access
            return user.dump()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc