• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

localstack / localstack / 20324119909

17 Dec 2025 04:57PM UTC coverage: 86.929% (+0.01%) from 86.917%
20324119909

push

github

web-flow
Fix RPC v2 CBOR timestamp parsing for float (#13541)

2 of 2 new or added lines in 1 file covered. (100.0%)

21 existing lines in 1 file now uncovered.

70013 of 80540 relevant lines covered (86.93%)

0.87 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.14
/localstack-core/localstack/services/sns/v2/provider.py
1
import contextlib
1✔
2
import copy
1✔
3
import functools
1✔
4
import json
1✔
5
import logging
1✔
6
import re
1✔
7

8
from botocore.utils import InvalidArnException
1✔
9
from rolo import Request, Router, route
1✔
10

11
from localstack.aws.api import CommonServiceException, RequestContext
1✔
12
from localstack.aws.api.sns import (
1✔
13
    ActionsList,
14
    AmazonResourceName,
15
    BatchEntryIdsNotDistinctException,
16
    CheckIfPhoneNumberIsOptedOutResponse,
17
    ConfirmSubscriptionResponse,
18
    CreateEndpointResponse,
19
    CreatePlatformApplicationResponse,
20
    CreateTopicResponse,
21
    DelegatesList,
22
    Endpoint,
23
    EndpointDisabledException,
24
    GetDataProtectionPolicyResponse,
25
    GetEndpointAttributesResponse,
26
    GetPlatformApplicationAttributesResponse,
27
    GetSMSAttributesResponse,
28
    GetSubscriptionAttributesResponse,
29
    GetTopicAttributesResponse,
30
    InvalidParameterException,
31
    InvalidParameterValueException,
32
    ListEndpointsByPlatformApplicationResponse,
33
    ListPhoneNumbersOptedOutResponse,
34
    ListPlatformApplicationsResponse,
35
    ListString,
36
    ListSubscriptionsByTopicResponse,
37
    ListSubscriptionsResponse,
38
    ListTagsForResourceResponse,
39
    ListTopicsResponse,
40
    MapStringToString,
41
    MessageAttributeMap,
42
    NotFoundException,
43
    OptInPhoneNumberResponse,
44
    PhoneNumber,
45
    PlatformApplication,
46
    PublishBatchRequestEntryList,
47
    PublishBatchResponse,
48
    PublishBatchResultEntry,
49
    PublishResponse,
50
    SetSMSAttributesResponse,
51
    SnsApi,
52
    String,
53
    SubscribeResponse,
54
    Subscription,
55
    SubscriptionAttributesMap,
56
    TagKeyList,
57
    TagList,
58
    TagResourceResponse,
59
    TooManyEntriesInBatchRequestException,
60
    TopicAttributesMap,
61
    UntagResourceResponse,
62
    attributeName,
63
    attributeValue,
64
    authenticateOnUnsubscribe,
65
    endpoint,
66
    label,
67
    message,
68
    messageStructure,
69
    nextToken,
70
    protocol,
71
    string,
72
    subject,
73
    subscriptionARN,
74
    topicARN,
75
    topicName,
76
)
77
from localstack.constants import AWS_REGION_US_EAST_1, DEFAULT_AWS_ACCOUNT_ID
1✔
78
from localstack.http import Response
1✔
79
from localstack.services.edge import ROUTER
1✔
80
from localstack.services.plugins import ServiceLifecycleHook
1✔
81
from localstack.services.sns.analytics import internal_api_calls
1✔
82
from localstack.services.sns.certificate import SNS_SERVER_CERT
1✔
83
from localstack.services.sns.constants import (
1✔
84
    ATTR_TYPE_REGEX,
85
    DUMMY_SUBSCRIPTION_PRINCIPAL,
86
    MAXIMUM_MESSAGE_LENGTH,
87
    MSG_ATTR_NAME_REGEX,
88
    PLATFORM_ENDPOINT_MSGS_ENDPOINT,
89
    SMS_MSGS_ENDPOINT,
90
    SNS_CERT_ENDPOINT,
91
    SNS_PROTOCOLS,
92
    SUBSCRIPTION_TOKENS_ENDPOINT,
93
    VALID_APPLICATION_PLATFORMS,
94
    VALID_MSG_ATTR_NAME_CHARS,
95
    VALID_POLICY_ACTIONS,
96
    VALID_SUBSCRIPTION_ATTR_NAME,
97
)
98
from localstack.services.sns.filter import FilterPolicyValidator
1✔
99
from localstack.services.sns.publisher import (
1✔
100
    PublishDispatcher,
101
    SnsBatchPublishContext,
102
    SnsPublishContext,
103
)
104
from localstack.services.sns.v2.models import (
1✔
105
    SMS_ATTRIBUTE_NAMES,
106
    SMS_DEFAULT_SENDER_REGEX,
107
    SMS_TYPES,
108
    EndpointAttributeNames,
109
    PlatformApplicationDetails,
110
    PlatformEndpoint,
111
    SnsMessage,
112
    SnsMessageType,
113
    SnsStore,
114
    SnsSubscription,
115
    Topic,
116
    sns_stores,
117
)
118
from localstack.services.sns.v2.utils import (
1✔
119
    create_platform_endpoint_arn,
120
    create_subscription_arn,
121
    encode_subscription_token_with_region,
122
    get_next_page_token_from_arn,
123
    get_region_from_subscription_token,
124
    get_topic_subscriptions,
125
    is_valid_e164_number,
126
    parse_and_validate_platform_application_arn,
127
    parse_and_validate_topic_arn,
128
    validate_subscription_attribute,
129
)
130
from localstack.state import StateVisitor
1✔
131
from localstack.utils.aws.arns import (
1✔
132
    extract_account_id_from_arn,
133
    extract_region_from_arn,
134
    get_partition,
135
    parse_arn,
136
    sns_platform_application_arn,
137
    sns_topic_arn,
138
)
139
from localstack.utils.collections import PaginatedList, select_from_typed_dict
1✔
140
from localstack.utils.strings import to_bytes
1✔
141

142
# module-level logger
LOG = logging.getLogger(__name__)

# Topic names: 1-256 ASCII letters, digits, underscores and hyphens.
# FIFO topic names are the same stem followed by a literal ".fifo" suffix.
_TOPIC_NAME_STEM = r"[a-zA-Z0-9_-]{1,256}"
SNS_TOPIC_NAME_PATTERN_FIFO = rf"^{_TOPIC_NAME_STEM}\.fifo$"
SNS_TOPIC_NAME_PATTERN = rf"^{_TOPIC_NAME_STEM}$"
147

148

149
class SnsProvider(SnsApi, ServiceLifecycleHook):
1✔
150
    def __init__(self) -> None:
        """Set up the provider's publish dispatcher and its message-signing certificate."""
        super().__init__()
        # PEM certificate served on SNS_CERT_ENDPOINT (see get_signature_cert_pem_file)
        self._signature_cert_pem: str = SNS_SERVER_CERT
        # dispatcher used to deliver messages to topic subscribers
        self._publisher = PublishDispatcher()
154

155
    def accept_state_visitor(self, visitor: StateVisitor):
        """Hand the SNS stores (``sns_stores``) to the given state visitor."""
        stores = sns_stores
        visitor.visit(stores)
157

158
    def on_before_stop(self):
        """Shut down the publish dispatcher before the service stops."""
        publisher = self._publisher
        publisher.shutdown()
160

161
    def on_after_init(self):
        """Register the provider's additional HTTP routes on the edge ``ROUTER``."""
        router = ROUTER
        # allow sent platform endpoint messages to be retrieved from the SNS endpoint
        register_sns_api_resource(router)
        # serve the certificate used to validate message signatures
        router.add(self.get_signature_cert_pem_file)
166

167
    @route(SNS_CERT_ENDPOINT, methods=["GET"])
    def get_signature_cert_pem_file(self, request: Request):
        """
        Serve the PEM certificate that subscribers use to verify SNS message signatures.

        See http://sns-public-resources.s3.amazonaws.com/SNS_Message_Signing_Release_Note_Jan_25_2011.pdf
        and https://docs.aws.amazon.com/sns/latest/dg/sns-verify-signature-of-message.html
        """
        cert_pem = self._signature_cert_pem
        return Response(cert_pem, 200)
172

173
    ## Topic Operations
174

175
    def create_topic(
        self,
        context: RequestContext,
        name: topicName,
        attributes: TopicAttributesMap | None = None,
        tags: TagList | None = None,
        data_protection_policy: attributeValue | None = None,
        **kwargs,
    ) -> CreateTopicResponse:
        """
        Create a topic, or return the ARN of an already existing topic with the same name.

        If the topic already exists in the caller's account/region, its ARN is returned as long as
        every requested attribute and the tags match the stored ones; otherwise an
        ``InvalidParameterException`` is raised.

        :param context: the request context; its account and region select the store
        :param name: the topic name; FIFO topics (``FifoTopic=true``) must end with ``.fifo``
        :param attributes: optional topic attributes; the caller's dict is not modified
        :param tags: optional tags applied to a newly created topic
        :param data_protection_policy: not used by this implementation
        :raises InvalidParameterException: for invalid topic names, or when an existing topic has
            different attributes or tags
        :return: the ARN of the new or existing topic
        """
        store = self.get_store(context.account_id, context.region)
        topic_arn = sns_topic_arn(
            topic_name=name, region_name=context.region, account_id=context.account_id
        )
        # work on a copy: the attributes are normalized below (FifoTopic removal, default
        # delivery policy) and must not leak back into the caller's request dict
        attributes = dict(attributes) if attributes else {}

        if topic := store.topics.get(topic_arn):
            existing_attrs = topic["attributes"]
            # iterate over (name, value) pairs; every requested attribute must match the stored one
            for attr_name, attr_value in attributes.items():
                if existing_attrs.get(attr_name) != attr_value:
                    # TODO:
                    raise InvalidParameterException("Fix this Exception message and type")
            if not _check_matching_tags(topic_arn, tags, store):
                raise InvalidParameterException(
                    "Invalid parameter: Tags Reason: Topic already exists with different tags"
                )
            return CreateTopicResponse(TopicArn=topic_arn)

        # fullmatch instead of match: with `match`, `$` would also accept a trailing newline
        if attributes.get("FifoTopic") and attributes["FifoTopic"].lower() == "true":
            if not re.fullmatch(SNS_TOPIC_NAME_PATTERN_FIFO, name):
                # TODO: check this with a separate test
                raise InvalidParameterException(
                    "Fifo Topic names must end with .fifo and must be made up of only uppercase and lowercase ASCII letters, numbers, underscores, and hyphens, and must be between 1 and 256 characters long."
                )
        else:
            # AWS does not seem to save explicit settings of fifo = false
            attributes.pop("FifoTopic", None)
            if not re.fullmatch(SNS_TOPIC_NAME_PATTERN, name):
                raise InvalidParameterException("Invalid parameter: Topic Name")

        attributes["EffectiveDeliveryPolicy"] = _create_default_effective_delivery_policy()

        topic = _create_topic(name=name, attributes=attributes, context=context)
        if tags:
            self.tag_resource(context=context, resource_arn=topic_arn, tags=tags)

        store.topics[topic_arn] = topic

        return CreateTopicResponse(TopicArn=topic_arn)
227

228
    def get_topic_attributes(
        self, context: RequestContext, topic_arn: topicARN, **kwargs
    ) -> GetTopicAttributesResponse:
        """
        Return the stored attributes of a topic.

        :raises NotFoundException: if no topic is returned for ``topic_arn``
        """
        topic: Topic = self._get_topic(arn=topic_arn, context=context)
        if not topic:
            raise NotFoundException("Topic does not exist")
        return GetTopicAttributesResponse(Attributes=topic["attributes"])
237

238
    def delete_topic(self, context: RequestContext, topic_arn: topicARN, **kwargs) -> None:
        """
        Delete a topic and all of its subscriptions.

        Deleting a topic that does not exist in the caller's account/region is a no-op.

        :param context: the request context; its account and region select the store
        :param topic_arn: the ARN of the topic to delete
        """
        store = self.get_store(context.account_id, context.region)

        topic = store.topics.pop(topic_arn, None)
        if not topic:
            return

        # DeleteTopic also removes the topic's subscriptions; without this they would linger in
        # the store (e.g. in ListSubscriptions) while pointing to a topic that no longer exists.
        # Same cleanup as in `unsubscribe`.
        for subscription_arn in topic["subscriptions"]:
            store.subscription_filter_policy.pop(subscription_arn, None)
            store.subscriptions.pop(subscription_arn, None)
242

243
    def list_topics(
        self, context: RequestContext, next_token: nextToken | None = None, **kwargs
    ) -> ListTopicsResponse:
        """
        List the topic ARNs of the caller's account and region, 100 per page.

        The response always carries a ``NextToken`` key (``None`` on the last page).
        """
        store = self.get_store(context.account_id, context.region)

        def _page_token(entry):
            return get_next_page_token_from_arn(entry["TopicArn"])

        topic_entries = PaginatedList(
            [{"TopicArn": topic["arn"]} for topic in list(store.topics.values())]
        )
        page, following_token = topic_entries.get_page(
            token_generator=_page_token,
            next_token=next_token,
            page_size=100,
        )
        return ListTopicsResponse(Topics=page, NextToken=following_token)
256

257
    def set_topic_attributes(
        self,
        context: RequestContext,
        topic_arn: topicARN,
        attribute_name: attributeName,
        attribute_value: attributeValue | None = None,
        **kwargs,
    ) -> None:
        """
        Set (or overwrite) a single attribute of a topic.

        :raises InvalidParameterException: when trying to change ``FifoTopic``, which this operation
            rejects (it is only handled by ``create_topic``)
        """
        topic: Topic = self._get_topic(arn=topic_arn, context=context)
        if attribute_name == "FifoTopic":
            raise InvalidParameterException("Invalid parameter: AttributeName")
        topic_attributes = topic["attributes"]
        topic_attributes[attribute_name] = attribute_value
269

270
    ## Subscribe operations
271

272
    def subscribe(
        self,
        context: RequestContext,
        topic_arn: topicARN,
        protocol: protocol,
        endpoint: endpoint | None = None,
        attributes: SubscriptionAttributesMap | None = None,
        return_subscription_arn: bool | None = None,
        **kwargs,
    ) -> SubscribeResponse:
        """
        Subscribe an endpoint to a topic.

        Subscribe is idempotent per endpoint: if the endpoint is already subscribed to the topic, the
        existing subscription ARN is returned, unless a requested attribute differs from the stored
        one, in which case an ``InvalidParameterException`` is raised.

        HTTP(S) subscriptions are sent a ``SubscriptionConfirmation`` message and stay pending;
        email/email-json subscriptions also stay pending; all other protocols are confirmed
        immediately.

        :param context: the request context
        :param topic_arn: the topic to subscribe to; must be in the request's region
        :param protocol: the delivery protocol, one of ``SNS_PROTOCOLS``
        :param endpoint: the endpoint receiving notifications (URL, ARN, phone number, ...)
        :param attributes: optional subscription attributes, validated before being stored
        :param return_subscription_arn: for HTTP(S) subscriptions, return the real subscription ARN
            instead of ``"pending confirmation"``
        :raises InvalidParameterException: for an invalid topic ARN, protocol, endpoint or attributes
        :raises NotFoundException: when no endpoint is given
        :return: the subscription ARN, or ``"pending confirmation"``
        """
        parsed_topic_arn = parse_and_validate_topic_arn(topic_arn)
        if context.region != parsed_topic_arn["region"]:
            raise InvalidParameterException("Invalid parameter: TopicArn")

        store = self.get_store(account_id=parsed_topic_arn["account"], region=context.region)

        topic = self._get_topic(arn=topic_arn, context=context)
        topic_subscriptions = topic["subscriptions"]
        if not endpoint:
            # TODO: check AWS behaviour (because endpoint is optional)
            raise NotFoundException("Endpoint not specified in subscription")
        if protocol not in SNS_PROTOCOLS:
            raise InvalidParameterException(
                f"Invalid parameter: Amazon SNS does not support this protocol string: {protocol}"
            )
        elif protocol in ["http", "https"] and not endpoint.startswith(f"{protocol}://"):
            raise InvalidParameterException(
                "Invalid parameter: Endpoint must match the specified protocol"
            )
        elif protocol == "sms" and not is_valid_e164_number(endpoint):
            raise InvalidParameterException(f"Invalid SMS endpoint: {endpoint}")

        elif protocol == "sqs":
            try:
                parse_arn(endpoint)
            except InvalidArnException:
                raise InvalidParameterException("Invalid parameter: SQS endpoint ARN")

        elif protocol == "application":
            # TODO: Validate exact behaviour
            try:
                parse_arn(endpoint)
            except InvalidArnException:
                raise InvalidParameterException("Invalid parameter: ApplicationEndpoint ARN")

        if ".fifo" in endpoint and ".fifo" not in topic_arn:
            # TODO: move to sqs protocol block if possible
            raise InvalidParameterException(
                "Invalid parameter: Invalid parameter: Endpoint Reason: FIFO SQS Queues can not be subscribed to standard SNS topics"
            )

        sub_attributes = copy.deepcopy(attributes) if attributes else None
        if sub_attributes:
            for attr_name, attr_value in sub_attributes.items():
                validate_subscription_attribute(
                    attribute_name=attr_name,
                    attribute_value=attr_value,
                    topic_arn=topic_arn,
                    endpoint=endpoint,
                    is_subscribe_call=True,
                )
            # normalize once, after the loop: doing it inside the loop meant that, depending on the
            # dict order, RawMessageDelivery was validated either with its original or with its
            # already lower-cased value
            if raw_msg_delivery := sub_attributes.get("RawMessageDelivery"):
                sub_attributes["RawMessageDelivery"] = raw_msg_delivery.lower()

        # An endpoint may only be subscribed to a topic once. Subsequent
        # subscribe calls do nothing (subscribe is idempotent), except if its attributes are different.
        for existing_topic_subscription in topic_subscriptions:
            sub = store.subscriptions.get(existing_topic_subscription, {})
            if sub.get("Endpoint") == endpoint:
                if sub_attributes:
                    # validate the subscription attributes aren't different
                    for attr in VALID_SUBSCRIPTION_ATTR_NAME:
                        # if a new attribute is present and different from an existent one, raise
                        if (new_attr := sub_attributes.get(attr)) and sub.get(attr) != new_attr:
                            raise InvalidParameterException(
                                "Invalid parameter: Attributes Reason: Subscription already exists with different attributes"
                            )

                return SubscribeResponse(SubscriptionArn=sub["SubscriptionArn"])
        principal = DUMMY_SUBSCRIPTION_PRINCIPAL.format(
            partition=get_partition(context.region), account_id=context.account_id
        )
        subscription_arn = create_subscription_arn(topic_arn)
        subscription = SnsSubscription(
            # http://docs.aws.amazon.com/cli/latest/reference/sns/get-subscription-attributes.html
            TopicArn=topic_arn,
            Endpoint=endpoint,
            Protocol=protocol,
            SubscriptionArn=subscription_arn,
            PendingConfirmation="true",
            Owner=context.account_id,
            RawMessageDelivery="false",  # default value, will be overridden if set
            FilterPolicyScope="MessageAttributes",  # default value, will be overridden if set
            SubscriptionPrincipal=principal,  # dummy value, could be fetched with a call to STS?
        )
        if sub_attributes:
            subscription.update(sub_attributes)
            if "FilterPolicy" in sub_attributes:
                filter_policy = (
                    json.loads(sub_attributes["FilterPolicy"])
                    if sub_attributes["FilterPolicy"]
                    else None
                )
                if filter_policy:
                    # the scope comes from the merged subscription, so an explicit
                    # FilterPolicyScope attribute takes precedence over the default
                    validator = FilterPolicyValidator(
                        scope=subscription.get("FilterPolicyScope", "MessageAttributes"),
                        is_subscribe_call=True,
                    )
                    validator.validate_filter_policy(filter_policy)

                store.subscription_filter_policy[subscription_arn] = filter_policy

        store.subscriptions[subscription_arn] = subscription

        topic_subscriptions.append(subscription_arn)

        # store the token and subscription arn
        # TODO: the token is a 288 hex char string
        subscription_token = encode_subscription_token_with_region(region=context.region)
        store.subscription_tokens[subscription_token] = subscription_arn

        response_subscription_arn = subscription_arn
        # Send out confirmation message for HTTP(S), fix for https://github.com/localstack/localstack/issues/881
        if protocol in ["http", "https"]:
            message_ctx = SnsMessage(
                type=SnsMessageType.SubscriptionConfirmation,
                token=subscription_token,
                message=f"You have chosen to subscribe to the topic {topic_arn}.\nTo confirm the subscription, visit the SubscribeURL included in this message.",
            )
            publish_ctx = SnsPublishContext(
                message=message_ctx,
                store=store,
                request_headers=context.request.headers,
                topic_attributes=topic["attributes"],
            )
            self._publisher.publish_to_topic_subscriber(
                ctx=publish_ctx,
                topic_arn=topic_arn,
                subscription_arn=subscription_arn,
            )
            if not return_subscription_arn:
                response_subscription_arn = "pending confirmation"

        elif protocol not in ["email", "email-json"]:
            # Only HTTP(S) and email subscriptions are not auto validated
            # Except if the endpoint and the topic are not in the same AWS account, then you'd need to manually confirm
            # the subscription with the token
            # TODO: revisit for multi-account
            # TODO: test with AWS for email & email-json confirmation message
            # we need to add the following check:
            # if parsed_topic_arn["account"] == endpoint account (depending on the type, SQS, lambda, parse the arn)
            subscription["PendingConfirmation"] = "false"
            subscription["ConfirmationWasAuthenticated"] = "true"

        return SubscribeResponse(SubscriptionArn=response_subscription_arn)
427

428
    def unsubscribe(
        self, context: RequestContext, subscription_arn: subscriptionARN, **kwargs
    ) -> None:
        """
        Delete a subscription.

        The store is selected by the account and region of ``subscription_arn``, not by the request
        context. Unknown subscription ARNs are ignored, except for 6-element ARNs that are not a known
        subscription, which are rejected. HTTP(S) subscribers are sent an
        ``UnsubscribeConfirmation`` message before the subscription is removed.

        :raises InvalidParameterException: if the ARN is missing or malformed, or is a 6-element ARN
            that is not a known subscription
        """
        if subscription_arn is None:
            raise InvalidParameterException(
                "Invalid parameter: SubscriptionArn Reason: no value for required parameter",
            )
        count = len(subscription_arn.split(":"))
        try:
            parsed_arn = parse_arn(subscription_arn)
        except InvalidArnException:
            # TODO: check for invalid SubscriptionGUID
            raise InvalidParameterException(
                f"Invalid parameter: SubscriptionArn Reason: An ARN must have at least 6 elements, not {count}"
            )

        account_id = parsed_arn["account"]
        region_name = parsed_arn["region"]

        store = self.get_store(account_id=account_id, region=region_name)
        if count == 6 and subscription_arn not in store.subscriptions:
            raise InvalidParameterException("Invalid parameter: SubscriptionId")

        # TODO: here was a moto_backend.unsubscribe call, check correct functionality and remove this comment
        #  before switching to v2 for production

        # pop the subscription at the end, to avoid race condition by iterating over the topic subscriptions
        subscription = store.subscriptions.get(subscription_arn)

        if not subscription:
            # unsubscribe is idempotent, so unsubscribing from a non-existing topic does nothing
            return

        if subscription["Protocol"] in ["http", "https"]:
            # TODO: actually validate this (re)subscribe behaviour somehow (localhost.run?)
            #  we might need to save the sub token in the store
            # TODO: AWS only sends the UnsubscribeConfirmation if the call is unauthenticated or the requester is not
            #  the owner
            # encode the subscription's region (not the caller's): confirm_subscription rejects a
            # token whose region differs from the region of the topic ARN it is used with
            subscription_token = encode_subscription_token_with_region(region=region_name)
            message_ctx = SnsMessage(
                type=SnsMessageType.UnsubscribeConfirmation,
                token=subscription_token,
                message=f"You have chosen to deactivate subscription {subscription_arn}.\nTo cancel this operation and restore the subscription, visit the SubscribeURL included in this message.",
            )
            publish_ctx = SnsPublishContext(
                message=message_ctx,
                store=store,
                request_headers=context.request.headers,
                # TODO: add the topic attributes once we ported them from moto to LocalStack
                # topic_attributes=vars(moto_topic),
            )
            self._publisher.publish_to_topic_subscriber(
                publish_ctx,
                topic_arn=subscription["TopicArn"],
                subscription_arn=subscription_arn,
            )

        # the topic may already be gone; removing the reference is best-effort
        with contextlib.suppress(KeyError):
            store.topics[subscription["TopicArn"]]["subscriptions"].remove(subscription_arn)
        store.subscription_filter_policy.pop(subscription_arn, None)
        store.subscriptions.pop(subscription_arn, None)
489

490
    def get_subscription_attributes(
        self, context: RequestContext, subscription_arn: subscriptionARN, **kwargs
    ) -> GetSubscriptionAttributesResponse:
        """
        Return the attributes of a subscription in the caller's account and region.

        The filter-policy attributes are adjusted in the response only; this read-only call never
        modifies the stored subscription:
        - ``FilterPolicyScope`` and ``FilterPolicy`` are omitted when no filter policy is set
        - ``FilterPolicyScope`` defaults to ``MessageAttributes`` when a policy has no scope

        :raises NotFoundException: if the subscription does not exist
        """
        store = self.get_store(account_id=context.account_id, region=context.region)
        sub = store.subscriptions.get(subscription_arn)
        if not sub:
            raise NotFoundException("Subscription does not exist")

        # stored keys that are never returned to the caller
        removed_attrs = {"sqs_queue_url"}
        if "FilterPolicyScope" in sub and not sub.get("FilterPolicy"):
            removed_attrs.update(("FilterPolicyScope", "FilterPolicy"))

        attributes = {k: v for k, v in sub.items() if k not in removed_attrs}
        if "FilterPolicy" in sub and "FilterPolicyScope" not in sub:
            # default applied to the response only, so the stored subscription stays untouched
            attributes["FilterPolicyScope"] = "MessageAttributes"
        return GetSubscriptionAttributesResponse(Attributes=attributes)
506

507
    def set_subscription_attributes(
1✔
508
        self,
509
        context: RequestContext,
510
        subscription_arn: subscriptionARN,
511
        attribute_name: attributeName,
512
        attribute_value: attributeValue = None,
513
        **kwargs,
514
    ) -> None:
515
        store = self.get_store(account_id=context.account_id, region=context.region)
1✔
516
        sub = store.subscriptions.get(subscription_arn)
1✔
517
        if not sub:
1✔
518
            raise NotFoundException("Subscription does not exist")
1✔
519

520
        validate_subscription_attribute(
1✔
521
            attribute_name=attribute_name,
522
            attribute_value=attribute_value,
523
            topic_arn=sub["TopicArn"],
524
            endpoint=sub["Endpoint"],
525
        )
526
        if attribute_name == "RawMessageDelivery":
1✔
527
            attribute_value = attribute_value.lower()
1✔
528

529
        elif attribute_name == "FilterPolicy":
1✔
530
            filter_policy = json.loads(attribute_value) if attribute_value else None
1✔
531
            if filter_policy:
1✔
532
                validator = FilterPolicyValidator(
1✔
533
                    scope=sub.get("FilterPolicyScope", "MessageAttributes"),
534
                    is_subscribe_call=False,
535
                )
536
                validator.validate_filter_policy(filter_policy)
1✔
537

538
            store.subscription_filter_policy[subscription_arn] = filter_policy
1✔
539

540
        sub[attribute_name] = attribute_value
1✔
541

542
    def confirm_subscription(
        self,
        context: RequestContext,
        topic_arn: topicARN,
        token: String,
        authenticate_on_unsubscribe: authenticateOnUnsubscribe = None,
        **kwargs,
    ) -> ConfirmSubscriptionResponse:
        """
        Confirm a pending subscription with the token that was sent to its endpoint.

        This request can come from any HTTP client and might not be signed, so the account and region
        are taken from ``topic_arn`` instead of the request context. Confirming an already-confirmed
        subscription only returns its ARN.

        :raises InvalidParameterException: if the topic ARN is malformed, its region differs from the
            region encoded in the token, or the token does not lead to an existing subscription
        """
        # TODO: validate format on the token (seems to be 288 hex chars)
        # TODO: `authenticate_on_unsubscribe` would need to be implemented to force a signing client
        try:
            parsed_arn = parse_arn(topic_arn)
        except InvalidArnException:
            raise InvalidParameterException("Invalid parameter: Topic")

        topic_region = parsed_arn["region"]
        store = self.get_store(account_id=parsed_arn["account"], region=topic_region)

        # it seems SNS is able to know what the region of the topic should be, even though a wrong
        # topic is accepted
        if get_region_from_subscription_token(token) != topic_region:
            raise InvalidParameterException("Invalid parameter: Topic")

        subscription_arn = store.subscription_tokens.get(token)
        # unknown token, or the subscription was deleted in the meantime
        subscription = store.subscriptions.get(subscription_arn) if subscription_arn else None
        if not subscription:
            raise InvalidParameterException("Invalid parameter: Token")

        # ConfirmSubscription is idempotent: only a still-pending subscription is updated
        if subscription.get("PendingConfirmation") != "false":
            subscription["PendingConfirmation"] = "false"
            subscription["ConfirmationWasAuthenticated"] = "true"

        return ConfirmSubscriptionResponse(SubscriptionArn=subscription_arn)
582

583
    def list_subscriptions(
        self, context: RequestContext, next_token: nextToken = None, **kwargs
    ) -> ListSubscriptionsResponse:
        """
        List the subscriptions of the caller's account and region, 100 per page.

        ``NextToken`` is only set in the response when the pagination returns a next-page token.
        """
        store = self.get_store(context.account_id, context.region)

        def _page_token(entry):
            return get_next_page_token_from_arn(entry["SubscriptionArn"])

        all_subscriptions = PaginatedList(
            [select_from_typed_dict(Subscription, sub) for sub in list(store.subscriptions.values())]
        )
        page, following_token = all_subscriptions.get_page(
            token_generator=_page_token,
            page_size=100,
            next_token=next_token,
        )

        response = ListSubscriptionsResponse(Subscriptions=page)
        if following_token:
            response["NextToken"] = following_token
        return response
601

602
    def list_subscriptions_by_topic(
        self, context: RequestContext, topic_arn: topicARN, next_token: nextToken = None, **kwargs
    ) -> ListSubscriptionsByTopicResponse:
        """
        List the subscriptions of a topic, 100 per page.

        The store is selected by the account and region of ``topic_arn``. ``NextToken`` is only set in
        the response when the pagination returns a next-page token.

        :param context: the request context
        :param topic_arn: the topic whose subscriptions are listed
        :param next_token: the token of the page to return, if any
        :return: a ``ListSubscriptionsByTopicResponse`` (the declared return type; previously a
            ``ListSubscriptionsResponse`` was built instead)
        """
        self._get_topic(topic_arn, context)  # for validation purposes only
        parsed_topic_arn = parse_and_validate_topic_arn(topic_arn)
        store = self.get_store(parsed_topic_arn["account"], parsed_topic_arn["region"])
        subscriptions = get_topic_subscriptions(store, topic_arn)

        paginated_subscriptions = PaginatedList(subscriptions)
        # separate name so the `next_token` parameter is not shadowed
        page, next_page_token = paginated_subscriptions.get_page(
            token_generator=lambda x: get_next_page_token_from_arn(x["SubscriptionArn"]),
            page_size=100,
            next_token=next_token,
        )

        response = ListSubscriptionsByTopicResponse(Subscriptions=page)
        if next_page_token:
            response["NextToken"] = next_page_token
        return response
621

622
    #
623
    # Publish
624
    #
625

626
    def publish(
1✔
627
        self,
628
        context: RequestContext,
629
        message: message,
630
        topic_arn: topicARN | None = None,
631
        target_arn: String | None = None,
632
        phone_number: PhoneNumber | None = None,
633
        subject: subject | None = None,
634
        message_structure: messageStructure | None = None,
635
        message_attributes: MessageAttributeMap | None = None,
636
        message_deduplication_id: String | None = None,
637
        message_group_id: String | None = None,
638
        **kwargs,
639
    ) -> PublishResponse:
640
        if subject == "":
1✔
641
            raise InvalidParameterException("Invalid parameter: Subject")
1✔
642
        if not message or all(not m for m in message):
1✔
643
            raise InvalidParameterException("Invalid parameter: Empty message")
1✔
644

645
        # TODO: check for topic + target + phone number at the same time?
646
        # TODO: more validation on phone, it might be opted out?
647
        if phone_number and not is_valid_e164_number(phone_number):
1✔
648
            raise InvalidParameterException(
1✔
649
                f"Invalid parameter: PhoneNumber Reason: {phone_number} is not valid to publish to"
650
            )
651

652
        if message_attributes:
1✔
653
            _validate_message_attributes(message_attributes)
1✔
654

655
        if _get_total_publish_size(message, message_attributes) > MAXIMUM_MESSAGE_LENGTH:
1✔
656
            raise InvalidParameterException("Invalid parameter: Message too long")
1✔
657

658
        # for compatibility reasons, AWS allows users to use either TargetArn or TopicArn for publishing to a topic
659
        # use any of them for topic validation
660
        topic_or_target_arn = topic_arn or target_arn
1✔
661
        topic = None
1✔
662

663
        if is_fifo := (topic_or_target_arn and ".fifo" in topic_or_target_arn):
1✔
664
            if not message_group_id:
1✔
665
                raise InvalidParameterException(
1✔
666
                    "Invalid parameter: The MessageGroupId parameter is required for FIFO topics",
667
                )
668
            topic = self._get_topic(topic_or_target_arn, context)
1✔
669
            if topic["attributes"]["ContentBasedDeduplication"] == "false":
1✔
670
                if not message_deduplication_id:
1✔
671
                    raise InvalidParameterException(
1✔
672
                        "Invalid parameter: The topic should either have ContentBasedDeduplication enabled or MessageDeduplicationId provided explicitly",
673
                    )
674
        elif message_deduplication_id:
1✔
675
            # this is the first one to raise if both are set while the topic is not fifo
676
            raise InvalidParameterException(
1✔
677
                "Invalid parameter: MessageDeduplicationId Reason: The request includes MessageDeduplicationId parameter that is not valid for this topic type"
678
            )
679

680
        is_endpoint_publish = target_arn and ":endpoint/" in target_arn
1✔
681
        if message_structure == "json":
1✔
682
            try:
1✔
683
                message = json.loads(message)
1✔
684
                # Keys in the JSON object that correspond to supported transport protocols must have
685
                # simple JSON string values.
686
                # Non-string values will cause the key to be ignored.
687
                message = {key: field for key, field in message.items() if isinstance(field, str)}
1✔
688
                # TODO: check no default key for direct TargetArn endpoint publish, need credentials
689
                # see example: https://docs.aws.amazon.com/sns/latest/dg/sns-send-custom-platform-specific-payloads-mobile-devices.html
690
                if "default" not in message and not is_endpoint_publish:
1✔
691
                    raise InvalidParameterException(
1✔
692
                        "Invalid parameter: Message Structure - No default entry in JSON message body"
693
                    )
694
            except json.JSONDecodeError:
1✔
695
                raise InvalidParameterException(
1✔
696
                    "Invalid parameter: Message Structure - JSON message body failed to parse"
697
                )
698

699
        if not phone_number:
1✔
700
            # use the account to get the store from the TopicArn (you can only publish in the same region as the topic)
701
            parsed_arn = parse_and_validate_topic_arn(topic_or_target_arn)
1✔
702
            store = self.get_store(account_id=parsed_arn["account"], region=context.region)
1✔
703
            if is_endpoint_publish:
1✔
704
                if not (platform_endpoint := store.platform_endpoints.get(target_arn)):
1✔
705
                    raise InvalidParameterException(
1✔
706
                        "Invalid parameter: TargetArn Reason: No endpoint found for the target arn specified"
707
                    )
708
                elif (
1✔
709
                    not platform_endpoint.platform_endpoint["Attributes"]
710
                    .get("Enabled", "false")
711
                    .lower()
712
                    == "true"
713
                ):
714
                    raise EndpointDisabledException("Endpoint is disabled")
1✔
715
            else:
716
                topic = self._get_topic(topic_or_target_arn, context)
1✔
717
        else:
718
            # use the store from the request context
719
            store = self.get_store(account_id=context.account_id, region=context.region)
1✔
720

721
        message_ctx = SnsMessage(
1✔
722
            type=SnsMessageType.Notification,
723
            message=message,
724
            message_attributes=message_attributes,
725
            message_deduplication_id=message_deduplication_id,
726
            message_group_id=message_group_id,
727
            message_structure=message_structure,
728
            subject=subject,
729
            is_fifo=is_fifo,
730
        )
731
        publish_ctx = SnsPublishContext(
1✔
732
            message=message_ctx, store=store, request_headers=context.request.headers
733
        )
734

735
        if is_endpoint_publish:
1✔
736
            self._publisher.publish_to_application_endpoint(
1✔
737
                ctx=publish_ctx, endpoint_arn=target_arn
738
            )
739
        elif phone_number:
1✔
740
            self._publisher.publish_to_phone_number(ctx=publish_ctx, phone_number=phone_number)
1✔
741
        else:
742
            # beware if the subscription is FIFO, the order might not be guaranteed.
743
            # 2 quick call to this method in succession might not be executed in order in the executor?
744
            # TODO: test how this behaves in a FIFO context with a lot of threads.
745
            publish_ctx.topic_attributes |= topic["attributes"]
1✔
746
            self._publisher.publish_to_topic(publish_ctx, topic_or_target_arn)
1✔
747

748
        if is_fifo:
1✔
749
            return PublishResponse(
1✔
750
                MessageId=message_ctx.message_id, SequenceNumber=message_ctx.sequencer_number
751
            )
752

753
        return PublishResponse(MessageId=message_ctx.message_id)
1✔
754

755
    def publish_batch(
1✔
756
        self,
757
        context: RequestContext,
758
        topic_arn: topicARN,
759
        publish_batch_request_entries: PublishBatchRequestEntryList,
760
        **kwargs,
761
    ) -> PublishBatchResponse:
762
        if len(publish_batch_request_entries) > 10:
1✔
763
            raise TooManyEntriesInBatchRequestException(
1✔
764
                "The batch request contains more entries than permissible."
765
            )
766

767
        parsed_arn = parse_and_validate_topic_arn(topic_arn)
1✔
768
        store = self.get_store(account_id=parsed_arn["account"], region=context.region)
1✔
769
        topic = self._get_topic(topic_arn, context)
1✔
770
        ids = [entry["Id"] for entry in publish_batch_request_entries]
1✔
771
        if len(set(ids)) != len(publish_batch_request_entries):
1✔
772
            raise BatchEntryIdsNotDistinctException(
1✔
773
                "Two or more batch entries in the request have the same Id."
774
            )
775

776
        response: PublishBatchResponse = {"Successful": [], "Failed": []}
1✔
777

778
        # TODO: write AWS validated tests with FilterPolicy and batching
779
        # TODO: find a scenario where we can fail to send a message synchronously to be able to report it
780
        # right now, it seems that AWS fails the whole publish if something is wrong in the format of 1 message
781

782
        total_batch_size = 0
1✔
783
        message_contexts = []
1✔
784
        for entry_index, entry in enumerate(publish_batch_request_entries, start=1):
1✔
785
            message_payload = entry.get("Message")
1✔
786
            message_attributes = entry.get("MessageAttributes", {})
1✔
787
            if message_attributes:
1✔
788
                # if a message contains non-valid message attributes, it
789
                # will fail for the first non-valid message encountered, and raise ParameterValueInvalid
790
                _validate_message_attributes(message_attributes, position=entry_index)
1✔
791

792
            total_batch_size += _get_total_publish_size(message_payload, message_attributes)
1✔
793

794
            # TODO: WRITE AWS VALIDATED
795
            if entry.get("MessageStructure") == "json":
1✔
796
                try:
1✔
797
                    message = json.loads(message_payload)
1✔
798
                    # Keys in the JSON object that correspond to supported transport protocols must have
799
                    # simple JSON string values.
800
                    # Non-string values will cause the key to be ignored.
801
                    message = {
1✔
802
                        key: field for key, field in message.items() if isinstance(field, str)
803
                    }
804
                    if "default" not in message:
1✔
805
                        raise InvalidParameterException(
1✔
806
                            "Invalid parameter: Message Structure - No default entry in JSON message body"
807
                        )
808
                    entry["Message"] = message  # noqa
1✔
809
                except json.JSONDecodeError:
1✔
UNCOV
810
                    raise InvalidParameterException(
×
811
                        "Invalid parameter: Message Structure - JSON message body failed to parse"
812
                    )
813

814
            if is_fifo := (topic_arn.endswith(".fifo")):
1✔
815
                if not all("MessageGroupId" in entry for entry in publish_batch_request_entries):
1✔
816
                    raise InvalidParameterException(
1✔
817
                        "Invalid parameter: The MessageGroupId parameter is required for FIFO topics"
818
                    )
819
                if topic["attributes"]["ContentBasedDeduplication"] == "false":
1✔
820
                    if not all(
1✔
821
                        "MessageDeduplicationId" in entry for entry in publish_batch_request_entries
822
                    ):
823
                        raise InvalidParameterException(
1✔
824
                            "Invalid parameter: The topic should either have ContentBasedDeduplication enabled or MessageDeduplicationId provided explicitly",
825
                        )
826

827
            msg_ctx = SnsMessage.from_batch_entry(entry, is_fifo=is_fifo)
1✔
828
            message_contexts.append(msg_ctx)
1✔
829
            success = PublishBatchResultEntry(
1✔
830
                Id=entry["Id"],
831
                MessageId=msg_ctx.message_id,
832
            )
833
            if is_fifo:
1✔
834
                success["SequenceNumber"] = msg_ctx.sequencer_number
1✔
835
            response["Successful"].append(success)
1✔
836

837
        if total_batch_size > MAXIMUM_MESSAGE_LENGTH:
1✔
838
            raise CommonServiceException(
1✔
839
                code="BatchRequestTooLong",
840
                message="The length of all the messages put together is more than the limit.",
841
                sender_fault=True,
842
            )
843

844
        publish_ctx = SnsBatchPublishContext(
1✔
845
            messages=message_contexts,
846
            store=store,
847
            request_headers=context.request.headers,
848
            topic_attributes=topic["attributes"],
849
        )
850
        self._publisher.publish_batch_to_topic(publish_ctx, topic_arn)
1✔
851

852
        return response
1✔
853

854
    #
855
    # PlatformApplications
856
    #
857
    def create_platform_application(
1✔
858
        self,
859
        context: RequestContext,
860
        name: String,
861
        platform: String,
862
        attributes: MapStringToString,
863
        **kwargs,
864
    ) -> CreatePlatformApplicationResponse:
865
        _validate_platform_application_name(name)
1✔
866
        if platform not in VALID_APPLICATION_PLATFORMS:
1✔
867
            raise InvalidParameterException(
1✔
868
                f"Invalid parameter: Platform Reason: {platform} is not supported"
869
            )
870

871
        _validate_platform_application_attributes(attributes)
1✔
872

873
        # attribute validation specific to create_platform_application
874
        if "PlatformCredential" in attributes and "PlatformPrincipal" not in attributes:
1✔
875
            raise InvalidParameterException(
1✔
876
                "Invalid parameter: Attributes Reason: PlatformCredential attribute provided without PlatformPrincipal"
877
            )
878

879
        elif "PlatformPrincipal" in attributes and "PlatformCredential" not in attributes:
1✔
880
            raise InvalidParameterException(
1✔
881
                "Invalid parameter: Attributes Reason: PlatformPrincipal attribute provided without PlatformCredential"
882
            )
883

884
        store = self.get_store(context.account_id, context.region)
1✔
885
        # We are not validating the access data here like AWS does (against ADM and the like)
886
        attributes.pop("PlatformPrincipal")
1✔
887
        attributes.pop("PlatformCredential")
1✔
888
        _attributes = {"Enabled": "true"}
1✔
889
        _attributes.update(attributes)
1✔
890
        application_arn = sns_platform_application_arn(
1✔
891
            platform_application_name=name,
892
            platform=platform,
893
            account_id=context.account_id,
894
            region_name=context.region,
895
        )
896
        platform_application_details = PlatformApplicationDetails(
1✔
897
            platform_application=PlatformApplication(
898
                PlatformApplicationArn=application_arn,
899
                Attributes=_attributes,
900
            ),
901
            platform_endpoints={},
902
        )
903
        store.platform_applications[application_arn] = platform_application_details
1✔
904

905
        return platform_application_details.platform_application
1✔
906

907
    def delete_platform_application(
1✔
908
        self, context: RequestContext, platform_application_arn: String, **kwargs
909
    ) -> None:
910
        store = self.get_store(context.account_id, context.region)
1✔
911
        store.platform_applications.pop(platform_application_arn, None)
1✔
912
        # TODO: if the platform had endpoints, should we remove them from the store? There is no way to list
913
        #   endpoints without an application, so this is impossible to check the state of AWS here
914

915
    def list_platform_applications(
1✔
916
        self, context: RequestContext, next_token: String | None = None, **kwargs
917
    ) -> ListPlatformApplicationsResponse:
918
        store = self.get_store(context.account_id, context.region)
1✔
919
        platform_applications = store.platform_applications.values()
1✔
920
        paginated_applications = PaginatedList(platform_applications)
1✔
921
        page, token = paginated_applications.get_page(
1✔
922
            token_generator=lambda x: get_next_page_token_from_arn(x["PlatformApplicationArn"]),
923
            page_size=100,
924
            next_token=next_token,
925
        )
926

927
        response = ListPlatformApplicationsResponse(
1✔
928
            PlatformApplications=[platform_app.platform_application for platform_app in page]
929
        )
930
        if token:
1✔
UNCOV
931
            response["NextToken"] = token
×
932
        return response
1✔
933

934
    def get_platform_application_attributes(
1✔
935
        self, context: RequestContext, platform_application_arn: String, **kwargs
936
    ) -> GetPlatformApplicationAttributesResponse:
937
        platform_application = self._get_platform_application(platform_application_arn, context)
1✔
938
        attributes = platform_application["Attributes"]
1✔
939
        return GetPlatformApplicationAttributesResponse(Attributes=attributes)
1✔
940

941
    def set_platform_application_attributes(
1✔
942
        self,
943
        context: RequestContext,
944
        platform_application_arn: String,
945
        attributes: MapStringToString,
946
        **kwargs,
947
    ) -> None:
948
        parse_and_validate_platform_application_arn(platform_application_arn)
1✔
949
        _validate_platform_application_attributes(attributes)
1✔
950

951
        platform_application = self._get_platform_application(platform_application_arn, context)
1✔
952
        platform_application["Attributes"].update(attributes)
1✔
953

954
    #
955
    # Platform Endpoints
956
    #
957

958
    def create_platform_endpoint(
1✔
959
        self,
960
        context: RequestContext,
961
        platform_application_arn: String,
962
        token: String,
963
        custom_user_data: String | None = None,
964
        attributes: MapStringToString | None = None,
965
        **kwargs,
966
    ) -> CreateEndpointResponse:
967
        store = self.get_store(context.account_id, context.region)
1✔
968
        application = store.platform_applications.get(platform_application_arn)
1✔
969
        if not application:
1✔
970
            raise NotFoundException("PlatformApplication does not exist")
1✔
971
        endpoint_arn = application.platform_endpoints.get(token, {})
1✔
972
        attributes = attributes or {}
1✔
973
        _validate_endpoint_attributes(attributes, allow_empty=True)
1✔
974
        # CustomUserData can be specified both in attributes and as parameter. Attributes take precedence
975
        attributes.setdefault(EndpointAttributeNames.CUSTOM_USER_DATA, custom_user_data)
1✔
976
        _attributes = {"Enabled": "true", "Token": token, **attributes}
1✔
977
        if endpoint_arn and (
1✔
978
            platform_endpoint_details := store.platform_endpoints.get(endpoint_arn)
979
        ):
980
            # endpoint for that application with that particular token already exists
981
            if not platform_endpoint_details.platform_endpoint["Attributes"] == _attributes:
1✔
982
                raise InvalidParameterException(
1✔
983
                    f"Invalid parameter: Token Reason: Endpoint {endpoint_arn} already exists with the same Token, but different attributes."
984
                )
985
            else:
986
                return CreateEndpointResponse(EndpointArn=endpoint_arn)
1✔
987

988
        endpoint_arn = create_platform_endpoint_arn(platform_application_arn)
1✔
989
        platform_endpoint = PlatformEndpoint(
1✔
990
            platform_application_arn=endpoint_arn,
991
            platform_endpoint=Endpoint(
992
                Attributes=_attributes,
993
                EndpointArn=endpoint_arn,
994
            ),
995
        )
996
        store.platform_endpoints[endpoint_arn] = platform_endpoint
1✔
997
        application.platform_endpoints[token] = endpoint_arn
1✔
998

999
        return CreateEndpointResponse(EndpointArn=endpoint_arn)
1✔
1000

1001
    def delete_endpoint(self, context: RequestContext, endpoint_arn: String, **kwargs) -> None:
1✔
1002
        store = self.get_store(context.account_id, context.region)
1✔
1003
        platform_endpoint_details = store.platform_endpoints.pop(endpoint_arn, None)
1✔
1004
        if platform_endpoint_details:
1✔
1005
            platform_application = store.platform_applications.get(
1✔
1006
                platform_endpoint_details.platform_application_arn
1007
            )
1008
            if platform_application:
1✔
UNCOV
1009
                platform_endpoint = platform_endpoint_details.platform_endpoint
×
UNCOV
1010
                platform_application.platform_endpoints.pop(
×
1011
                    platform_endpoint["Attributes"]["Token"], None
1012
                )
1013

1014
    def list_endpoints_by_platform_application(
1✔
1015
        self,
1016
        context: RequestContext,
1017
        platform_application_arn: String,
1018
        next_token: String | None = None,
1019
        **kwargs,
1020
    ) -> ListEndpointsByPlatformApplicationResponse:
1021
        store = self.get_store(context.account_id, context.region)
1✔
1022
        platform_application = store.platform_applications.get(platform_application_arn)
1✔
1023
        if not platform_application:
1✔
1024
            raise NotFoundException("PlatformApplication does not exist")
1✔
1025
        endpoint_arns = platform_application.platform_endpoints.values()
1✔
1026
        paginated_endpoint_arns = PaginatedList(endpoint_arns)
1✔
1027
        page, token = paginated_endpoint_arns.get_page(
1✔
1028
            token_generator=lambda x: get_next_page_token_from_arn(x),
1029
            page_size=100,
1030
            next_token=next_token,
1031
        )
1032

1033
        response = ListEndpointsByPlatformApplicationResponse(
1✔
1034
            Endpoints=[
1035
                store.platform_endpoints[endpoint_arn].platform_endpoint
1036
                for endpoint_arn in page
1037
                if endpoint_arn in store.platform_endpoints
1038
            ]
1039
        )
1040
        if token:
1✔
UNCOV
1041
            response["NextToken"] = token
×
1042
        return response
1✔
1043

1044
    def get_endpoint_attributes(
1✔
1045
        self, context: RequestContext, endpoint_arn: String, **kwargs
1046
    ) -> GetEndpointAttributesResponse:
1047
        store = self.get_store(context.account_id, context.region)
1✔
1048
        platform_endpoint_details = store.platform_endpoints.get(endpoint_arn)
1✔
1049
        if not platform_endpoint_details:
1✔
1050
            raise NotFoundException("Endpoint does not exist")
1✔
1051
        attributes = platform_endpoint_details.platform_endpoint["Attributes"]
1✔
1052
        return GetEndpointAttributesResponse(Attributes=attributes)
1✔
1053

1054
    def set_endpoint_attributes(
1✔
1055
        self, context: RequestContext, endpoint_arn: String, attributes: MapStringToString, **kwargs
1056
    ) -> None:
1057
        store = self.get_store(context.account_id, context.region)
1✔
1058
        platform_endpoint_details = store.platform_endpoints.get(endpoint_arn)
1✔
1059
        if not platform_endpoint_details:
1✔
1060
            raise NotFoundException("Endpoint does not exist")
1✔
1061
        _validate_endpoint_attributes(attributes)
1✔
1062
        attributes = attributes or {}
1✔
1063
        platform_endpoint_details.platform_endpoint["Attributes"].update(attributes)
1✔
1064

1065
    #
1066
    # Sms operations
1067
    #
1068

1069
    def set_sms_attributes(
1✔
1070
        self, context: RequestContext, attributes: MapStringToString, **kwargs
1071
    ) -> SetSMSAttributesResponse:
1072
        store = self.get_store(context.account_id, context.region)
1✔
1073
        _validate_sms_attributes(attributes)
1✔
1074
        _set_sms_attribute_default(store)
1✔
1075
        store.sms_attributes.update(attributes or {})
1✔
1076
        return SetSMSAttributesResponse()
1✔
1077

1078
    def get_sms_attributes(
1✔
1079
        self, context: RequestContext, attributes: ListString | None = None, **kwargs
1080
    ) -> GetSMSAttributesResponse:
1081
        store = self.get_store(context.account_id, context.region)
1✔
1082
        _set_sms_attribute_default(store)
1✔
1083
        store_attributes = store.sms_attributes
1✔
1084
        return_attributes = {}
1✔
1085
        for k, v in store_attributes.items():
1✔
1086
            if not attributes or k in attributes:
1✔
1087
                return_attributes[k] = store_attributes[k]
1✔
1088

1089
        return GetSMSAttributesResponse(attributes=return_attributes)
1✔
1090

1091
    #
1092
    # Phone number operations
1093
    #
1094

1095
    def check_if_phone_number_is_opted_out(
1✔
1096
        self, context: RequestContext, phone_number: PhoneNumber, **kwargs
1097
    ) -> CheckIfPhoneNumberIsOptedOutResponse:
1098
        store = sns_stores[context.account_id][context.region]
1✔
1099
        return CheckIfPhoneNumberIsOptedOutResponse(
1✔
1100
            isOptedOut=phone_number in store.PHONE_NUMBERS_OPTED_OUT
1101
        )
1102

1103
    def list_phone_numbers_opted_out(
1✔
1104
        self, context: RequestContext, next_token: string | None = None, **kwargs
1105
    ) -> ListPhoneNumbersOptedOutResponse:
1106
        store = self.get_store(context.account_id, context.region)
1✔
1107
        numbers_opted_out = PaginatedList(store.PHONE_NUMBERS_OPTED_OUT)
1✔
1108
        page, nxt = numbers_opted_out.get_page(
1✔
1109
            token_generator=lambda x: x,
1110
            next_token=next_token,
1111
            page_size=100,
1112
        )
1113
        phone_numbers = {"phoneNumbers": page, "nextToken": nxt}
1✔
1114
        return ListPhoneNumbersOptedOutResponse(**phone_numbers)
1✔
1115

1116
    def opt_in_phone_number(
1✔
1117
        self, context: RequestContext, phone_number: PhoneNumber, **kwargs
1118
    ) -> OptInPhoneNumberResponse:
1119
        store = self.get_store(context.account_id, context.region)
1✔
1120
        if phone_number in store.PHONE_NUMBERS_OPTED_OUT:
1✔
1121
            store.PHONE_NUMBERS_OPTED_OUT.remove(phone_number)
1✔
1122
        return OptInPhoneNumberResponse()
1✔
1123

1124
    #
1125
    # Permission operations
1126
    #
1127

1128
    def add_permission(
1✔
1129
        self,
1130
        context: RequestContext,
1131
        topic_arn: topicARN,
1132
        label: label,
1133
        aws_account_id: DelegatesList,
1134
        action_name: ActionsList,
1135
        **kwargs,
1136
    ) -> None:
1137
        topic: Topic = self._get_topic(topic_arn, context)
1✔
1138
        policy = json.loads(topic["attributes"]["Policy"])
1✔
1139
        statement = next(
1✔
1140
            (statement for statement in policy["Statement"] if statement["Sid"] == label),
1141
            None,
1142
        )
1143

1144
        if statement:
1✔
1145
            raise InvalidParameterException("Invalid parameter: Statement already exists")
1✔
1146

1147
        if any(action not in VALID_POLICY_ACTIONS for action in action_name):
1✔
1148
            raise InvalidParameterException(
1✔
1149
                "Invalid parameter: Policy statement action out of service scope!"
1150
            )
1151

1152
        principals = [
1✔
1153
            f"arn:{get_partition(context.region)}:iam::{account_id}:root"
1154
            for account_id in aws_account_id
1155
        ]
1156
        actions = [f"SNS:{action}" for action in action_name]
1✔
1157

1158
        statement = {
1✔
1159
            "Sid": label,
1160
            "Effect": "Allow",
1161
            "Principal": {"AWS": principals[0] if len(principals) == 1 else principals},
1162
            "Action": actions[0] if len(actions) == 1 else actions,
1163
            "Resource": topic_arn,
1164
        }
1165

1166
        policy["Statement"].append(statement)
1✔
1167
        topic["attributes"]["Policy"] = json.dumps(policy)
1✔
1168

1169
    def remove_permission(
1✔
1170
        self, context: RequestContext, topic_arn: topicARN, label: label, **kwargs
1171
    ) -> None:
1172
        topic = self._get_topic(topic_arn, context)
1✔
1173
        policy = json.loads(topic["attributes"]["Policy"])
1✔
1174
        statements = policy["Statement"]
1✔
1175
        statements = [statement for statement in statements if statement["Sid"] != label]
1✔
1176
        policy["Statement"] = statements
1✔
1177
        topic["attributes"]["Policy"] = json.dumps(policy)
1✔
1178

1179
    #
1180
    # Data Protection Policy operations
1181
    #
1182

1183
    def get_data_protection_policy(
1✔
1184
        self, context: RequestContext, resource_arn: topicARN, **kwargs
1185
    ) -> GetDataProtectionPolicyResponse:
1186
        topic = self._get_topic(resource_arn, context)
1✔
1187
        return GetDataProtectionPolicyResponse(
1✔
1188
            DataProtectionPolicy=topic.get("data_protection_policy")
1189
        )
1190

1191
    def put_data_protection_policy(
1✔
1192
        self,
1193
        context: RequestContext,
1194
        resource_arn: topicARN,
1195
        data_protection_policy: attributeValue,
1196
        **kwargs,
1197
    ) -> None:
1198
        topic = self._get_topic(resource_arn, context)
1✔
1199
        topic["data_protection_policy"] = data_protection_policy
1✔
1200

1201
    def list_tags_for_resource(
1✔
1202
        self, context: RequestContext, resource_arn: AmazonResourceName, **kwargs
1203
    ) -> ListTagsForResourceResponse:
1204
        store = sns_stores[context.account_id][context.region]
1✔
1205
        tags = store.TAGS.list_tags_for_resource(resource_arn)
1✔
1206
        return ListTagsForResourceResponse(Tags=tags.get("Tags"))
1✔
1207

1208
    def tag_resource(
1✔
1209
        self, context: RequestContext, resource_arn: AmazonResourceName, tags: TagList, **kwargs
1210
    ) -> TagResourceResponse:
1211
        unique_tag_keys = {tag["Key"] for tag in tags}
1✔
1212
        if len(unique_tag_keys) < len(tags):
1✔
1213
            raise InvalidParameterException("Invalid parameter: Duplicated keys are not allowed.")
1✔
1214
        store = sns_stores[context.account_id][context.region]
1✔
1215
        store.TAGS.tag_resource(resource_arn, tags)
1✔
1216
        return TagResourceResponse()
1✔
1217

1218
    def untag_resource(
1✔
1219
        self,
1220
        context: RequestContext,
1221
        resource_arn: AmazonResourceName,
1222
        tag_keys: TagKeyList,
1223
        **kwargs,
1224
    ) -> UntagResourceResponse:
1225
        store = sns_stores[context.account_id][context.region]
1✔
1226
        store.TAGS.untag_resource(resource_arn, tag_keys)
1✔
1227
        return UntagResourceResponse()
1✔
1228

1229
    @staticmethod
1✔
1230
    def get_store(account_id: str, region: str) -> SnsStore:
1✔
1231
        return sns_stores[account_id][region]
1✔
1232

1233
    @staticmethod
1✔
1234
    def _get_topic(arn: str, context: RequestContext, multi_region: bool = False) -> Topic:
1✔
1235
        """
1236
        :param arn: the Topic ARN
1237
        :param context: the RequestContext of the request
1238
        :return: the model Topic
1239
        """
1240
        arn_data = parse_and_validate_topic_arn(arn)
1✔
1241
        if not multi_region and context.region != arn_data["region"]:
1✔
1242
            raise InvalidParameterException("Invalid parameter: TopicArn")
1✔
1243
        try:
1✔
1244
            store = SnsProvider.get_store(context.account_id, context.region)
1✔
1245
            return store.topics[arn]
1✔
1246
        except KeyError:
1✔
1247
            raise NotFoundException("Topic does not exist")
1✔
1248

1249
    @staticmethod
1✔
1250
    def _get_platform_application(
1✔
1251
        platform_application_arn: str, context: RequestContext
1252
    ) -> PlatformApplication:
1253
        parse_and_validate_platform_application_arn(platform_application_arn)
1✔
1254
        try:
1✔
1255
            store = SnsProvider.get_store(context.account_id, context.region)
1✔
1256
            return store.platform_applications[platform_application_arn].platform_application
1✔
1257
        except KeyError:
1✔
1258
            raise NotFoundException("PlatformApplication does not exist")
1✔
1259

1260

1261
def _create_topic(name: str, attributes: dict, context: RequestContext) -> Topic:
1✔
1262
    topic_arn = sns_topic_arn(
1✔
1263
        topic_name=name, region_name=context.region, account_id=context.account_id
1264
    )
1265
    topic: Topic = {
1✔
1266
        "name": name,
1267
        "arn": topic_arn,
1268
        "attributes": {},
1269
        "subscriptions": [],
1270
    }
1271
    attrs = _default_attributes(topic, context)
1✔
1272
    attrs.update(attributes or {})
1✔
1273
    topic["attributes"] = attrs
1✔
1274

1275
    return topic
1✔
1276

1277

1278
def _default_attributes(topic: Topic, context: RequestContext) -> TopicAttributesMap:
1✔
1279
    default_attributes = {
1✔
1280
        "DisplayName": "",
1281
        "Owner": context.account_id,
1282
        "Policy": _create_default_topic_policy(topic, context),
1283
        "SubscriptionsConfirmed": "0",
1284
        "SubscriptionsDeleted": "0",
1285
        "SubscriptionsPending": "0",
1286
        "TopicArn": topic["arn"],
1287
    }
1288
    if topic["name"].endswith(".fifo"):
1✔
1289
        default_attributes.update(
1✔
1290
            {
1291
                "ContentBasedDeduplication": "false",
1292
                "FifoTopic": "false",
1293
            }
1294
        )
1295
    return default_attributes
1✔
1296

1297

1298
def _create_default_effective_delivery_policy():
1✔
1299
    return json.dumps(
1✔
1300
        {
1301
            "http": {
1302
                "defaultHealthyRetryPolicy": {
1303
                    "minDelayTarget": 20,
1304
                    "maxDelayTarget": 20,
1305
                    "numRetries": 3,
1306
                    "numMaxDelayRetries": 0,
1307
                    "numNoDelayRetries": 0,
1308
                    "numMinDelayRetries": 0,
1309
                    "backoffFunction": "linear",
1310
                },
1311
                "disableSubscriptionOverrides": False,
1312
                "defaultRequestPolicy": {"headerContentType": "text/plain; charset=UTF-8"},
1313
            }
1314
        }
1315
    )
1316

1317

1318
def _create_default_topic_policy(topic: Topic, context: RequestContext) -> str:
    """
    Build the default access policy attached to a new topic, serialized to JSON.

    The single statement allows any principal to perform the listed SNS actions on the topic,
    conditioned on `AWS:SourceOwner` being the caller's account.

    :param topic: the topic the policy applies to; its "arn" is used as the resource
    :param context: request context; its account id is used in the condition
    :return: the policy document as a JSON string
    """
    allowed_actions = [
        "SNS:GetTopicAttributes",
        "SNS:SetTopicAttributes",
        "SNS:AddPermission",
        "SNS:RemovePermission",
        "SNS:DeleteTopic",
        "SNS:Subscribe",
        "SNS:ListSubscriptionsByTopic",
        "SNS:Publish",
    ]
    default_statement = {
        "Effect": "Allow",
        "Sid": "__default_statement_ID",
        "Principal": {"AWS": "*"},
        "Action": allowed_actions,
        "Resource": topic["arn"],
        "Condition": {"StringEquals": {"AWS:SourceOwner": context.account_id}},
    }
    policy_document = {
        "Version": "2008-10-17",
        "Id": "__default_policy_ID",
        "Statement": [default_statement],
    }
    return json.dumps(policy_document)
1344

1345

1346
def _validate_message_attributes(
    message_attributes: MessageAttributeMap, position: int | None = None
) -> None:
    """
    Validate the message attributes, and raise an exception if they do not follow AWS validation.
    See: https://docs.aws.amazon.com/sns/latest/dg/sns-message-attributes.html
    Regex from: https://stackoverflow.com/questions/40718851/regex-that-does-not-allow-consecutive-dots

    :param message_attributes: the message attributes map for the message
    :param position: the batch entry position when called from `publishBatch` (used to build the
        error path of a missing `DataType`); None for a single `publish`. Presumably 1-based like
        AWS query `member.N` indices — confirm against the caller.
    :raises InvalidParameterValueException: on an invalid attribute name, type or value
    :raises CommonServiceException: a ValidationError when an attribute has no `DataType`
    :return: None
    """
    for attr_name, attr in message_attributes.items():
        if len(attr_name) > 256:
            raise InvalidParameterValueException(
                "Length of message attribute name must be less than 256 bytes."
            )
        _validate_message_attribute_name(attr_name)
        # `DataType` is a required field for MessageAttributeValue
        if (data_type := attr.get("DataType")) is None:
            # compare against None explicitly: a falsy position (0) is still a batch position
            if position is not None:
                at = f"publishBatchRequestEntries.{position}.member.messageAttributes.{attr_name}.member.dataType"
            else:
                at = f"messageAttributes.{attr_name}.member.dataType"

            raise CommonServiceException(
                code="ValidationError",
                message=f"1 validation error detected: Value null at '{at}' failed to satisfy constraint: Member must not be null",
                sender_fault=True,
            )

        # custom types are allowed as long as they use one of the supported prefixes
        if data_type not in (
            "String",
            "Number",
            "Binary",
        ) and not ATTR_TYPE_REGEX.match(data_type):
            raise InvalidParameterValueException(
                f"The message attribute '{attr_name}' has an invalid message attribute type, the set of supported type prefixes is Binary, Number, and String."
            )
        # at least one `*Value` field (StringValue / BinaryValue) must be present
        if not any(attr_key.endswith("Value") for attr_key in attr):
            raise InvalidParameterValueException(
                f"The message attribute '{attr_name}' must contain non-empty message attribute value for message attribute type '{data_type}'."
            )

        # Binary* types carry their payload in BinaryValue, everything else in StringValue
        value_key_data_type = "Binary" if data_type.startswith("Binary") else "String"
        value_key = f"{value_key_data_type}Value"
        if value_key not in attr:
            raise InvalidParameterValueException(
                f"The message attribute '{attr_name}' with type '{data_type}' must use field '{value_key_data_type}'."
            )
        elif not attr[value_key]:
            raise InvalidParameterValueException(
                f"The message attribute '{attr_name}' must contain non-empty message attribute value for message attribute type '{data_type}'.",
            )
1400

1401

1402
def _validate_message_attribute_name(name: str) -> None:
1✔
1403
    """
1404
    Validate the message attribute name with the specification of AWS.
1405
    The message attribute name can contain the following characters: A-Z, a-z, 0-9, underscore(_), hyphen(-), and period (.). The name must not start or end with a period, and it should not have successive periods.
1406
    :param name: message attribute name
1407
    :raises InvalidParameterValueException: if the name does not conform to the spec
1408
    """
1409
    if not MSG_ATTR_NAME_REGEX.match(name):
1✔
1410
        # find the proper exception
1411
        if name[0] == ".":
1✔
1412
            raise InvalidParameterValueException(
1✔
1413
                "Invalid message attribute name starting with character '.' was found."
1414
            )
1415
        elif name[-1] == ".":
1✔
1416
            raise InvalidParameterValueException(
1✔
1417
                "Invalid message attribute name ending with character '.' was found."
1418
            )
1419

1420
        for idx, char in enumerate(name):
1✔
1421
            if char not in VALID_MSG_ATTR_NAME_CHARS:
1✔
1422
                # change prefix from 0x to #x, without capitalizing the x
1423
                hex_char = "#x" + hex(ord(char)).upper()[2:]
1✔
1424
                raise InvalidParameterValueException(
1✔
1425
                    f"Invalid non-alphanumeric character '{hex_char}' was found in the message attribute name. Can only include alphanumeric characters, hyphens, underscores, or dots."
1426
                )
1427
            # even if we go negative index, it will be covered by starting/ending with dot
1428
            if char == "." and name[idx - 1] == ".":
1✔
1429
                raise InvalidParameterValueException(
1✔
1430
                    "Message attribute name can not have successive '.' character."
1431
                )
1432

1433

1434
def _validate_platform_application_name(name: str) -> None:
1✔
1435
    reason = ""
1✔
1436
    if not name:
1✔
1437
        reason = "cannot be empty"
1✔
1438
    elif not re.match(r"^.{0,256}$", name):
1✔
1439
        reason = "must be at most 256 characters long"
1✔
1440
    elif not re.match(r"^[A-Za-z0-9._-]+$", name):
1✔
1441
        reason = "must contain only characters 'a'-'z', 'A'-'Z', '0'-'9', '_', '-', and '.'"
1✔
1442

1443
    if reason:
1✔
1444
        raise InvalidParameterException(f"Invalid parameter: {name} Reason: {reason}")
1✔
1445

1446

1447
def _validate_platform_application_attributes(attributes: dict) -> None:
1✔
1448
    _check_empty_attributes(attributes)
1✔
1449

1450

1451
def _check_empty_attributes(attributes: dict) -> None:
1✔
1452
    if not attributes:
1✔
1453
        raise CommonServiceException(
1✔
1454
            code="ValidationError",
1455
            message="1 validation error detected: Value null at 'attributes' failed to satisfy constraint: Member must not be null",
1456
            sender_fault=True,
1457
        )
1458

1459

1460
def _validate_endpoint_attributes(attributes: dict, allow_empty: bool = False) -> None:
1✔
1461
    if not allow_empty:
1✔
1462
        _check_empty_attributes(attributes)
1✔
1463
    for key in attributes:
1✔
1464
        if key not in EndpointAttributeNames:
1✔
1465
            raise InvalidParameterException(
1✔
1466
                f"Invalid parameter: Attributes Reason: Invalid attribute name: {key}"
1467
            )
1468
    if len(attributes.get(EndpointAttributeNames.CUSTOM_USER_DATA, "")) > 2048:
1✔
1469
        raise InvalidParameterException(
1✔
1470
            "Invalid parameter: Attributes Reason: Invalid value for attribute: CustomUserData: must be at most 2048 bytes long in UTF-8 encoding"
1471
        )
1472

1473

1474
def _validate_sms_attributes(attributes: dict) -> None:
1✔
1475
    for k, v in attributes.items():
1✔
1476
        if k not in SMS_ATTRIBUTE_NAMES:
1✔
1477
            raise InvalidParameterException(f"{k} is not a valid attribute")
1✔
1478
    default_send_id = attributes.get("DefaultSendID")
1✔
1479
    if default_send_id and not re.match(SMS_DEFAULT_SENDER_REGEX, default_send_id):
1✔
UNCOV
1480
        raise InvalidParameterException("DefaultSendID is not a valid attribute")
×
1481
    sms_type = attributes.get("DefaultSMSType")
1✔
1482
    if sms_type and sms_type not in SMS_TYPES:
1✔
1483
        raise InvalidParameterException("DefaultSMSType is invalid")
1✔
1484

1485

1486
def _set_sms_attribute_default(store: SnsStore) -> None:
    """Make sure the store's SMS attributes have a "MonthlySpendLimit" (default "1")."""
    # TODO: don't call this on every sms attribute crud api call
    sms_attributes = store.sms_attributes
    if "MonthlySpendLimit" not in sms_attributes:
        sms_attributes["MonthlySpendLimit"] = "1"
1489

1490

1491
def _check_matching_tags(topic_arn: str, tags: TagList | None, store: SnsStore) -> bool:
    """
    Check whether re-creating an existing topic would conflict with its current tags.

    Re-issuing CreateTopic for an existing topic with a tag it does not already carry should
    fail, according to the AWS documentation.

    :param topic_arn: Arn of the topic
    :param tags: Tags requested for the topic (may be None)
    :param store: Store object that holds the topics and tags
    :return: False if a requested tag is missing from the existing topic, True otherwise
        (including when the topic does not exist yet)
    """
    if topic_arn not in store.topics:
        # a brand-new topic has nothing to conflict with
        return True

    existing_tags = store.TAGS.list_tags_for_resource(topic_arn)["Tags"]
    if existing_tags is None:
        return True
    # each requested tag entry must already be present as-is on the topic
    return all(tag in existing_tags for tag in tags or [])
1510

1511

1512
def _get_total_publish_size(
    message_body: str, message_attributes: MessageAttributeMap | None
) -> int:
    """
    Compute the size in bytes that counts against the SNS publish size limit.

    Per https://docs.aws.amazon.com/sns/latest/dg/sns-message-attributes.html, every part of a
    message attribute (name, type and value) counts towards the message size restriction,
    which is 256 KB.

    :param message_body: the message body
    :param message_attributes: optional attribute map; each attribute's name and all of its
        field values (DataType and StringValue or BinaryValue) are added
    :return: total byte size
    """
    total = _get_byte_size(message_body)
    for attribute_name, attribute in (message_attributes or {}).items():
        total += _get_byte_size(attribute_name)
        total += sum(_get_byte_size(field_value) for field_value in attribute.values())
    return total
1528

1529

1530
def _get_byte_size(payload: str | bytes) -> int:
1✔
1531
    # Calculate the real length of the byte object if the object is a string
1532
    return len(to_bytes(payload))
1✔
1533

1534

1535
def _register_sns_api_resource(router: Router):
    """
    Register the retrospection endpoints as internal LocalStack endpoints.

    Kept for backward compatibility; delegates to `register_sns_api_resource` so the list of
    registered resources is maintained in one place instead of being duplicated here.
    """
    register_sns_api_resource(router)
1540

1541

1542
class SNSInternalResource:
    """Base class with helper to properly track usage of internal endpoints"""

    # Label value reported to `internal_api_calls`; each concrete resource sets its own.
    # (The docstring must come first in the class body: placed after this annotation, it was a
    # no-op string expression and left `__doc__` as None.)
    resource_type: str

    def count_usage(self):
        """Record one call to this internal resource, labelled with its `resource_type`."""
        internal_api_calls.labels(resource_type=self.resource_type).increment()
1548

1549

1550
def count_usage(f):
    """
    Decorate an internal-resource handler so every call is counted before it runs.

    The decorated callable must be a method: its first positional argument (``self``) has
    ``count_usage()`` called on it, then the original handler runs with the same arguments
    and its result is returned unchanged.
    """

    @functools.wraps(f)
    def _counted_handler(self, *args, **kwargs):
        self.count_usage()
        result = f(self, *args, **kwargs)
        return result

    return _counted_handler
1557

1558

1559
class SNSServicePlatformEndpointMessagesApiResource(SNSInternalResource):
    """Provides a REST API for retrospective access to platform endpoint messages sent via SNS.

    This is registered as a LocalStack internal HTTP resource.

    This endpoint accepts:
    - GET param `accountId`: selector for AWS account. If not specified, return fallback `000000000000` test ID
    - GET param `region`: selector for AWS `region`. If not specified, return default "us-east-1"
    - GET param `endpointArn`: filter for `endpointArn` resource in SNS
    - DELETE param `accountId`: selector for AWS account
    - DELETE param `region`: will delete saved messages for `region`
    - DELETE param `endpointArn`: will delete saved messages for `endpointArn`

    When `endpointArn` is given, the account and region are taken from that ARN and the
    `accountId` / `region` params are ignored.
    """

    # Label reported by `count_usage`. (Declared after the docstring: placed before it, the
    # docstring was a no-op string expression and the class had no `__doc__`.)
    resource_type = "platform-endpoint-message"

    # keys of each stored message that are exposed by the GET endpoint
    _PAYLOAD_FIELDS = [
        "TargetArn",
        "TopicArn",
        "Message",
        "MessageAttributes",
        "MessageStructure",
        "Subject",
        "MessageId",
    ]

    @staticmethod
    def _resolve_scope(request: Request) -> tuple[str | None, str, SnsStore]:
        """
        Work out which SNS store a request targets (shared by GET and DELETE).

        :param request: the incoming request; reads the `endpointArn`, `accountId`, `region` args
        :return: the `endpointArn` filter (or None when absent), the region, and the store
        """
        filter_endpoint_arn = request.args.get("endpointArn")
        if filter_endpoint_arn:
            # the ARN pins account and region, so the explicit params are not consulted
            account_id = extract_account_id_from_arn(filter_endpoint_arn)
            region = extract_region_from_arn(filter_endpoint_arn)
        else:
            account_id = request.args.get("accountId", DEFAULT_AWS_ACCOUNT_ID)
            region = request.args.get("region", AWS_REGION_US_EAST_1)
        return filter_endpoint_arn, region, sns_stores[account_id][region]

    @route(PLATFORM_ENDPOINT_MSGS_ENDPOINT, methods=["GET"])
    @count_usage
    def on_get(self, request: Request):
        """Return the saved platform endpoint messages, optionally only those for `endpointArn`."""
        filter_endpoint_arn, region, store = self._resolve_scope(request)
        if filter_endpoint_arn:
            messages = store.platform_endpoint_messages.get(filter_endpoint_arn, [])
            messages = _format_messages(messages, self._PAYLOAD_FIELDS)
            return {
                "platform_endpoint_messages": {filter_endpoint_arn: messages},
                "region": region,
            }

        platform_endpoint_messages = {
            endpoint_arn: _format_messages(messages, self._PAYLOAD_FIELDS)
            for endpoint_arn, messages in store.platform_endpoint_messages.items()
        }
        return {
            "platform_endpoint_messages": platform_endpoint_messages,
            "region": region,
        }

    @route(PLATFORM_ENDPOINT_MSGS_ENDPOINT, methods=["DELETE"])
    @count_usage
    def on_delete(self, request: Request) -> Response:
        """Delete the saved messages of `endpointArn`, or all saved messages of the store."""
        filter_endpoint_arn, _region, store = self._resolve_scope(request)
        if filter_endpoint_arn:
            store.platform_endpoint_messages.pop(filter_endpoint_arn, None)
            return Response("", status=204)

        store.platform_endpoint_messages.clear()
        return Response("", status=204)
1637

1638

1639
def register_sns_api_resource(router: Router):
    """
    Register the retrospection endpoints as internal LocalStack endpoints.

    One instance of each internal resource is created and added to `router`, in this order:
    platform endpoint messages, SMS messages, subscription tokens.
    """
    for resource_cls in (
        SNSServicePlatformEndpointMessagesApiResource,
        SNSServiceSMSMessagesApiResource,
        SNSServiceSubscriptionTokenApiResource,
    ):
        router.add(resource_cls())
1644

1645

1646
def _format_messages(sent_messages: list[dict[str, str]], validated_keys: list[str]):
1✔
1647
    """
1648
    This method format the messages to be more readable and undo the format change that was needed for Moto
1649
    Should be removed once we refactor SNS.
1650
    """
1651
    formatted_messages = []
1✔
1652
    for sent_message in sent_messages:
1✔
1653
        msg = {
1✔
1654
            key: json.dumps(value)
1655
            if key == "Message" and sent_message.get("MessageStructure") == "json"
1656
            else value
1657
            for key, value in sent_message.items()
1658
            if key in validated_keys
1659
        }
1660
        formatted_messages.append(msg)
1✔
1661

1662
    return formatted_messages
1✔
1663

1664

1665
class SNSServiceSMSMessagesApiResource(SNSInternalResource):
    """Provides a REST API for retrospective access to SMS messages sent via SNS.

    This is registered as a LocalStack internal HTTP resource.

    This endpoint accepts:
    - GET param `accountId`: selector for AWS account. If not specified, return fallback `000000000000` test ID
    - GET param `region`: selector for AWS `region`. If not specified, return default "us-east-1"
    - GET param `phoneNumber`: filter for `phoneNumber` resource in SNS
    - DELETE param `accountId`: selector for AWS account, with the same fallback as GET
    - DELETE param `region`: selector for AWS `region`, with the same default as GET
    - DELETE param `phoneNumber`: only delete saved messages for `phoneNumber`; without it, all
      saved SMS messages of the selected store are deleted
    """

    # Label reported by `count_usage`. (Declared after the docstring: placed before it, the
    # docstring was a no-op string expression and the class had no `__doc__`.)
    resource_type = "sms-message"

    # keys of each stored message that are exposed by the GET endpoint
    _PAYLOAD_FIELDS = [
        "PhoneNumber",
        "TopicArn",
        "SubscriptionArn",
        "MessageId",
        "Message",
        "MessageAttributes",
        "MessageStructure",
        "Subject",
    ]

    @route(SMS_MSGS_ENDPOINT, methods=["GET"])
    @count_usage
    def on_get(self, request: Request):
        """Return saved SMS messages grouped by phone number, optionally for one `phoneNumber`."""
        account_id = request.args.get("accountId", DEFAULT_AWS_ACCOUNT_ID)
        region = request.args.get("region", AWS_REGION_US_EAST_1)
        filter_phone_number = request.args.get("phoneNumber")
        store: SnsStore = sns_stores[account_id][region]
        if filter_phone_number:
            messages = [
                m for m in store.sms_messages if m.get("PhoneNumber") == filter_phone_number
            ]
            messages = _format_messages(messages, self._PAYLOAD_FIELDS)
            return {
                "sms_messages": {filter_phone_number: messages},
                "region": region,
            }

        sms_messages = {}

        for m in _format_messages(store.sms_messages, self._PAYLOAD_FIELDS):
            sms_messages.setdefault(m.get("PhoneNumber"), []).append(m)

        return {
            "sms_messages": sms_messages,
            "region": region,
        }

    @route(SMS_MSGS_ENDPOINT, methods=["DELETE"])
    @count_usage
    def on_delete(self, request: Request) -> Response:
        """Delete saved SMS messages for `phoneNumber`, or all saved SMS messages of the store."""
        account_id = request.args.get("accountId", DEFAULT_AWS_ACCOUNT_ID)
        region = request.args.get("region", AWS_REGION_US_EAST_1)
        filter_phone_number = request.args.get("phoneNumber")
        store: SnsStore = sns_stores[account_id][region]
        if filter_phone_number:
            store.sms_messages = [
                m for m in store.sms_messages if m.get("PhoneNumber") != filter_phone_number
            ]
            return Response("", status=204)

        store.sms_messages.clear()
        return Response("", status=204)
1730

1731

1732
class SNSServiceSubscriptionTokenApiResource(SNSInternalResource):
    """Provides a REST API for retrospective access to Subscription Confirmation Tokens to confirm subscriptions.
    Those are not sent for email, and sometimes inaccessible when working with external HTTPS endpoint which won't be
    able to reach your local host.

    This is registered as a LocalStack internal HTTP resource.

    This endpoint has the following parameter:
    - GET path parameter `subscription_arn`: the `subscriptionArn` resource in SNS for which you want the
      SubscriptionToken

    Responds with 400 if the ARN cannot be parsed and 404 if no token is stored for it.
    """

    # Label reported by `count_usage`. (Declared after the docstring: placed before it, the
    # docstring was a no-op string expression and the class had no `__doc__`.)
    resource_type = "subscription-token"

    @route(f"{SUBSCRIPTION_TOKENS_ENDPOINT}/<path:subscription_arn>", methods=["GET"])
    @count_usage
    def on_get(self, _request: Request, subscription_arn: str):
        """Return the confirmation token stored for `subscription_arn`."""
        try:
            parsed_arn = parse_arn(subscription_arn)
        except InvalidArnException:
            response = Response("", 400)
            response.set_json(
                {
                    "error": "The provided SubscriptionARN is invalid",
                    "subscription_arn": subscription_arn,
                }
            )
            return response

        # the subscription's own account and region select the store to search
        store: SnsStore = sns_stores[parsed_arn["account"]][parsed_arn["region"]]

        # tokens are keyed by token, so finding the one for an ARN is a linear scan
        for token, sub_arn in store.subscription_tokens.items():
            if sub_arn == subscription_arn:
                return {
                    "subscription_token": token,
                    "subscription_arn": subscription_arn,
                }

        response = Response("", 404)
        response.set_json(
            {
                "error": "The provided SubscriptionARN is not found",
                "subscription_arn": subscription_arn,
            }
        )
        return response
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc