• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

SwissDataScienceCenter / renku-data-services / 18123243513

30 Sep 2025 08:10AM UTC coverage: 86.702% (-0.01%) from 86.714%
18123243513

Pull #1019

github

web-flow
Merge e726c4543 into 0690bab65
Pull Request #1019: feat: Attempt to support dockerhub private images

70 of 101 new or added lines in 9 files covered. (69.31%)

106 existing lines in 6 files now uncovered.

22357 of 25786 relevant lines covered (86.7%)

1.52 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

73.48
/components/renku_data_services/connected_services/db.py
1
"""Adapters for connected services database classes."""
2

3
from base64 import b64decode, b64encode
2✔
4
from collections.abc import AsyncGenerator, Callable
2✔
5
from contextlib import asynccontextmanager
2✔
6
from typing import Any
2✔
7
from urllib.parse import urljoin, urlparse
2✔
8

9
from authlib.integrations.base_client import InvalidTokenError
2✔
10
from authlib.integrations.httpx_client import AsyncOAuth2Client, OAuthError
2✔
11
from sqlalchemy import and_, select
2✔
12
from sqlalchemy.ext.asyncio import AsyncSession
2✔
13
from sqlalchemy.orm import selectinload
2✔
14
from ulid import ULID
2✔
15

16
import renku_data_services.base_models as base_models
2✔
17
from renku_data_services import errors
2✔
18
from renku_data_services.app_config import logging
2✔
19
from renku_data_services.base_api.pagination import PaginationRequest
2✔
20
from renku_data_services.connected_services import models
2✔
21
from renku_data_services.connected_services import orm as schemas
2✔
22
from renku_data_services.connected_services.provider_adapters import (
2✔
23
    GitHubAdapter,
24
    ProviderAdapter,
25
    get_provider_adapter,
26
)
27
from renku_data_services.connected_services.utils import generate_code_verifier
2✔
28
from renku_data_services.notebooks.api.classes.image import Image, ImageRepoDockerAPI
2✔
29
from renku_data_services.users.db import APIUser
2✔
30
from renku_data_services.utils.cryptography import decrypt_string, encrypt_string
2✔
31

32
logger = logging.getLogger(__name__)
2✔
33

34

35
class ConnectedServicesRepository:
2✔
36
    """Repository for connected services."""
37

38
    def __init__(
        self,
        session_maker: Callable[..., AsyncSession],
        encryption_key: bytes,
        async_oauth2_client_class: type[AsyncOAuth2Client],
    ):
        """Keep the session factory, the encryption key and the OAuth2 client class used by the repository."""
        # Provider kinds that `get_provider_for_image` accepts when matching a client to an image registry.
        self.supported_image_registry_providers = {
            models.ProviderKind.dockerhub,
            models.ProviderKind.github,
            models.ProviderKind.gitlab,
        }
        self.async_oauth2_client_class = async_oauth2_client_class
        self.encryption_key = encryption_key
        self.session_maker = session_maker
52

53
    async def get_oauth2_clients(
        self,
        user: base_models.APIUser,
    ) -> list[models.OAuth2Client]:
        """Return every OAuth2 client stored in the database.

        Each row is converted with `dump`, which is told whether `user` is an admin.
        """
        query = select(schemas.OAuth2ClientORM)
        async with self.session_maker() as session:
            rows = (await session.scalars(query)).all()
            return [row.dump(user_is_admin=user.is_admin) for row in rows]
62

63
    async def get_oauth2_client(self, provider_id: str, user: base_models.APIUser) -> models.OAuth2Client:
        """Return the OAuth2 client whose id is `provider_id`.

        Throw a MissingResourceError if no such client is stored.
        """
        query = select(schemas.OAuth2ClientORM).where(schemas.OAuth2ClientORM.id == provider_id)
        async with self.session_maker() as session:
            found = (await session.scalars(query)).one_or_none()
            if found is not None:
                return found.dump(user_is_admin=user.is_admin)
            raise errors.MissingResourceError(
                message=f"OAuth2 Client with id '{provider_id}' does not exist or you do not have access to it."
            )
75

76
    async def insert_oauth2_client(
        self, user: base_models.APIUser, new_client: models.UnsavedOAuth2Client
    ) -> models.OAuth2Client:
        """Store a new OAuth2 client and return the saved entry.

        Only admins may do this: users without an id get an UnauthorizedError, non-admins a ForbiddenError.
        The client id is the slug derived from `new_client.id`; a ValidationError is raised if it is taken.
        """
        denied = "You do not have the required permissions for this operation."
        if user.id is None:
            raise errors.UnauthorizedError(message=denied)
        if not user.is_admin:
            raise errors.ForbiddenError(message=denied)

        plain_secret = new_client.client_secret
        orm_client = schemas.OAuth2ClientORM(
            id=base_models.Slug.from_name(new_client.id).value,
            kind=new_client.kind,
            app_slug=new_client.app_slug or "",
            client_id=new_client.client_id,
            # The secret is stored encrypted, keyed on the id of the admin creating the client.
            client_secret=encrypt_string(self.encryption_key, user.id, plain_secret) if plain_secret else None,
            display_name=new_client.display_name,
            scope=new_client.scope,
            url=new_client.url,
            use_pkce=new_client.use_pkce or False,
            created_by_id=user.id,
            image_registry_url=new_client.image_registry_url,
            oidc_issuer_url=new_client.oidc_issuer_url or None,
        )

        same_id_query = select(schemas.OAuth2ClientORM).where(schemas.OAuth2ClientORM.id == orm_client.id)
        async with self.session_maker() as session, session.begin():
            if (await session.scalars(same_id_query)).one_or_none() is not None:
                raise errors.ValidationError(message=f"OAuth2 Client with id '{orm_client.id}' already exists.")

            session.add(orm_client)
            await session.flush()
            await session.refresh(orm_client)
            return orm_client.dump(user_is_admin=user.is_admin)
116

117
    async def update_oauth2_client(
        self,
        user: base_models.APIUser,
        provider_id: str,
        patch: models.OAuth2ClientPatch,
    ) -> models.OAuth2Client:
        """Apply `patch` to the OAuth2 client with id `provider_id` and return the updated entry.

        Only admins may do this (ForbiddenError otherwise); a MissingResourceError is raised if the client
        does not exist. Fields that are None in the patch are left untouched.
        """
        if not user.is_admin:
            raise errors.ForbiddenError(message="You do not have the required permissions for this operation.")

        query = select(schemas.OAuth2ClientORM).where(schemas.OAuth2ClientORM.id == provider_id)
        async with self.session_maker() as session, session.begin():
            client = (await session.scalars(query)).one_or_none()
            if client is None:
                raise errors.MissingResourceError(message=f"OAuth2 Client with id '{provider_id}' does not exist.")

            # Plain fields: copy over whatever the patch provides.
            for field in ("kind", "app_slug", "client_id", "display_name", "scope", "url", "use_pkce"):
                new_value = getattr(patch, field)
                if new_value is not None:
                    setattr(client, field, new_value)

            # Clearable fields: a non-empty string replaces the value, "" removes it.
            if patch.client_secret:
                client.client_secret = encrypt_string(self.encryption_key, client.created_by_id, patch.client_secret)
            elif patch.client_secret == "":  # nosec B105
                client.client_secret = None

            if patch.image_registry_url:
                client.image_registry_url = patch.image_registry_url
            elif patch.image_registry_url == "":
                client.image_registry_url = None

            if patch.oidc_issuer_url:
                client.oidc_issuer_url = patch.oidc_issuer_url
            elif patch.oidc_issuer_url == "":
                client.oidc_issuer_url = None

            # Unset oidc_issuer_url when the kind has been changed to a value other than 'generic_oidc'
            if client.kind != models.ProviderKind.generic_oidc:
                client.oidc_issuer_url = None

            await session.flush()
            await session.refresh(client)
            return client.dump(user_is_admin=user.is_admin)
171

172
    async def delete_oauth2_client(self, user: base_models.APIUser, provider_id: str) -> None:
        """Remove the OAuth2 client with id `provider_id`.

        Only admins may do this (ForbiddenError otherwise). Nothing happens if the client does not exist.
        """
        if not user.is_admin:
            raise errors.ForbiddenError(message="You do not have the required permissions for this operation.")

        query = select(schemas.OAuth2ClientORM).where(schemas.OAuth2ClientORM.id == provider_id)
        async with self.session_maker() as session, session.begin():
            target = (await session.scalars(query)).one_or_none()
            if target is not None:
                await session.delete(target)
187

188
    async def authorize_client(
        self, user: base_models.APIUser, provider_id: str, callback_url: str, next_url: str | None = None
    ) -> str:
        """Start the OAuth2 authorization flow and return the authorization URL.

        The generated state (and the PKCE code verifier, when the client uses PKCE) is saved on the user's
        connection to this client; the connection is created if missing, otherwise reset to pending.
        """
        if not user.is_authenticated or user.id is None:
            raise errors.UnauthorizedError(message="You do not have the required permissions for this operation.")

        client_query = select(schemas.OAuth2ClientORM).where(schemas.OAuth2ClientORM.id == provider_id)
        async with self.session_maker() as session, session.begin():
            client = (await session.scalars(client_query)).one_or_none()
            if client is None:
                raise errors.MissingResourceError(message=f"OAuth2 Client with id '{provider_id}' does not exist.")

            adapter = get_provider_adapter(client)
            client_secret = (
                None
                if not client.client_secret
                else decrypt_string(self.encryption_key, client.created_by_id, client.client_secret)
            )
            code_verifier = generate_code_verifier() if client.use_pkce else None
            async with self.async_oauth2_client_class(
                client_id=client.client_id,
                client_secret=client_secret,
                scope=client.scope,
                redirect_uri=callback_url,
                code_challenge_method="S256" if client.use_pkce else None,
            ) as oauth2_client:
                authorization_url: str
                state: str
                authorization_url, state = oauth2_client.create_authorization_url(
                    adapter.authorization_url, code_verifier=code_verifier, **adapter.authorization_url_extra_params
                )

                connection_query = (
                    select(schemas.OAuth2ConnectionORM)
                    .where(schemas.OAuth2ConnectionORM.client_id == client.id)
                    .where(schemas.OAuth2ConnectionORM.user_id == user.id)
                )
                connection = (await session.scalars(connection_query)).one_or_none()
                if connection is not None:
                    # Restart the flow on the connection the user already has.
                    connection.state = state
                    connection.status = models.ConnectionStatus.pending
                    connection.code_verifier = code_verifier
                    connection.next_url = next_url
                else:
                    connection = schemas.OAuth2ConnectionORM(
                        user_id=user.id,
                        client_id=client.id,
                        token=None,
                        state=state,
                        status=models.ConnectionStatus.pending,
                        code_verifier=code_verifier,
                        next_url=next_url,
                    )
                    session.add(connection)

                await session.flush()
                await session.refresh(connection)
                return authorization_url
253

254
    async def custom_connect(self, user: APIUser, client_id: str, token: dict[str, Any]) -> ULID:
        """Adds a custom connection using the opaque token as given.

        The token is stored encrypted and the connection is marked as connected right away, without going
        through the OAuth2 authorization flow. Only clients of kind `dockerhub` accept this.

        If the user already has a connection to the client, that connection is updated in place instead of
        inserting a second one: `authorize_client` looks up the (client, user) pair with `one_or_none()`
        and would fail on duplicates.

        Returns the id of the created or updated connection.

        Throws a ForbiddenError if the user is not authenticated, a ValidationError if the client id or the
        access token is missing or empty, and a MissingResourceError if the client does not exist or is not
        of a supported kind.
        """
        if not user.is_authenticated or user.id is None:
            raise errors.ForbiddenError(message="You do not have the required permissions for this operation.")

        # An empty access token would produce a "connected" connection that can never authenticate,
        # so it is rejected together with a missing one (this also covers an empty token dict).
        if not client_id or not token.get("access_token"):
            raise errors.ValidationError(message="Client id and token are mandatory")

        token_set = self._encrypt_token_set(token=token, user_id=user.id)
        supported_providers = {models.ProviderKind.dockerhub}
        async with self.session_maker() as session, session.begin():
            result = await session.scalars(
                select(schemas.OAuth2ClientORM)
                .where(schemas.OAuth2ClientORM.id == client_id)
                .where(schemas.OAuth2ClientORM.kind.in_(supported_providers))
            )
            client = result.one_or_none()
            if client is None:
                raise errors.MissingResourceError(
                    message=f"OAuth2 Client with id '{client_id}' does not exist or doesn't support direct connections."
                )

            result_conn = await session.scalars(
                select(schemas.OAuth2ConnectionORM)
                .where(schemas.OAuth2ConnectionORM.client_id == client.id)
                .where(schemas.OAuth2ConnectionORM.user_id == user.id)
            )
            connection = result_conn.one_or_none()

            if connection is None:
                connection = schemas.OAuth2ConnectionORM(
                    user_id=user.id,
                    client_id=client.id,
                    token=token_set,
                    state=None,
                    status=models.ConnectionStatus.connected,
                    code_verifier=None,
                    next_url=None,
                )
                session.add(connection)
            else:
                # Replace the credentials and clear any leftovers of a pending authorization flow.
                connection.token = token_set
                connection.state = None
                connection.status = models.ConnectionStatus.connected
                connection.code_verifier = None
                connection.next_url = None

            # Flush so the row is written (and any database error raised) before the id is handed out.
            await session.flush()
            return connection.id
287

288
    async def authorize_callback(self, state: str, raw_url: str, callback_url: str) -> str | None:
        """Finish the OAuth2 authorization flow for the connection identified by `state`.

        The authorization response in `raw_url` is exchanged for a token, which is stored encrypted on the
        connection; the connection is then marked as connected and its state is cleared.

        Returns the `next_url` parameter value the authorization flow was started with.
        """
        forbidden_message = "You do not have the required permissions for this operation."
        if not state:
            raise errors.ForbiddenError(message=forbidden_message)

        connection_query = (
            select(schemas.OAuth2ConnectionORM)
            .where(schemas.OAuth2ConnectionORM.state == state)
            .options(selectinload(schemas.OAuth2ConnectionORM.client))
        )
        async with self.session_maker() as session, session.begin():
            connection = (await session.scalars(connection_query)).one_or_none()
            if connection is None:
                raise errors.ForbiddenError(message=forbidden_message)

            client = connection.client
            adapter = get_provider_adapter(client)
            client_secret = (
                None
                if not client.client_secret
                else decrypt_string(self.encryption_key, client.created_by_id, client.client_secret)
            )
            # A stored code verifier means the flow was started with PKCE.
            code_verifier = connection.code_verifier
            async with self.async_oauth2_client_class(
                client_id=client.client_id,
                client_secret=client_secret,
                scope=client.scope,
                redirect_uri=callback_url,
                code_challenge_method="S256" if code_verifier else None,
                state=connection.state,
            ) as oauth2_client:
                token = await oauth2_client.fetch_token(
                    adapter.token_endpoint_url, authorization_response=raw_url, code_verifier=code_verifier
                )

                logger.info(f"Token for client {client.id} has keys: {', '.join(token.keys())}")

                # Read next_url before it is cleared below.
                next_url = connection.next_url

                connection.token = self._encrypt_token_set(token=token, user_id=connection.user_id)
                connection.state = None
                connection.status = models.ConnectionStatus.connected
                connection.next_url = None

                return next_url
338

339
    async def delete_oauth2_connection(self, user: base_models.APIUser, connection_id: ULID) -> bool:
        """Delete the connection `connection_id` if it belongs to `user`.

        Returns True if a connection was deleted, False if the user is not authenticated or owns no such
        connection.
        """
        if not user.is_authenticated or user.id is None:
            return False

        query = (
            select(schemas.OAuth2ConnectionORM)
            .where(schemas.OAuth2ConnectionORM.id == connection_id)
            .where(schemas.OAuth2ConnectionORM.user_id == user.id)
        )
        async with self.session_maker() as session, session.begin():
            owned = (await session.scalars(query)).one_or_none()
            if owned is not None:
                await session.delete(owned)
            return owned is not None
357

358
    async def get_oauth2_connections(
        self,
        user: base_models.APIUser,
    ) -> list[models.OAuth2Connection]:
        """Return all OAuth2 connections owned by the user; an empty list for unauthenticated users."""
        if not user.is_authenticated or user.id is None:
            return []

        query = select(schemas.OAuth2ConnectionORM).where(schemas.OAuth2ConnectionORM.user_id == user.id)
        async with self.session_maker() as session:
            rows = (await session.scalars(query)).all()
            return [row.dump() for row in rows]
372

373
    async def get_oauth2_connection_or_none(
        self, connection_id: ULID, user: base_models.APIUser
    ) -> models.OAuth2Connection | None:
        """Return the user's connection with the given id, or None if the user owns no such connection.

        Throw a MissingResourceError if the user is not authenticated.
        """
        if not user.is_authenticated or user.id is None:
            raise errors.MissingResourceError(
                message=f"OAuth2 connection with id '{connection_id}' does not exist or you do not have access to it."
            )

        query = (
            select(schemas.OAuth2ConnectionORM)
            .where(schemas.OAuth2ConnectionORM.id == connection_id)
            .where(schemas.OAuth2ConnectionORM.user_id == user.id)
        )
        async with self.session_maker() as session:
            found = (await session.scalars(query)).one_or_none()
            return found.dump() if found else None
393

394
    async def get_oauth2_connection(self, connection_id: ULID, user: base_models.APIUser) -> models.OAuth2Connection:
        """Return the user's connection with the given id.

        Throw a MissingResourceError if the connection doesn't exist or the user is not authenticated.
        """
        found = await self.get_oauth2_connection_or_none(connection_id, user)
        if found is not None:
            return found
        raise errors.MissingResourceError(
            message=f"OAuth2 connection with id '{connection_id}' does not exist or you do not have access to it."
        )
406

407
    async def get_oauth2_connected_account(
        self, connection_id: ULID, user: base_models.APIUser
    ) -> models.ConnectedAccount:
        """Fetch the account information of the user at the connected service.

        The adapter's user-info endpoint is called with POST or GET, as the adapter requests. An
        UnauthorizedError is raised if the token is invalid or the status code is above 200.
        """
        async with self.get_async_oauth2_client(connection_id=connection_id, user=user) as (oauth2_client, _, adapter):
            request_url = urljoin(adapter.api_url, adapter.user_info_endpoint)
            try:
                send = oauth2_client.post if adapter.user_info_method == "POST" else oauth2_client.get
                response = await send(request_url, headers=adapter.api_common_headers)
            except InvalidTokenError as e:
                raise errors.UnauthorizedError(message="OAuth2 token for connected service invalid or expired.") from e

            if response.status_code > 200:
                raise errors.UnauthorizedError(message=f"Could not get account information.{response.text}")

            return adapter.api_validate_account_response(response)
426

427
    async def get_non_oauth2_token(self, connection_id: ULID, user: base_models.APIUser) -> models.OAuth2TokenSet:
        """Return the decrypted token stored on the connection, without contacting the provider.

        Throw a MissingResourceError if the user is not authenticated.
        """
        if not user.is_authenticated or user.id is None:
            raise errors.MissingResourceError(
                message=f"OAuth2 connection with id '{connection_id}' does not exist or you do not have access to it."
            )

        valid_connection = await self._get_valid_connection(connection_id, user)
        return self._decrypt_token_set(token=valid_connection.token or {}, user_id=user.id)
437

438
    async def get_oauth2_connection_token(
        self, connection_id: ULID, user: base_models.APIUser
    ) -> models.OAuth2TokenSet:
        """Return the token set of the connection after making sure the token is active.

        An OAuthError with error "bad_refresh_token" is turned into an InvalidTokenError asking the user to
        reconnect; any other OAuthError is re-raised unchanged.
        """
        async with self.get_async_oauth2_client(connection_id=connection_id, user=user) as (oauth2_client, _, _):
            try:
                await oauth2_client.ensure_active_token(oauth2_client.token)
            except OAuthError as err:
                if err.error != "bad_refresh_token":
                    raise
                raise errors.InvalidTokenError(
                    message="The refresh token for the connected service has expired or is invalid.",
                    detail=f"Please reconnect your integration for the service with ID {str(connection_id)} "
                    "and try again.",
                ) from err
            return models.OAuth2TokenSet.from_dict(oauth2_client.token)
455

456
    async def get_provider_for_image(self, user: APIUser, image: Image) -> models.ImageProvider | None:
        """Find a provider supporting the given image.

        A provider matches when its `image_registry_url` is the image hostname with an http or https scheme
        and its kind is one of `supported_image_registry_providers`. The user's connection to the provider,
        if any, is returned alongside it.

        Returns None if no provider matches. If several match, providers the user is connected to are
        preferred, most recently updated connection first.
        """
        registry_urls = [f"http://{image.hostname}", f"https://{image.hostname}"]
        async with self.session_maker() as session:
            stmt = (
                select(schemas.OAuth2ClientORM, schemas.OAuth2ConnectionORM)
                .join(
                    schemas.OAuth2ConnectionORM,
                    and_(
                        schemas.OAuth2ConnectionORM.client_id == schemas.OAuth2ClientORM.id,
                        schemas.OAuth2ConnectionORM.user_id == user.id,
                    ),
                    isouter=True,  # isouter makes it a left-join, not an outer join
                )
                .where(schemas.OAuth2ClientORM.image_registry_url.in_(registry_urls))
                .where(schemas.OAuth2ClientORM.kind.in_(self.supported_image_registry_providers))
                # There could be multiple matches. Because of the left join, providers without a connection
                # have a NULL `updated_at`; on PostgreSQL a plain DESC sort would put those first, so
                # NULLs are explicitly sorted last to keep connected providers ahead of unconnected ones.
                .order_by(schemas.OAuth2ConnectionORM.updated_at.desc().nulls_last())
                .limit(1)
            )
            result = await session.execute(stmt)
            row = result.one_or_none()
            if row is None or row.OAuth2ClientORM is None:
                return None

            connected_user = (
                models.ConnectedUser(row.OAuth2ConnectionORM.dump(), user)
                if row.OAuth2ConnectionORM is not None
                else None
            )
            return models.ImageProvider(
                row.OAuth2ClientORM.dump(),
                connected_user,
                str(row.OAuth2ClientORM.image_registry_url),  # above query makes it non-nil
            )
488

489
    async def get_image_repo_client(self, image_provider: models.ImageProvider) -> ImageRepoDockerAPI:
        """Build an image repository client for the registry of the given image provider.

        If the provider has a connected user whose token set contains an access token, the client is
        returned with that token attached; otherwise a client without credentials is returned.
        """
        parsed = urlparse(image_provider.registry_url)
        repo_api = ImageRepoDockerAPI(hostname=parsed.netloc, scheme=parsed.scheme)
        if not image_provider.is_connected():
            return repo_api

        assert image_provider.connected_user is not None
        connected = image_provider.connected_user
        user, conn = connected.user, connected.connection
        # Dockerhub connections are read as stored; other kinds go through the OAuth2 client.
        fetch_token = (
            self.get_non_oauth2_token
            if image_provider.provider.kind == models.ProviderKind.dockerhub
            else self.get_oauth2_connection_token
        )
        token_set = await fetch_token(conn.id, user)
        access_token = token_set.access_token
        if not access_token:
            return repo_api

        logger.debug(
            f"Use connection {conn.id} to {image_provider.provider.id} for user {user.id}/{token_set.username}"
        )
        return repo_api.with_oauth2_token(access_token, token_set.username)
508

509
    async def get_oauth2_app_installations(
        self, connection_id: ULID, user: base_models.APIUser, pagination: PaginationRequest
    ) -> models.AppInstallationList:
        """Return one page of app installations visible through the given connection.

        An empty list is returned for anything but GitHub. A "bad_refresh_token" OAuthError becomes an
        InvalidTokenError; a status code above 200 becomes an UnauthorizedError.
        """
        async with self.get_async_oauth2_client(connection_id=connection_id, user=user) as (
            oauth2_client,
            connection,
            adapter,
        ):
            # NOTE: App installations are only available from GitHub
            if connection.client.kind != models.ProviderKind.github or not isinstance(adapter, GitHubAdapter):
                return models.AppInstallationList(total_count=0, installations=[])

            request_url = urljoin(adapter.api_url, "user/installations")
            params = {"page": pagination.page, "per_page": pagination.per_page}
            try:
                response = await oauth2_client.get(request_url, params=params, headers=adapter.api_common_headers)
            except OAuthError as err:
                if err.error != "bad_refresh_token":
                    raise
                raise errors.InvalidTokenError(
                    message="The refresh token for the connected service has expired or is invalid.",
                    detail=f"Please reconnect your integration for the service with ID {str(connection_id)} "
                    "and try again.",
                ) from err

            if response.status_code > 200:
                raise errors.UnauthorizedError(message="Could not get installation information.")

            return adapter.api_validate_app_installations_response(response)
539

540
    async def _get_valid_connection(
        self, connection_id: ULID, user: base_models.APIUser
    ) -> schemas.OAuth2ConnectionORM:
        """Load the user's connection, with its client, and check that it is usable.

        Throw a MissingResourceError if the user owns no such connection, and an UnauthorizedError if the
        connection is not in status `connected` or has no token.
        """
        query = (
            select(schemas.OAuth2ConnectionORM)
            .where(schemas.OAuth2ConnectionORM.id == connection_id)
            .where(schemas.OAuth2ConnectionORM.user_id == user.id)
            .options(selectinload(schemas.OAuth2ConnectionORM.client))
        )
        async with self.session_maker() as session:
            found = (await session.scalars(query)).one_or_none()
            if found is None:
                raise errors.MissingResourceError(
                    message=f"OAuth2 connection with id '{connection_id}' does not exist or you do not have access to it."  # noqa: E501
                )

            is_usable = found.status == models.ConnectionStatus.connected and found.token is not None
            if not is_usable:
                raise errors.UnauthorizedError(message=f"OAuth2 connection with id '{connection_id}' is not valid.")
            return found
560

561
    @asynccontextmanager
    async def get_async_oauth2_client(
        self, connection_id: ULID, user: base_models.APIUser
    ) -> AsyncGenerator[tuple[AsyncOAuth2Client, schemas.OAuth2ConnectionORM, ProviderAdapter], None]:
        """Get the AsyncOAuth2Client for the given connection_id and user.

        Yields the client (set up with the decrypted token of the connection), the connection itself and
        the provider adapter of its client. The OAuth2 client is opened for the duration of the context
        and closed when the caller leaves it, including when the caller raises.

        Throws a MissingResourceError if the user is not authenticated; errors from
        `_get_valid_connection` propagate unchanged.
        """
        if not user.is_authenticated or user.id is None:
            raise errors.MissingResourceError(
                message=f"OAuth2 connection with id '{connection_id}' does not exist or you do not have access to it."
            )
        connection = await self._get_valid_connection(connection_id, user)
        client = connection.client
        token = self._decrypt_token_set(token=connection.token or {}, user_id=user.id)

        async def update_token(token: dict[str, Any], refresh_token: str | None = None) -> None:
            # Passed to the OAuth2 client as its `update_token` callback; the new token is persisted only
            # when the callback is invoked with a refresh token.
            if refresh_token is None:
                return
            async with self.session_maker() as session, session.begin():
                # Re-attach the connection to a fresh session and reload it before storing the new token.
                session.add(connection)
                await session.refresh(connection)
                connection.token = self._encrypt_token_set(token=token, user_id=connection.user_id)
                await session.flush()
                await session.refresh(connection)
                logger.info("Token refreshed!")

        adapter = get_provider_adapter(client)
        client_secret = (
            decrypt_string(self.encryption_key, client.created_by_id, client.client_secret)
            if client.client_secret
            else None
        )
        code_verifier = connection.code_verifier
        code_challenge_method = "S256" if code_verifier else None
        # Enter the client as a context manager, like the other methods of this class do, so that it is
        # closed once the caller is done instead of being left open.
        async with self.async_oauth2_client_class(
            client_id=client.client_id,
            client_secret=client_secret,
            scope=client.scope,
            code_challenge_method=code_challenge_method,
            token_endpoint=adapter.token_endpoint_url,
            token=token,
            update_token=update_token,
        ) as oauth2_client:
            yield oauth2_client, connection, adapter
606

607
    def _encrypt_token_set(self, token: dict[str, Any], user_id: str) -> models.OAuth2TokenSet:
        """Return a token set whose access and refresh tokens are encrypted for storage.

        Each token that is present is encrypted with the repository key and `user_id`, then base64-encoded
        to an ASCII string. All other fields are kept as given.
        """
        token_set = models.OAuth2TokenSet.from_dict(token)
        for field in ("access_token", "refresh_token"):
            plain = getattr(token_set, field)
            if plain:
                token_set[field] = b64encode(encrypt_string(self.encryption_key, user_id, plain)).decode("ascii")
        return token_set
619

620
    def _decrypt_token_set(self, token: dict[str, Any], user_id: str) -> models.OAuth2TokenSet:
        """Return a token set whose access and refresh tokens are decrypted after loading from storage.

        Each token that is present is base64-decoded and decrypted with the repository key and `user_id`.
        All other fields are kept as given.
        """
        token_set = models.OAuth2TokenSet.from_dict(token)
        for field in ("access_token", "refresh_token"):
            stored = getattr(token_set, field)
            if stored:
                token_set[field] = decrypt_string(self.encryption_key, user_id, b64decode(stored))
        return token_set
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc