localstack / localstack / 16820655284

07 Aug 2025 05:03PM UTC coverage: 86.841% (-0.05%) from 86.892%
Build 16820655284 (push, github, web-flow)
CFNV2: support CDK bootstrap and deployment (#12967)

32 of 38 new or added lines in 5 files covered. (84.21%)
2013 existing lines in 125 files now uncovered.
66606 of 76699 relevant lines covered (86.84%)
0.87 hits per line

Source File: /localstack-core/localstack/services/dynamodb/provider.py (91.97% covered)
import copy
import json
import logging
import os
import random
import re
import threading
import time
import traceback
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from datetime import datetime
from operator import itemgetter

import requests
import werkzeug

from localstack import config
from localstack.aws import handlers
from localstack.aws.api import (
    CommonServiceException,
    RequestContext,
    ServiceRequest,
    ServiceResponse,
    handler,
)
from localstack.aws.api.dynamodb import (
    ApproximateCreationDateTimePrecision,
    AttributeMap,
    BatchExecuteStatementOutput,
    BatchGetItemOutput,
    BatchGetRequestMap,
    BatchGetResponseMap,
    BatchWriteItemInput,
    BatchWriteItemOutput,
    BatchWriteItemRequestMap,
    BillingMode,
    ContinuousBackupsDescription,
    ContinuousBackupsStatus,
    CreateGlobalTableOutput,
    CreateTableInput,
    CreateTableOutput,
    Delete,
    DeleteItemInput,
    DeleteItemOutput,
    DeleteRequest,
    DeleteTableOutput,
    DescribeContinuousBackupsOutput,
    DescribeGlobalTableOutput,
    DescribeKinesisStreamingDestinationOutput,
    DescribeTableOutput,
    DescribeTimeToLiveOutput,
    DestinationStatus,
    DynamodbApi,
    EnableKinesisStreamingConfiguration,
    ExecuteStatementInput,
    ExecuteStatementOutput,
    ExecuteTransactionInput,
    ExecuteTransactionOutput,
    GetItemInput,
    GetItemOutput,
    GlobalTableAlreadyExistsException,
    GlobalTableNotFoundException,
    KinesisStreamingDestinationOutput,
    ListGlobalTablesOutput,
    ListTablesInputLimit,
    ListTablesOutput,
    ListTagsOfResourceOutput,
    NextTokenString,
    PartiQLBatchRequest,
    PointInTimeRecoveryDescription,
    PointInTimeRecoverySpecification,
    PointInTimeRecoveryStatus,
    PositiveIntegerObject,
    ProvisionedThroughputExceededException,
    Put,
    PutItemInput,
    PutItemOutput,
    PutRequest,
    QueryInput,
    QueryOutput,
    RegionName,
    ReplicaDescription,
    ReplicaList,
    ReplicaStatus,
    ReplicaUpdateList,
    ResourceArnString,
    ResourceInUseException,
    ResourceNotFoundException,
    ReturnConsumedCapacity,
    ScanInput,
    ScanOutput,
    StreamArn,
    TableArn,
    TableDescription,
    TableName,
    TagKeyList,
    TagList,
    TimeToLiveSpecification,
    TransactGetItemList,
    TransactGetItemsOutput,
    TransactWriteItem,
    TransactWriteItemList,
    TransactWriteItemsInput,
    TransactWriteItemsOutput,
    Update,
    UpdateContinuousBackupsOutput,
    UpdateGlobalTableOutput,
    UpdateItemInput,
    UpdateItemOutput,
    UpdateKinesisStreamingConfiguration,
    UpdateKinesisStreamingDestinationOutput,
    UpdateTableInput,
    UpdateTableOutput,
    UpdateTimeToLiveOutput,
    WriteRequest,
)
from localstack.aws.api.dynamodbstreams import StreamStatus
from localstack.aws.connect import connect_to
from localstack.config import is_persistence_enabled
from localstack.constants import (
    AUTH_CREDENTIAL_REGEX,
    AWS_REGION_US_EAST_1,
    INTERNAL_AWS_SECRET_ACCESS_KEY,
)
from localstack.http import Request, Response, route
from localstack.services.dynamodb.models import (
    DynamoDBStore,
    RecordsMap,
    StreamRecord,
    StreamRecords,
    TableRecords,
    TableStreamType,
    dynamodb_stores,
)
from localstack.services.dynamodb.server import DynamodbServer
from localstack.services.dynamodb.utils import (
    ItemFinder,
    ItemSet,
    SchemaExtractor,
    de_dynamize_record,
    extract_table_name_from_partiql_update,
    get_ddb_access_key,
    modify_ddblocal_arns,
)
from localstack.services.dynamodbstreams import dynamodbstreams_api
from localstack.services.dynamodbstreams.models import dynamodbstreams_stores
from localstack.services.edge import ROUTER
from localstack.services.plugins import ServiceLifecycleHook
from localstack.state import AssetDirectory, StateVisitor
from localstack.utils.aws import arns
from localstack.utils.aws.arns import (
    extract_account_id_from_arn,
    extract_region_from_arn,
    get_partition,
)
from localstack.utils.aws.aws_stack import get_valid_regions_for_service
from localstack.utils.aws.request_context import (
    extract_account_id_from_headers,
    extract_region_from_headers,
)
from localstack.utils.collections import select_attributes, select_from_typed_dict
from localstack.utils.common import short_uid, to_bytes
from localstack.utils.files import cp_r, rm_rf
from localstack.utils.json import BytesEncoder, canonical_json
from localstack.utils.scheduler import Scheduler
from localstack.utils.strings import long_uid, md5, to_str
from localstack.utils.threads import FuncThread, start_thread

# set up logger
LOG = logging.getLogger(__name__)

# action header prefix
ACTION_PREFIX = "DynamoDB_20120810."

# list of actions subject to throughput limitations
READ_THROTTLED_ACTIONS = [
    "GetItem",
    "Query",
    "Scan",
    "TransactGetItems",
    "BatchGetItem",
]
WRITE_THROTTLED_ACTIONS = [
    "PutItem",
    "BatchWriteItem",
    "UpdateItem",
    "DeleteItem",
    "TransactWriteItems",
]
THROTTLED_ACTIONS = READ_THROTTLED_ACTIONS + WRITE_THROTTLED_ACTIONS

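# cache of SSE KMS key IDs managed by LocalStack, keyed by region (populated by SSEUtils.get_sse_kms_managed_key below)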
MANAGED_KMS_KEYS = {}


def dynamodb_table_exists(table_name: str, client=None) -> bool:
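    """Return True if the given table name appears in the (paginated) ListTables results."""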
    client = client or connect_to().dynamodb
    paginator = client.get_paginator("list_tables")
    pages = paginator.paginate(PaginationConfig={"PageSize": 100})
    table_name = to_str(table_name)
    return any(table_name in page["TableNames"] for page in pages)


class EventForwarder:
    def __init__(self, num_thread: int = 10):
        self.executor = ThreadPoolExecutor(num_thread, thread_name_prefix="ddb_stream_fwd")

    def shutdown(self):
        self.executor.shutdown(wait=False)

    def forward_to_targets(
        self, account_id: str, region_name: str, records_map: RecordsMap, background: bool = True
    ) -> None:
        if background:
            self._submit_records(
                account_id=account_id,
                region_name=region_name,
                records_map=records_map,
            )
        else:
            self._forward(account_id, region_name, records_map)

    def _submit_records(self, account_id: str, region_name: str, records_map: RecordsMap):
        """Required for patching submit with local thread context for EventStudio"""
        self.executor.submit(
            self._forward,
            account_id,
            region_name,
            records_map,
        )

    def _forward(self, account_id: str, region_name: str, records_map: RecordsMap) -> None:
        try:
            self.forward_to_kinesis_stream(account_id, region_name, records_map)
        except Exception as e:
            LOG.debug(
                "Error while publishing to Kinesis streams: '%s'",
                e,
                exc_info=LOG.isEnabledFor(logging.DEBUG),
            )

        try:
            self.forward_to_ddb_stream(account_id, region_name, records_map)
        except Exception as e:
            LOG.debug(
                "Error while publishing to DynamoDB streams, '%s'",
                e,
                exc_info=LOG.isEnabledFor(logging.DEBUG),
            )

    @staticmethod
    def forward_to_ddb_stream(account_id: str, region_name: str, records_map: RecordsMap) -> None:
        dynamodbstreams_api.forward_events(account_id, region_name, records_map)

    @staticmethod
    def forward_to_kinesis_stream(
        account_id: str, region_name: str, records_map: RecordsMap
    ) -> None:
        # You can only stream data from DynamoDB to Kinesis Data Streams in the same AWS account and AWS Region as your
        # table.
        # You can only stream data from a DynamoDB table to one Kinesis data stream.
        store = get_store(account_id, region_name)

        for table_name, table_records in records_map.items():
            table_stream_type = table_records["table_stream_type"]
            if not table_stream_type.is_kinesis:
                continue

            kinesis_records = []

            table_arn = arns.dynamodb_table_arn(table_name, account_id, region_name)
            records = table_records["records"]
            table_def = store.table_definitions.get(table_name) or {}
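            # forward to the most recently registered Kinesis streaming destination for this table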
            stream_arn = table_def["KinesisDataStreamDestinations"][-1]["StreamArn"]
            for record in records:
                kinesis_record = dict(
                    tableName=table_name,
                    recordFormat="application/json",
                    userIdentity=None,
                    **record,
                )
                fields_to_remove = {"StreamViewType", "SequenceNumber"}
                kinesis_record["dynamodb"] = {
                    k: v for k, v in record["dynamodb"].items() if k not in fields_to_remove
                }
                kinesis_record.pop("eventVersion", None)

                hash_keys = list(
                    filter(lambda key: key["KeyType"] == "HASH", table_def["KeySchema"])
                )
                # TODO: reverse properly how AWS creates the partition key, it seems to be an MD5 hash
                kinesis_partition_key = md5(f"{table_name}{hash_keys[0]['AttributeName']}")

                kinesis_records.append(
                    {
                        "Data": json.dumps(kinesis_record, cls=BytesEncoder),
                        "PartitionKey": kinesis_partition_key,
                    }
                )

            kinesis = connect_to(
                aws_access_key_id=account_id,
                aws_secret_access_key=INTERNAL_AWS_SECRET_ACCESS_KEY,
                region_name=region_name,
            ).kinesis.request_metadata(service_principal="dynamodb", source_arn=table_arn)

            kinesis.put_records(
                StreamARN=stream_arn,
                Records=kinesis_records,
            )

    @classmethod
    def is_kinesis_stream_exists(cls, stream_arn):
        account_id = extract_account_id_from_arn(stream_arn)
        region_name = extract_region_from_arn(stream_arn)

        kinesis = connect_to(
            aws_access_key_id=account_id,
            aws_secret_access_key=INTERNAL_AWS_SECRET_ACCESS_KEY,
            region_name=region_name,
        ).kinesis
        stream_name_from_arn = stream_arn.split("/", 1)[1]
        # check if the stream exists in kinesis for the user
        filtered = list(
            filter(
                lambda stream_name: stream_name == stream_name_from_arn,
                kinesis.list_streams()["StreamNames"],
            )
        )
        return bool(filtered)


class SSEUtils:
    """Utils for server-side encryption (SSE)"""

    @classmethod
    def get_sse_kms_managed_key(cls, account_id: str, region_name: str):
        from localstack.services.kms import provider

        existing_key = MANAGED_KMS_KEYS.get(region_name)
        if existing_key:
            return existing_key
        kms_client = connect_to(
            aws_access_key_id=account_id,
            aws_secret_access_key=INTERNAL_AWS_SECRET_ACCESS_KEY,
            region_name=region_name,
        ).kms
        key_data = kms_client.create_key(
            Description="Default key that protects my DynamoDB data when no other key is defined"
        )
        key_id = key_data["KeyMetadata"]["KeyId"]

        provider.set_key_managed(key_id, account_id, region_name)
        MANAGED_KMS_KEYS[region_name] = key_id
        return key_id

    @classmethod
    def get_sse_description(cls, account_id: str, region_name: str, data):
        if data.get("Enabled"):
            kms_master_key_id = data.get("KMSMasterKeyId")
            if not kms_master_key_id:
                # this is of course not the actual key used by DynamoDB; an existing KMS key simply makes for a better mock
                kms_master_key_id = cls.get_sse_kms_managed_key(account_id, region_name)
            kms_master_key_id = arns.kms_key_arn(kms_master_key_id, account_id, region_name)
            return {
                "Status": "ENABLED",
                "SSEType": "KMS",  # no other value is allowed here
                "KMSMasterKeyArn": kms_master_key_id,
            }
        return {}


class ValidationException(CommonServiceException):
    def __init__(self, message: str):
        super().__init__(code="ValidationException", status_code=400, message=message)


def get_store(account_id: str, region_name: str) -> DynamoDBStore:
    # special case: AWS NoSQL Workbench sends "localhost" as region - replace with proper region here
    region_name = DynamoDBProvider.ddb_region_name(region_name)
    return dynamodb_stores[account_id][region_name]


@contextmanager
def modify_context_region(context: RequestContext, region: str):
    """
    Context manager that modifies the region of a `RequestContext`. At the exit, the context is restored to its
    original state.

    :param context: the context to modify
    :param region: the modified region
    :return: a modified `RequestContext`
    """
    original_region = context.region
    original_authorization = context.request.headers.get("Authorization")

    key = get_ddb_access_key(context.account_id, region)

    context.region = region
    context.request.headers["Authorization"] = re.sub(
        AUTH_CREDENTIAL_REGEX,
        rf"Credential={key}/\2/{region}/\4/",
        original_authorization or "",
        flags=re.IGNORECASE,
    )

    try:
        yield context
    except Exception:
        raise
    finally:
        # revert the original context
        context.region = original_region
        context.request.headers["Authorization"] = original_authorization


class DynamoDBDeveloperEndpoints:
    """
    Developer endpoints for DynamoDB
    DELETE /_aws/dynamodb/expired - delete expired items from tables with TTL enabled; return the number of expired
        items deleted
    """

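    # Example invocation (assuming the default LocalStack gateway on localhost:4566):
    #   curl -X DELETE http://localhost:4566/_aws/dynamodb/expired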
    @route("/_aws/dynamodb/expired", methods=["DELETE"])
    def delete_expired_messages(self, _: Request):
        no_expired_items = delete_expired_items()
        return {"ExpiredItems": no_expired_items}


def delete_expired_items() -> int:
    """
    This utility function iterates over all stores, looks for tables with TTL enabled,
    scans those tables, and deletes the expired items.
    """
    no_expired_items = 0
    for account_id, region_name, state in dynamodb_stores.iter_stores():
        ttl_specs = state.ttl_specifications
        client = connect_to(aws_access_key_id=account_id, region_name=region_name).dynamodb
        for table_name, ttl_spec in ttl_specs.items():
            if ttl_spec.get("Enabled", False):
                attribute_name = ttl_spec.get("AttributeName")
                current_time = int(datetime.now().timestamp())
                try:
                    result = client.scan(
                        TableName=table_name,
                        FilterExpression="#ttl <= :threshold",
                        ExpressionAttributeValues={":threshold": {"N": str(current_time)}},
                        ExpressionAttributeNames={"#ttl": attribute_name},
                    )
                    items_to_delete = result.get("Items", [])
                    no_expired_items += len(items_to_delete)
                    table_description = client.describe_table(TableName=table_name)
                    partition_key, range_key = _get_hash_and_range_key(table_description)
                    keys_to_delete = [
                        {partition_key: item.get(partition_key)}
                        if range_key is None
                        else {
                            partition_key: item.get(partition_key),
                            range_key: item.get(range_key),
                        }
                        for item in items_to_delete
                    ]
                    delete_requests = [{"DeleteRequest": {"Key": key}} for key in keys_to_delete]
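                    # DynamoDB's BatchWriteItem accepts at most 25 put/delete requests per call, so delete in chunks of 25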
                    for i in range(0, len(delete_requests), 25):
                        batch = delete_requests[i : i + 25]
                        client.batch_write_item(RequestItems={table_name: batch})
                except Exception as e:
                    LOG.warning(
                        "An error occurred when deleting expired items from table %s: %s",
                        table_name,
                        e,
                    )
    return no_expired_items


def _get_hash_and_range_key(table_description: DescribeTableOutput) -> tuple[str, str | None]:
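    """Return the HASH (partition) and RANGE (sort) key attribute names from a DescribeTable response."""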
    key_schema = table_description.get("Table", {}).get("KeySchema", [])
    hash_key, range_key = None, None
    for key in key_schema:
        if key["KeyType"] == "HASH":
            hash_key = key["AttributeName"]
        if key["KeyType"] == "RANGE":
            range_key = key["AttributeName"]
    return hash_key, range_key


class ExpiredItemsWorker:
    """A worker that periodically computes and deletes expired items from DynamoDB tables"""

    def __init__(self) -> None:
        super().__init__()
        self.scheduler = Scheduler()
        self.thread: FuncThread | None = None
        self.mutex = threading.RLock()

    def start(self):
        with self.mutex:
            if self.thread:
                return

            self.scheduler = Scheduler()
            self.scheduler.schedule(
                delete_expired_items, period=60 * 60
            )  # the background process seems slow on AWS

            def _run(*_args):
                self.scheduler.run()

            self.thread = start_thread(_run, name="ddb-remove-expired-items")

    def stop(self):
        with self.mutex:
            if self.scheduler:
                self.scheduler.close()

            if self.thread:
                self.thread.stop()

            self.thread = None
            self.scheduler = None


class DynamoDBProvider(DynamodbApi, ServiceLifecycleHook):
    server: DynamodbServer
    """The instance of the server managing the instance of DynamoDB local"""
    asset_directory = f"{config.dirs.data}/dynamodb"
    """The directory that contains the .db files saved by DynamoDB Local"""
    tmp_asset_directory = f"{config.dirs.tmp}/dynamodb"
    """Temporary directory for the .db files saved by DynamoDB Local when MANUAL snapshot persistence is enabled"""

    def __init__(self):
        self.server = self._new_dynamodb_server()
        self._expired_items_worker = ExpiredItemsWorker()
        self._router_rules = []
        self._event_forwarder = EventForwarder()

    def on_before_start(self):
        # We must copy back whatever state was saved to the temporary location, to avoid always starting from a
        # blank state. See the `on_before_state_save` hook.
        if is_persistence_enabled() and config.SNAPSHOT_SAVE_STRATEGY == "MANUAL":
            if os.path.exists(self.asset_directory):
                LOG.debug("Copying %s to %s", self.tmp_asset_directory, self.asset_directory)
                cp_r(self.asset_directory, self.tmp_asset_directory, rm_dest_on_conflict=True)

        self.server.start_dynamodb()
        if config.DYNAMODB_REMOVE_EXPIRED_ITEMS:
            self._expired_items_worker.start()
        self._router_rules = ROUTER.add(DynamoDBDeveloperEndpoints())

    def on_before_stop(self):
        self._expired_items_worker.stop()
        ROUTER.remove(self._router_rules)
        self._event_forwarder.shutdown()

    def accept_state_visitor(self, visitor: StateVisitor):
        visitor.visit(dynamodb_stores)
        visitor.visit(dynamodbstreams_stores)
        visitor.visit(AssetDirectory(self.service, os.path.join(config.dirs.data, self.service)))

    def on_before_state_reset(self):
        self.server.stop_dynamodb()
        rm_rf(self.tmp_asset_directory)

    def on_before_state_load(self):
        self.server.stop_dynamodb()

    def on_after_state_reset(self):
        self.server.start_dynamodb()

    @staticmethod
    def _new_dynamodb_server() -> DynamodbServer:
        return DynamodbServer.get()

    def on_after_state_load(self):
        self.server.start_dynamodb()

    def on_before_state_save(self) -> None:
        # When the save strategy is MANUAL, we do not save the DB files to the usual ``config.dirs.data`` folder.
        # With the MANUAL strategy, we want to take a snapshot on demand, but this is not possible if the DB files
        # are already in ``config.dirs.data``. For instance, the sequence of operations below would result in both
        # tables being implicitly saved:
        # - awslocal dynamodb create-table table1
        # - curl -X POST http://localhost:4566/_localstack/state/save
        # - awslocal dynamodb create-table table2
        # To avoid this problem, we start the DDBLocal server in a temporary directory that is then copied over to
        # ``config.dirs.data`` when the state needs to be saved.
        # The ideal solution to the problem would be to always start the server in memory and have a dump capability.
        if is_persistence_enabled() and config.SNAPSHOT_SAVE_STRATEGY == "MANUAL":
            LOG.debug("Copying %s to %s", self.tmp_asset_directory, self.asset_directory)
            cp_r(self.tmp_asset_directory, self.asset_directory, rm_dest_on_conflict=True)

    def on_after_init(self):
        # add response processor specific to ddblocal
        handlers.modify_service_response.append(self.service, modify_ddblocal_arns)

        # routes for the shell ui
        ROUTER.add(
            path="/shell",
            endpoint=self.handle_shell_ui_redirect,
            methods=["GET"],
        )
        ROUTER.add(
            path="/shell/<regex('.*'):req_path>",
            endpoint=self.handle_shell_ui_request,
        )

    def _forward_request(
        self,
        context: RequestContext,
        region: str | None,
        service_request: ServiceRequest | None = None,
    ) -> ServiceResponse:
        """
        Modify the context region and then forward request to DynamoDB Local.

        This is used for operations impacted by global tables. In LocalStack, a single copy of a global table
        is kept, and any requests to replicated tables are forwarded to this original table.
        """
        if region:
            with modify_context_region(context, region):
                return self.forward_request(context, service_request=service_request)
        return self.forward_request(context, service_request=service_request)

    def forward_request(
        self, context: RequestContext, service_request: ServiceRequest = None
    ) -> ServiceResponse:
        """
        Forward a request to DynamoDB Local.
        """
        self.check_provisioned_throughput(context.operation.name)
        self.prepare_request_headers(
            context.request.headers, account_id=context.account_id, region_name=context.region
        )
        return self.server.proxy(context, service_request)

    def get_forward_url(self, account_id: str, region_name: str) -> str:
        """Return the URL of the backend DynamoDBLocal server to forward requests to"""
        return self.server.url

    def handle_shell_ui_redirect(self, request: werkzeug.Request) -> Response:
        headers = {"Refresh": f"0; url={config.external_service_url()}/shell/index.html"}
        return Response("", headers=headers)

    def handle_shell_ui_request(self, request: werkzeug.Request, req_path: str) -> Response:
        # TODO: "DynamoDB Local Web Shell was deprecated with version 1.16.X and is not available any
        #  longer from 1.17.X to latest. There are no immediate plans for a new Web Shell to be introduced."
        #  -> keeping this for now, to allow configuring custom installs; should consider removing it in the future
        # https://repost.aws/questions/QUHyIzoEDqQ3iOKlUEp1LPWQ#ANdBm9Nz9TRf6VqR3jZtcA1g
        req_path = f"/{req_path}" if not req_path.startswith("/") else req_path
        account_id = extract_account_id_from_headers(request.headers)
        region_name = extract_region_from_headers(request.headers)
        url = f"{self.get_forward_url(account_id, region_name)}/shell{req_path}"
        result = requests.request(
            method=request.method, url=url, headers=request.headers, data=request.data
        )
        return Response(result.content, headers=dict(result.headers), status=result.status_code)

    #
    # Table ops
    #

    @handler("CreateTable", expand=False)
    def create_table(
        self,
        context: RequestContext,
        create_table_input: CreateTableInput,
    ) -> CreateTableOutput:
        table_name = create_table_input["TableName"]

        # Return this specific error message to keep parity with AWS
        if self.table_exists(context.account_id, context.region, table_name):
            raise ResourceInUseException(f"Table already exists: {table_name}")

        billing_mode = create_table_input.get("BillingMode")
        provisioned_throughput = create_table_input.get("ProvisionedThroughput")
        if billing_mode == BillingMode.PAY_PER_REQUEST and provisioned_throughput is not None:
            raise ValidationException(
                "One or more parameter values were invalid: Neither ReadCapacityUnits nor WriteCapacityUnits can be "
                "specified when BillingMode is PAY_PER_REQUEST"
            )

        result = self.forward_request(context)

        table_description = result["TableDescription"]
        table_description["TableArn"] = table_arn = self.fix_table_arn(
            context.account_id, context.region, table_description["TableArn"]
        )

        backend = get_store(context.account_id, context.region)
        backend.table_definitions[table_name] = table_definitions = dict(create_table_input)
        backend.TABLE_REGION[table_name] = context.region

        if "TableId" not in table_definitions:
            table_definitions["TableId"] = long_uid()

        if "SSESpecification" in table_definitions:
            sse_specification = table_definitions.pop("SSESpecification")
            table_definitions["SSEDescription"] = SSEUtils.get_sse_description(
                context.account_id, context.region, sse_specification
            )

        if table_definitions:
            table_content = result.get("Table", {})
            table_content.update(table_definitions)
            table_description.update(table_content)

        if "StreamSpecification" in table_definitions:
            create_dynamodb_stream(
                context.account_id,
                context.region,
                table_definitions,
                table_description.get("LatestStreamLabel"),
            )

        if "TableClass" in table_definitions:
            table_class = table_description.pop("TableClass", None) or table_definitions.pop(
                "TableClass"
            )
            table_description["TableClassSummary"] = {"TableClass": table_class}

        if "GlobalSecondaryIndexes" in table_description:
            gsis = copy.deepcopy(table_description["GlobalSecondaryIndexes"])
            # update the different values, as DynamoDB-local v2 has a regression around GSI and does not return anything
            # anymore
            for gsi in gsis:
                index_name = gsi.get("IndexName", "")
                gsi.update(
                    {
                        "IndexArn": f"{table_arn}/index/{index_name}",
                        "IndexSizeBytes": 0,
                        "IndexStatus": "ACTIVE",
                        "ItemCount": 0,
                    }
                )
                gsi_provisioned_throughput = gsi.setdefault("ProvisionedThroughput", {})
                gsi_provisioned_throughput["NumberOfDecreasesToday"] = 0

                if billing_mode == BillingMode.PAY_PER_REQUEST:
                    gsi_provisioned_throughput["ReadCapacityUnits"] = 0
                    gsi_provisioned_throughput["WriteCapacityUnits"] = 0

            table_description["GlobalSecondaryIndexes"] = gsis

        if "ProvisionedThroughput" in table_description:
            if "NumberOfDecreasesToday" not in table_description["ProvisionedThroughput"]:
                table_description["ProvisionedThroughput"]["NumberOfDecreasesToday"] = 0

        tags = table_definitions.pop("Tags", [])
        if tags:
            get_store(context.account_id, context.region).TABLE_TAGS[table_arn] = {
                tag["Key"]: tag["Value"] for tag in tags
            }

        # remove invalid attributes from result
        table_description.pop("Tags", None)
        table_description.pop("BillingMode", None)

        return result

    def delete_table(
        self, context: RequestContext, table_name: TableName, **kwargs
    ) -> DeleteTableOutput:
        global_table_region = self.get_global_table_region(context, table_name)

        # Limitation note: On AWS, for a replicated table, if the source table is deleted, the replicated tables continue to exist.
        # This is not the case for LocalStack, where all replicated tables will also be removed if source is deleted.

        result = self._forward_request(context=context, region=global_table_region)

        table_arn = result.get("TableDescription", {}).get("TableArn")
        table_arn = self.fix_table_arn(context.account_id, context.region, table_arn)
        dynamodbstreams_api.delete_streams(context.account_id, context.region, table_arn)

        store = get_store(context.account_id, context.region)
        store.TABLE_TAGS.pop(table_arn, None)
        store.REPLICAS.pop(table_name, None)

        return result

    def describe_table(
        self, context: RequestContext, table_name: TableName, **kwargs
    ) -> DescribeTableOutput:
        global_table_region = self.get_global_table_region(context, table_name)

        result = self._forward_request(context=context, region=global_table_region)
        table_description: TableDescription = result["Table"]

        # Update table properties from LocalStack stores
        if table_props := get_store(context.account_id, context.region).table_properties.get(
            table_name
        ):
            table_description.update(table_props)

        store = get_store(context.account_id, context.region)

        # Update replication details
        replicas: dict[RegionName, ReplicaDescription] = store.REPLICAS.get(table_name, {})

        replica_description_list = []

        if global_table_region != context.region:
            replica_description_list.append(
                ReplicaDescription(
                    RegionName=global_table_region, ReplicaStatus=ReplicaStatus.ACTIVE
                )
            )

        for replica_region, replica_description in replicas.items():
            # The replica in the region being queried must not be returned
            if replica_region != context.region:
                replica_description_list.append(replica_description)

        if replica_description_list:
            table_description.update({"Replicas": replica_description_list})

        # update only TableId and SSEDescription if present
        if table_definitions := store.table_definitions.get(table_name):
            for key in ["TableId", "SSEDescription"]:
                if table_definitions.get(key):
                    table_description[key] = table_definitions[key]
            if "TableClass" in table_definitions:
                table_description["TableClassSummary"] = {
                    "TableClass": table_definitions["TableClass"]
                }

        if "GlobalSecondaryIndexes" in table_description:
            for gsi in table_description["GlobalSecondaryIndexes"]:
                default_values = {
                    "NumberOfDecreasesToday": 0,
                    "ReadCapacityUnits": 0,
                    "WriteCapacityUnits": 0,
                }
                # even if the billing mode is PAY_PER_REQUEST, AWS returns the Read and Write Capacity Units
                # Terraform depends on this parity for update operations
                gsi["ProvisionedThroughput"] = default_values | gsi.get("ProvisionedThroughput", {})

        return DescribeTableOutput(
            Table=select_from_typed_dict(TableDescription, table_description)
        )

    @handler("UpdateTable", expand=False)
    def update_table(
        self, context: RequestContext, update_table_input: UpdateTableInput
    ) -> UpdateTableOutput:
        table_name = update_table_input["TableName"]
        global_table_region = self.get_global_table_region(context, table_name)

        try:
            result = self._forward_request(context=context, region=global_table_region)
        except CommonServiceException as exc:
            # DynamoDBLocal refuses to update certain table params and raises.
            # But we still need to update this info in LocalStack stores
            if not (exc.code == "ValidationException" and exc.message == "Nothing to update"):
                raise

            if table_class := update_table_input.get("TableClass"):
                table_definitions = get_store(
                    context.account_id, context.region
                ).table_definitions.setdefault(table_name, {})
                table_definitions["TableClass"] = table_class

            if replica_updates := update_table_input.get("ReplicaUpdates"):
                store = get_store(context.account_id, global_table_region)

                # Dict with source region to set of replicated regions
                replicas: dict[RegionName, ReplicaDescription] = store.REPLICAS.get(table_name, {})

                for replica_update in replica_updates:
                    for key, details in replica_update.items():
                        # Replicated region
                        target_region = details.get("RegionName")

                        # Check if replicated region is valid
                        if target_region not in get_valid_regions_for_service("dynamodb"):
                            raise ValidationException(f"Region {target_region} is not supported")

                        match key:
                            case "Create":
                                if target_region in replicas:
                                    raise ValidationException(
                                        f"Failed to create a the new replica of table with name: '{table_name}' because one or more replicas already existed as tables."
                                    )
                                replicas[target_region] = ReplicaDescription(
                                    RegionName=target_region,
                                    KMSMasterKeyId=details.get("KMSMasterKeyId"),
                                    ProvisionedThroughputOverride=details.get(
                                        "ProvisionedThroughputOverride"
                                    ),
                                    GlobalSecondaryIndexes=details.get("GlobalSecondaryIndexes"),
                                    ReplicaStatus=ReplicaStatus.ACTIVE,
                                )
                            case "Delete":
                                try:
                                    replicas.pop(target_region)
                                except KeyError:
                                    raise ValidationException(
                                        "Update global table operation failed because one or more replicas were not part of the global table."
                                    )

                store.REPLICAS[table_name] = replicas

            # update response content
            SchemaExtractor.invalidate_table_schema(
                table_name, context.account_id, global_table_region
            )

            schema = SchemaExtractor.get_table_schema(
                table_name, context.account_id, global_table_region
            )

            if sse_specification_input := update_table_input.get("SSESpecification"):
                # If SSESpecification is changed, update store and return the 'UPDATING' status in the response
                table_definition = get_store(
                    context.account_id, context.region
                ).table_definitions.setdefault(table_name, {})
                if not sse_specification_input["Enabled"]:
                    table_definition.pop("SSEDescription", None)
                    schema["Table"]["SSEDescription"]["Status"] = "UPDATING"

            return UpdateTableOutput(TableDescription=schema["Table"])

        SchemaExtractor.invalidate_table_schema(table_name, context.account_id, global_table_region)

        schema = SchemaExtractor.get_table_schema(
            table_name, context.account_id, global_table_region
        )

        # TODO: DDB streams must also be created for replicas
        if update_table_input.get("StreamSpecification"):
            create_dynamodb_stream(
                context.account_id,
                context.region,
                update_table_input,
                result["TableDescription"].get("LatestStreamLabel"),
            )

        return UpdateTableOutput(TableDescription=schema["Table"])

    def list_tables(
        self,
        context: RequestContext,
        exclusive_start_table_name: TableName = None,
        limit: ListTablesInputLimit = None,
        **kwargs,
    ) -> ListTablesOutput:
        response = self.forward_request(context)

        # Add replicated tables
        replicas = get_store(context.account_id, context.region).REPLICAS
        for replicated_table, replications in replicas.items():
            for replica_region, replica_description in replications.items():
                if context.region == replica_region:
                    response["TableNames"].append(replicated_table)

        return response

    #
    # Item ops
    #

    @handler("PutItem", expand=False)
    def put_item(self, context: RequestContext, put_item_input: PutItemInput) -> PutItemOutput:
        table_name = put_item_input["TableName"]
        global_table_region = self.get_global_table_region(context, table_name)

        has_return_values = put_item_input.get("ReturnValues") == "ALL_OLD"
        stream_type = get_table_stream_type(context.account_id, context.region, table_name)

        # if the request doesn't ask for ReturnValues and we have stream enabled, we need to modify the request to
        # force DDBLocal to return those values
        if stream_type and not has_return_values:
            service_req = copy.copy(context.service_request)
            service_req["ReturnValues"] = "ALL_OLD"
            result = self._forward_request(
                context=context, region=global_table_region, service_request=service_req
            )
        else:
            result = self._forward_request(context=context, region=global_table_region)

        # Since this operation makes use of global table region, we need to use the same region for all
        # calls made via the inter-service client. This is taken care of by passing the account ID and
        # region, e.g. when getting the stream spec

        # Get stream specifications details for the table
        if stream_type:
            item = put_item_input["Item"]
            # prepare record keys
            keys = SchemaExtractor.extract_keys(
                item=item,
                table_name=table_name,
                account_id=context.account_id,
                region_name=global_table_region,
            )
            # because we modified the request, we will always have the ReturnValues if we have streams enabled
            if has_return_values:
                existing_item = result.get("Attributes")
            else:
                # remove the ReturnValues if the client didn't ask for it
                existing_item = result.pop("Attributes", None)

            if existing_item == item:
                return result

            # create record
            record = self.get_record_template(
                context.region,
            )
            record["eventName"] = "INSERT" if not existing_item else "MODIFY"
            record["dynamodb"]["Keys"] = keys
            record["dynamodb"]["SizeBytes"] = _get_size_bytes(item)

            if stream_type.needs_new_image:
                record["dynamodb"]["NewImage"] = item
            if stream_type.stream_view_type:
                record["dynamodb"]["StreamViewType"] = stream_type.stream_view_type
            if existing_item and stream_type.needs_old_image:
                record["dynamodb"]["OldImage"] = existing_item

            records_map = {
                table_name: TableRecords(records=[record], table_stream_type=stream_type)
            }
            self.forward_stream_records(context.account_id, context.region, records_map)
        return result

    @handler("DeleteItem", expand=False)
    def delete_item(
        self,
        context: RequestContext,
        delete_item_input: DeleteItemInput,
    ) -> DeleteItemOutput:
        table_name = delete_item_input["TableName"]
        global_table_region = self.get_global_table_region(context, table_name)

        has_return_values = delete_item_input.get("ReturnValues") == "ALL_OLD"
        stream_type = get_table_stream_type(context.account_id, context.region, table_name)

        # if the request doesn't ask for ReturnValues and we have stream enabled, we need to modify the request to
        # force DDBLocal to return those values
        if stream_type and not has_return_values:
            service_req = copy.copy(context.service_request)
            service_req["ReturnValues"] = "ALL_OLD"
            result = self._forward_request(
                context=context, region=global_table_region, service_request=service_req
            )
        else:
            result = self._forward_request(context=context, region=global_table_region)

        # determine and forward stream record
        if stream_type:
            # because we modified the request, we will always have the ReturnValues if we have streams enabled
            if has_return_values:
                existing_item = result.get("Attributes")
            else:
                # remove the ReturnValues if the client didn't ask for it
                existing_item = result.pop("Attributes", None)

            if not existing_item:
                return result

            # create record
            record = self.get_record_template(context.region)
            record["eventName"] = "REMOVE"
            record["dynamodb"]["Keys"] = delete_item_input["Key"]
            record["dynamodb"]["SizeBytes"] = _get_size_bytes(existing_item)

            if stream_type.stream_view_type:
                record["dynamodb"]["StreamViewType"] = stream_type.stream_view_type
            if stream_type.needs_old_image:
                record["dynamodb"]["OldImage"] = existing_item

            records_map = {
                table_name: TableRecords(records=[record], table_stream_type=stream_type)
            }
            self.forward_stream_records(context.account_id, context.region, records_map)

        return result

    @handler("UpdateItem", expand=False)
    def update_item(
        self,
        context: RequestContext,
        update_item_input: UpdateItemInput,
    ) -> UpdateItemOutput:
        # TODO: UpdateItem is harder to use ReturnValues for Streams, because it needs the Before and After images.
        table_name = update_item_input["TableName"]
        global_table_region = self.get_global_table_region(context, table_name)

        existing_item = None
        stream_type = get_table_stream_type(context.account_id, context.region, table_name)

        # even if we don't need the OldImage, we still need to fetch the existing item to know if the event is INSERT
        # or MODIFY (UpdateItem will create the object if it doesn't exist, and you don't use a ConditionExpression)
        if stream_type:
            existing_item = ItemFinder.find_existing_item(
                put_item=update_item_input,
                table_name=table_name,
                account_id=context.account_id,
                region_name=context.region,
                endpoint_url=self.server.url,
            )

        result = self._forward_request(context=context, region=global_table_region)

        # construct and forward stream record
        if stream_type:
            updated_item = ItemFinder.find_existing_item(
                put_item=update_item_input,
                table_name=table_name,
                account_id=context.account_id,
                region_name=context.region,
                endpoint_url=self.server.url,
            )
            if not updated_item or updated_item == existing_item:
                return result

            record = self.get_record_template(context.region)
            record["eventName"] = "INSERT" if not existing_item else "MODIFY"
            record["dynamodb"]["Keys"] = update_item_input["Key"]
            record["dynamodb"]["SizeBytes"] = _get_size_bytes(updated_item)

            if stream_type.stream_view_type:
                record["dynamodb"]["StreamViewType"] = stream_type.stream_view_type
            if existing_item and stream_type.needs_old_image:
                record["dynamodb"]["OldImage"] = existing_item
            if stream_type.needs_new_image:
                record["dynamodb"]["NewImage"] = updated_item

            records_map = {
                table_name: TableRecords(records=[record], table_stream_type=stream_type)
            }
            self.forward_stream_records(context.account_id, context.region, records_map)

        return result

    @handler("GetItem", expand=False)
    def get_item(self, context: RequestContext, get_item_input: GetItemInput) -> GetItemOutput:
        table_name = get_item_input["TableName"]
        global_table_region = self.get_global_table_region(context, table_name)
        result = self._forward_request(context=context, region=global_table_region)
        self.fix_consumed_capacity(get_item_input, result)
        return result

    #
    # Queries
    #

    @handler("Query", expand=False)
    def query(self, context: RequestContext, query_input: QueryInput) -> QueryOutput:
        index_name = query_input.get("IndexName")
        if index_name:
            if not is_index_query_valid(context.account_id, context.region, query_input):
                raise ValidationException(
                    "One or more parameter values were invalid: Select type ALL_ATTRIBUTES "
                    "is not supported for global secondary index id-index because its projection "
                    "type is not ALL",
                )

        table_name = query_input["TableName"]
        global_table_region = self.get_global_table_region(context, table_name)
        result = self._forward_request(context=context, region=global_table_region)
        self.fix_consumed_capacity(query_input, result)
        return result

    @handler("Scan", expand=False)
    def scan(self, context: RequestContext, scan_input: ScanInput) -> ScanOutput:
        table_name = scan_input["TableName"]
        global_table_region = self.get_global_table_region(context, table_name)
        result = self._forward_request(context=context, region=global_table_region)
        return result

    #
    # Batch ops
    #

    @handler("BatchWriteItem", expand=False)
    def batch_write_item(
        self,
        context: RequestContext,
        batch_write_item_input: BatchWriteItemInput,
    ) -> BatchWriteItemOutput:
        # TODO: add global table support
        existing_items = {}
        existing_items_to_fetch: BatchWriteItemRequestMap = {}
        # UnprocessedItems should have the same format as RequestItems
        unprocessed_items = {}
        request_items = batch_write_item_input["RequestItems"]

        tables_stream_type: dict[TableName, TableStreamType] = {}

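        # iterate over the tables in deterministic (sorted) order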
        for table_name, items in sorted(request_items.items(), key=itemgetter(0)):
            if stream_type := get_table_stream_type(context.account_id, context.region, table_name):
                tables_stream_type[table_name] = stream_type

            for request in items:
                request: WriteRequest
                for key, inner_request in request.items():
                    inner_request: PutRequest | DeleteRequest
                    if self.should_throttle("BatchWriteItem"):
                        unprocessed_items_for_table = unprocessed_items.setdefault(table_name, [])
                        unprocessed_items_for_table.append(request)

                    elif stream_type:
                        existing_items_to_fetch_for_table = existing_items_to_fetch.setdefault(
                            table_name, []
                        )
                        existing_items_to_fetch_for_table.append(inner_request)

        if existing_items_to_fetch:
            existing_items = ItemFinder.find_existing_items(
                put_items_per_table=existing_items_to_fetch,
                account_id=context.account_id,
                region_name=context.region,
                endpoint_url=self.server.url,
            )

        try:
            result = self.forward_request(context)
        except CommonServiceException as e:
            # TODO: validate if DynamoDB still raises `One of the required keys was not given a value`
            # for now, replace with the schema error validation
            if e.message == "One of the required keys was not given a value":
                raise ValidationException("The provided key element does not match the schema")
            raise e

        # determine and forward stream records
        if tables_stream_type:
            records_map = self.prepare_batch_write_item_records(
                account_id=context.account_id,
                region_name=context.region,
                tables_stream_type=tables_stream_type,
                request_items=request_items,
                existing_items=existing_items,
            )
            self.forward_stream_records(context.account_id, context.region, records_map)

1237
        # TODO: should unprocessed items that have been mutated by `prepare_batch_write_item_records` be returned?
1238
        for table_name, unprocessed_items_in_table in unprocessed_items.items():
1✔
UNCOV
1239
            unprocessed: dict = result["UnprocessedItems"]
×
1240
            result_unprocessed_table = unprocessed.setdefault(table_name, [])
×
1241

1242
            # add the Unprocessed items to the response
1243
            # TODO: check before if the same request has not been Unprocessed by DDB local already?
1244
            # those might actually have been processed? shouldn't we remove them from the proxied request?
UNCOV
1245
            for request in unprocessed_items_in_table:
×
1246
                result_unprocessed_table.append(request)
×
1247

1248
            # remove any table entry if it's empty
UNCOV
1249
            result["UnprocessedItems"] = {k: v for k, v in unprocessed.items() if v}
×
1250

1251
        return result
1✔
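    # [Illustrative note added by the editor, not part of the original source; table and item values
    # are hypothetical.] UnprocessedItems, as merged above, mirrors the RequestItems shape of the
    # original BatchWriteItem call, e.g.:
    #
    #   {"UnprocessedItems": {"my-table": [
    #       {"PutRequest": {"Item": {"pk": {"S": "id-1"}, "value": {"N": "42"}}}},
    #       {"DeleteRequest": {"Key": {"pk": {"S": "id-2"}}}},
    #   ]}}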
1252

1253
    @handler("BatchGetItem")
1✔
1254
    def batch_get_item(
1✔
1255
        self,
1256
        context: RequestContext,
1257
        request_items: BatchGetRequestMap,
1258
        return_consumed_capacity: ReturnConsumedCapacity = None,
1259
        **kwargs,
1260
    ) -> BatchGetItemOutput:
1261
        # TODO: add global table support
1262
        return self.forward_request(context)
1✔
1263

1264
    #
1265
    # Transactions
1266
    #
1267

1268
    @handler("TransactWriteItems", expand=False)
1✔
1269
    def transact_write_items(
1✔
1270
        self,
1271
        context: RequestContext,
1272
        transact_write_items_input: TransactWriteItemsInput,
1273
    ) -> TransactWriteItemsOutput:
1274
        # TODO: add global table support
1275
        existing_items = {}
1✔
1276
        existing_items_to_fetch: dict[str, list[Put | Update | Delete]] = {}
1✔
1277
        updated_items_to_fetch: dict[str, list[Update]] = {}
1✔
1278
        transact_items = transact_write_items_input["TransactItems"]
1✔
1279
        tables_stream_type: dict[TableName, TableStreamType] = {}
1✔
1280
        no_stream_tables = set()
1✔
1281

1282
        for item in transact_items:
1✔
1283
            item: TransactWriteItem
1284
            for key in ["Put", "Update", "Delete"]:
1✔
1285
                inner_item: Put | Delete | Update = item.get(key)
1✔
1286
                if inner_item:
1✔
1287
                    table_name = inner_item["TableName"]
1✔
1288
                    # if we've seen the table already and it does not have streams, skip
1289
                    if table_name in no_stream_tables:
1✔
1290
                        continue
1✔
1291

1292
                    # if we have not seen the table, fetch its streaming status
1293
                    if table_name not in tables_stream_type:
1✔
1294
                        if stream_type := get_table_stream_type(
1✔
1295
                            context.account_id, context.region, table_name
1296
                        ):
1297
                            tables_stream_type[table_name] = stream_type
1✔
1298
                        else:
1299
                            # no stream, remember the table so it can be skipped for later items
1300
                            no_stream_tables.add(table_name)
1✔
1301
                            continue
1✔
1302

1303
                    existing_items_to_fetch_for_table = existing_items_to_fetch.setdefault(
1✔
1304
                        table_name, []
1305
                    )
1306
                    existing_items_to_fetch_for_table.append(inner_item)
1✔
1307
                    if key == "Update":
1✔
1308
                        updated_items_to_fetch_for_table = updated_items_to_fetch.setdefault(
1✔
1309
                            table_name, []
1310
                        )
1311
                        updated_items_to_fetch_for_table.append(inner_item)
1✔
1312

1313
                    continue
1✔
1314

1315
        if existing_items_to_fetch:
1✔
1316
            existing_items = ItemFinder.find_existing_items(
1✔
1317
                put_items_per_table=existing_items_to_fetch,
1318
                account_id=context.account_id,
1319
                region_name=context.region,
1320
                endpoint_url=self.server.url,
1321
            )
1322

1323
        client_token: str | None = transact_write_items_input.get("ClientRequestToken")
1✔
1324

1325
        if client_token:
1✔
1326
            # we sort the payload since an identical payload with a different key order could cause an
1327
            # IdempotentParameterMismatchException error if a client token is provided
1328
            context.request.data = to_bytes(canonical_json(json.loads(context.request.data)))
1✔
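            # [Illustrative note added by the editor, not part of the original source.]
            # canonical_json is assumed here to re-serialize the payload with deterministically
            # sorted keys, so two requests differing only in key order produce identical bytes and
            # therefore the same idempotency fingerprint for the ClientRequestToken, roughly:
            #
            #   >>> canonical_json(json.loads('{"b": 1, "a": 2}'))
            #   '{"a": 2, "b": 1}'          # exact spacing/format may differ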
1329

1330
        result = self.forward_request(context)
1✔
1331

1332
        # determine and forward stream records
1333
        if tables_stream_type:
1✔
1334
            updated_items = (
1✔
1335
                ItemFinder.find_existing_items(
1336
                    put_items_per_table=existing_items_to_fetch,
1337
                    account_id=context.account_id,
1338
                    region_name=context.region,
1339
                    endpoint_url=self.server.url,
1340
                )
1341
                if updated_items_to_fetch
1342
                else {}
1343
            )
1344

1345
            records_map = self.prepare_transact_write_item_records(
1✔
1346
                account_id=context.account_id,
1347
                region_name=context.region,
1348
                transact_items=transact_items,
1349
                existing_items=existing_items,
1350
                updated_items=updated_items,
1351
                tables_stream_type=tables_stream_type,
1352
            )
1353
            self.forward_stream_records(context.account_id, context.region, records_map)
1✔
1354

1355
        return result
1✔
1356

1357
    @handler("TransactGetItems", expand=False)
1✔
1358
    def transact_get_items(
1✔
1359
        self,
1360
        context: RequestContext,
1361
        transact_items: TransactGetItemList,
1362
        return_consumed_capacity: ReturnConsumedCapacity = None,
1363
    ) -> TransactGetItemsOutput:
1364
        return self.forward_request(context)
1✔
1365

1366
    @handler("ExecuteTransaction", expand=False)
1✔
1367
    def execute_transaction(
1✔
1368
        self, context: RequestContext, execute_transaction_input: ExecuteTransactionInput
1369
    ) -> ExecuteTransactionOutput:
1370
        result = self.forward_request(context)
1✔
1371
        return result
1✔
1372

1373
    @handler("ExecuteStatement", expand=False)
1✔
1374
    def execute_statement(
1✔
1375
        self,
1376
        context: RequestContext,
1377
        execute_statement_input: ExecuteStatementInput,
1378
    ) -> ExecuteStatementOutput:
1379
        # TODO: this operation is still really slow with streams enabled
1380
        #  find a way to speed it up, in the same way as the other operations, by using ReturnValues
1381
        # see https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/ql-reference.update.html
1382
        statement = execute_statement_input["Statement"]
1✔
1383
        # We found out that 'Parameters' can be an empty list when the request comes from the AWS JS client.
1384
        if execute_statement_input.get("Parameters", None) == []:  # noqa
1✔
1385
            raise ValidationException(
1✔
1386
                "1 validation error detected: Value '[]' at 'parameters' failed to satisfy constraint: Member must have length greater than or equal to 1"
1387
            )
1388
        table_name = extract_table_name_from_partiql_update(statement)
1✔
1389
        existing_items = None
1✔
1390
        stream_type = table_name and get_table_stream_type(
1✔
1391
            context.account_id, context.region, table_name
1392
        )
1393
        if stream_type:
1✔
1394
            # Note: fetching the entire list of items is hugely inefficient, especially for larger tables
1395
            # TODO: find a mechanism to hook into the PartiQL update mechanism of DynamoDB Local directly!
1396
            existing_items = ItemFinder.list_existing_items_for_statement(
1✔
1397
                partiql_statement=statement,
1398
                account_id=context.account_id,
1399
                region_name=context.region,
1400
                endpoint_url=self.server.url,
1401
            )
1402

1403
        result = self.forward_request(context)
1✔
1404

1405
        # construct and forward stream record
1406
        if stream_type:
1✔
1407
            records = get_updated_records(
1✔
1408
                account_id=context.account_id,
1409
                region_name=context.region,
1410
                table_name=table_name,
1411
                existing_items=existing_items,
1412
                server_url=self.server.url,
1413
                table_stream_type=stream_type,
1414
            )
1415
            self.forward_stream_records(context.account_id, context.region, records)
1✔
1416

1417
        return result
1✔
1418

1419
    #
1420
    # Tags
1421
    #
1422

1423
    def tag_resource(
1✔
1424
        self, context: RequestContext, resource_arn: ResourceArnString, tags: TagList, **kwargs
1425
    ) -> None:
1426
        table_tags = get_store(context.account_id, context.region).TABLE_TAGS
1✔
1427
        if resource_arn not in table_tags:
1✔
UNCOV
1428
            table_tags[resource_arn] = {}
×
1429
        table_tags[resource_arn].update({tag["Key"]: tag["Value"] for tag in tags})
1✔
1430

1431
    def untag_resource(
1✔
1432
        self,
1433
        context: RequestContext,
1434
        resource_arn: ResourceArnString,
1435
        tag_keys: TagKeyList,
1436
        **kwargs,
1437
    ) -> None:
1438
        for tag_key in tag_keys or []:
1✔
1439
            get_store(context.account_id, context.region).TABLE_TAGS.get(resource_arn, {}).pop(
1✔
1440
                tag_key, None
1441
            )
1442

1443
    def list_tags_of_resource(
1✔
1444
        self,
1445
        context: RequestContext,
1446
        resource_arn: ResourceArnString,
1447
        next_token: NextTokenString = None,
1448
        **kwargs,
1449
    ) -> ListTagsOfResourceOutput:
1450
        result = [
1✔
1451
            {"Key": k, "Value": v}
1452
            for k, v in get_store(context.account_id, context.region)
1453
            .TABLE_TAGS.get(resource_arn, {})
1454
            .items()
1455
        ]
1456
        return ListTagsOfResourceOutput(Tags=result)
1✔
1457

1458
    #
1459
    # TTLs
1460
    #
1461

1462
    def describe_time_to_live(
1✔
1463
        self, context: RequestContext, table_name: TableName, **kwargs
1464
    ) -> DescribeTimeToLiveOutput:
1465
        if not self.table_exists(context.account_id, context.region, table_name):
1✔
1466
            raise ResourceNotFoundException(
1✔
1467
                f"Requested resource not found: Table: {table_name} not found"
1468
            )
1469

1470
        backend = get_store(context.account_id, context.region)
1✔
1471
        ttl_spec = backend.ttl_specifications.get(table_name)
1✔
1472

1473
        result = {"TimeToLiveStatus": "DISABLED"}
1✔
1474
        if ttl_spec:
1✔
1475
            if ttl_spec.get("Enabled"):
1✔
1476
                ttl_status = "ENABLED"
1✔
1477
            else:
1478
                ttl_status = "DISABLED"
1✔
1479
            result = {
1✔
1480
                "AttributeName": ttl_spec.get("AttributeName"),
1481
                "TimeToLiveStatus": ttl_status,
1482
            }
1483

1484
        return DescribeTimeToLiveOutput(TimeToLiveDescription=result)
1✔
1485

1486
    def update_time_to_live(
1✔
1487
        self,
1488
        context: RequestContext,
1489
        table_name: TableName,
1490
        time_to_live_specification: TimeToLiveSpecification,
1491
        **kwargs,
1492
    ) -> UpdateTimeToLiveOutput:
1493
        if not self.table_exists(context.account_id, context.region, table_name):
1✔
1494
            raise ResourceNotFoundException(
1✔
1495
                f"Requested resource not found: Table: {table_name} not found"
1496
            )
1497

1498
        # TODO: TTL status is maintained/mocked but no real expiry is happening for items
1499
        backend = get_store(context.account_id, context.region)
1✔
1500
        backend.ttl_specifications[table_name] = time_to_live_specification
1✔
1501
        return UpdateTimeToLiveOutput(TimeToLiveSpecification=time_to_live_specification)
1✔
1502

1503
    #
1504
    # Global tables
1505
    #
1506

1507
    def create_global_table(
1✔
1508
        self,
1509
        context: RequestContext,
1510
        global_table_name: TableName,
1511
        replication_group: ReplicaList,
1512
        **kwargs,
1513
    ) -> CreateGlobalTableOutput:
1514
        global_tables: dict = get_store(context.account_id, context.region).GLOBAL_TABLES
1✔
1515
        if global_table_name in global_tables:
1✔
1516
            raise GlobalTableAlreadyExistsException("Global table with this name already exists")
1✔
1517
        replication_group = [grp.copy() for grp in replication_group or []]
1✔
1518
        data = {"GlobalTableName": global_table_name, "ReplicationGroup": replication_group}
1✔
1519
        global_tables[global_table_name] = data
1✔
1520
        for group in replication_group:
1✔
1521
            group["ReplicaStatus"] = "ACTIVE"
1✔
1522
            group["ReplicaStatusDescription"] = "Replica active"
1✔
1523
        return CreateGlobalTableOutput(GlobalTableDescription=data)
1✔
1524

1525
    def describe_global_table(
1✔
1526
        self, context: RequestContext, global_table_name: TableName, **kwargs
1527
    ) -> DescribeGlobalTableOutput:
1528
        details = get_store(context.account_id, context.region).GLOBAL_TABLES.get(global_table_name)
1✔
1529
        if not details:
1✔
1530
            raise GlobalTableNotFoundException("Global table with this name does not exist")
1✔
1531
        return DescribeGlobalTableOutput(GlobalTableDescription=details)
1✔
1532

1533
    def list_global_tables(
1✔
1534
        self,
1535
        context: RequestContext,
1536
        exclusive_start_global_table_name: TableName = None,
1537
        limit: PositiveIntegerObject = None,
1538
        region_name: RegionName = None,
1539
        **kwargs,
1540
    ) -> ListGlobalTablesOutput:
1541
        # TODO: add paging support
UNCOV
1542
        result = [
×
1543
            select_attributes(tab, ["GlobalTableName", "ReplicationGroup"])
1544
            for tab in get_store(context.account_id, context.region).GLOBAL_TABLES.values()
1545
        ]
UNCOV
1546
        return ListGlobalTablesOutput(GlobalTables=result)
×
1547

1548
    def update_global_table(
1✔
1549
        self,
1550
        context: RequestContext,
1551
        global_table_name: TableName,
1552
        replica_updates: ReplicaUpdateList,
1553
        **kwargs,
1554
    ) -> UpdateGlobalTableOutput:
1555
        details = get_store(context.account_id, context.region).GLOBAL_TABLES.get(global_table_name)
1✔
1556
        if not details:
1✔
UNCOV
1557
            raise GlobalTableNotFoundException("Global table with this name does not exist")
×
1558
        for update in replica_updates or []:
1✔
1559
            repl_group = details["ReplicationGroup"]
1✔
1560
            # delete existing
1561
            delete = update.get("Delete")
1✔
1562
            if delete:
1✔
1563
                details["ReplicationGroup"] = [
1✔
1564
                    g for g in repl_group if g["RegionName"] != delete["RegionName"]
1565
                ]
1566
            # create new
1567
            create = update.get("Create")
1✔
1568
            if create:
1✔
1569
                exists = [g for g in repl_group if g["RegionName"] == create["RegionName"]]
1✔
1570
                if exists:
1✔
UNCOV
1571
                    continue
×
1572
                new_group = {
1✔
1573
                    "RegionName": create["RegionName"],
1574
                    "ReplicaStatus": "ACTIVE",
1575
                    "ReplicaStatusDescription": "Replica active",
1576
                }
1577
                details["ReplicationGroup"].append(new_group)
1✔
1578
        return UpdateGlobalTableOutput(GlobalTableDescription=details)
1✔
1579

1580
    #
1581
    # Kinesis Streaming
1582
    #
1583

1584
    def enable_kinesis_streaming_destination(
1✔
1585
        self,
1586
        context: RequestContext,
1587
        table_name: TableName,
1588
        stream_arn: StreamArn,
1589
        enable_kinesis_streaming_configuration: EnableKinesisStreamingConfiguration = None,
1590
        **kwargs,
1591
    ) -> KinesisStreamingDestinationOutput:
1592
        self.ensure_table_exists(
1✔
1593
            context.account_id,
1594
            context.region,
1595
            table_name,
1596
            error_message=f"Requested resource not found: Table: {table_name} not found",
1597
        )
1598

1599
        # TODO: Use the time precision in config if set
1600
        enable_kinesis_streaming_configuration = enable_kinesis_streaming_configuration or {}
1✔
1601

1602
        stream = self._event_forwarder.is_kinesis_stream_exists(stream_arn=stream_arn)
1✔
1603
        if not stream:
1✔
UNCOV
1604
            raise ValidationException("User does not have a permission to use kinesis stream")
×
1605

1606
        table_def = get_store(context.account_id, context.region).table_definitions.setdefault(
1✔
1607
            table_name, {}
1608
        )
1609

1610
        dest_status = table_def.get("KinesisDataStreamDestinationStatus")
1✔
1611
        if dest_status not in ["DISABLED", "ENABLE_FAILED", None]:
1✔
UNCOV
1612
            raise ValidationException(
×
1613
                "Table is not in a valid state to enable Kinesis Streaming "
1614
                "Destination:EnableKinesisStreamingDestination must be DISABLED or ENABLE_FAILED "
1615
                "to perform ENABLE operation."
1616
            )
1617

1618
        table_def.setdefault("KinesisDataStreamDestinations", [])
1✔
1619

1620
        # remove the stream destination if already present
1621
        table_def["KinesisDataStreamDestinations"] = [
1✔
1622
            t for t in table_def["KinesisDataStreamDestinations"] if t["StreamArn"] != stream_arn
1623
        ]
1624
        # append the active stream destination at the end of the list
1625
        table_def["KinesisDataStreamDestinations"].append(
1✔
1626
            {
1627
                "DestinationStatus": DestinationStatus.ACTIVE,
1628
                "DestinationStatusDescription": "Stream is active",
1629
                "StreamArn": stream_arn,
1630
                "ApproximateCreationDateTimePrecision": ApproximateCreationDateTimePrecision.MILLISECOND,
1631
            }
1632
        )
1633
        table_def["KinesisDataStreamDestinationStatus"] = DestinationStatus.ACTIVE
1✔
1634
        return KinesisStreamingDestinationOutput(
1✔
1635
            DestinationStatus=DestinationStatus.ENABLING,
1636
            StreamArn=stream_arn,
1637
            TableName=table_name,
1638
            EnableKinesisStreamingConfiguration=enable_kinesis_streaming_configuration,
1639
        )
1640

1641
    def disable_kinesis_streaming_destination(
1✔
1642
        self,
1643
        context: RequestContext,
1644
        table_name: TableName,
1645
        stream_arn: StreamArn,
1646
        enable_kinesis_streaming_configuration: EnableKinesisStreamingConfiguration = None,
1647
        **kwargs,
1648
    ) -> KinesisStreamingDestinationOutput:
1649
        self.ensure_table_exists(
1✔
1650
            context.account_id,
1651
            context.region,
1652
            table_name,
1653
            error_message=f"Requested resource not found: Table: {table_name} not found",
1654
        )
1655

1656
        # TODO: Must raise if invoked before KinesisStreamingDestination is ACTIVE
1657

1658
        stream = self._event_forwarder.is_kinesis_stream_exists(stream_arn=stream_arn)
1✔
1659
        if not stream:
1✔
UNCOV
1660
            raise ValidationException(
×
1661
                "User does not have a permission to use kinesis stream",
1662
            )
1663

1664
        table_def = get_store(context.account_id, context.region).table_definitions.setdefault(
1✔
1665
            table_name, {}
1666
        )
1667

1668
        stream_destinations = table_def.get("KinesisDataStreamDestinations")
1✔
1669
        if stream_destinations:
1✔
1670
            if table_def["KinesisDataStreamDestinationStatus"] == DestinationStatus.ACTIVE:
1✔
1671
                for dest in stream_destinations:
1✔
1672
                    if (
1✔
1673
                        dest["StreamArn"] == stream_arn
1674
                        and dest["DestinationStatus"] == DestinationStatus.ACTIVE
1675
                    ):
1676
                        dest["DestinationStatus"] = DestinationStatus.DISABLED
1✔
1677
                        dest["DestinationStatusDescription"] = ("Stream is disabled",)
1✔
1678
                        table_def["KinesisDataStreamDestinationStatus"] = DestinationStatus.DISABLED
1✔
1679
                        return KinesisStreamingDestinationOutput(
1✔
1680
                            DestinationStatus=DestinationStatus.DISABLING,
1681
                            StreamArn=stream_arn,
1682
                            TableName=table_name,
1683
                        )
UNCOV
1684
        raise ValidationException(
×
1685
            "Table is not in a valid state to disable Kinesis Streaming Destination:"
1686
            "DisableKinesisStreamingDestination must be ACTIVE to perform DISABLE operation."
1687
        )
1688

1689
    def describe_kinesis_streaming_destination(
1✔
1690
        self, context: RequestContext, table_name: TableName, **kwargs
1691
    ) -> DescribeKinesisStreamingDestinationOutput:
1692
        self.ensure_table_exists(context.account_id, context.region, table_name)
1✔
1693

1694
        table_def = (
1✔
1695
            get_store(context.account_id, context.region).table_definitions.get(table_name) or {}
1696
        )
1697

1698
        stream_destinations = table_def.get("KinesisDataStreamDestinations") or []
1✔
1699
        stream_destinations = copy.deepcopy(stream_destinations)
1✔
1700

1701
        for destination in stream_destinations:
1✔
1702
            destination.pop("ApproximateCreationDateTimePrecision", None)
1✔
1703
            destination.pop("DestinationStatusDescription", None)
1✔
1704

1705
        return DescribeKinesisStreamingDestinationOutput(
1✔
1706
            KinesisDataStreamDestinations=stream_destinations,
1707
            TableName=table_name,
1708
        )
1709

1710
    def update_kinesis_streaming_destination(
1✔
1711
        self,
1712
        context: RequestContext,
1713
        table_name: TableArn,
1714
        stream_arn: StreamArn,
1715
        update_kinesis_streaming_configuration: UpdateKinesisStreamingConfiguration | None = None,
1716
        **kwargs,
1717
    ) -> UpdateKinesisStreamingDestinationOutput:
1718
        self.ensure_table_exists(context.account_id, context.region, table_name)
1✔
1719

1720
        if not update_kinesis_streaming_configuration:
1✔
1721
            raise ValidationException(
1✔
1722
                "Streaming destination cannot be updated with given parameters: "
1723
                "UpdateKinesisStreamingConfiguration cannot be null or contain only null values"
1724
            )
1725

1726
        time_precision = update_kinesis_streaming_configuration.get(
1✔
1727
            "ApproximateCreationDateTimePrecision"
1728
        )
1729
        if time_precision not in (
1✔
1730
            ApproximateCreationDateTimePrecision.MILLISECOND,
1731
            ApproximateCreationDateTimePrecision.MICROSECOND,
1732
        ):
1733
            raise ValidationException(
1✔
1734
                f"1 validation error detected: Value '{time_precision}' at "
1735
                "'updateKinesisStreamingConfiguration.approximateCreationDateTimePrecision' failed to satisfy constraint: "
1736
                "Member must satisfy enum value set: [MILLISECOND, MICROSECOND]"
1737
            )
1738

1739
        store = get_store(context.account_id, context.region)
1✔
1740

1741
        table_def = store.table_definitions.get(table_name) or {}
1✔
1742
        table_def.setdefault("KinesisDataStreamDestinations", [])
1✔
1743

1744
        table_id = table_def["TableId"]
1✔
1745

1746
        destination = None
1✔
1747
        for stream in table_def["KinesisDataStreamDestinations"]:
1✔
1748
            if stream["StreamArn"] == stream_arn:
1✔
1749
                destination = stream
1✔
1750

1751
        if destination is None:
1✔
1752
            raise ValidationException(
1✔
1753
                "Table is not in a valid state to enable Kinesis Streaming Destination: "
1754
                f"No streaming destination with streamArn: {stream_arn} found for table with tableName: {table_name}"
1755
            )
1756

1757
        if (
1✔
1758
            existing_precision := destination["ApproximateCreationDateTimePrecision"]
1759
        ) == update_kinesis_streaming_configuration["ApproximateCreationDateTimePrecision"]:
1760
            raise ValidationException(
1✔
1761
                f"Invalid Request: Precision is already set to the desired value of {existing_precision} "
1762
                f"for tableId: {table_id}, kdsArn: {stream_arn}"
1763
            )
1764

1765
        destination["ApproximateCreationDateTimePrecision"] = time_precision
1✔
1766

1767
        return UpdateKinesisStreamingDestinationOutput(
1✔
1768
            TableName=table_name,
1769
            StreamArn=stream_arn,
1770
            DestinationStatus=DestinationStatus.UPDATING,
1771
            UpdateKinesisStreamingConfiguration=UpdateKinesisStreamingConfiguration(
1772
                ApproximateCreationDateTimePrecision=time_precision,
1773
            ),
1774
        )
1775

1776
    #
1777
    # Continuous Backups
1778
    #
1779

1780
    def describe_continuous_backups(
1✔
1781
        self, context: RequestContext, table_name: TableName, **kwargs
1782
    ) -> DescribeContinuousBackupsOutput:
1783
        self.get_global_table_region(context, table_name)
1✔
1784
        store = get_store(context.account_id, context.region)
1✔
1785
        continuous_backup_description = (
1✔
1786
            store.table_properties.get(table_name, {}).get("ContinuousBackupsDescription")
1787
        ) or ContinuousBackupsDescription(
1788
            ContinuousBackupsStatus=ContinuousBackupsStatus.ENABLED,
1789
            PointInTimeRecoveryDescription=PointInTimeRecoveryDescription(
1790
                PointInTimeRecoveryStatus=PointInTimeRecoveryStatus.DISABLED
1791
            ),
1792
        )
1793

1794
        return DescribeContinuousBackupsOutput(
1✔
1795
            ContinuousBackupsDescription=continuous_backup_description
1796
        )
1797

1798
    def update_continuous_backups(
1✔
1799
        self,
1800
        context: RequestContext,
1801
        table_name: TableName,
1802
        point_in_time_recovery_specification: PointInTimeRecoverySpecification,
1803
        **kwargs,
1804
    ) -> UpdateContinuousBackupsOutput:
1805
        self.get_global_table_region(context, table_name)
1✔
1806

1807
        store = get_store(context.account_id, context.region)
1✔
1808
        pit_recovery_status = (
1✔
1809
            PointInTimeRecoveryStatus.ENABLED
1810
            if point_in_time_recovery_specification["PointInTimeRecoveryEnabled"]
1811
            else PointInTimeRecoveryStatus.DISABLED
1812
        )
1813
        continuous_backup_description = ContinuousBackupsDescription(
1✔
1814
            ContinuousBackupsStatus=ContinuousBackupsStatus.ENABLED,
1815
            PointInTimeRecoveryDescription=PointInTimeRecoveryDescription(
1816
                PointInTimeRecoveryStatus=pit_recovery_status
1817
            ),
1818
        )
1819
        table_props = store.table_properties.setdefault(table_name, {})
1✔
1820
        table_props["ContinuousBackupsDescription"] = continuous_backup_description
1✔
1821

1822
        return UpdateContinuousBackupsOutput(
1✔
1823
            ContinuousBackupsDescription=continuous_backup_description
1824
        )
1825

1826
    #
1827
    # Helpers
1828
    #
1829

1830
    @staticmethod
1✔
1831
    def ddb_region_name(region_name: str) -> str:
1✔
1832
        """Map `local` or `localhost` region to the us-east-1 region. These values are used by NoSQL Workbench."""
1833
        # TODO: could this be somehow moved into the request handler chain?
1834
        if region_name in ("local", "localhost"):
1✔
UNCOV
1835
            region_name = AWS_REGION_US_EAST_1
×
1836

1837
        return region_name
1✔
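    # [Illustrative note added by the editor, not part of the original source.]
    # NoSQL Workbench signs requests with the pseudo-regions "local"/"localhost"; the helper above
    # folds both into us-east-1 so they share a namespace with regular requests:
    #
    #   >>> DynamoDBProvider.ddb_region_name("localhost")
    #   'us-east-1'
    #   >>> DynamoDBProvider.ddb_region_name("eu-central-1")
    #   'eu-central-1'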
1838

1839
    @staticmethod
1✔
1840
    def table_exists(account_id: str, region_name: str, table_name: str) -> bool:
1✔
1841
        region_name = DynamoDBProvider.ddb_region_name(region_name)
1✔
1842

1843
        client = connect_to(
1✔
1844
            aws_access_key_id=account_id,
1845
            aws_secret_access_key=INTERNAL_AWS_SECRET_ACCESS_KEY,
1846
            region_name=region_name,
1847
        ).dynamodb
1848
        return dynamodb_table_exists(table_name, client)
1✔
1849

1850
    @staticmethod
1✔
1851
    def ensure_table_exists(
1✔
1852
        account_id: str,
1853
        region_name: str,
1854
        table_name: str,
1855
        error_message: str = "Cannot do operations on a non-existent table",
1856
    ):
1857
        """
1858
        Raise ResourceNotFoundException if the given table does not exist.
1859

1860
        :param account_id: account id
1861
        :param region_name: region name
1862
        :param table_name: table name
1863
        :raise: ResourceNotFoundException if table does not exist in DynamoDB Local
1864
        """
1865
        if not DynamoDBProvider.table_exists(account_id, region_name, table_name):
1✔
1866
            raise ResourceNotFoundException(error_message)
1✔
1867

1868
    @staticmethod
1✔
1869
    def get_global_table_region(context: RequestContext, table_name: str) -> str:
1✔
1870
        """
1871
        Return the table region considering that it might be a replicated table.
1872

1873
        Replication in LocalStack works by keeping a single copy of a table and forwarding
1874
        requests to the region where this table exists.
1875

1876
        This method does not check whether the table actually exists in DDBLocal.
1877

1878
        :param context: request context
1879
        :param table_name: table name
1880
        :return: region
1881
        """
1882
        store = get_store(context.account_id, context.region)
1✔
1883

1884
        table_region = store.TABLE_REGION.get(table_name)
1✔
1885
        replicated_at = store.REPLICAS.get(table_name, {}).keys()
1✔
1886

1887
        if context.region == table_region or context.region in replicated_at:
1✔
1888
            return table_region
1✔
1889

1890
        return context.region
1✔
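    # [Illustrative note added by the editor, not part of the original source; names and regions are
    # hypothetical.] Sketch of the replica routing described above: if "orders" was created in
    # eu-west-1 and replicated to us-east-1, a request arriving in us-east-1 is served from the
    # single physical copy in eu-west-1:
    #
    #   TABLE_REGION = {"orders": "eu-west-1"}
    #   REPLICAS     = {"orders": {"us-east-1": {...}}}
    #   get_global_table_region(<context in us-east-1>, "orders")  -> "eu-west-1"
    #
    # For a table the store does not know about, the caller's own region is returned unchanged.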
1891

1892
    @staticmethod
1✔
1893
    def prepare_request_headers(headers: dict, account_id: str, region_name: str):
1✔
1894
        """
1895
        Modify the Credentials field of Authorization header to achieve namespacing in DynamoDBLocal.
1896
        """
1897
        region_name = DynamoDBProvider.ddb_region_name(region_name)
1✔
1898
        key = get_ddb_access_key(account_id, region_name)
1✔
1899

1900
        # DynamoDBLocal namespaces based on the value of Credentials
1901
        # Since we want to namespace by both account ID and region, use an aggregate key
1902
        # We also replace the region to keep compatibility with NoSQL Workbench
1903
        headers["Authorization"] = re.sub(
1✔
1904
            AUTH_CREDENTIAL_REGEX,
1905
            rf"Credential={key}/\2/{region_name}/\4/",
1906
            headers.get("Authorization") or "",
1907
            flags=re.IGNORECASE,
1908
        )
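        # [Illustrative note added by the editor, not part of the original source; access key and
        # date are hypothetical.] Only the Credential scope of the SigV4 Authorization header is
        # rewritten, e.g.:
        #
        #   Credential=AKIAEXAMPLE/20250807/eu-central-1/dynamodb/aws4_request
        #     -> Credential=<key derived from account id and region>/20250807/eu-central-1/dynamodb/aws4_request
        #
        # so DynamoDB Local namespaces its data per (account id, region) pair instead of per access key.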
1909

1910
    def fix_consumed_capacity(self, request: dict, result: dict):
1✔
1911
        # make sure we append 'ConsumedCapacity', which is properly
1912
        # returned by dynalite, but not by AWS's DynamoDBLocal
1913
        table_name = request.get("TableName")
1✔
1914
        return_cap = request.get("ReturnConsumedCapacity")
1✔
1915
        if "ConsumedCapacity" not in result and return_cap in ["TOTAL", "INDEXES"]:
1✔
UNCOV
1916
            request["ConsumedCapacity"] = {
×
1917
                "TableName": table_name,
1918
                "CapacityUnits": 5,  # TODO hardcoded
1919
                "ReadCapacityUnits": 2,
1920
                "WriteCapacityUnits": 3,
1921
            }
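        # [Illustrative note added by the editor, not part of the original source; the table name is
        # hypothetical.] With ReturnConsumedCapacity set to "TOTAL" or "INDEXES", the stub above
        # backfills a fixed ConsumedCapacity entry; the capacity units are hardcoded placeholders,
        # not real metering:
        #
        #   {"ConsumedCapacity": {"TableName": "my-table", "CapacityUnits": 5,
        #                         "ReadCapacityUnits": 2, "WriteCapacityUnits": 3}}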
1922

1923
    def fix_table_arn(self, account_id: str, region_name: str, arn: str) -> str:
1✔
1924
        """
1925
        Set the correct account ID and region in ARNs returned by DynamoDB Local.
1926
        """
1927
        partition = get_partition(region_name)
1✔
1928
        return (
1✔
1929
            arn.replace("arn:aws:", f"arn:{partition}:")
1930
            .replace(":ddblocal:", f":{region_name}:")
1931
            .replace(":000000000000:", f":{account_id}:")
1932
        )
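        # [Illustrative note added by the editor, not part of the original source; the account id is
        # hypothetical.] Example of the rewriting performed above for an ARN emitted by DynamoDB Local:
        #
        #   arn:aws:dynamodb:ddblocal:000000000000:table/my-table
        #     -> arn:aws:dynamodb:us-east-2:123456789012:table/my-table   (region us-east-2, partition "aws")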
1933

1934
    def prepare_transact_write_item_records(
1✔
1935
        self,
1936
        account_id: str,
1937
        region_name: str,
1938
        transact_items: TransactWriteItemList,
1939
        existing_items: BatchGetResponseMap,
1940
        updated_items: BatchGetResponseMap,
1941
        tables_stream_type: dict[TableName, TableStreamType],
1942
    ) -> RecordsMap:
1943
        records_only_map: dict[TableName, StreamRecords] = defaultdict(list)
1✔
1944

1945
        for request in transact_items:
1✔
1946
            record = self.get_record_template(region_name)
1✔
1947
            match request:
1✔
1948
                case {"Put": {"TableName": table_name, "Item": new_item}}:
1✔
1949
                    if not (stream_type := tables_stream_type.get(table_name)):
1✔
1950
                        continue
1✔
1951
                    keys = SchemaExtractor.extract_keys(
1✔
1952
                        item=new_item,
1953
                        table_name=table_name,
1954
                        account_id=account_id,
1955
                        region_name=region_name,
1956
                    )
1957
                    existing_item = find_item_for_keys_values_in_batch(
1✔
1958
                        table_name, keys, existing_items
1959
                    )
1960
                    if existing_item == new_item:
1✔
1961
                        continue
1✔
1962

1963
                    if stream_type.stream_view_type:
1✔
1964
                        record["dynamodb"]["StreamViewType"] = stream_type.stream_view_type
1✔
1965

1966
                    record["eventID"] = short_uid()
1✔
1967
                    record["eventName"] = "INSERT" if not existing_item else "MODIFY"
1✔
1968
                    record["dynamodb"]["Keys"] = keys
1✔
1969
                    if stream_type.needs_new_image:
1✔
1970
                        record["dynamodb"]["NewImage"] = new_item
1✔
1971
                    if existing_item and stream_type.needs_old_image:
1✔
1972
                        record["dynamodb"]["OldImage"] = existing_item
1✔
1973

1974
                    record_item = de_dynamize_record(new_item)
1✔
1975
                    record["dynamodb"]["SizeBytes"] = _get_size_bytes(record_item)
1✔
1976
                    records_only_map[table_name].append(record)
1✔
1977
                    continue
1✔
1978

1979
                case {"Update": {"TableName": table_name, "Key": keys}}:
1✔
1980
                    if not (stream_type := tables_stream_type.get(table_name)):
1✔
UNCOV
1981
                        continue
×
1982
                    updated_item = find_item_for_keys_values_in_batch(
1✔
1983
                        table_name, keys, updated_items
1984
                    )
1985
                    if not updated_item:
1✔
UNCOV
1986
                        continue
×
1987

1988
                    existing_item = find_item_for_keys_values_in_batch(
1✔
1989
                        table_name, keys, existing_items
1990
                    )
1991
                    if existing_item == updated_item:
1✔
1992
                        # if the item is the same as the previous version, AWS does not send an event
1993
                        continue
1✔
1994

1995
                    if stream_type.stream_view_type:
1✔
1996
                        record["dynamodb"]["StreamViewType"] = stream_type.stream_view_type
1✔
1997

1998
                    record["eventID"] = short_uid()
1✔
1999
                    record["eventName"] = "MODIFY" if existing_item else "INSERT"
1✔
2000
                    record["dynamodb"]["Keys"] = keys
1✔
2001

2002
                    if existing_item and stream_type.needs_old_image:
1✔
2003
                        record["dynamodb"]["OldImage"] = existing_item
1✔
2004
                    if stream_type.needs_new_image:
1✔
2005
                        record["dynamodb"]["NewImage"] = updated_item
1✔
2006

2007
                    record["dynamodb"]["SizeBytes"] = _get_size_bytes(updated_item)
1✔
2008
                    records_only_map[table_name].append(record)
1✔
2009
                    continue
1✔
2010

2011
                case {"Delete": {"TableName": table_name, "Key": keys}}:
1✔
2012
                    if not (stream_type := tables_stream_type.get(table_name)):
1✔
UNCOV
2013
                        continue
×
2014

2015
                    existing_item = find_item_for_keys_values_in_batch(
1✔
2016
                        table_name, keys, existing_items
2017
                    )
2018
                    if not existing_item:
1✔
UNCOV
2019
                        continue
×
2020

2021
                    if stream_type.stream_view_type:
1✔
2022
                        record["dynamodb"]["StreamViewType"] = stream_type.stream_view_type
1✔
2023

2024
                    record["eventID"] = short_uid()
1✔
2025
                    record["eventName"] = "REMOVE"
1✔
2026
                    record["dynamodb"]["Keys"] = keys
1✔
2027
                    if stream_type.needs_old_image:
1✔
2028
                        record["dynamodb"]["OldImage"] = existing_item
1✔
2029
                    record_item = de_dynamize_record(existing_item)
1✔
2030
                    record["dynamodb"]["SizeBytes"] = _get_size_bytes(record_item)
1✔
2031

2032
                    records_only_map[table_name].append(record)
1✔
2033
                    continue
1✔
2034

2035
        records_map = {
1✔
2036
            table_name: TableRecords(
2037
                records=records, table_stream_type=tables_stream_type[table_name]
2038
            )
2039
            for table_name, records in records_only_map.items()
2040
        }
2041

2042
        return records_map
1✔
2043

2044
    def batch_execute_statement(
1✔
2045
        self,
2046
        context: RequestContext,
2047
        statements: PartiQLBatchRequest,
2048
        return_consumed_capacity: ReturnConsumedCapacity = None,
2049
        **kwargs,
2050
    ) -> BatchExecuteStatementOutput:
2051
        result = self.forward_request(context)
1✔
2052
        return result
1✔
2053

2054
    def prepare_batch_write_item_records(
1✔
2055
        self,
2056
        account_id: str,
2057
        region_name: str,
2058
        tables_stream_type: dict[TableName, TableStreamType],
2059
        request_items: BatchWriteItemRequestMap,
2060
        existing_items: BatchGetResponseMap,
2061
    ) -> RecordsMap:
2062
        records_map: RecordsMap = {}
1✔
2063

2064
        # only iterate over tables with streams
2065
        for table_name, stream_type in tables_stream_type.items():
1✔
2066
            existing_items_for_table_unordered = existing_items.get(table_name, [])
1✔
2067
            table_records: StreamRecords = []
1✔
2068

2069
            def find_existing_item_for_keys_values(item_keys: dict) -> AttributeMap | None:
1✔
2070
                """
2071
                This function looks up the provided item key subset in the existing items. If present, returns the
2072
                full item.
2073
                :param item_keys: the request item keys
2074
                :return:
2075
                """
2076
                keys_items = item_keys.items()
1✔
2077
                for item in existing_items_for_table_unordered:
1✔
2078
                    if keys_items <= item.items():
1✔
2079
                        return item
1✔
2080

2081
            for write_request in request_items[table_name]:
1✔
2082
                record = self.get_record_template(
1✔
2083
                    region_name,
2084
                    stream_view_type=stream_type.stream_view_type,
2085
                )
2086
                match write_request:
1✔
2087
                    case {"PutRequest": request}:
1✔
2088
                        keys = SchemaExtractor.extract_keys(
1✔
2089
                            item=request["Item"],
2090
                            table_name=table_name,
2091
                            account_id=account_id,
2092
                            region_name=region_name,
2093
                        )
2094
                        # we need to find if there was an existing item even if we don't need it for `OldImage`, because
2095
                        # of the `eventName`
2096
                        existing_item = find_existing_item_for_keys_values(keys)
1✔
2097
                        if existing_item == request["Item"]:
1✔
2098
                            # if the item is the same as the previous version, AWS does not send an event
2099
                            continue
1✔
2100
                        record["eventID"] = short_uid()
1✔
2101
                        record["dynamodb"]["SizeBytes"] = _get_size_bytes(request["Item"])
1✔
2102
                        record["eventName"] = "INSERT" if not existing_item else "MODIFY"
1✔
2103
                        record["dynamodb"]["Keys"] = keys
1✔
2104

2105
                        if stream_type.needs_new_image:
1✔
2106
                            record["dynamodb"]["NewImage"] = request["Item"]
1✔
2107
                        if existing_item and stream_type.needs_old_image:
1✔
2108
                            record["dynamodb"]["OldImage"] = existing_item
1✔
2109

2110
                        table_records.append(record)
1✔
2111
                        continue
1✔
2112

2113
                    case {"DeleteRequest": request}:
1✔
2114
                        keys = request["Key"]
1✔
2115
                        if not (existing_item := find_existing_item_for_keys_values(keys)):
1✔
UNCOV
2116
                            continue
×
2117

2118
                        record["eventID"] = short_uid()
1✔
2119
                        record["eventName"] = "REMOVE"
1✔
2120
                        record["dynamodb"]["Keys"] = keys
1✔
2121
                        if stream_type.needs_old_image:
1✔
2122
                            record["dynamodb"]["OldImage"] = existing_item
1✔
2123
                        record["dynamodb"]["SizeBytes"] = _get_size_bytes(existing_item)
1✔
2124
                        table_records.append(record)
1✔
2125
                        continue
1✔
2126

2127
            records_map[table_name] = TableRecords(
1✔
2128
                records=table_records, table_stream_type=stream_type
2129
            )
2130

2131
        return records_map
1✔
2132

2133
    def forward_stream_records(
1✔
2134
        self,
2135
        account_id: str,
2136
        region_name: str,
2137
        records_map: RecordsMap,
2138
    ) -> None:
2139
        if not records_map:
1✔
UNCOV
2140
            return
×
2141

2142
        self._event_forwarder.forward_to_targets(
1✔
2143
            account_id, region_name, records_map, background=True
2144
        )
2145

2146
    @staticmethod
1✔
2147
    def get_record_template(region_name: str, stream_view_type: str | None = None) -> StreamRecord:
1✔
2148
        record = {
1✔
2149
            "eventID": short_uid(),
2150
            "eventVersion": "1.1",
2151
            "dynamodb": {
2152
                # expects nearest second rounded down
2153
                "ApproximateCreationDateTime": int(time.time()),
2154
                "SizeBytes": -1,
2155
            },
2156
            "awsRegion": region_name,
2157
            "eventSource": "aws:dynamodb",
2158
        }
2159
        if stream_view_type:
1✔
2160
            record["dynamodb"]["StreamViewType"] = stream_view_type
1✔
2161

2162
        return record
1✔
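    # [Illustrative note added by the editor, not part of the original source; region and timestamp
    # are hypothetical.] Skeleton of the stream record produced above, before callers fill in
    # eventName, Keys, the item images and the real SizeBytes:
    #
    #   {"eventID": "<short uid>", "eventVersion": "1.1", "eventSource": "aws:dynamodb",
    #    "awsRegion": "us-east-1",
    #    "dynamodb": {"ApproximateCreationDateTime": 1754586000, "SizeBytes": -1,
    #                 "StreamViewType": "NEW_AND_OLD_IMAGES"}}   # StreamViewType only when requested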
2163

2164
    def check_provisioned_throughput(self, action):
1✔
2165
        """
2166
        Check rate limiting for an API operation and raise an error if provisioned throughput is exceeded.
2167
        """
2168
        if self.should_throttle(action):
1✔
2169
            message = (
1✔
2170
                "The level of configured provisioned throughput for the table was exceeded. "
2171
                + "Consider increasing your provisioning level with the UpdateTable API"
2172
            )
2173
            raise ProvisionedThroughputExceededException(message)
1✔
2174

2175
    def action_should_throttle(self, action, actions):
1✔
2176
        throttled = [f"{ACTION_PREFIX}{a}" for a in actions]
1✔
2177
        return (action in throttled) or (action in actions)
1✔
2178

2179
    def should_throttle(self, action):
1✔
2180
        if (
1✔
2181
            not config.DYNAMODB_READ_ERROR_PROBABILITY
2182
            and not config.DYNAMODB_ERROR_PROBABILITY
2183
            and not config.DYNAMODB_WRITE_ERROR_PROBABILITY
2184
        ):
2185
            # early exit so we don't need to call random()
2186
            return False
1✔
2187

2188
        rand = random.random()
1✔
2189
        if rand < config.DYNAMODB_READ_ERROR_PROBABILITY and self.action_should_throttle(
1✔
2190
            action, READ_THROTTLED_ACTIONS
2191
        ):
2192
            return True
1✔
2193
        elif rand < config.DYNAMODB_WRITE_ERROR_PROBABILITY and self.action_should_throttle(
1✔
2194
            action, WRITE_THROTTLED_ACTIONS
2195
        ):
2196
            return True
1✔
2197
        elif rand < config.DYNAMODB_ERROR_PROBABILITY and self.action_should_throttle(
1✔
2198
            action, THROTTLED_ACTIONS
2199
        ):
2200
            return True
1✔
2201
        return False
1✔
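    # [Illustrative note added by the editor, not part of the original source; the invocation below
    # is hypothetical.] The fault-injection knobs above are plain probabilities: e.g. with
    #
    #   DYNAMODB_ERROR_PROBABILITY=0.1 localstack start
    #
    # roughly one in ten calls to an action listed in THROTTLED_ACTIONS is reported as throttled,
    # which check_provisioned_throughput turns into a ProvisionedThroughputExceededException. The
    # read/write-specific probabilities work the same way but only match READ_THROTTLED_ACTIONS or
    # WRITE_THROTTLED_ACTIONS respectively.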
2202

2203

2204
# ---
2205
# Misc. util functions
2206
# ---
2207

2208

2209
def _get_size_bytes(item: dict) -> int:
1✔
2210
    try:
1✔
2211
        size_bytes = len(json.dumps(item, separators=(",", ":")))
1✔
2212
    except TypeError:
1✔
2213
        size_bytes = len(str(item))
1✔
2214
    return size_bytes
1✔
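# [Illustrative note added by the editor, not part of the original source.]
# The size estimate above is simply the length of the compact JSON encoding (not DynamoDB's exact
# item-size accounting), with str() as a fallback for non-serializable values:
#
#   >>> _get_size_bytes({"pk": {"S": "id-1"}, "value": {"N": "42"}})
#   38    # == len('{"pk":{"S":"id-1"},"value":{"N":"42"}}')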
2215

2216

2217
def get_global_secondary_index(account_id: str, region_name: str, table_name: str, index_name: str):
1✔
2218
    schema = SchemaExtractor.get_table_schema(table_name, account_id, region_name)
1✔
2219
    for index in schema["Table"].get("GlobalSecondaryIndexes", []):
1✔
2220
        if index["IndexName"] == index_name:
1✔
2221
            return index
1✔
UNCOV
2222
    raise ResourceNotFoundException("Index not found")
×
2223

2224

2225
def is_local_secondary_index(
1✔
2226
    account_id: str, region_name: str, table_name: str, index_name: str
2227
) -> bool:
2228
    schema = SchemaExtractor.get_table_schema(table_name, account_id, region_name)
1✔
2229
    for index in schema["Table"].get("LocalSecondaryIndexes", []):
1✔
2230
        if index["IndexName"] == index_name:
1✔
2231
            return True
1✔
2232
    return False
1✔
2233

2234

2235
def is_index_query_valid(account_id: str, region_name: str, query_data: dict) -> bool:
1✔
2236
    table_name = to_str(query_data["TableName"])
1✔
2237
    index_name = to_str(query_data["IndexName"])
1✔
2238
    if is_local_secondary_index(account_id, region_name, table_name, index_name):
1✔
2239
        return True
1✔
2240
    index_query_type = query_data.get("Select")
1✔
2241
    index = get_global_secondary_index(account_id, region_name, table_name, index_name)
1✔
2242
    index_projection_type = index.get("Projection").get("ProjectionType")
1✔
2243
    if index_query_type == "ALL_ATTRIBUTES" and index_projection_type != "ALL":
1✔
2244
        return False
1✔
2245
    return True
1✔
2246

2247

2248
def get_table_stream_type(
1✔
2249
    account_id: str, region_name: str, table_name_or_arn: str
2250
) -> TableStreamType | None:
2251
    """
2252
    :param account_id: the account id of the table
2253
    :param region_name: the region of the table
2254
    :param table_name_or_arn: the table name or ARN
2255
    :return: a TableStreamViewType object if the table has streams enabled. If not, return None
2256
    """
2257
    if not table_name_or_arn:
1✔
UNCOV
2258
        return
×
2259

2260
    table_name = table_name_or_arn.split(":table/")[-1]
1✔
2261

2262
    is_kinesis = False
1✔
2263
    stream_view_type = None
1✔
2264

2265
    if table_definition := get_store(account_id, region_name).table_definitions.get(table_name):
1✔
2266
        if table_definition.get("KinesisDataStreamDestinationStatus") == "ACTIVE":
1✔
2267
            is_kinesis = True
1✔
2268

2269
    table_arn = arns.dynamodb_table_arn(table_name, account_id=account_id, region_name=region_name)
1✔
2270

2271
    if (
1✔
2272
        stream := dynamodbstreams_api.get_stream_for_table(account_id, region_name, table_arn)
2273
    ) and stream["StreamStatus"] in (StreamStatus.ENABLING, StreamStatus.ENABLED):
2274
        stream_view_type = stream["StreamViewType"]
1✔
2275

2276
    if is_kinesis or stream_view_type:
1✔
2277
        return TableStreamType(stream_view_type, is_kinesis=is_kinesis)
1✔
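# [Illustrative note added by the editor, not part of the original source.]
# Possible outcomes of the lookup above:
#   - DynamoDB stream enabled with NEW_AND_OLD_IMAGES, no Kinesis destination:
#       TableStreamType("NEW_AND_OLD_IMAGES", is_kinesis=False)
#   - only an ACTIVE Kinesis data stream destination:
#       TableStreamType(None, is_kinesis=True)
#   - neither: the function falls through and implicitly returns None.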
2278

2279

2280
def get_updated_records(
1✔
2281
    account_id: str,
2282
    region_name: str,
2283
    table_name: str,
2284
    existing_items: list,
2285
    server_url: str,
2286
    table_stream_type: TableStreamType,
2287
) -> RecordsMap:
2288
    """
2289
    Determine the list of record updates, to be sent to a DDB stream after a PartiQL update operation.
2290

2291
    Note: This is currently a fairly expensive operation, as we need to retrieve the list of all items
2292
          from the table, and compare the items to the previously available ones. This is a limitation, as
2293
          we're currently using the DynamoDB Local backend as a blackbox. In future, we should consider hooking
2294
          into the PartiQL query execution inside DynamoDB Local and directly extract the list of updated items.
2295
    """
2296
    result = []
1✔
2297

2298
    key_schema = SchemaExtractor.get_key_schema(table_name, account_id, region_name)
1✔
2299
    before = ItemSet(existing_items, key_schema=key_schema)
1✔
2300
    all_table_items = ItemFinder.get_all_table_items(
1✔
2301
        account_id=account_id,
2302
        region_name=region_name,
2303
        table_name=table_name,
2304
        endpoint_url=server_url,
2305
    )
2306
    after = ItemSet(all_table_items, key_schema=key_schema)
1✔
2307

2308
    def _add_record(item, comparison_set: ItemSet):
1✔
2309
        matching_item = comparison_set.find_item(item)
1✔
2310
        if matching_item == item:
1✔
UNCOV
2311
            return
×
2312

2313
        # determine event type
2314
        if comparison_set == after:
1✔
2315
            if matching_item:
1✔
2316
                return
1✔
2317
            event_name = "REMOVE"
1✔
2318
        else:
2319
            event_name = "INSERT" if not matching_item else "MODIFY"
1✔
2320

2321
        old_image = item if event_name == "REMOVE" else matching_item
1✔
2322
        new_image = matching_item if event_name == "REMOVE" else item
1✔
2323

2324
        # prepare record
2325
        keys = SchemaExtractor.extract_keys_for_schema(item=item, key_schema=key_schema)
1✔
2326

2327
        record = DynamoDBProvider.get_record_template(region_name)
1✔
2328
        record["eventName"] = event_name
1✔
2329
        record["dynamodb"]["Keys"] = keys
1✔
2330
        record["dynamodb"]["SizeBytes"] = _get_size_bytes(item)
1✔
2331

2332
        if table_stream_type.stream_view_type:
1✔
2333
            record["dynamodb"]["StreamViewType"] = table_stream_type.stream_view_type
1✔
2334
        if table_stream_type.needs_new_image:
1✔
UNCOV
2335
            record["dynamodb"]["NewImage"] = new_image
×
2336
        if old_image and table_stream_type.needs_old_image:
1✔
UNCOV
2337
            record["dynamodb"]["OldImage"] = old_image
×
2338

2339
        result.append(record)
1✔
2340

2341
    # loop over items in new item list (find INSERT/MODIFY events)
2342
    for item in after.items_list:
1✔
2343
        _add_record(item, before)
1✔
2344
    # loop over items in old item list (find REMOVE events)
2345
    for item in before.items_list:
1✔
2346
        _add_record(item, after)
1✔
2347

2348
    return {table_name: TableRecords(records=result, table_stream_type=table_stream_type)}
1✔
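# [Illustrative note added by the editor, not part of the original source.]
# The diff above classifies items by comparing the table snapshots taken before and after the
# PartiQL statement:
#   - present only in the "after" snapshot           -> INSERT
#   - present in both, but with different attributes -> MODIFY
#   - present only in the "before" snapshot          -> REMOVE
# Unchanged items produce no record, mirroring AWS, which does not emit no-op stream events.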
2349

2350

2351
def create_dynamodb_stream(account_id: str, region_name: str, data, latest_stream_label):
1✔
2352
    stream = data["StreamSpecification"]
1✔
2353
    enabled = stream.get("StreamEnabled")
1✔
2354

2355
    if enabled not in [False, "False"]:
1✔
2356
        table_name = data["TableName"]
1✔
2357
        view_type = stream["StreamViewType"]
1✔
2358

2359
        dynamodbstreams_api.add_dynamodb_stream(
1✔
2360
            account_id=account_id,
2361
            region_name=region_name,
2362
            table_name=table_name,
2363
            latest_stream_label=latest_stream_label,
2364
            view_type=view_type,
2365
            enabled=enabled,
2366
        )
2367

2368

2369
def dynamodb_get_table_stream_specification(account_id: str, region_name: str, table_name: str):
1✔
UNCOV
2370
    try:
×
2371
        table_schema = SchemaExtractor.get_table_schema(
×
2372
            table_name, account_id=account_id, region_name=region_name
2373
        )
UNCOV
2374
        return table_schema["Table"].get("StreamSpecification")
×
2375
    except Exception as e:
×
2376
        LOG.info(
×
2377
            "Unable to get stream specification for table %s: %s %s",
2378
            table_name,
2379
            e,
2380
            traceback.format_exc(),
2381
        )
UNCOV
2382
        raise e
×
2383

2384

2385
def find_item_for_keys_values_in_batch(
1✔
2386
    table_name: str, item_keys: dict, batch: BatchGetResponseMap
2387
) -> AttributeMap | None:
2388
    """
2389
    This function looks up the provided item key subset in the existing items. If present, returns the
2390
    full item.
2391
    :param table_name: the table name for the item
2392
    :param item_keys: the request item keys
2393
    :param batch: the values in which to look for the item
2394
    :return: a DynamoDB Item (AttributeMap)
2395
    """
2396
    keys = item_keys.items()
1✔
2397
    for item in batch.get(table_name, []):
1✔
2398
        if keys <= item.items():
1✔
2399
            return item
1✔
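# [Illustrative note added by the editor, not part of the original source; values are hypothetical.]
# The lookup relies on dict-items subset semantics: `keys <= item.items()` is True when every
# key/value pair of the requested key is also present in the candidate item, e.g.:
#
#   batch = {"my-table": [{"pk": {"S": "id-1"}, "value": {"N": "42"}}]}
#   find_item_for_keys_values_in_batch("my-table", {"pk": {"S": "id-1"}}, batch)
#     -> {"pk": {"S": "id-1"}, "value": {"N": "42"}}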