• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

uc-cdis / fence / 25175224135

30 Apr 2026 03:49PM UTC coverage: 75.077% (-0.003%) from 75.08%
25175224135

Pull #1335

github

BinamB
add whitelist for cognito logout
Pull Request #1335: Add idp logout for cognito

8489 of 11307 relevant lines covered (75.08%)

0.75 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

63.93
fence/utils.py
1
import bcrypt
1✔
2
import collections
1✔
3
from functools import wraps
1✔
4
import logging
1✔
5
import json
1✔
6
from random import SystemRandom
1✔
7
import re
1✔
8
import string
1✔
9
import requests
1✔
10
from urllib.parse import urlencode, urlparse
1✔
11
from urllib.parse import parse_qs, urlsplit, urlunsplit
1✔
12
import sys
1✔
13

14
from cdislogging import get_logger
1✔
15
import flask
1✔
16
from werkzeug.datastructures import ImmutableMultiDict
1✔
17

18
from fence.models import Client, User, query_for_user
1✔
19
from fence.errors import NotFound, UserError
1✔
20
from fence.config import config
1✔
21
from authlib.oauth2.rfc6749.util import scope_to_list
1✔
22
from authlib.oauth2.rfc6749.errors import InvalidScopeError
1✔
23

24
# Random source for random_str; SystemRandom draws from the OS (os.urandom).
rng = SystemRandom()
# Character pool for random_str: A-Z, then a-z, then 0-9.
alphanumeric = "".join((string.ascii_uppercase, string.ascii_lowercase, string.digits))
# Module-level logger named after this module.
logger = get_logger(__name__)
27

28

29
def random_str(length):
    """
    Return a string of ``length`` characters, each picked independently from
    ``alphanumeric`` using the module-level SystemRandom instance ``rng``.
    """
    chars = [rng.choice(alphanumeric) for _ in range(length)]
    return "".join(chars)
31

32

33
def json_res(data):
    """
    Serialize ``data`` with ``json.dumps`` and wrap it in a flask Response
    with mimetype ``application/json``.
    """
    body = json.dumps(data)
    return flask.Response(body, mimetype="application/json")
35

36

37
def fetch_url_data(url: str, format: str, expected_status_code: int = 200) -> str:
    """
    Perform a GET request and return the response data.
    Using this function instead of making the request directly in the caller function allows us
    to mock the returned data in unit tests.

    Args:
        url (str): URL to GET
        format (str): how to decode the response body: "text" returns the raw
            response text, "json" returns the decoded JSON value
        expected_status_code (int): status code the response must have

    Returns:
        str: raw response data for format "text"; the decoded JSON value
        (e.g. dict or list) for format "json"

    Raises:
        Exception: if ``format`` is neither "text" nor "json" (checked before
            any request is sent)
        AssertionError: if the response status code differs from
            ``expected_status_code``
    """
    if format not in ("text", "json"):
        raise Exception(f"Unknown 'fetch_url_data' format '{format}'")

    res = requests.get(url)
    # Explicit raise instead of `assert`: assert statements are stripped when
    # Python runs with -O, which would silently accept error responses. The
    # exception type stays AssertionError for callers that already catch it.
    if res.status_code != expected_status_code:
        raise AssertionError(f"Unable to fetch data from '{url}'")

    if format == "text":
        return res.text
    return res.json()
56

57

58
def generate_client_credentials(confidential):
    """
    Create a new client ID and, for confidential clients only, a client secret.

    The plaintext secret is what gets handed back to the user; the bcrypt hash
    of it is what should be persisted in the database.

    Args:
        confidential (bool): True for a confidential client, False for a public one

    Returns:
        tuple: (client ID, plaintext client secret or None, hashed client secret or None)
    """
    client_id = random_str(40)
    if not confidential:
        # Public clients never get a secret.
        return client_id, None, None

    client_secret = random_str(55)
    salt = bcrypt.gensalt()
    hashed_secret = bcrypt.hashpw(client_secret.encode("utf-8"), salt).decode("utf-8")
    return client_id, client_secret, hashed_secret
79

80

81
def create_client(
    DB,
    username=None,
    urls=None,
    name="",
    description="",
    auto_approve=False,
    is_admin=False,
    grant_types=None,
    confidential=True,
    arborist=None,
    policies=None,
    allowed_scopes=None,
    expires_in=None,
):
    """
    Create an OAuth client (and, if needed, its owning user) in the database.

    Args:
        DB: database connection URL, passed to ``get_SQLAlchemyDriver``
        username (str, optional): owner of the client; a new user (with
            ``is_admin``) is created if no user with this name exists
        urls (list, optional): redirect URIs for the client; defaults to an
            empty list
        name (str): client name; must not already be used by another client
        description (str): client description
        auto_approve (bool): stored on the client as ``auto_approve``
        is_admin (bool): admin flag, only used when a new user is created
        grant_types: stored on the client as ``grant_types``
        confidential (bool): if True a client secret is generated and the token
            endpoint auth method is "client_secret_basic"; otherwise "none"
        arborist: optional arborist client; when given, the client is
            registered there with ``policies`` and removed again if the client
            name turns out to be taken
        policies: policies passed to ``arborist.create_client``
        allowed_scopes (list, optional): scopes the client may request; each
            must be in config["CLIENT_ALLOWED_SCOPES"], which is also the
            default. "openid" is always added.
        expires_in: stored on the client as ``expires_in``

    Returns:
        tuple: (client ID, plaintext client secret or None)

    Raises:
        ValueError: if a requested scope is not in config["CLIENT_ALLOWED_SCOPES"]
        Exception: if a client named ``name`` already exists
    """
    if urls is None:
        urls = []

    # Work on a copy: appending "openid" below must not mutate the caller's
    # list or, when defaulting, the shared config["CLIENT_ALLOWED_SCOPES"].
    allowed_scopes = list(allowed_scopes or config["CLIENT_ALLOWED_SCOPES"])
    if not set(allowed_scopes).issubset(set(config["CLIENT_ALLOWED_SCOPES"])):
        raise ValueError(
            "Each allowed scope must be one of: {}".format(
                config["CLIENT_ALLOWED_SCOPES"]
            )
        )

    if "openid" not in allowed_scopes:
        allowed_scopes.append("openid")
        logger.warning('Adding required "openid" scope to list of allowed scopes.')

    client_id, client_secret, hashed_secret = generate_client_credentials(confidential)
    # Register with arborist only after input validation, so a rejected scope
    # list cannot leave an orphaned arborist client behind.
    if arborist is not None:
        arborist.create_client(client_id, policies)
    driver = get_SQLAlchemyDriver(DB)
    auth_method = "client_secret_basic" if confidential else "none"

    with driver.session as s:
        user = None
        if username:
            user = query_for_user(session=s, username=username)
            if not user:
                user = User(username=username, is_admin=is_admin)
                s.add(user)

        if s.query(Client).filter(Client.name == name).first():
            # Undo the arborist registration made above for this client_id.
            if arborist is not None:
                arborist.delete_client(client_id)
            raise Exception("client {} already exists".format(name))

        client = Client(
            client_id=client_id,
            client_secret=hashed_secret,
            user=user,
            redirect_uris=urls,
            allowed_scopes=" ".join(allowed_scopes),
            description=description,
            name=name,
            auto_approve=auto_approve,
            grant_types=grant_types,
            is_confidential=confidential,
            token_endpoint_auth_method=auth_method,
            expires_in=expires_in,
        )
        s.add(client)
        s.commit()

    return client_id, client_secret
145

146

147
def hash_secret(f):
    """
    Decorator for flask views that receive client credentials in the form body.

    When ``flask.request.form`` contains both ``client_id`` and
    ``client_secret``, the submitted secret is replaced with
    ``bcrypt.hashpw(secret, stored_hash)``, using the matching client's stored
    hashed secret as the salt. The resulting form value can then be compared
    directly against the stored hash.

    The submitted secret is left as-is when no client matches ``client_id``,
    or when the client has no stored secret.
    """

    @wraps(f)
    def wrapper(*args, **kwargs):
        submitted = flask.request.form
        if "client_secret" in submitted and "client_id" in submitted:
            form = submitted.to_dict()
            with flask.current_app.db.session as session:
                client = (
                    session.query(Client)
                    .filter(Client.client_id == form["client_id"])
                    .first()
                )
                # Public (non-confidential) clients store no secret, so there is
                # no bcrypt salt to hash with; leave the submitted value alone
                # instead of failing with AttributeError on None.encode().
                if client and client.client_secret:
                    form["client_secret"] = bcrypt.hashpw(
                        form["client_secret"].encode("utf-8"),
                        client.client_secret.encode("utf-8"),
                    ).decode("utf-8")
                flask.request.form = ImmutableMultiDict(form)

        return f(*args, **kwargs)

    return wrapper
170

171

172
def wrap_list_required(f):
    """
    Decorator letting a dict-processing function ``f(d, *args, **kwargs)`` also
    accept a list.

    A list ``d`` is wrapped as ``{"data": d}`` before calling ``f``, and the
    processed list is unwrapped from the result afterwards. Any other ``d`` is
    passed to ``f`` unchanged, and ``f``'s result is returned as-is.
    """

    @wraps(f)
    def wrapper(d, *args, **kwargs):
        if not isinstance(d, list):
            return f(d, *args, **kwargs)

        result = f({"data": d}, *args, **kwargs)
        # ``f`` may rename the wrapper key itself: convert_key applies its
        # converter to every key, so e.g. str.upper turns "data" into "DATA".
        # The wrapper dict had exactly one entry, so fall back to that entry's
        # value instead of failing with KeyError.
        if "data" not in result and len(result) == 1:
            return next(iter(result.values()))
        return result["data"]

    return wrapper


@wrap_list_required
def convert_key(d, converter):
    """
    Recursively apply ``converter`` to every key of a (possibly nested) dict.

    Nested dicts, including dicts inside lists, have their keys converted too;
    values are left unchanged. Strings and non-iterables are returned as-is.
    A top-level list is handled by ``wrap_list_required``.
    """
    # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in
    # ``collections.abc``.
    from collections.abc import Iterable

    if isinstance(d, str) or not isinstance(d, Iterable):
        return d

    new = {}
    for k, v in d.items():
        if isinstance(v, dict):
            new_v = convert_key(v, converter)
        elif isinstance(v, list):
            new_v = [convert_key(x, converter) for x in v]
        else:
            new_v = v
        new[converter(k)] = new_v
    return new


@wrap_list_required
def convert_value(d, converter):
    """
    Recursively apply ``converter`` to the values of a (possibly nested) dict.

    A string or non-iterable ``d`` is returned as ``converter(d)``. For a dict,
    nested dicts and list items are converted first. ``converter`` is then
    also applied to each resulting value, including converted dicts and lists
    as a whole. Keys are left unchanged. A top-level list is handled by
    ``wrap_list_required``.
    """
    # ``collections.Iterable`` was removed in Python 3.10; use collections.abc.
    from collections.abc import Iterable

    if isinstance(d, str) or not isinstance(d, Iterable):
        return converter(d)

    new = {}
    for k, v in d.items():
        if isinstance(v, dict):
            new_v = convert_value(v, converter)
        elif isinstance(v, list):
            new_v = [convert_value(x, converter) for x in v]
        else:
            new_v = v
        new[k] = converter(new_v)
    return new
222

223

224
def to_underscore(s):
    """
    Convert a camelCase / PascalCase identifier to snake_case.

    >>> to_underscore("camelCaseName")
    'camel_case_name'
    >>> to_underscore("HTTPResponseCode")
    'http_response_code'
    """
    # Pass 1: break before each capitalised word ("fooBar" -> "foo_Bar").
    with_word_breaks = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", s)
    # Pass 2: break before any capital still following a lowercase letter or digit.
    with_all_breaks = re.sub("([a-z0-9])([A-Z])", r"\1_\2", with_word_breaks)
    return with_all_breaks.lower()
227

228

229
def strip(s):
    """Return ``s.strip()`` for strings; any other value is returned untouched."""
    return s.strip() if isinstance(s, str) else s
233

234

235
def clear_cookies(response):
    """
    Expire every cookie sent with the current flask request by setting it on
    ``response`` with an empty value, ``expires=0`` and the HttpOnly flag.
    """
    # Snapshot the names before touching the response.
    cookie_names = list(flask.request.cookies)
    for name in cookie_names:
        response.set_cookie(key=name, value="", expires=0, httponly=True)
241

242

243
def get_error_params(error, description):
    """
    Build a URL-encoded ``error=...&error_description=...`` query string.

    Returns an empty string when ``error`` is falsy.
    """
    if not error:
        return ""
    return urlencode({"error": error, "error_description": description})
249

250

251
def append_query_params(original_url, **kwargs):
    """
    Add additional query string arguments to the given url.

    Each keyword argument becomes a single-valued query parameter, replacing
    any existing parameter of the same name. Other existing parameters
    (including ones with blank values) and the fragment are preserved.

    Example call:
        new_url = append_query_params(
            original_url, error='this is an error',
            another_arg='this is another argument')
    """
    scheme, netloc, path, query_string, fragment = urlsplit(original_url)
    # keep_blank_values=True so parameters such as "?foo=" already on the URL
    # are not silently dropped when the query string is rebuilt.
    query_params = parse_qs(query_string, keep_blank_values=True)
    for key, value in kwargs.items():
        query_params[key] = [value]

    new_query_string = urlencode(query_params, doseq=True)
    return urlunsplit((scheme, netloc, path, new_query_string, fragment))
269

270

271
def split_url_and_query_params(url):
    """
    Separate ``url`` into the URL without its query string and the parsed query.

    Returns:
        tuple: (url with the query removed but the fragment kept,
                dict of parameter name -> list of values, as from ``parse_qs``)
    """
    parts = urlsplit(url)
    bare_url = urlunsplit((parts.scheme, parts.netloc, parts.path, None, parts.fragment))
    return bare_url, parse_qs(parts.query)
276

277

278
def send_email(from_email, to_emails, subject, text, smtp_domain):
    """
    Send email to group of emails using mail gun api.

    https://app.mailgun.com/

    Args:
        from_email(str): from email
        to_emails(list): list of emails to receive the messages
        subject(str): subject line of the message
        text(str): the text message
        smtp_domain(str): key into config["GUN_MAIL"]; the value stored under
            it holds that domain's settings, e.g.

            {
                "smtp_hostname": "smtp.mailgun.org",
                "default_login": "postmaster@mailgun.planx-pla.net",
                "api_url": "https://api.mailgun.net/v3/mailgun.planx-pla.net",
                "smtp_password": "password", # pragma: allowlist secret
                "api_key": "api key" # pragma: allowlist secret
            }

    Returns:
        Http response

    Exceptions:
        KeyError: if the config has no "GUN_MAIL" section
        NotFound: if ``smtp_domain`` is not configured under GUN_MAIL (or its
            entry is empty), or its settings have no ``smtp_password``
    """
    # A missing or null domain entry is treated as "not configured" instead of
    # failing with AttributeError on None.get().
    domain_settings = config["GUN_MAIL"].get(smtp_domain) or {}
    # NOTE(review): smtp_password is required here even though the request
    # below only uses api_key and api_url.
    if not domain_settings.get("smtp_password"):
        raise NotFound(
            "SMTP Domain '{}' does not exist in configuration for GUN_MAIL or "
            "smtp_password was not provided. "
            "Cannot send email.".format(smtp_domain)
        )

    api_key = domain_settings.get("api_key", "")
    email_url = domain_settings.get("api_url", "") + "/messages"

    return requests.post(
        email_url,
        auth=("api", api_key),
        data={"from": from_email, "to": to_emails, "subject": subject, "text": text},
    )
322

323

324
def get_valid_expiration_from_request(
    expiry_param="expires_in", max_limit=None, default=None
):
    """
    Read the ``expiry_param`` query argument (``expires_in`` unless another
    name is given) from the current flask request and validate it with
    ``get_valid_expiration``, forwarding ``max_limit`` and ``default``.
    """
    requested = flask.request.args.get(expiry_param)
    return get_valid_expiration(requested, max_limit=max_limit, default=default)
334

335

336
def get_valid_expiration(requested_expiration, max_limit=None, default=None):
    """
    Validate a requested expiration.

    - If requested_expiration is None, return default (which may also be None).
    - If it cannot be converted with ``int()`` or is not positive, raise UserError.
    - If max_limit is truthy and the value exceeds it, return max_limit.
    - Otherwise return the value as an int.
    """
    if requested_expiration is None:
        return default

    try:
        rv = int(requested_expiration)
    except (TypeError, ValueError):
        rv = None

    # Explicit check instead of `assert`, which is stripped under `python -O`
    # and would then let zero/negative expirations through.
    if rv is None or rv <= 0:
        raise UserError(
            "Requested expiry must be a positive integer; instead got {}".format(
                requested_expiration
            )
        )

    if max_limit:
        rv = min(rv, max_limit)
    return rv
358

359

360
def _print_func_name(function):
1✔
361
    return "{}.{}".format(function.__module__, function.__name__)
1✔
362

363

364
def _print_kwargs(kwargs):
1✔
365
    return ", ".join("{}={}".format(k, repr(v)) for k, v in list(kwargs.items()))
1✔
366

367

368
def log_backoff_retry(details):
1✔
369
    args_str = ", ".join(map(str, details["args"]))
1✔
370
    kwargs_str = (
1✔
371
        (", " + _print_kwargs(details["kwargs"])) if details.get("kwargs") else ""
372
    )
373
    func_call_log = "{}({}{})".format(
1✔
374
        _print_func_name(details["target"]), args_str, kwargs_str
375
    )
376
    logging.warning(
1✔
377
        "backoff: call {func_call} delay {wait:0.1f} seconds after {tries} tries".format(
378
            func_call=func_call_log, **details
379
        )
380
    )
381

382

383
def log_backoff_giveup(details):
1✔
384
    args_str = ", ".join(map(str, details["args"]))
1✔
385
    kwargs_str = (
1✔
386
        (", " + _print_kwargs(details["kwargs"])) if details.get("kwargs") else ""
387
    )
388
    func_call_log = "{}({}{})".format(
1✔
389
        _print_func_name(details["target"]), args_str, kwargs_str
390
    )
391
    logging.error(
1✔
392
        "backoff: gave up call {func_call} after {tries} tries; exception: {exc}".format(
393
            func_call=func_call_log, exc=sys.exc_info(), **details
394
        )
395
    )
396

397

398
def exception_do_not_retry(error):
    """
    Backoff ``giveup`` predicate: return True (stop retrying) when the error's
    ``code``, ``status`` or ``status_code`` attribute equals 404 or 409 when
    compared as a string; otherwise return False.
    """
    non_retryable = {"409", "404"}
    for attr in ("code", "status", "status_code"):
        if str(getattr(error, attr, None)) in non_retryable:
            return True
    return False
410

411

412
def get_from_cache(
    item_id, memory_cache, db_cache_table, db_cache_table_id_field="id", now=None
):
    """
    Attempt to get a cached item and store in memory cache from db if necessary.

    An entry is only returned if its ``expires_at`` is later than ``now``.

    NOTE: This requires custom implementation for putting items in the db cache table.

    Args:
        item_id: key to look up
        memory_cache (dict): in-process cache mapping item_id -> (value, expires_at)
        db_cache_table: model class queried when the memory cache misses or is stale
        db_cache_table_id_field (str): attribute of ``db_cache_table`` compared
            against ``item_id``
        now (float, optional): timestamp an entry must expire after to be
            returned; defaults to ``time.time()``. NOTE(review): assumes
            ``expires_at`` values are Unix epoch seconds — confirm against the
            cache tables.

    Returns:
        The cached value, or None if no unexpired entry was found.
    """
    import time

    # Entries are only valid if they expire after this moment.
    expiry = time.time() if now is None else now

    # try to retrieve from local in-memory cache
    rv, expires_at = memory_cache.get(item_id, (None, 0))
    if expires_at > expiry:
        return rv

    # try to retrieve from database cache
    if hasattr(flask.current_app, "db"):  # we don't have db in startup
        with flask.current_app.db.session as session:
            cache = (
                session.query(db_cache_table)
                .filter(
                    getattr(db_cache_table, db_cache_table_id_field, None) == item_id
                )
                .first()
            )
            if cache and cache.expires_at and cache.expires_at > expiry:
                # NOTE(review): dict(cache) requires the row object to be
                # convertible to a dict — confirm for each cache table.
                rv = dict(cache)

                # store in memory cache
                memory_cache[item_id] = rv, cache.expires_at
                return rv

    return None
439

440

441
def get_SQLAlchemyDriver(db_conn_url):
    """
    Return a userdatamodel ``SQLAlchemyDriver`` for ``db_conn_url`` with its
    ``setup_db`` step disabled.

    userdatamodel's ``setup_db`` creates tables and runs database migrations,
    which Alembic handles now. It is therefore replaced with a no-op on the
    class itself (so it applies to every driver instance) before instantiating.
    TODO move userdatamodel code to Fence and remove dependencies to it
    """
    from userdatamodel.driver import SQLAlchemyDriver

    def _skip_setup_db(_self):
        return None

    SQLAlchemyDriver.setup_db = _skip_setup_db
    return SQLAlchemyDriver(db_conn_url)
449

450

451
# Default settings to control usage of backoff library: log every retry and
# the final give-up, cap attempts at the configured maximum, and give up
# immediately when exception_do_not_retry says so.
DEFAULT_BACKOFF_SETTINGS = dict(
    on_backoff=log_backoff_retry,
    on_giveup=log_backoff_giveup,
    max_tries=config["DEFAULT_BACKOFF_SETTINGS_MAX_TRIES"],
    giveup=exception_do_not_retry,
)
458

459

460
def validate_scopes(request_scopes, client):
    """
    Check that every requested scope is allowed for ``client``.

    Args:
        request_scopes: requested scopes in a form accepted by authlib's
            ``scope_to_list`` (e.g. a space-delimited string); a falsy value
            skips the check
        client: OAuth client whose ``check_requested_scopes`` decides validity

    Returns:
        bool: True when the scopes are acceptable or none were requested

    Raises:
        Exception: if ``client`` is None (or otherwise falsy)
        InvalidScopeError: if the client does not support a requested scope
    """
    if not client:
        raise Exception("Client object is None")

    if not request_scopes:
        return True

    scopes = scope_to_list(request_scopes)
    if not client.check_requested_scopes(set(scopes)):
        # Lazy %-style arguments: the message is only formatted when debug
        # logging is enabled, and a None client.scope cannot raise TypeError
        # the way string concatenation did.
        logger.debug(
            "Request Scope are %s but client supported scopes are %s",
            " ".join(scopes),
            client.scope,
        )
        raise InvalidScopeError("Failed to Authorize due to unsupported scope")

    return True
477

478

479
def strtobool(val: str) -> bool:
    """Interpret a string as a boolean, case-insensitively.

    True spellings: 'y', 'yes', 't', 'true', 'on', '1'.
    False spellings: 'n', 'no', 'f', 'false', 'off', '0'.
    Any other input raises ValueError.
    """
    normalized = val.lower()
    truthy = {"y", "yes", "t", "true", "on", "1"}
    falsy = {"n", "no", "f", "false", "off", "0"}
    if normalized in truthy:
        return True
    if normalized in falsy:
        return False
    raise ValueError(f"invalid truth value {normalized!r}")
493

494

495
def allowed_login_redirects():
    """
    Determine which domains a login redirect endpoint (``/login/google``, etc)
    should be allowed to redirect back to after login. This includes the
    configured LOGIN_REDIRECT_WHITELIST, the base URL of this flask
    application, and the redirect URLs registered for any OAuth clients.

    Return:
        Set[str]: allowed redirect domains (each URL reduced by ``domain``),
        not full URLs
    """
    # Copy the configured list: appending to it directly would mutate
    # config["LOGIN_REDIRECT_WHITELIST"] and make it grow on every call.
    allowed = list(config.get("LOGIN_REDIRECT_WHITELIST") or [])
    allowed.append(config["BASE_URL"])
    with flask.current_app.db.session as session:
        for client in session.query(Client).all():
            if isinstance(client.redirect_uris, list):
                allowed.extend(client.redirect_uris)
            elif isinstance(client.redirect_uris, str):
                allowed.append(client.redirect_uris)
    return {domain(url) for url in allowed}
515

516

517
def domain(url):
    """
    Return only the network location (``urlparse(...).netloc``) of ``url``,
    dropping the scheme, path, etc. This lets URLs from flask, the config and
    the user be compared consistently. A relative URL (starting with "/")
    resolves to the netloc of config["BASE_URL"]; an empty URL yields "".
    """
    if not url:
        return ""
    target = config["BASE_URL"] if url.startswith("/") else url
    return urlparse(target).netloc
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc