• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

SwissDataScienceCenter / renku-data-services / 11034384359

25 Sep 2024 01:46PM UTC coverage: 90.429% (+0.03%) from 90.403%
11034384359

Pull #425

github

web-flow
Merge 871b24a54 into 013683161
Pull Request #425: refactor: define a single method to remove entities from authz

28 of 33 new or added lines in 1 file covered. (84.85%)

2 existing lines in 2 files now uncovered.

9335 of 10323 relevant lines covered (90.43%)

1.6 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

94.18
/components/renku_data_services/users/db.py
1
"""Database adapters and helpers for users."""
2✔
2

3
import secrets
from collections.abc import Callable
from dataclasses import asdict, dataclass, field
from datetime import datetime, timedelta, timezone
from typing import Any, cast

from sanic.log import logger
from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession

from renku_data_services import base_models
from renku_data_services.authz.authz import Authz, AuthzOperation, ResourceType
from renku_data_services.base_api.auth import APIUser, only_authenticated
from renku_data_services.errors import errors
from renku_data_services.message_queue import events
from renku_data_services.message_queue.avro_models.io.renku.events import v2 as avro_schema_v2
from renku_data_services.message_queue.db import EventRepository
from renku_data_services.message_queue.interface import IMessageQueue
from renku_data_services.message_queue.redis_queue import dispatch_message
from renku_data_services.namespace.db import GroupRepository
from renku_data_services.namespace.orm import NamespaceORM
from renku_data_services.users.config import UserPreferencesConfig
from renku_data_services.users.kc_api import IKeycloakAPI
from renku_data_services.users.models import (
    KeycloakAdminEvent,
    PinnedProjects,
    UserInfo,
    UserInfoFieldUpdate,
    UserInfoUpdate,
    UserPreferences,
)
from renku_data_services.users.orm import LastKeycloakEventTimestamp, UserORM, UserPreferencesORM
from renku_data_services.utils.core import with_db_transaction
from renku_data_services.utils.cryptography import decrypt_string, encrypt_string
2✔
37

38

39
@dataclass
class UserRepo:
    """An adapter for accessing users from the database."""

    session_maker: Callable[..., AsyncSession]
    message_queue: IMessageQueue
    event_repo: EventRepository
    group_repo: GroupRepository
    # Kept out of the generated ``repr`` so the key does not end up in logs.
    encryption_key: bytes = field(repr=False)
    authz: Authz

    def __post_init__(self) -> None:
        """Build the Keycloak sync helper from this repository's own dependencies."""
        self._users_sync = UsersSync(
            self.session_maker, self.message_queue, self.event_repo, self.group_repo, self.authz
        )

    async def initialize(self, kc_api: IKeycloakAPI) -> None:
        """Do a total sync of users from Keycloak if there is nothing in the DB."""
        users = await self._get_users()
        if users:
            return
        await self._users_sync.users_sync(kc_api)

    async def _add_api_user(self, user: APIUser) -> UserInfo:
        """Insert (or update) the database record for an authenticated API user.

        Raises:
            errors.UnauthorizedError: if the API user has no ID.
        """
        if not user.id:
            raise errors.UnauthorizedError(message="The user has to be authenticated to be inserted in the DB.")
        result = await self._users_sync.update_or_insert_user(
            user_id=user.id,
            payload={
                "first_name": user.first_name,
                "last_name": user.last_name,
                "email": user.email,
            },
        )
        return result.new

    async def get_user(self, id: str) -> UserInfo | None:
        """Get a specific user from the database.

        Returns ``None`` when no user with the given Keycloak ID exists.

        Raises:
            errors.ProgrammingError: if the user exists but has no namespace.
        """
        async with self.session_maker() as session:
            result = await session.scalars(select(UserORM).where(UserORM.keycloak_id == id))
            user = result.one_or_none()
            if user is None:
                return None
            if user.namespace is None:
                raise errors.ProgrammingError(message=f"Cannot find a user namespace for user {id}.")
            return user.namespace.dump_user()

    async def get_or_create_user(self, requested_by: APIUser, id: str) -> UserInfo | None:
        """Get a specific user from the database and create it potentially if it does not exist.

        If the caller is the same user that is being retrieved and they are authenticated and
        their user information is not in the database then this call adds the user in the DB
        in addition to returning the user information.
        """
        # NOTE: No session is opened here on purpose. ``get_user`` opens its own session and
        # ``_add_api_user`` goes through ``update_or_insert_user``, which is handed its own
        # session, so a session created in this method would never be used.
        user = await self.get_user(id=id)
        if not user and id == requested_by.id:
            return await self._add_api_user(requested_by)
        return user

    @only_authenticated
    async def get_users(self, requested_by: APIUser, email: str | None = None) -> list[UserInfo]:
        """Get users from the database.

        Without an ``email`` filter all users are listed, which only admins may do. When
        listing all users and the requester is not yet among them, the requester is added
        to the database and appended to the result.

        Raises:
            errors.ForbiddenError: if a non-admin user tries to list all users.
        """
        if not email and not requested_by.is_admin:
            raise errors.ForbiddenError(message="Non-admin users cannot list all users.")
        users = await self._get_users(email)

        if not email and not any(requested_by.id == user.id for user in users):
            api_user_info = await self._add_api_user(requested_by)
            users.append(api_user_info)
        return users

    async def _get_users(self, email: str | None = None) -> list[UserInfo]:
        """Load all users, or only those with the given email when one is provided.

        Raises:
            errors.ProgrammingError: if any of the loaded users has no namespace.
        """
        async with self.session_maker() as session:
            stmt = select(UserORM)
            if email:
                stmt = stmt.where(UserORM.email == email)
            result = await session.scalars(stmt)
            users = result.all()

            for user in users:
                if user.namespace is None:
                    # Report the offending user's ID; the bare name ``id`` here would be the builtin.
                    raise errors.ProgrammingError(
                        message=f"Cannot find a user namespace for user {user.keycloak_id}."
                    )

            return [user.dump() for user in users]

    @only_authenticated
    async def get_or_create_user_secret_key(self, requested_by: APIUser) -> str:
        """Get a user's secret encryption key or create it if it doesn't exist.

        The key is stored encrypted with the repository's encryption key and the user's
        Keycloak ID; the decrypted value is returned.

        Raises:
            errors.MissingResourceError: if the requesting user is not in the database.
        """
        async with self.session_maker() as session, session.begin():
            stmt = select(UserORM).where(UserORM.keycloak_id == requested_by.id)
            user = await session.scalar(stmt)
            if not user:
                raise errors.MissingResourceError(message=f"User with id {requested_by.id} not found")
            if user.secret_key is not None:
                return decrypt_string(self.encryption_key, user.keycloak_id, user.secret_key)
            # create a new secret key
            secret_key = secrets.token_urlsafe(32)
            user.secret_key = encrypt_string(self.encryption_key, user.keycloak_id, secret_key)
            session.add(user)

        # Returned only after the transaction block has exited.
        return secret_key
2✔
144

145

146
class UsersSync:
    """Sync users from Keycloak to the database."""

    def __init__(
        self,
        session_maker: Callable[..., AsyncSession],
        message_queue: IMessageQueue,
        event_repo: EventRepository,
        group_repo: GroupRepository,
        authz: Authz,
    ) -> None:
        self.session_maker = session_maker
        self.message_queue: IMessageQueue = message_queue
        self.event_repo: EventRepository = event_repo
        self.group_repo = group_repo
        self.authz = authz

    async def _get_user(self, id: str) -> UserInfo | None:
        """Get a specific user, or ``None`` if no user has the given Keycloak ID."""
        async with self.session_maker() as session, session.begin():
            stmt = select(UserORM).where(UserORM.keycloak_id == id)
            res = await session.execute(stmt)
            orm = res.scalar_one_or_none()
            return orm.dump() if orm else None

    @with_db_transaction
    @Authz.authz_change(AuthzOperation.update_or_insert, ResourceType.user)
    @dispatch_message(events.UpdateOrInsertUser)
    async def update_or_insert_user(
        self, user_id: str, payload: dict[str, Any], *, session: AsyncSession | None = None
    ) -> UserInfoUpdate:
        """Update a user or insert it if it does not exist.

        ``payload`` holds the user fields to set and is forwarded as keyword arguments to
        the insert or update helper.

        Raises:
            errors.ProgrammingError: if no database session was provided.
        """
        if not session:
            raise errors.ProgrammingError(message="A database session is required")
        res = await session.execute(select(UserORM).where(UserORM.keycloak_id == user_id))
        existing_user = res.scalar_one_or_none()
        if existing_user:
            return await self._update_user(session=session, user_id=user_id, existing_user=existing_user, **payload)
        return await self._insert_user(session=session, user_id=user_id, **payload)

    async def _insert_user(self, session: AsyncSession, user_id: str, **kwargs: Any) -> UserInfoUpdate:
        """Insert a user together with a freshly created user namespace."""
        # The identifier always comes from ``user_id``, never from the payload.
        kwargs.pop("keycloak_id", None)
        kwargs.pop("id", None)
        slug = base_models.Slug.from_user(
            kwargs.get("email"), kwargs.get("first_name"), kwargs.get("last_name"), user_id
        ).value
        namespace = await self.group_repo._create_user_namespace_slug(
            session, user_slug=slug, retry_enumerate=5, retry_random=True
        )
        slug = base_models.Slug.from_name(namespace)
        new_user = UserORM(keycloak_id=user_id, namespace=NamespaceORM(slug=slug.value, user_id=user_id), **kwargs)
        new_user.namespace.user = new_user
        session.add(new_user)
        await session.flush()
        return UserInfoUpdate(None, new_user.dump())

    async def _update_user(
        self, session: AsyncSession, user_id: str, existing_user: UserORM | None, **kwargs: Any
    ) -> UserInfoUpdate:
        """Update a user.

        Only fields whose value differs from the stored one are assigned. The returned
        ``UserInfoUpdate`` carries the user as dumped before and after the changes.

        Raises:
            errors.MissingResourceError: if the user does not exist.
            errors.ProgrammingError: if the user has no namespace.
        """
        if not existing_user:
            # Look the user up through the caller's session. Opening a separate session here
            # would rebind ``session`` to one that is already closed by the time the changes
            # below are applied, so they would never reach the caller's transaction.
            res = await session.execute(select(UserORM).where(UserORM.keycloak_id == user_id))
            existing_user = res.scalar_one_or_none()
        if not existing_user:
            raise errors.MissingResourceError(message=f"The user with id '{user_id}' cannot be found")
        old_user = existing_user.dump()

        # The identifier always comes from ``user_id``, never from the payload.
        kwargs.pop("keycloak_id", None)
        kwargs.pop("id", None)
        session.add(existing_user)  # reattach to session
        for field_name, field_value in kwargs.items():
            if getattr(existing_user, field_name, None) != field_value:
                setattr(existing_user, field_name, field_value)
        namespace = await self.group_repo.get_user_namespace(user_id)
        if not namespace:
            raise errors.ProgrammingError(
                message=f"Cannot find a user namespace for user {user_id} when updating the user."
            )
        return UserInfoUpdate(old_user, existing_user.dump())

    @with_db_transaction
    @dispatch_message(avro_schema_v2.UserRemoved)
    async def _remove_user(self, user_id: str, *, session: AsyncSession | None = None) -> str | None:
        """Remove a user from the database.

        Returns the removed user's ID, or ``None`` if no such user was found.

        Raises:
            errors.ProgrammingError: if no database session was provided.
        """
        if not session:
            raise errors.ProgrammingError(message="A database session is required")
        logger.info(f"Trying to remove user with ID {user_id}")
        stmt = delete(UserORM).where(UserORM.keycloak_id == user_id).returning(UserORM)
        user = await session.scalar(stmt)
        # The authz cleanup runs whether or not a database row was deleted.
        await self.authz._remove_user_namespace(user_id)
        if not user:
            logger.info(f"User with ID {user_id} was not found.")
            return None
        logger.info(f"User with ID {user_id} was removed from the database.")
        logger.info(f"User namespace with ID {user_id} was removed from the authorization database.")
        return user_id

    async def users_sync(self, kc_api: IKeycloakAPI) -> None:
        """Sync all users from Keycloak into the users database."""
        logger.info("Starting a total user database sync.")
        kc_users = kc_api.get_users()

        async def _do_update(raw_kc_user: dict[str, Any]) -> None:
            kc_user = UserInfo.from_kc_user_payload(raw_kc_user)
            logger.info(f"Checking user with Keycloak ID {kc_user.id}")
            db_user = await self._get_user(kc_user.id)
            # Only write when the database copy is missing or differs from Keycloak.
            if db_user != kc_user:
                logger.info(f"Inserting or updating user {db_user} -> {kc_user}")
                await self.update_or_insert_user(kc_user.id, asdict(kc_user))

        # NOTE: If asyncio.gather is used here you quickly exhaust all DB connections
        # or timeout on waiting for available connections
        for user in kc_users:
            await _do_update(user)

    async def events_sync(self, kc_api: IKeycloakAPI) -> None:
        """Use the events from Keycloak to update the users database.

        If the users table is empty a total sync is done first. Then the user and admin
        events starting from yesterday's UTC date are pulled, events not newer than the
        last stored timestamp are dropped, updates and deletions are applied in timestamp
        order, and the newest processed timestamp is stored for the next run.
        """
        async with self.session_maker() as session, session.begin():
            res_count = await session.execute(select(func.count()).select_from(UserORM))
            count = res_count.scalar() or 0
            if count == 0:
                await self.users_sync(kc_api)
            logger.info("Starting periodic event sync.")
            stmt = select(LastKeycloakEventTimestamp)
            latest_utc_timestamp_orm = (await session.execute(stmt)).scalar_one_or_none()
            previous_sync_latest_utc_timestamp = (
                latest_utc_timestamp_orm.timestamp_utc if latest_utc_timestamp_orm is not None else None
            )
            logger.info(f"The previous sync latest event is {previous_sync_latest_utc_timestamp} UTC")
            # ``datetime.utcnow()`` is deprecated; only the UTC calendar date is used below,
            # which is the same for the timezone-aware value.
            now_utc = datetime.now(timezone.utc)
            start_date = now_utc.date() - timedelta(days=1)
            logger.info(f"Pulling events with a start date of {start_date} UTC")
            user_events = kc_api.get_user_events(start_date=start_date)
            update_admin_events = kc_api.get_admin_events(
                start_date=start_date, event_types=[KeycloakAdminEvent.CREATE, KeycloakAdminEvent.UPDATE]
            )
            delete_admin_events = kc_api.get_admin_events(
                start_date=start_date, event_types=[KeycloakAdminEvent.DELETE]
            )
            parsed_updates = UserInfoFieldUpdate.from_json_admin_events(update_admin_events)
            parsed_updates.extend(UserInfoFieldUpdate.from_json_user_events(user_events))
            parsed_deletions = UserInfoFieldUpdate.from_json_admin_events(delete_admin_events)
            parsed_updates = sorted(parsed_updates, key=lambda x: x.timestamp_utc)
            parsed_deletions = sorted(parsed_deletions, key=lambda x: x.timestamp_utc)
            if previous_sync_latest_utc_timestamp is not None:
                # Some events have already been processed - filter out old events we have seen
                logger.info(f"Filtering events older than {previous_sync_latest_utc_timestamp}")
                parsed_updates = [u for u in parsed_updates if u.timestamp_utc > previous_sync_latest_utc_timestamp]
                parsed_deletions = [u for u in parsed_deletions if u.timestamp_utc > previous_sync_latest_utc_timestamp]
            latest_update_timestamp = None
            latest_delete_timestamp = None
            for update in parsed_updates:
                logger.info(f"Processing update event {update}")
                await self.update_or_insert_user(update.user_id, {update.field_name: update.new_value})
                latest_update_timestamp = update.timestamp_utc
            for deletion in parsed_deletions:
                logger.info(f"Processing deletion event {deletion}")
                await self._remove_user(deletion.user_id)
                latest_delete_timestamp = deletion.timestamp_utc
            # Update the latest processed event timestamp: the newer of the last update and
            # the last deletion, or None when nothing was processed in this run.
            current_sync_latest_utc_timestamp = max(
                (ts for ts in (latest_update_timestamp, latest_delete_timestamp) if ts is not None),
                default=None,
            )
            if current_sync_latest_utc_timestamp is not None:
                if latest_utc_timestamp_orm is None:
                    session.add(LastKeycloakEventTimestamp(current_sync_latest_utc_timestamp))
                    logger.info(
                        f"Inserted the latest sync event timestamp in the database: {current_sync_latest_utc_timestamp}"
                    )
                else:
                    latest_utc_timestamp_orm.timestamp_utc = current_sync_latest_utc_timestamp
                    logger.info(
                        f"Updated the latest sync event timestamp in the database: {current_sync_latest_utc_timestamp}"
                    )
325

326

327
@dataclass
class UserPreferencesRepository:
    """Repository for user preferences."""

    session_maker: Callable[..., AsyncSession]
    user_preferences_config: UserPreferencesConfig

    @staticmethod
    async def _load_preferences(session: AsyncSession, user_id: str | None) -> UserPreferencesORM | None:
        """Fetch the preferences row of the given user, or ``None`` if there is none."""
        query = select(UserPreferencesORM).where(UserPreferencesORM.user_id == user_id)
        rows = await session.scalars(query)
        return rows.one_or_none()

    @only_authenticated
    async def get_user_preferences(
        self,
        requested_by: APIUser,
    ) -> UserPreferences:
        """Return the stored preferences of the requesting user.

        Raises:
            errors.MissingResourceError: if the user has no stored preferences.
        """
        async with self.session_maker() as session:
            stored = await self._load_preferences(session, requested_by.id)
            if stored is not None:
                return stored.dump()
            raise errors.MissingResourceError(message="Preferences not found for user.", quiet=True)

    @only_authenticated
    async def delete_user_preferences(self, requested_by: APIUser) -> None:
        """Delete the requesting user's preferences; do nothing if there are none."""
        async with self.session_maker() as session, session.begin():
            stored = await self._load_preferences(session, requested_by.id)
            if stored is not None:
                await session.delete(stored)

    @only_authenticated
    async def add_pinned_project(self, requested_by: APIUser, project_slug: str) -> UserPreferences:
        """Pin a project for the requesting user and return the resulting preferences.

        Preferences are created if the user has none yet. Pinning a project that is already
        pinned (compared case-insensitively) leaves the preferences unchanged.

        Raises:
            errors.ValidationError: if a positive pin limit is configured and already reached.
        """
        async with self.session_maker() as session, session.begin():
            stored = await self._load_preferences(session, requested_by.id)

            if stored is None:
                # First pin for this user: create the preferences record.
                initial_pins = PinnedProjects(project_slugs=[project_slug])
                created = UserPreferencesORM.load(
                    UserPreferences(user_id=cast(str, requested_by.id), pinned_projects=initial_pins)
                )
                session.add(created)
                return created.dump()

            current_slugs: list[str] = stored.pinned_projects.get("project_slugs", [])

            # Nothing to do when the project is already pinned.
            if any(project_slug.lower() == pinned.lower() for pinned in current_slugs):
                return stored.dump()

            # A non-positive limit means the number of pins is not capped.
            limit = self.user_preferences_config.max_pinned_projects
            if limit > 0 and len(current_slugs) >= limit:
                raise errors.ValidationError(
                    message="Maximum number of pinned projects already allocated"
                    + f" (limit: {limit}, current: {len(current_slugs)})"
                )

            updated_slugs = [*current_slugs, project_slug]
            stored.pinned_projects = PinnedProjects(project_slugs=updated_slugs).model_dump()
            return stored.dump()

    @only_authenticated
    async def remove_pinned_project(self, requested_by: APIUser, project_slug: str) -> UserPreferences:
        """Remove one or all pinned projects from the requesting user's preferences.

        A non-empty ``project_slug`` removes that project (compared case-insensitively);
        an empty one clears every pin.

        Raises:
            errors.MissingResourceError: if the user has no stored preferences.
        """
        async with self.session_maker() as session, session.begin():
            stored = await self._load_preferences(session, requested_by.id)

            if stored is None:
                raise errors.MissingResourceError(message="Preferences not found for user.", quiet=True)

            current_slugs: list[str] = stored.pinned_projects.get("project_slugs", [])

            if project_slug:
                target = project_slug.lower()
                remaining = [pinned for pinned in current_slugs if pinned.lower() != target]
            else:
                # No specific project requested: drop every pin.
                remaining = []

            stored.pinned_projects = PinnedProjects(project_slugs=remaining).model_dump()
            return stored.dump()
2✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc