• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

localstack / localstack / 21f9bd53-116e-4d5e-b4ed-df3e0ed75cd9

05 Mar 2025 07:28PM UTC coverage: 86.862% (-0.03%) from 86.896%
21f9bd53-116e-4d5e-b4ed-df3e0ed75cd9

push

circleci

web-flow
Step Functions: Improve Nested Map Run Stability (#12343)

48 of 49 new or added lines in 7 files covered. (97.96%)

40 existing lines in 18 files now uncovered.

61896 of 71258 relevant lines covered (86.86%)

0.87 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

84.74
/localstack-core/localstack/services/lambda_/event_source_mapping/pollers/stream_poller.py
1
import json
1✔
2
import logging
1✔
3
import threading
1✔
4
from abc import abstractmethod
1✔
5
from datetime import datetime
1✔
6
from typing import Iterator
1✔
7

8
from botocore.client import BaseClient
1✔
9
from botocore.exceptions import ClientError
1✔
10

11
from localstack.aws.api.pipes import (
1✔
12
    OnPartialBatchItemFailureStreams,
13
)
14
from localstack.services.lambda_.event_source_mapping.event_processor import (
1✔
15
    BatchFailureError,
16
    CustomerInvocationError,
17
    EventProcessor,
18
    PartialBatchFailureError,
19
    PipeInternalError,
20
)
21
from localstack.services.lambda_.event_source_mapping.pipe_utils import (
1✔
22
    get_current_time,
23
    get_datetime_from_timestamp,
24
    get_internal_client,
25
)
26
from localstack.services.lambda_.event_source_mapping.pollers.poller import (
1✔
27
    EmptyPollResultsException,
28
    Poller,
29
    get_batch_item_failures,
30
)
31
from localstack.services.lambda_.event_source_mapping.pollers.sqs_poller import get_queue_url
1✔
32
from localstack.utils.aws.arns import parse_arn, s3_bucket_name
1✔
33
from localstack.utils.backoff import ExponentialBackoff
1✔
34
from localstack.utils.strings import long_uid
1✔
35

36
LOG = logging.getLogger(__name__)
1✔
37

38

39
# TODO: fix this poller to support resharding
40
#   https://docs.aws.amazon.com/streams/latest/dev/kinesis-using-sdk-java-resharding.html
41
class StreamPoller(Poller):
    """Generalized poller base for AWS stream sources (Kinesis and DynamoDB Streams).

    Drives the shard round-robin polling loop, retry handling with backoff, and
    failure-destination (DLQ) delivery. Service-specific behavior (shard
    initialization, record transformation, timestamp formatting, ...) is
    supplied by subclasses via the abstract hooks below.
    """

    # Mapping of shard id => shard iterator
    shards: dict[str, str]
    # Iterator for round-robin polling from different shards because a batch cannot contain events from different shards
    # This is a workaround for not handling shards in parallel.
    iterator_over_shards: Iterator[tuple[str, str]] | None
    # ESM UUID is needed in failure processing to form s3 failure destination object key
    esm_uuid: str | None

    # The ARN of the processor (e.g., Pipe ARN)
    partner_resource_arn: str | None

    # Used for backing-off between retries and breaking the retry loop
    _is_shutdown: threading.Event

    def __init__(
        self,
        source_arn: str,
        source_parameters: dict | None = None,
        source_client: BaseClient | None = None,
        processor: EventProcessor | None = None,
        partner_resource_arn: str | None = None,
        esm_uuid: str | None = None,
    ):
        super().__init__(source_arn, source_parameters, source_client, processor)
        self.partner_resource_arn = partner_resource_arn
        self.esm_uuid = esm_uuid
        # Shard state is populated lazily on the first poll_events() call.
        self.shards = {}
        self.iterator_over_shards = None

        self._is_shutdown = threading.Event()
72

73
    @abstractmethod
    def transform_into_events(self, records: list[dict], shard_id) -> list[dict]:
        """Transform raw GetRecords `records` from the given shard into event dicts."""
        pass

    @property
    @abstractmethod
    def stream_parameters(self) -> dict:
        """Return the source-specific stream parameters (e.g., BatchSize, DeadLetterConfig)."""
        pass

    @abstractmethod
    def initialize_shards(self) -> dict[str, str]:
        """Returns a shard dict mapping from shard id -> shard iterator
        The implementations for Kinesis and DynamoDB are similar but differ in various ways:
        * Kinesis uses "StreamARN" and DynamoDB uses "StreamArn" as source parameter
        * Kinesis uses "StreamStatus.ACTIVE" and DynamoDB uses "StreamStatus.ENABLED"
        * Only Kinesis supports the additional StartingPosition called "AT_TIMESTAMP" using "StartingPositionTimestamp"
        """
        pass

    @abstractmethod
    def stream_arn_param(self) -> dict:
        """Returns a dict of the correct key/value pair for the stream arn used in GetRecords.
        Either StreamARN for Kinesis or {} for DynamoDB (unsupported)"""
        pass

    @abstractmethod
    def failure_payload_details_field_name(self) -> str:
        """Return the field name under which stream details appear in DLQ failure payloads."""
        pass

    @abstractmethod
    def get_approximate_arrival_time(self, record: dict) -> float:
        """Return the approximate arrival time of `record` as an epoch timestamp."""
        pass

    @abstractmethod
    def format_datetime(self, time: datetime) -> str:
        """Formats a datetime in the correct format for DynamoDB (with ms) or Kinesis (without ms)"""
        pass

    @abstractmethod
    def get_sequence_number(self, record: dict) -> str:
        """Return the sequence number identifying `record` within its shard."""
        pass
114

115
    def close(self):
        """Signal shutdown: unblocks any backoff wait and breaks the retry loop."""
        self._is_shutdown.set()

    def pre_filter(self, events: list[dict]) -> list[dict]:
        # Hook for service-specific preprocessing before filter criteria are applied.
        # Default implementation is a no-op passthrough.
        return events

    def post_filter(self, events: list[dict]) -> list[dict]:
        # Hook for service-specific postprocessing after filter criteria are applied.
        # Default implementation is a no-op passthrough.
        return events
123

124
    def poll_events(self):
        """Poll one batch of events from the next shard in round-robin order.

        Examples of Kinesis consumers:
        * StackOverflow: https://stackoverflow.com/a/22403036/6875981
        * AWS Sample: https://github.com/aws-samples/kinesis-poster-worker/blob/master/worker.py
        Examples of DynamoDB consumers:
        * Blogpost: https://www.tecracer.com/blog/2022/05/getting-a-near-real-time-view-of-a-dynamodb-stream-with-python.html
        """
        # TODO: consider potential shard iterator timeout after 300 seconds (likely not relevant with short-polling):
        #   https://docs.aws.amazon.com/streams/latest/dev/troubleshooting-consumers.html#shard-iterator-expires-unexpectedly
        #  Does this happen if no records are received for 300 seconds?
        if not self.shards:
            self.shards = self.initialize_shards()

        # TODO: improve efficiency because this currently limits the throughput to at most batch size per poll interval
        shard_tuple = self._advance_round_robin()
        try:
            self.poll_events_from_shard(*shard_tuple)
        # TODO: implement exponential back-off for errors in general
        except PipeInternalError:
            # TODO: standardize logging
            # Swallow internal errors; the next polling interval acts as the retry.
            pass

    def _advance_round_robin(self) -> tuple[str, str] | None:
        """Return the next (shard id, shard iterator) pair, restarting once all shards were handled."""
        if self.iterator_over_shards is None:
            self.iterator_over_shards = iter(self.shards.items())
        shard_tuple = next(self.iterator_over_shards, None)
        if not shard_tuple:
            # Exhausted: re-initialize the iterator over the current shard map.
            self.iterator_over_shards = iter(self.shards.items())
            shard_tuple = next(self.iterator_over_shards, None)
        return shard_tuple
155

156
    def poll_events_from_shard(self, shard_id: str, shard_iterator: str) -> None:
        """Fetch, filter, and process one batch of records from a single shard.

        Processes the batch with retries (and exponential backoff between
        attempts) until success, shutdown, retry exhaustion, or an abort
        condition; on abort/exhaustion the events are sent to the configured
        failure destination (DLQ) and the shard iterator is advanced.

        :param shard_id: id of the shard to poll from
        :param shard_iterator: current iterator position within that shard
        :raises EmptyPollResultsException: if the poll yielded no events
        :raises PipeInternalError: propagated from get_records (handled by poll_events)
        """
        abort_condition = None
        get_records_response = self.get_records(shard_iterator)
        records = get_records_response["Records"]
        polled_events = self.transform_into_events(records, shard_id)
        if not polled_events:
            raise EmptyPollResultsException(service=self.event_source, source_arn=self.source_arn)

        # Check MaximumRecordAgeInSeconds: if the newest polled record is already
        # older than the configured maximum, abort and route the batch to the DLQ.
        if maximum_record_age_in_seconds := self.stream_parameters.get("MaximumRecordAgeInSeconds"):
            arrival_timestamp_of_last_event = polled_events[-1]["approximateArrivalTimestamp"]
            now = get_current_time().timestamp()
            record_age_in_seconds = now - arrival_timestamp_of_last_event
            if record_age_in_seconds > maximum_record_age_in_seconds:
                abort_condition = "RecordAgeExpired"

        # TODO: implement format detection behavior (e.g., for JSON body):
        #  https://docs.aws.amazon.com/eventbridge/latest/userguide/eb-pipes-event-filtering.html
        #  Check whether we need poller-specific filter-preprocessing here without modifying the actual event!
        # convert to json for filtering (HACK for fixing parity with v1 and getting regression tests passing)
        # localstack.services.lambda_.event_source_listeners.kinesis_event_source_listener.KinesisEventSourceListener._filter_records
        # TODO: explore better abstraction for the entire filtering, including the set_data and get_data remapping
        #  We need better clarify which transformations happen before and after filtering -> fix missing test coverage
        parsed_events = self.pre_filter(polled_events)
        # TODO: advance iterator past matching events!
        #  We need to checkpoint the sequence number for each shard and then advance the shard iterator using
        #  GetShardIterator with a given sequence number
        #  https://docs.aws.amazon.com/kinesis/latest/APIReference/API_GetShardIterator.html
        #  Failing to do so kinda blocks the stream resulting in very high latency.
        matching_events = self.filter_events(parsed_events)
        matching_events_post_filter = self.post_filter(matching_events)

        # TODO: implement MaximumBatchingWindowInSeconds flush condition (before or after filter?)
        # Don't trigger upon empty events
        if len(matching_events_post_filter) == 0:
            # Update shard iterator if no records match the filter
            self.shards[shard_id] = get_records_response["NextShardIterator"]
            return
        events = self.add_source_metadata(matching_events_post_filter)
        LOG.debug("Polled %d events from %s in shard %s", len(events), self.source_arn, shard_id)
        # TODO: A retry should probably re-trigger fetching the record from the stream again?!
        #  -> This could be tested by setting a high retry number, using a long pipe execution, and a relatively
        #  short record expiration age at the source. Check what happens if the record expires at the source.
        #  A potential implementation could use checkpointing based on the iterator position (within shard scope)
        # TODO: handle partial batch failure (see poller.py:parse_batch_item_failures)
        # TODO: think about how to avoid starvation of other shards if one shard runs into infinite retries
        attempts = 0
        error_payload = {}

        # NOTE(review): ExponentialBackoff is constructed with max_retries=attempts,
        # which is 0 at this point — confirm this is intended (the surrounding while
        # condition, not the backoff object, bounds the retries here).
        boff = ExponentialBackoff(max_retries=attempts)
        while (
            not abort_condition
            and not self.max_retries_exceeded(attempts)
            and not self._is_shutdown.is_set()
        ):
            try:
                if attempts > 0:
                    # TODO: Should we always backoff (with jitter) before processing since we may not want multiple pollers
                    # all starting up and polling simultaneously
                    # For example: 500 persisted ESMs starting up and requesting concurrently could flood gateway
                    # Waiting on the shutdown event doubles as an interruptible sleep.
                    self._is_shutdown.wait(boff.next_backoff())

                self.processor.process_events_batch(events)
                boff.reset()

                # Update shard iterator if execution is successful
                self.shards[shard_id] = get_records_response["NextShardIterator"]
                return
            except PartialBatchFailureError as ex:
                # TODO: add tests for partial batch failure scenarios
                if (
                    self.stream_parameters.get("OnPartialBatchItemFailure")
                    == OnPartialBatchItemFailureStreams.AUTOMATIC_BISECT
                ):
                    # TODO: implement and test splitting batches in half until batch size 1
                    #  https://docs.aws.amazon.com/eventbridge/latest/pipes-reference/API_PipeSourceKinesisStreamParameters.html
                    LOG.warning(
                        "AUTOMATIC_BISECT upon partial batch item failure is not yet implemented. Retrying the entire batch."
                    )
                error_payload = ex.error

                # Extract all sequence numbers from events in batch. This allows us to fail the whole batch if
                # an unknown itemidentifier is returned.
                batch_sequence_numbers = {
                    self.get_sequence_number(event) for event in matching_events
                }

                # If the batchItemFailures array contains multiple items, Lambda uses the record with the lowest sequence number as the checkpoint.
                # Lambda then retries all records starting from that checkpoint.
                failed_sequence_ids: list[int] | None = get_batch_item_failures(
                    ex.partial_failure_payload, batch_sequence_numbers
                )

                # If None is returned, consider the entire batch a failure.
                if failed_sequence_ids is None:
                    continue

                # This shouldn't be possible since a PartialBatchFailureError was raised
                # NOTE(review): assert is stripped under `python -O`; if this state ever
                # occurred there, min() below would raise ValueError instead — confirm.
                if len(failed_sequence_ids) == 0:
                    assert failed_sequence_ids, (
                        "Invalid state encountered: PartialBatchFailureError raised but no batch item failures found."
                    )

                lowest_sequence_id: str = min(failed_sequence_ids, key=int)

                # Discard all successful events and re-process from sequence number of failed event
                _, events = self.bisect_events(lowest_sequence_id, events)
            # NOTE(review): this tuple is equivalent to `except Exception` since
            # BatchFailureError is handled via the isinstance check below.
            except (BatchFailureError, Exception) as ex:
                if isinstance(ex, BatchFailureError):
                    error_payload = ex.error

                # FIXME partner_resource_arn is not defined in ESM
                LOG.debug(
                    "Attempt %d failed while processing %s with events: %s",
                    attempts,
                    self.partner_resource_arn or self.source_arn,
                    events,
                )
            finally:
                # Retry polling until the record expires at the source
                attempts += 1

        # Send failed events to potential DLQ
        abort_condition = abort_condition or "RetryAttemptsExhausted"
        failure_context = self.processor.generate_event_failure_context(
            abort_condition=abort_condition,
            error=error_payload,
            attempts_count=attempts,
            partner_resource_arn=self.partner_resource_arn,
        )
        self.send_events_to_dlq(shard_id, events, context=failure_context)
        # Update shard iterator if the execution failed but the events are sent to a DLQ
        self.shards[shard_id] = get_records_response["NextShardIterator"]
289

290
    def get_records(self, shard_iterator: str) -> dict:
        """Returns a GetRecordsOutput from the GetRecords endpoint of streaming services such as Kinesis or DynamoDB.

        :param shard_iterator: iterator position to read from
        :raises CustomerInvocationError: on insufficient permissions or a missing source stream
        :raises PipeInternalError: on expired/invalid iterators (after re-initializing
            shards) and on any other client error; the caller retries on the next
            polling interval
        """
        try:
            get_records_response = self.source_client.get_records(
                # TODO: add test for cross-account scenario
                # Differs for Kinesis and DynamoDB but required for cross-account scenario
                **self.stream_arn_param(),
                ShardIterator=shard_iterator,
                Limit=self.stream_parameters["BatchSize"],
            )
            return get_records_response
        # TODO: test iterator expired with conditional error scenario (requires failure destinations)
        except self.source_client.exceptions.ExpiredIteratorException as e:
            LOG.debug(
                "Shard iterator %s expired for stream %s, re-initializing shards",
                shard_iterator,
                self.source_arn,
            )
            # TODO: test TRIM_HORIZON and AT_TIMESTAMP scenarios for this case. We don't want to start from scratch and
            #  might need to think about checkpointing here.
            self.shards = self.initialize_shards()
            raise PipeInternalError from e
        except ClientError as e:
            if "AccessDeniedException" in str(e):
                LOG.warning(
                    "Insufficient permissions to get records from stream %s: %s",
                    self.source_arn,
                    e,
                )
                raise CustomerInvocationError from e
            elif "ResourceNotFoundException" in str(e):
                # FIXME: The 'Invalid ShardId in ShardIterator' error is returned by DynamoDB-local. Unsure when/why this is returned.
                if "Invalid ShardId in ShardIterator" in str(e):
                    LOG.warning(
                        "Invalid ShardId in ShardIterator for %s. Re-initializing shards.",
                        self.source_arn,
                    )
                    # Fix: assign the freshly initialized shard map (the return value
                    # was previously discarded) and raise a retryable error instead of
                    # falling through to an implicit `return None`, which made the
                    # caller crash with a TypeError on `None["Records"]`.
                    self.shards = self.initialize_shards()
                    raise PipeInternalError from e
                else:
                    LOG.warning(
                        "Source stream %s does not exist: %s",
                        self.source_arn,
                        e,
                    )
                    raise CustomerInvocationError from e
            elif "TrimmedDataAccessException" in str(e):
                LOG.debug(
                    "Attempted to iterate over trimmed record or expired shard iterator %s for stream %s, re-initializing shards",
                    shard_iterator,
                    self.source_arn,
                )
                # Fix: same as above — keep the new shard map and signal a retry,
                # consistent with the ExpiredIteratorException branch.
                self.shards = self.initialize_shards()
                raise PipeInternalError from e
            else:
                LOG.debug("ClientError during get_records for stream %s: %s", self.source_arn, e)
                raise PipeInternalError from e
345

346
    def send_events_to_dlq(self, shard_id, events, context) -> None:
        """Deliver failed events to the configured DLQ target (SQS, SNS, or S3), if any."""
        dlq_arn = self.stream_parameters.get("DeadLetterConfig", {}).get("Arn")
        if not dlq_arn:
            return

        failure_timestamp = get_current_time()
        dlq_event = self.create_dlq_event(shard_id, events, context, failure_timestamp)
        # Route the DLQ event based on the service encoded in the target ARN.
        # TODO: use a sender instance here, likely inject via DI into poller (what if it updates?)
        service = parse_arn(dlq_arn)["service"]
        if service == "sqs":
            # TODO: inject and cache SQS client using proper IAM role (supports cross-account operations)
            sqs_client = get_internal_client(dlq_arn)
            # TODO: check if the DLQ exists
            dlq_url = get_queue_url(dlq_arn)
            # TODO: validate no FIFO queue because they are unsupported
            sqs_client.send_message(QueueUrl=dlq_url, MessageBody=json.dumps(dlq_event))
        elif service == "sns":
            sns_client = get_internal_client(dlq_arn)
            sns_client.publish(TopicArn=dlq_arn, Message=json.dumps(dlq_event))
        elif service == "s3":
            s3_client = get_internal_client(dlq_arn)
            # S3 destinations additionally receive the raw records as payload.
            dlq_event_with_payload = {**dlq_event, "payload": {"Records": events}}
            s3_client.put_object(
                Bucket=s3_bucket_name(dlq_arn),
                Key=get_failure_s3_object_key(self.esm_uuid, shard_id, failure_timestamp),
                Body=json.dumps(dlq_event_with_payload),
            )
        else:
            LOG.warning("Unsupported DLQ service %s", service)
380

381
    def create_dlq_event(
        self, shard_id: str, events: list[dict], context: dict, failure_timestamp: datetime
    ) -> dict:
        """Build the DLQ payload envelope describing the failed batch of events."""
        # Arrival times of the batch boundaries, converted to the service-specific format.
        first_arrival = get_datetime_from_timestamp(self.get_approximate_arrival_time(events[0]))
        last_arrival = get_datetime_from_timestamp(self.get_approximate_arrival_time(events[-1]))

        details = {
            "approximateArrivalOfFirstRecord": self.format_datetime(first_arrival),
            "approximateArrivalOfLastRecord": self.format_datetime(last_arrival),
            "batchSize": len(events),
            "endSequenceNumber": self.get_sequence_number(events[-1]),
            "shardId": shard_id,
            "startSequenceNumber": self.get_sequence_number(events[0]),
            "streamArn": self.source_arn,
        }
        # ISO timestamp with millisecond precision and a trailing "Z" instead of "+00:00".
        timestamp = failure_timestamp.isoformat(timespec="milliseconds").replace("+00:00", "Z")
        return {
            **context,
            self.failure_payload_details_field_name(): details,
            "timestamp": timestamp,
            "version": "1.0",
        }
409

410
    def max_retries_exceeded(self, attempts: int) -> bool:
        """Return True once `attempts` exceeds the configured MaximumRetryAttempts.

        A configured value of -1 (the default) means infinite retries until the
        record expires at the source.
        """
        max_attempts = self.stream_parameters.get("MaximumRetryAttempts", -1)
        return max_attempts != -1 and attempts > max_attempts
416

417
    def bisect_events(
        self, sequence_number: str, events: list[dict]
    ) -> tuple[list[dict], list[dict]]:
        """Splits list of events in two, where a sequence number equals a passed parameter `sequence_number`.
        This is used for:
          - `ReportBatchItemFailures`: Discarding events in a batch following a failure when is set.
          - `BisectBatchOnFunctionError`: Used to split a failed batch in two when doing a retry (not implemented)."""
        # Index of the first event carrying the given sequence number, or None.
        split_index = next(
            (
                idx
                for idx, event in enumerate(events)
                if self.get_sequence_number(event) == sequence_number
            ),
            None,
        )
        if split_index is None:
            # No match: everything is considered "before" the split point.
            return events, []
        return events[:split_index], events[split_index:]
429

430

431
def get_failure_s3_object_key(esm_uuid: str, shard_id: str, failure_datetime: datetime) -> str:
    """
    From https://docs.aws.amazon.com/lambda/latest/dg/kinesis-on-failure-destination.html:

    The S3 object containing the invocation record uses the following naming convention:
    aws/lambda/<ESM-UUID>/<shardID>/YYYY/MM/DD/YYYY-MM-DDTHH.MM.SS-<Random UUID>

    :return: Key for s3 object that invocation failure record will be put to
    """
    # Date directory path and timestamp prefix both derive from the failure time.
    date_path = failure_datetime.strftime("%Y/%m/%d")
    time_prefix = failure_datetime.strftime("%Y-%m-%dT%H.%M.%S")
    return f"aws/lambda/{esm_uuid}/{shard_id}/{date_path}/{time_prefix}-{long_uid()}"
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc