• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

localstack / localstack / 18580604313

17 Oct 2025 12:09AM UTC coverage: 86.896% (+0.01%) from 86.886%
18580604313

push

github

web-flow
APIGW: expand coverage for API Keys and Usage Plans (#13201)

Co-authored-by: Benjamin Simon <benjh.simon@gmail.com>

19 of 19 new or added lines in 1 file covered. (100.0%)

118 existing lines in 7 files now uncovered.

68346 of 78653 relevant lines covered (86.9%)

0.87 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

94.32
/localstack-core/localstack/services/s3/utils.py
1
import base64
1✔
2
import codecs
1✔
3
import datetime
1✔
4
import hashlib
1✔
5
import itertools
1✔
6
import logging
1✔
7
import re
1✔
8
import time
1✔
9
import zlib
1✔
10
from collections.abc import Mapping
1✔
11
from enum import StrEnum
1✔
12
from secrets import token_bytes
1✔
13
from typing import Any, Literal, NamedTuple, Protocol
1✔
14
from urllib import parse as urlparser
1✔
15
from zoneinfo import ZoneInfo
1✔
16

17
import xmltodict
1✔
18
from botocore.exceptions import ClientError
1✔
19
from botocore.utils import InvalidArnException
1✔
20

21
from localstack import config, constants
1✔
22
from localstack.aws.api import CommonServiceException, RequestContext
1✔
23
from localstack.aws.api.s3 import (
1✔
24
    AccessControlPolicy,
25
    BucketCannedACL,
26
    BucketName,
27
    ChecksumAlgorithm,
28
    ContentMD5,
29
    CopyObjectRequest,
30
    CopySource,
31
    ETag,
32
    GetObjectRequest,
33
    Grant,
34
    Grantee,
35
    HeadObjectRequest,
36
    InvalidArgument,
37
    InvalidRange,
38
    InvalidTag,
39
    LifecycleExpiration,
40
    LifecycleRule,
41
    LifecycleRules,
42
    Metadata,
43
    ObjectCannedACL,
44
    ObjectKey,
45
    ObjectSize,
46
    ObjectVersionId,
47
    Owner,
48
    Permission,
49
    PreconditionFailed,
50
    PutObjectRequest,
51
    SSEKMSKeyId,
52
    TaggingHeader,
53
    TagSet,
54
    UploadPartCopyRequest,
55
    UploadPartRequest,
56
)
57
from localstack.aws.api.s3 import Type as GranteeType
1✔
58
from localstack.aws.chain import HandlerChain
1✔
59
from localstack.aws.connect import connect_to
1✔
60
from localstack.http import Response
1✔
61
from localstack.services.s3 import checksums
1✔
62
from localstack.services.s3.constants import (
1✔
63
    ALL_USERS_ACL_GRANTEE,
64
    AUTHENTICATED_USERS_ACL_GRANTEE,
65
    CHECKSUM_ALGORITHMS,
66
    LOG_DELIVERY_ACL_GRANTEE,
67
    SIGNATURE_V2_PARAMS,
68
    SIGNATURE_V4_PARAMS,
69
    SYSTEM_METADATA_SETTABLE_HEADERS,
70
)
71
from localstack.services.s3.exceptions import InvalidRequest, MalformedXML
1✔
72
from localstack.utils.aws import arns
1✔
73
from localstack.utils.aws.arns import parse_arn
1✔
74
from localstack.utils.objects import singleton_factory
1✔
75
from localstack.utils.strings import (
1✔
76
    is_base64,
77
    to_bytes,
78
    to_str,
79
)
80
from localstack.utils.urls import localstack_host
1✔
81

82
LOG = logging.getLogger(__name__)
1✔
83

84
# Bucket name validation pattern: total length of 3 to 63 characters, not made only of dot-separated digit groups
# (IP-address-like), and composed of dot-separated labels of lowercase letters, digits and hyphens that start and end
# with a letter or digit.
BUCKET_NAME_REGEX = (
    r"(?=^.{3,63}$)(?!^(\d+\.)+\d+$)"
    + r"(^(([a-z0-9]|[a-z0-9][a-z0-9\-]*[a-z0-9])\.)*([a-z0-9]|[a-z0-9][a-z0-9\-]*[a-z0-9])$)"
)

# Characters accepted by this pattern: word characters, whitespace and `. : / = + - @` (the empty string matches too)
TAG_REGEX = re.compile(r"^[\w\s.:/=+\-@]*$")


# Captures the `bucket` group (everything before `.s3.`) and an optional `region` group from a host name.
# NOTE(review): the dots around `s3` are not escaped, so they match any character - confirm this is intended.
S3_VIRTUAL_HOSTNAME_REGEX = (
    r"(?P<bucket>.*).s3.(?P<region>(?:us-gov|us|ap|ca|cn|eu|sa)-[a-z]+-\d)?.*"
)

# pre-compiled version of the pattern above, used by the host parsing helpers of this module
_s3_virtual_host_regex = re.compile(S3_VIRTUAL_HOSTNAME_REGEX)


# strftime/strptime format used for RFC 1123 dates; the zone is the literal `GMT`
RFC1123 = "%a, %d %b %Y %H:%M:%S GMT"
_gmt_zone_info = ZoneInfo("GMT")
1✔
101

102

103
def s3_response_handler(chain: HandlerChain, context: RequestContext, response: Response):
    """
    This response handler is taking care of removing certain headers from S3 responses.
    We cannot handle this in the serializer, because the serializer handler calls `Response.update_from`, which does
    not allow you to remove headers, only add them.
    This handler can delete headers from the response.

    :param chain: the handler chain (not used by this handler)
    :param context: the request context; nothing is done when it carries no parsed service operation
    :param response: the response, modified in place
    """
    # some requests, for example coming from extensions, are flagged as S3 requests. This check confirms that it is
    # indeed truly an S3 request by checking if it parsed properly as an S3 operation
    if not context.service_operation:
        return

    # if AWS returns 204, it will not return a body, Content-Length and Content-Type
    # the web server is already taking care of deleting the body, but it's more explicit to remove it here
    if response.status_code == 204:
        response.data = b""
        response.headers.pop("Content-Type", None)
        response.headers.pop("Content-Length", None)

    elif (
        response.status_code == 200
        and context.request.method == "PUT"
        # header values are normally strings, so an empty body shows up as "0"; the integer 0 and None (header not
        # set) are still accepted
        and response.headers.get("Content-Length") in (0, "0", None)
    ):
        # AWS does not return a Content-Type if the Content-Length is 0
        response.headers.pop("Content-Type", None)
1✔
129

130

131
def get_owner_for_account_id(account_id: str):
    """
    Build the S3 Owner for an account. The value is currently hardcoded (as it was in moto) and does not depend on
    the given account id; it could later be extended to return a different owner per account.
    See https://docs.aws.amazon.com/AmazonS3/latest/API/API_Owner.html
    :param account_id: the owner account id (currently unused)
    :return: the Owner object containing the DisplayName and owner ID
    """
    owner_fields = {
        "DisplayName": "webfile",  # only in certain regions, see above
        "ID": "75aa57f09aa0c8caeab4f8c24e99d10f8e7faeebf76c078efc7c6caea54ba06a",
    }
    return Owner(**owner_fields)
143

144

145
def extract_bucket_key_version_id_from_copy_source(
    copy_source: CopySource,
) -> tuple[BucketName, ObjectKey, ObjectVersionId | None]:
    """
    Split a CopySource value into its bucket name, object key and optional version id. The accepted format is:
    - <bucket-name>/<object-key>?versionId=<version-id>, used for example in CopySource for CopyObject
    :param copy_source: the S3 CopySource to parse
    :raises InvalidArgument: if the decoded path holds no `/` separating the bucket from the key
    :return: parsed BucketName, ObjectKey and optionally VersionId
    """
    parsed = urlparser.urlparse(copy_source)
    # `+` is replaced by a space before URL decoding: clients do not all encode spaces the same way (%20 vs +),
    # and `unquote` does not turn `+` into a space
    decoded_path = urlparser.unquote(parsed.path.replace("+", " "))
    bucket, separator, key = decoded_path.lstrip("/").partition("/")

    if not separator:
        raise InvalidArgument(
            "Invalid copy source object key",
            ArgumentName="x-amz-copy-source",
            ArgumentValue="x-amz-copy-source",
        )

    version_ids = urlparser.parse_qs(parsed.query).get("versionId", [None])
    return bucket, key, version_ids[0]
1✔
169

170

171
class ChecksumHash(Protocol):
    """
    Structural type shared by the hash objects used for S3 checksums (hashlib.shaX, zlib.crc32 wrapped in
    S3CRC32Checksum, and botocore CrtCrc32cChecksum): they can be fed bytes and return a bytes digest.
    """

    def update(self, value: bytes):
        ...

    def digest(self) -> bytes:
        ...
1✔
180

181

182
def get_s3_checksum_algorithm_from_request(
    request: PutObjectRequest | UploadPartRequest,
) -> ChecksumAlgorithm | None:
    """
    Find which checksum algorithm a request provides a value for, by looking up its `Checksum<ALGORITHM>` fields.
    :param request: the request to inspect
    :raises InvalidRequest: if values are provided for more than one algorithm
    :return: the algorithm with a non-empty checksum value, or None if there is none
    """
    provided: list[ChecksumAlgorithm] = []
    for algorithm in CHECKSUM_ALGORITHMS:
        if request.get(f"Checksum{algorithm}"):
            provided.append(algorithm)

    if len(provided) > 1:
        raise InvalidRequest(
            "Expecting a single x-amz-checksum- header. Multiple checksum Types are not allowed."
        )

    return provided[0] if provided else None
1✔
197

198

199
def get_s3_checksum_algorithm_from_trailing_headers(
    trailing_headers: str,
) -> ChecksumAlgorithm | None:
    """
    Find which checksum algorithm is announced in the trailing headers declaration of a request.
    :param trailing_headers: the declared trailing header names, as a comma-separated string
    :raises InvalidRequest: if more than one `x-amz-checksum-*` header is declared
    :return: the announced algorithm, or None if no checksum header is declared
    """
    if isinstance(trailing_headers, str):
        # compare complete header names: with a substring test, `x-amz-checksum-crc32c` would also be detected as
        # `x-amz-checksum-crc32` and wrongly rejected as a multiple checksum declaration
        declared_headers = {name.strip() for name in trailing_headers.split(",")}
    else:
        # any other container keeps using its own membership test
        declared_headers = trailing_headers

    checksum_algorithm: list[ChecksumAlgorithm] = [
        algo for algo in CHECKSUM_ALGORITHMS if f"x-amz-checksum-{algo.lower()}" in declared_headers
    ]
    if not checksum_algorithm:
        return None

    if len(checksum_algorithm) > 1:
        raise InvalidRequest(
            "Expecting a single x-amz-checksum- header. Multiple checksum Types are not allowed."
        )

    return checksum_algorithm[0]
1✔
214

215

216
def get_s3_checksum(algorithm) -> ChecksumHash:
    """
    Create a new hash object for the given checksum algorithm.
    :param algorithm: the checksum algorithm
    :raises InvalidRequest: if the algorithm is not one of CRC32, CRC32C, CRC64NVME, SHA1 or SHA256
    :return: an object exposing `update` and `digest`
    """
    if algorithm == ChecksumAlgorithm.CRC32:
        return S3CRC32Checksum()

    if algorithm == ChecksumAlgorithm.CRC32C:
        # imported only when this algorithm is requested
        from botocore.httpchecksum import CrtCrc32cChecksum

        return CrtCrc32cChecksum()

    if algorithm == ChecksumAlgorithm.CRC64NVME:
        # imported only when this algorithm is requested
        from botocore.httpchecksum import CrtCrc64NvmeChecksum

        return CrtCrc64NvmeChecksum()

    if algorithm == ChecksumAlgorithm.SHA1:
        return hashlib.sha1(usedforsecurity=False)

    if algorithm == ChecksumAlgorithm.SHA256:
        return hashlib.sha256(usedforsecurity=False)

    # TODO: check proper error? for now validated client side, need to check server response
    raise InvalidRequest("The value specified in the x-amz-trailer header is not supported")
×
240

241

242
class S3CRC32Checksum:
    """Wraps zlib.crc32 behind the `update`/`digest` interface of hashlib.sha and botocore CrtCrc32cChecksum"""

    __slots__ = ["checksum"]

    def __init__(self):
        # running CRC value, starting from the CRC of no data
        self.checksum = zlib.crc32(b"")

    def digest(self) -> bytes:
        """Return the current CRC as 4 big-endian bytes."""
        return self.checksum.to_bytes(length=4, byteorder="big")

    def update(self, value: bytes):
        """Feed more data into the running CRC."""
        running = self.checksum
        self.checksum = zlib.crc32(value, running)
1✔
255

256

257
class CombinedCrcHash:
    """Accumulates CRC checksums of consecutive parts into one checksum, using the combine function of the type"""

    def __init__(self, checksum_type: ChecksumAlgorithm):
        """
        :param checksum_type: one of the CRC algorithms (CRC32, CRC32C, CRC64NVME)
        :raises ValueError: for any other algorithm
        """
        if checksum_type == ChecksumAlgorithm.CRC32:
            combiner = checksums.combine_crc32
        elif checksum_type == ChecksumAlgorithm.CRC32C:
            combiner = checksums.combine_crc32c
        elif checksum_type == ChecksumAlgorithm.CRC64NVME:
            combiner = checksums.combine_crc64_nvme
        else:
            raise ValueError("You cannot combine SHA based checksums")

        self.combine_function = combiner
        self.checksum = b""

    def combine(self, value: bytes, object_len: int):
        """
        Merge the checksum of the next part into the accumulated checksum.
        :param value: the checksum of the part
        :param object_len: the length passed to the combine function along with `value`
        """
        if self.checksum:
            self.checksum = self.combine_function(self.checksum, value, object_len)
        else:
            # first part: nothing to combine with yet
            self.checksum = value

    def digest(self):
        """Return the accumulated checksum (`b""` if nothing was combined)."""
        return self.checksum
1✔
281

282

283
class ObjectRange(NamedTuple):
    """
    NamedTuple representing a parsed Range header with the requested S3 object size
    https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Range
    """

    content_range: str  # formatted by the parsers of this module as `bytes <begin>-<end>/<object size>`
    content_length: int  # the number of bytes in the range (end - begin + 1)
    begin: int  # the offset of the first byte of the range
    end: int  # the offset of the last byte of the range (inclusive)
1✔
293

294

295
def parse_range_header(range_header: str, object_size: int) -> ObjectRange | None:
    """
    Parse a Range header into the slice of the S3 object that should be returned. A header that cannot be parsed
    yields None, so that the request is handled as a regular, non-range request.
    :param range_header: a Range header
    :param object_size: the requested S3 object total size
    :raises InvalidRange: if the range starts after the last byte, or is a suffix range of 0 bytes
    :return: ObjectRange or None if the Range header is invalid
    """
    final_index = object_size - 1

    unit_and_spec = range_header.split("=")
    if len(unit_and_spec) != 2:
        return None
    spec = unit_and_spec[1]
    # multiple ranges are not handled
    if "," in spec:
        return None

    try:
        begin, end = (int(bound) if bound else None for bound in spec.split("-"))
    except ValueError:
        # if we can't parse the Range header, S3 just treat the request as a non-range request
        return None

    empty_suffix = begin is None and end == 0
    starts_past_the_end = begin is not None and begin > final_index
    if empty_suffix or starts_past_the_end:
        raise InvalidRange(
            "The requested range is not satisfiable",
            ActualObjectSize=str(object_size),
            RangeRequested=range_header,
        )

    if begin is not None:
        # byte range: an open or too large end is capped to the last byte
        end = final_index if end is None else min(end, final_index)
    elif end is not None:
        # suffix byte range: the last `end` bytes of the object
        begin = object_size - min(end, object_size)
        end = final_index
    else:
        # Treat as non-range request
        return None

    if begin > min(end, final_index):
        # Treat as non-range request if after the logic is applied
        return None

    return ObjectRange(
        content_range=f"bytes {begin}-{end}/{object_size}",
        content_length=end - begin + 1,
        begin=begin,
        end=end,
    )
343

344

345
def parse_copy_source_range_header(copy_source_range: str, object_size: int) -> ObjectRange:
    """
    Parse a CopySourceRange parameter into the slice of the source S3 object to copy. Unlike `parse_range_header`,
    every malformed or out of bounds value is rejected with an exception.
    :param copy_source_range: a CopySourceRange parameter for UploadCopyPart
    :param object_size: the requested S3 object total size
    :raises InvalidArgument: if the value is not of the form `bytes=first-last`, or `last` is past the last byte
    :raises InvalidRequest: if `first` is past the last byte of the source object
    :return: ObjectRange
    """

    def _malformed_range():
        # single place building the error shared by all the format checks
        return InvalidArgument(
            "The x-amz-copy-source-range value must be of the form bytes=first-last where first and last are the zero-based offsets of the first and last bytes to copy",
            ArgumentName="x-amz-copy-source-range",
            ArgumentValue=copy_source_range,
        )

    final_index = object_size - 1
    try:
        _, spec = copy_source_range.split("=")
    except ValueError:
        raise _malformed_range()

    if "," in spec:
        raise _malformed_range()

    try:
        first, final = [int(bound) if bound else None for bound in spec.split("-")]
    except ValueError:
        raise _malformed_range()

    # both bounds are mandatory and must be ordered
    if first is None or final is None or first > final:
        raise _malformed_range()

    if first > final_index:
        raise InvalidRequest(
            "The specified copy range is invalid for the source object size",
        )
    if final > final_index:
        raise InvalidArgument(
            f"Range specified is not valid for source object of size: {object_size}",
            ArgumentName="x-amz-copy-source-range",
            ArgumentValue=copy_source_range,
        )

    return ObjectRange(
        content_range=f"bytes {first}-{final}/{object_size}",
        content_length=final - first + 1,
        begin=first,
        end=final,
    )
405

406

407
def get_failed_upload_part_copy_source_preconditions(
    request: UploadPartCopyRequest, last_modified: datetime.datetime, etag: ETag
) -> str | None:
    """
    Utility which parses the conditions from a S3 UploadPartCopy request.
    Note: The order in which these conditions are checked if used in conjunction matters

    :param UploadPartCopyRequest request: The S3 UploadPartCopy request.
    :param datetime last_modified: The time the source object was last modified.
    :param ETag etag: The ETag of the source object.

    :returns: The header name of the first failed precondition, or None if no precondition failed.
    """
    if_match = request.get("CopySourceIfMatch")
    if_none_match = request.get("CopySourceIfNoneMatch")
    if_unmodified_since = request.get("CopySourceIfUnmodifiedSince")
    if_modified_since = request.get("CopySourceIfModifiedSince")

    if if_match:
        # ETags are compared without their surrounding quotes
        if if_match.strip('"') != etag.strip('"'):
            return "x-amz-copy-source-If-Match"
        if if_modified_since and if_modified_since > last_modified:
            return "x-amz-copy-source-If-Modified-Since"
        # CopySourceIfMatch is unaffected by CopySourceIfUnmodifiedSince so return early
        if if_unmodified_since:
            return None

    if if_unmodified_since and if_unmodified_since < last_modified:
        return "x-amz-copy-source-If-Unmodified-Since"

    if if_none_match and if_none_match.strip('"') == etag.strip('"'):
        return "x-amz-copy-source-If-None-Match"

    # this condition only fails for a date that is after the last modification and before the current time
    if if_modified_since and last_modified < if_modified_since < datetime.datetime.now(
        tz=_gmt_zone_info
    ):
        return "x-amz-copy-source-If-Modified-Since"
1✔
444

445

446
def get_full_default_bucket_location(bucket_name: BucketName) -> str:
    """
    Build the location URL of a bucket on the LocalStack host.
    :param bucket_name: the bucket name
    :return: a virtual-host style URL when the default LocalStack hostname is used, a path style URL otherwise
    """
    host_definition = localstack_host()
    protocol = config.get_protocol()

    if host_definition.host == constants.LOCALHOST_HOSTNAME:
        return f"{protocol}://{bucket_name}.s3.{host_definition.host_and_port()}/"

    # the user has customised their LocalStack hostname, and may not support subdomains.
    # Return the location in path form.
    return f"{protocol}://{host_definition.host_and_port()}/{bucket_name}/"
1✔
454

455

456
def etag_to_base_64_content_md5(etag: ETag) -> str:
    """
    Turn an ETag holding a hex digest (surrounding quotes are ignored) into the base64 encoding of the same digest
    :param etag: an ETag, might be quoted
    :return: the base64 value
    """
    hex_digest = to_bytes(etag.strip('"'))
    # hex text -> raw digest bytes -> base64 text
    raw_digest = codecs.decode(hex_digest, "hex")
    encoded_digest = base64.b64encode(raw_digest)
    return to_str(encoded_digest)
1✔
465

466

467
def base_64_content_md5_to_etag(content_md5: ContentMD5) -> str | None:
    """
    Turn a ContentMD5 header (base64 encoding of a binary MD5 digest) into its hex encoded ETag value
    :param content_md5: a ContentMD5 header, base64 encoded
    :return: the ETag value, hex coded MD5 digest, or None if the input is not valid b64 or does not decode to a
    32 characters hex digest
    """
    if not is_base64(content_md5):
        return None

    raw_digest = base64.b64decode(content_md5)
    hex_digest = to_str(codecs.encode(raw_digest, "hex"))
    # an MD5 digest is 32 hex characters long
    return hex_digest if len(hex_digest) == 32 else None
1✔
484

485

486
def is_presigned_url_request(context: RequestContext) -> bool:
    """
    Tell whether a request uses a pre-signed URL, based on its query string parameters
    :param context: the request context from the handler chain
    :return: True if any signature v2 or signature v4 query string parameter is present
    """
    query_parameters = context.request.args
    signature_parameters = itertools.chain(SIGNATURE_V2_PARAMS, SIGNATURE_V4_PARAMS)
    return any(parameter in query_parameters for parameter in signature_parameters)
497

498

499
def is_bucket_name_valid(bucket_name: str) -> bool:
    """
    Check a bucket name against BUCKET_NAME_REGEX
    ref. https://docs.aws.amazon.com/AmazonS3/latest/userguide/bucketnamingrules.html
    :param bucket_name: the bucket name to check
    :return: True if the name matches the pattern
    """
    return re.match(BUCKET_NAME_REGEX, bucket_name) is not None
1✔
504

505

506
def get_permission_header_name(permission: Permission) -> str:
    """
    Build the `x-amz-grant-*` header name of a permission: underscores become dashes and the value is lowercased
    :param permission: the permission
    :return: the header name
    """
    dashed = permission.replace("_", "-")
    return "x-amz-grant-" + dashed.lower()
1✔
508

509

510
def get_permission_from_header(capitalized_field: str) -> Permission:
    """
    Derive a permission from a capitalized field name: the name is cut into its capitalized words, the first word
    is dropped and the remaining ones are uppercased and joined with underscores
    :param capitalized_field: the field name, made of capitalized words
    :return: the permission
    """
    words = []
    for fragment in re.split(r"([A-Z][a-z]+)", capitalized_field):
        # splitting on a capturing group also yields empty strings between the words
        if fragment:
            words.append(fragment.upper())
    return "_".join(words[1:])
1✔
513

514

515
def is_valid_canonical_id(canonical_id: str) -> bool:
    """
    Validate that the string is a hex string with 64 char
    :param canonical_id: the canonical id to validate
    :return: True if the value is made of exactly 64 hexadecimal digits
    """
    # a pattern is used rather than `int(canonical_id, 16)`: the latter also accepts signs, a `0x` prefix,
    # underscores and surrounding whitespace, and evaluates to a falsy 0 for an id made only of zeros
    return re.fullmatch(r"[0-9a-fA-F]{64}", canonical_id) is not None
1✔
523

524

525
def uses_host_addressing(headers: Mapping[str, str]) -> str | None:
    """
    Determines if the request is targeting S3 with virtual host addressing
    :param headers: the request headers
    :return: if the request targets S3 with virtual host addressing, returns the bucket name else None
    """
    host = headers.get("host", "")

    # cheap substring test first, as the regex is very greedy
    if ".s3." not in host:
        return None

    host_match = _s3_virtual_host_regex.match(host)
    if not host_match:
        return None

    # an empty bucket group means there is no bucket in the host
    return host_match.group("bucket") or None
1✔
538

539

540
def get_class_attrs_from_spec_class(spec_class: type[StrEnum]) -> set[str]:
1✔
541
    return {str(spec) for spec in spec_class}
1✔
542

543

544
def get_system_metadata_from_request(request: dict) -> Metadata:
    """
    Pick the settable system metadata fields out of a request
    :param request: the request to read the fields from
    :return: the fields of SYSTEM_METADATA_SETTABLE_HEADERS having a non-empty value in the request
    """
    return {
        field_name: field_value
        for field_name in SYSTEM_METADATA_SETTABLE_HEADERS
        if (field_value := request.get(field_name))
    }
1✔
552

553

554
def extract_bucket_name_and_key_from_headers_and_path(
1✔
555
    headers: dict[str, str], path: str
556
) -> tuple[str | None, str | None]:
557
    """
558
    Extract the bucket name and the object key from a request headers and path. This works with both virtual host
559
    and path style requests.
560
    :param headers: the request headers, used to get the Host
561
    :param path: the request path
562
    :return: if found, the bucket name and object key
563
    """
564
    bucket_name = None
1✔
565
    object_key = None
1✔
566
    host = headers.get("host", "")
1✔
567
    if ".s3" in host:
1✔
568
        vhost_match = _s3_virtual_host_regex.match(host)
1✔
569
        if vhost_match and vhost_match.group("bucket"):
1✔
570
            bucket_name = vhost_match.group("bucket") or None
1✔
571
            split = path.split("/", maxsplit=1)
1✔
572
            if len(split) > 1 and split[1]:
1✔
573
                object_key = split[1]
1✔
574
    else:
575
        path_without_params = path.partition("?")[0]
1✔
576
        split = path_without_params.split("/", maxsplit=2)
1✔
577
        bucket_name = split[1] or None
1✔
578
        if len(split) > 2:
1✔
579
            object_key = split[2]
1✔
580

581
    return bucket_name, object_key
1✔
582

583

584
def normalize_bucket_name(bucket_name):
    """
    Lowercase a bucket name
    :param bucket_name: the bucket name, may be empty or None
    :return: the lowercased name, or an empty string for an empty or None value
    """
    return (bucket_name or "").lower()
1✔
588

589

590
def get_bucket_and_key_from_s3_uri(s3_uri: str) -> tuple[str, str]:
    """
    Extracts the bucket name and key from s3 uri
    :param s3_uri: the uri, with or without the `s3://` prefix
    :return: the bucket name and the key (an empty string if the uri holds no key)
    """
    without_scheme = s3_uri.removeprefix("s3://")
    # everything after the first `/` belongs to the key
    bucket, *remainder = without_scheme.split("/", 1)
    return bucket, remainder[0] if remainder else ""
1✔
596

597

598
def get_bucket_and_key_from_presign_url(presign_url: str) -> tuple[str, str]:
    """
    Extracts the bucket name and key from s3 presign url
    :param presign_url: a path style url, i.e. `<scheme>://<host>/<bucket>/<key>?<params>`
    :return: the bucket name (first path segment) and the key (remaining path segments)
    """
    segments = urlparser.urlparse(presign_url).path.split("/")
    # segments[0] is the empty string found before the leading `/`
    bucket = segments[1]
    key = "/".join(segments[2:]).partition("?")[0]
    return bucket, key
1✔
606

607

608
def capitalize_header_name_from_snake_case(header_name: str) -> str:
    """
    Capitalize every dash-separated part of a header name
    :param header_name: the header name, with its parts separated by `-`
    :return: the header name with each part capitalized
    """
    return "-".join(map(str.capitalize, header_name.split("-")))
1✔
610

611

612
def get_kms_key_arn(kms_key: str, account_id: str, bucket_region: str) -> str | None:
    """
    In S3, the KMS key can be passed as a KeyId or a KeyArn. This helper always returns the KeyArn: an ARN is
    returned unchanged, anything else is treated as a key id and turned into an ARN in the bucket account and region.
    :param kms_key: the KMS key id or ARN
    :param account_id: the bucket account id
    :param bucket_region: the bucket region
    :raise KMS.NotFoundException if the key is an ARN of another region than the bucket
    :return: the key ARN, or None if no key is given
    """
    if not kms_key:
        return None

    try:
        parsed_arn = parse_arn(kms_key)
    except InvalidArnException:
        # not an ARN: the passed value is a key id with no region data
        # recreate the ARN manually with the bucket region and bucket owner
        # if the KMS key is cross-account, user should provide an ARN and not a KeyId
        return arns.kms_key_arn(key_id=kms_key, account_id=account_id, region_name=bucket_region)

    key_region = parsed_arn["region"]
    # the KMS key should be in the same region as the bucket, we can raise an exception without calling KMS
    if bucket_region and key_region != bucket_region:
        raise CommonServiceException(
            code="KMS.NotFoundException", message=f"Invalid arn {key_region}"
        )

    return kms_key
1✔
641

642

643
# TODO: replace Any by a replacement for S3Bucket, some kind of defined type?
644
def validate_kms_key_id(kms_key: str, bucket: Any) -> None:
    """
    Validate that the KMS key used to encrypt the object is valid, by describing it with the KMS client
    :param kms_key: the KMS key id or ARN
    :param bucket: the targeted bucket, exposing its region as `region_name` or `bucket_region` and its account as
    `account_id` or `bucket_account_id`
    :raise KMS.KMSInvalidStateException if the key is not enabled and pending deletion
    :raise KMS.DisabledException if the key is disabled
    :raise KMS.NotFoundException if the key is not in the same region or does not exist
    :return: None, the function only raises when the key is not usable
    """
    # the bucket object can expose its region and account under two different attribute names
    if hasattr(bucket, "region_name"):
        bucket_region = bucket.region_name
    else:
        bucket_region = bucket.bucket_region

    if hasattr(bucket, "account_id"):
        bucket_account_id = bucket.account_id
    else:
        bucket_account_id = bucket.bucket_account_id

    kms_key_arn = get_kms_key_arn(kms_key, bucket_account_id, bucket_region)

    # the KMS key should be in the same region as the bucket, create the client in the bucket region
    kms_client = connect_to(region_name=bucket_region).kms
    try:
        key = kms_client.describe_key(KeyId=kms_key_arn)
        if not key["KeyMetadata"]["Enabled"]:
            if key["KeyMetadata"]["KeyState"] == "PendingDeletion":
                raise CommonServiceException(
                    code="KMS.KMSInvalidStateException",
                    message=f"{key['KeyMetadata']['Arn']} is pending deletion.",
                )
            raise CommonServiceException(
                code="KMS.DisabledException", message=f"{key['KeyMetadata']['Arn']} is disabled."
            )

    except ClientError as e:
        # only the "not found" client error is translated, any other client error is re-raised unchanged
        if e.response["Error"]["Code"] == "NotFoundException":
            raise CommonServiceException(
                code="KMS.NotFoundException", message=e.response["Error"]["Message"]
            )
        raise
×
685

686

687
def create_s3_kms_managed_key_for_region(account_id: str, region_name: str) -> SSEKMSKeyId:
    """
    Create a KMS key with the KMS client of the given account and region
    :param account_id: the account id, passed as access key id to the client
    :param region_name: the region of the client
    :return: the ARN of the created key
    """
    kms_client = connect_to(aws_access_key_id=account_id, region_name=region_name).kms
    created_key = kms_client.create_key(
        Description="Default key that protects my S3 objects when no other key is defined"
    )
    key_metadata = created_key["KeyMetadata"]
    return key_metadata["Arn"]
1✔
694

695

696
def rfc_1123_datetime(src: datetime.datetime) -> str:
    """
    Format a datetime with the RFC1123 format string
    :param src: the datetime to format
    :return: the formatted date
    """
    formatted = src.strftime(RFC1123)
    return formatted
1✔
698

699

700
def str_to_rfc_1123_datetime(value: str) -> datetime.datetime:
    """
    Parse a date written with the RFC1123 format string
    :param value: the formatted date
    :return: the parsed datetime, with the GMT zone attached
    """
    parsed = datetime.datetime.strptime(value, RFC1123)
    return parsed.replace(tzinfo=_gmt_zone_info)
1✔
702

703

704
def add_expiration_days_to_datetime(user_datatime: datetime.datetime, exp_days: int) -> str:
    """
    Add expiration days to a datetime and round the result up to the next midnight.
    :param user_datatime: datetime object
    :param exp_days: provided days
    :return: the resulting date at midnight, as a string formatted to rfc_1123
    """
    midnight = user_datatime.replace(hour=0, minute=0, second=0, microsecond=0)
    # one extra day, as the time of the day was dropped above
    expiration = midnight + datetime.timedelta(days=exp_days + 1)
    return rfc_1123_datetime(expiration)
1✔
716

717

718
def serialize_expiration_header(
    rule_id: str, lifecycle_exp: LifecycleExpiration, last_modified: datetime.datetime
):
    """
    Build the expiration header value of an object matching a lifecycle rule
    :param rule_id: the id of the lifecycle rule
    :param lifecycle_exp: the expiration of the rule, holding either `Days` or `Date`
    :param last_modified: the last modification date of the object, used with `Days`
    :return: the header value, holding the expiry date and the rule id
    """
    exp_days = lifecycle_exp.get("Days")
    if exp_days:
        # AWS round to the next day at midnight UTC
        exp_date = add_expiration_days_to_datetime(last_modified, exp_days)
    else:
        exp_date = rfc_1123_datetime(lifecycle_exp["Date"])

    return f'expiry-date="{exp_date}", rule-id="{rule_id}"'
1✔
728

729

730
def get_lifecycle_rule_from_object(
    lifecycle_conf_rules: LifecycleRules,
    object_key: ObjectKey,
    size: ObjectSize,
    object_tags: dict[str, str],
) -> LifecycleRule:
    """
    Find the first lifecycle rule with an expiration that applies to an object.
    :param lifecycle_conf_rules: the lifecycle rules to go through, in order
    :param object_key: the key of the object
    :param size: the size of the object
    :param object_tags: the tags of the object
    :return: the first matching rule, or None if no rule matches
    """
    for candidate in lifecycle_conf_rules:
        expiration = candidate.get("Expiration")
        # rules without expiration, or only expiring delete markers, are ignored
        if not expiration or "ExpiredObjectDeleteMarker" in expiration:
            continue

        rule_filter = candidate.get("Filter")
        if not rule_filter:
            # no filter: the rule applies to every object
            return candidate

        and_rules = rule_filter.get("And")
        if and_rules and all(
            _match_lifecycle_filter(key, value, object_key, size, object_tags)
            for key, value in and_rules.items()
        ):
            return candidate

        # after validation, we can only have one of `Prefix`, `Tag`, `ObjectSizeGreaterThan` or
        # `ObjectSizeLessThan` in the dict. Instead of manually checking, we iterate over the only key and try to
        # match it
        if any(
            _match_lifecycle_filter(key, value, object_key, size, object_tags)
            for key, value in rule_filter.items()
        ):
            return candidate

    return None
1✔
757

758

759
def _match_lifecycle_filter(
1✔
760
    filter_key: str,
761
    filter_value: str | int | dict[str, str],
762
    object_key: ObjectKey,
763
    size: ObjectSize,
764
    object_tags: dict[str, str],
765
):
766
    match filter_key:
1✔
767
        case "Prefix":
1✔
768
            return object_key.startswith(filter_value)
1✔
769
        case "Tag":
1✔
770
            return object_tags and object_tags.get(filter_value.get("Key")) == filter_value.get(
1✔
771
                "Value"
772
            )
773
        case "ObjectSizeGreaterThan":
1✔
774
            return size > filter_value
1✔
775
        case "ObjectSizeLessThan":
1✔
776
            return size < filter_value
1✔
777
        case "Tags":  # this is inside the `And` field
1✔
778
            return object_tags and all(
1✔
779
                object_tags.get(tag.get("Key")) == tag.get("Value") for tag in filter_value
780
            )
781

782

783
def parse_expiration_header(
1✔
784
    expiration_header: str,
785
) -> tuple[datetime.datetime | None, str | None]:
786
    try:
1✔
787
        header_values = dict(
1✔
788
            (p.strip('"') for p in v.split("=")) for v in expiration_header.split('", ')
789
        )
790
        expiration_date = str_to_rfc_1123_datetime(header_values["expiry-date"])
1✔
791
        return expiration_date, header_values["rule-id"]
1✔
792

793
    except (IndexError, ValueError, KeyError):
1✔
794
        return None, None
1✔
795

796

797
def validate_dict_fields(data: dict, required_fields: set, optional_fields: set = None):
    """
    Check that the keys of `data` include every required field and nothing outside of the required and optional
    fields.
    TODO: we could pass the TypedDict to also use its required/optional properties, but it could be sensitive to
     mistake/changes in the specs and not always right
    :param data: the dict we want to validate
    :param required_fields: a set containing the required fields
    :param optional_fields: a set containing the optional fields, defaults to no optional fields
    :return: bool, whether the dict is valid or not
    """
    present_fields = set(data)
    if not present_fields >= required_fields:
        # at least one required field is missing
        return False

    extra_allowed = set() if optional_fields is None else optional_fields
    return present_fields <= (required_fields | extra_allowed)
813

814

815
def parse_tagging_header(tagging_header: TaggingHeader) -> dict:
    """
    Parse the value of the `x-amz-tagging` header (URL query parameters format) into a dict of tags.
    :param tagging_header: the raw header value
    :return: a dict mapping every tag key to its value
    :raises InvalidArgument: if the header cannot be parsed, contains duplicated keys, or a key or value does not
     match TAG_REGEX
    :raises InvalidTag: if a tag key starts with `aws:`
    """
    try:
        parsed_tags = urlparser.parse_qs(tagging_header, keep_blank_values=True)
        tags: dict[str, str] = {}
        for key, val in parsed_tags.items():
            # parse_qs groups the values by key, more than one value means the key was duplicated
            if len(val) != 1 or not TAG_REGEX.match(key) or not TAG_REGEX.match(val[0]):
                raise InvalidArgument(
                    "The header 'x-amz-tagging' shall be encoded as UTF-8 then URLEncoded URL query parameters without tag name duplicates.",
                    ArgumentName="x-amz-tagging",
                    ArgumentValue=tagging_header,
                )
            elif key.startswith("aws:"):
                # a bare `raise` here has no active exception and would surface as a RuntimeError.
                # NOTE(review): error aligned with the object case of `validate_tag_set` - confirm against AWS
                raise InvalidTag(
                    "Your TagKey cannot be prefixed with aws:",
                    TagKey=key,
                )
            tags[key] = val[0]
        return tags

    except ValueError:
        raise InvalidArgument(
            "The header 'x-amz-tagging' shall be encoded as UTF-8 then URLEncoded URL query parameters without tag name duplicates.",
            ArgumentName="x-amz-tagging",
            ArgumentValue=tagging_header,
        )
837

838

839
def validate_tag_set(tag_set: TagSet, type_set: Literal["bucket", "object"] = "bucket"):
    """
    Validate every tag of a TagSet: shape, key uniqueness, reserved `aws:` prefix and TAG_REGEX compliance.
    :param tag_set: the list of tags to validate
    :param type_set: the kind of resource being tagged, only changes the message for `aws:` prefixed keys
    :raises MalformedXML: if a tag does not contain exactly `Key` and `Value`
    :raises InvalidTag: if a key is duplicated, is prefixed with `aws:`, or a key or value is invalid
    """
    seen_keys = set()
    for entry in tag_set:
        if set(entry) != {"Key", "Value"}:
            raise MalformedXML()

        # both fields are guaranteed to exist after the shape check above
        tag_key, tag_value = entry["Key"], entry["Value"]

        if tag_key in seen_keys:
            raise InvalidTag(
                "Cannot provide multiple Tags with the same key",
                TagKey=tag_key,
            )

        if tag_key.startswith("aws:"):
            message = (
                "System tags cannot be added/updated by requester"
                if type_set == "bucket"
                else "Your TagKey cannot be prefixed with aws:"
            )
            raise InvalidTag(
                message,
                TagKey=tag_key,
            )

        if not TAG_REGEX.match(tag_key):
            raise InvalidTag(
                "The TagKey you have provided is invalid",
                TagKey=tag_key,
            )

        if not TAG_REGEX.match(tag_value):
            raise InvalidTag(
                "The TagValue you have provided is invalid", TagKey=tag_key, TagValue=tag_value
            )

        seen_keys.add(tag_key)
1✔
873

874

875
def get_unique_key_id(
    bucket: BucketName, object_key: ObjectKey, version_id: ObjectVersionId
) -> str:
    """
    Build an identifier for an object version out of its bucket, key and version id.
    :param bucket: bucket name
    :param object_key: object key
    :param version_id: the version id of the object, a falsy value is rendered as `null`
    :return: a string formatted as `<bucket>/<object_key>/<version_id>`
    """
    version_part = version_id or "null"
    return "{}/{}/{}".format(bucket, object_key, version_part)
1✔
879

880

881
def get_retention_from_now(
    days: int | None = None, years: int | None = None
) -> datetime.datetime:
    """
    This calculates a retention date from now, adding days or years to it
    :param days: provided days, takes precedence over years if both are provided
    :param years: provided years, exclusive with days
    :return: return a datetime object
    :raises ValueError: if neither days nor years is provided
    """
    if not days and not years:
        raise ValueError("Either 'days' or 'years' needs to be provided")
    now = datetime.datetime.now(tz=_gmt_zone_info)
    if days:
        return now + datetime.timedelta(days=days)

    target_year = now.year + years
    try:
        return now.replace(year=target_year)
    except ValueError:
        if (now.month, now.day) != (2, 29):
            # not a leap day issue (for instance the year is out of range), nothing to recover from
            raise
        # February 29th does not exist in the target year, fall back to February 28th
        # NOTE(review): the fallback day is a choice, confirm against AWS behavior
        return now.replace(year=target_year, day=28)
1✔
897

898

899
def get_failed_precondition_copy_source(
    request: CopyObjectRequest, last_modified: datetime.datetime, etag: ETag
) -> str | None:
    """
    Check the copy source preconditions of the request against the source object LastModified and ETag, and return
    the name of the first one that fails.
    The checks run in this order: If-Match, If-Unmodified-Since, If-None-Match, If-Modified-Since.
    # see https://docs.aws.amazon.com/AmazonS3/latest/API/API_CopyObject.html
    :param request: the CopyObjectRequest
    :param last_modified: source object LastModified
    :param etag: source object ETag
    :return str: the failed precondition to raise, or None if every precondition holds
    """
    expected_etag = request.get("CopySourceIfMatch")
    if expected_etag and etag.strip('"') != expected_etag.strip('"'):
        return "x-amz-copy-source-If-Match"

    unmodified_since = request.get("CopySourceIfUnmodifiedSince")
    if unmodified_since and last_modified > unmodified_since:
        return "x-amz-copy-source-If-Unmodified-Since"

    excluded_etag = request.get("CopySourceIfNoneMatch")
    if excluded_etag and etag.strip('"') == excluded_etag.strip('"'):
        return "x-amz-copy-source-If-None-Match"

    modified_since = request.get("CopySourceIfModifiedSince")
    # a date that is not in the past never fails the precondition
    if modified_since and last_modified < modified_since < datetime.datetime.now(
        tz=_gmt_zone_info
    ):
        return "x-amz-copy-source-If-Modified-Since"

    return None
1✔
930

931

932
def validate_failed_precondition(
    request: GetObjectRequest | HeadObjectRequest, last_modified: datetime.datetime, etag: ETag
) -> None:
    """
    Validate the preconditions of the request against the object LastModified and ETag, and raise if one of them
    does not hold. ETags are compared without their surrounding double quotes, on both sides.
    :param request: the GetObjectRequest or HeadObjectRequest
    :param last_modified: S3 object LastModified
    :param etag: S3 object ETag
    :raises PreconditionFailed: if `IfMatch` or `IfUnmodifiedSince` does not hold
    :raises NotModified, 304 with an empty body: if `IfNoneMatch` or `IfModifiedSince` does not hold
    """
    precondition_failed = None
    # last_modified needs to be rounded to a second so that strict equality can be enforced from a RFC1123 header
    last_modified = last_modified.replace(microsecond=0)
    # strip the quotes from the object ETag as well, like `get_failed_precondition_copy_source` does, so that the
    # comparison does not depend on how the ETag was stored
    if (if_match := request.get("IfMatch")) and etag.strip('"') != if_match.strip('"'):
        precondition_failed = "If-Match"

    elif (
        if_unmodified_since := request.get("IfUnmodifiedSince")
    ) and last_modified > if_unmodified_since:
        precondition_failed = "If-Unmodified-Since"

    if precondition_failed:
        raise PreconditionFailed(
            "At least one of the pre-conditions you specified did not hold",
            Condition=precondition_failed,
        )

    if (
        (if_none_match := request.get("IfNoneMatch"))
        and etag.strip('"') == if_none_match.strip('"')
    ) or (
        (if_modified_since := request.get("IfModifiedSince"))
        # a date that is not in the past never triggers the 304
        and last_modified <= if_modified_since < datetime.datetime.now(tz=_gmt_zone_info)
    ):
        raise CommonServiceException(
            message="Not Modified",
            code="NotModified",
            status_code=304,
        )
970

971

972
def get_canned_acl(
    canned_acl: BucketCannedACL | ObjectCannedACL, owner: Owner
) -> AccessControlPolicy:
    """
    Translate a CannedACL into the matching Owner and Grants. The owner always gets FULL_CONTROL, the other grants
    depend on the CannedACL.
    See https://docs.aws.amazon.com/AmazonS3/latest/userguide/acl-overview.html#canned-acl
    :param canned_acl: an S3 CannedACL
    :param owner: the current owner of the bucket or object
    :return: an AccessControlPolicy containing the Grants and Owner
    """
    grants = [
        Grant(
            Grantee=Grantee(**owner, Type=GranteeType.CanonicalUser),
            Permission=Permission.FULL_CONTROL,
        )
    ]

    if canned_acl == ObjectCannedACL.private:
        pass  # no other permissions
    elif canned_acl == ObjectCannedACL.public_read:
        grants.append(Grant(Grantee=ALL_USERS_ACL_GRANTEE, Permission=Permission.READ))
    elif canned_acl == ObjectCannedACL.public_read_write:
        grants.extend(
            [
                Grant(Grantee=ALL_USERS_ACL_GRANTEE, Permission=Permission.READ),
                Grant(Grantee=ALL_USERS_ACL_GRANTEE, Permission=Permission.WRITE),
            ]
        )
    elif canned_acl == ObjectCannedACL.authenticated_read:
        grants.append(Grant(Grantee=AUTHENTICATED_USERS_ACL_GRANTEE, Permission=Permission.READ))
    elif canned_acl == ObjectCannedACL.bucket_owner_read:
        pass  # TODO: bucket owner ACL
    elif canned_acl == ObjectCannedACL.bucket_owner_full_control:
        pass  # TODO: bucket owner ACL
    elif canned_acl == ObjectCannedACL.aws_exec_read:
        pass  # TODO: bucket owner, EC2 Read
    elif canned_acl == BucketCannedACL.log_delivery_write:
        grants.extend(
            [
                Grant(Grantee=LOG_DELIVERY_ACL_GRANTEE, Permission=Permission.READ_ACP),
                Grant(Grantee=LOG_DELIVERY_ACL_GRANTEE, Permission=Permission.WRITE),
            ]
        )

    return AccessControlPolicy(Owner=owner, Grants=grants)
1✔
1009

1010

1011
def create_redirect_for_post_request(
    base_redirect: str, bucket: BucketName, object_key: ObjectKey, etag: ETag
):
    """
    Build the redirection URL of a successful POST request: the provided URL with the `key`, `bucket` and `etag`
    query string parameters set (replacing any existing parameter with the same name). It needs to be a full URL.
    :param base_redirect: the URL provided for redirection
    :param bucket: bucket name
    :param object_key: object key
    :param etag: key ETag
    :return: the URL provided with the new appended query string parameters
    :raises ValueError: if the provided URL has no network location
    """
    parsed_url = urlparser.urlparse(base_redirect)
    if not parsed_url.netloc:
        raise ValueError("The provided URL is not valid")

    query_params = urlparser.parse_qs(parsed_url.query)
    query_params.update({"key": [object_key], "bucket": [bucket], "etag": [etag]})
    encoded_query = urlparser.urlencode(query_params, doseq=True)

    # only the query component changes, every other part of the URL is kept as provided
    return urlparser.urlunparse(parsed_url._replace(query=encoded_query))
1✔
1040

1041

1042
def parse_post_object_tagging_xml(tagging: str) -> dict | None:
    """
    Parse the XML `Tagging` document of a POST object request into a dict of tags.
    :param tagging: the XML document, as a string
    :return: a dict mapping every tag key to its value, or None if no `Tag` is found under `Tagging/TagSet`
    :raises MalformedXML: if anything fails while parsing the document or reading the tags
    """
    try:
        document = xmltodict.parse(tagging)
        tag_entries = document.get("Tagging", {}).get("TagSet", {}).get("Tag", [])
        if not tag_entries:
            # if the Tagging does not respect the schema, just return
            return None

        # a single <Tag> element is not parsed as a list, normalize it
        if not isinstance(tag_entries, list):
            tag_entries = [tag_entries]

        return {entry["Key"]: entry["Value"] for entry in tag_entries}

    except Exception:
        raise MalformedXML()
1✔
1059

1060

1061
def generate_safe_version_id() -> str:
    """
    Generate a VersionId that is safe to render in XML.
    VersionId cannot have `-` in it, as it fails in XML
    The id is made of the next value of the global sequence, encoded on 6 bytes (the 8 first characters once
    encoded), followed by 18 random bytes.
    The sequence part is needed to properly implement pagination around ListObjectVersions: as the version-id is
    prefixed with a global increasing number, the versions can be sorted.
    :return: an S3 VersionId containing a timestamp part in the first 8 characters
    """
    sequence_part = next(global_version_id_sequence()).to_bytes(length=6)
    raw_id = sequence_part + token_bytes(18)
    # `.` and `_` replace the `+` and `/` of the standard base64 alphabet
    encoded_id = base64.b64encode(raw_id, altchars=b"._")
    return encoded_id.rstrip(b"=").decode("ascii")
1✔
1072

1073

1074
@singleton_factory
def global_version_id_sequence():
    """
    Create the counter backing the sequence part of the VersionIds, starting at the current time in milliseconds.
    NOTE(review): presumably `singleton_factory` makes every call return the same counter - the decorator is
     defined outside of this block, confirm there
    :return: an `itertools.count` starting at the current epoch time in milliseconds
    """
    # itertools.count is thread safe over the GIL since its getAndIncrement operation is a single python bytecode op
    return itertools.count(int(time.time() * 1000))
1✔
1079

1080

1081
def is_version_older_than_other(version_id: str, other: str):
    """
    Tell whether a VersionId sorts before another VersionId (for instance a VersionIdMarker), by comparing their
    decoded bytes, which start with the sequence part. Used for pagination
    See `generate_safe_version_id`
    :param version_id: the VersionId to check
    :param other: the VersionId to compare against
    :return: True if the decoded `version_id` is strictly lower than the decoded `other`
    """
    own_bytes = base64.b64decode(version_id, altchars=b"._")
    other_bytes = base64.b64decode(other, altchars=b"._")
    return own_bytes < other_bytes
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc