• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

localstack / localstack / 22665822670

04 Mar 2026 10:41AM UTC coverage: 86.951% (-0.02%) from 86.974%
22665822670

push

github

web-flow
feat(stepfunctions): Add Parallel state support for StepFunctions TestState API (#13861)

30 of 40 new or added lines in 11 files covered. (75.0%)

22 existing lines in 7 files now uncovered.

69848 of 80330 relevant lines covered (86.95%)

0.87 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.29
/localstack-core/localstack/services/stepfunctions/asl/parse/test_state/preprocessor.py
1
import enum
1✔
2
from typing import Final
1✔
3

4
from antlr4.tree.Tree import ParseTree
1✔
5

6
from localstack.services.stepfunctions.asl.antlr.runtime.ASLParser import ASLParser
1✔
7
from localstack.services.stepfunctions.asl.antlt4utils.antlr4utils import (
1✔
8
    is_production,
9
)
10
from localstack.services.stepfunctions.asl.component.common.parargs import (
1✔
11
    ArgumentsJSONataTemplateValueObject,
12
    ArgumentsStringJSONata,
13
    Parameters,
14
)
15
from localstack.services.stepfunctions.asl.component.common.path.input_path import InputPath
1✔
16
from localstack.services.stepfunctions.asl.component.common.path.items_path import ItemsPath
1✔
17
from localstack.services.stepfunctions.asl.component.common.path.result_path import ResultPath
1✔
18
from localstack.services.stepfunctions.asl.component.common.query_language import QueryLanguage
1✔
19
from localstack.services.stepfunctions.asl.component.common.result_selector import ResultSelector
1✔
20
from localstack.services.stepfunctions.asl.component.state.state import CommonStateField
1✔
21
from localstack.services.stepfunctions.asl.component.state.state_choice.state_choice import (
1✔
22
    StateChoice,
23
)
24
from localstack.services.stepfunctions.asl.component.state.state_execution.state_map.max_concurrency import (
1✔
25
    MaxConcurrency,
26
    MaxConcurrencyJSONata,
27
    MaxConcurrencyPath,
28
)
29
from localstack.services.stepfunctions.asl.component.state.state_execution.state_map.state_map import (
1✔
30
    StateMap,
31
)
32
from localstack.services.stepfunctions.asl.component.state.state_execution.state_map.tolerated_failure import (
1✔
33
    ToleratedFailureCountInt,
34
    ToleratedFailureCountPath,
35
    ToleratedFailureCountStringJSONata,
36
    ToleratedFailurePercentage,
37
    ToleratedFailurePercentagePath,
38
    ToleratedFailurePercentageStringJSONata,
39
)
40
from localstack.services.stepfunctions.asl.component.state.state_execution.state_parallel.state_parallel import (
1✔
41
    StateParallel,
42
)
43
from localstack.services.stepfunctions.asl.component.state.state_execution.state_task.state_task import (
1✔
44
    StateTask,
45
)
46
from localstack.services.stepfunctions.asl.component.state.state_fail.state_fail import StateFail
1✔
47
from localstack.services.stepfunctions.asl.component.state.state_pass.result import Result
1✔
48
from localstack.services.stepfunctions.asl.component.state.state_pass.state_pass import StatePass
1✔
49
from localstack.services.stepfunctions.asl.component.state.state_succeed.state_succeed import (
1✔
50
    StateSucceed,
51
)
52
from localstack.services.stepfunctions.asl.component.test_state.program.test_state_program import (
1✔
53
    TestStateProgram,
54
)
55
from localstack.services.stepfunctions.asl.component.test_state.state.common import (
1✔
56
    MockedCommonState,
57
)
58
from localstack.services.stepfunctions.asl.component.test_state.state.map import (
1✔
59
    MockedStateMap,
60
)
61
from localstack.services.stepfunctions.asl.component.test_state.state.parallel import (
1✔
62
    MockedStateParallel,
63
)
64
from localstack.services.stepfunctions.asl.component.test_state.state.task import (
1✔
65
    MockedStateTask,
66
)
67
from localstack.services.stepfunctions.asl.component.test_state.state.test_state_state_props import (
1✔
68
    TestStateStateProps,
69
)
70
from localstack.services.stepfunctions.asl.eval.test_state.environment import TestStateEnvironment
1✔
71
from localstack.services.stepfunctions.asl.parse.preprocessor import Preprocessor
1✔
72
from localstack.services.stepfunctions.asl.utils.encoding import to_json_str
1✔
73

74

75
class InspectionDataKey(enum.Enum):
    """Field names of the TestState inspection data populated by this preprocessor.

    Each member's value is the key written into ``env.inspection_data`` when the
    corresponding decorated component finishes evaluating.
    """

    # Input and output processing stages.
    INPUT = "input"
    AFTER_INPUT_PATH = "afterInputPath"
    AFTER_PARAMETERS = "afterParameters"
    AFTER_ARGUMENTS = "afterArguments"
    RESULT = "result"
    AFTER_RESULT_SELECTOR = "afterResultSelector"
    AFTER_RESULT_PATH = "afterResultPath"
    AFTER_ITEMS_PATH = "afterItemsPath"
    REQUEST = "request"
    RESPONSE = "response"

    # Map-state execution settings (MaxConcurrency / ToleratedFailure* fields).
    MAX_CONCURRENCY = "maxConcurrency"
    TOLERATED_FAILURE_COUNT = "toleratedFailureCount"
    TOLERATED_FAILURE_PERCENTAGE = "toleratedFailurePercentage"
90

91

92
def _decorated_updates_inspection_data(method, inspection_data_key: InspectionDataKey):
    """Wrap an ``_eval_body`` callable so its result is recorded as TestState inspection data.

    The returned wrapper calls ``method`` with the same arguments, then reads the value left
    on top of ``env.stack`` and stores it in ``env.inspection_data`` under
    ``inspection_data_key.value``.

    Values recorded under the numeric keys (max concurrency, tolerated failure count and
    tolerated failure percentage) are kept as raw numbers. Every other value, including
    numbers and booleans recorded under string fields, is serialized with ``to_json_str``.

    :param method: the component's original ``_eval_body``; receives ``env`` as first argument.
    :param inspection_data_key: the inspection data field to populate.
    :return: a drop-in replacement for ``method``.
    """
    # Decided once at decoration time: only these fields carry numbers; all other
    # inspection data fields hold JSON-encoded strings.
    keeps_numbers = inspection_data_key in (
        InspectionDataKey.MAX_CONCURRENCY,
        InspectionDataKey.TOLERATED_FAILURE_COUNT,
        InspectionDataKey.TOLERATED_FAILURE_PERCENTAGE,
    )

    def wrapper(env: TestStateEnvironment, *args, **kwargs):
        method(env, *args, **kwargs)
        result = env.stack[-1]
        # bool is a subclass of int, so it must be excluded explicitly: booleans are
        # rendered as the JSON literals "true"/"false" rather than passed through.
        is_number = isinstance(result, (int, float)) and not isinstance(result, bool)
        if not (keeps_numbers and is_number):
            result = to_json_str(result)
        # We know that the enum value used here corresponds to a supported inspection data field by design.
        env.inspection_data[inspection_data_key.value] = result  # noqa

    return wrapper
102

103

104
def _decorate_state_field(state_field: CommonStateField, is_single_state: bool = False) -> None:
    """Wrap ``state_field`` in place with the mocked-state class matching its type.

    The first matching entry of the dispatch table below wins. States of any other type are
    left untouched.

    :param state_field: the state selected for testing.
    :param is_single_state: forwarded to ``wrap``; ``True`` when the TestState request
        carried a single state definition rather than a full state machine.
    """
    dispatch = (
        (StateMap, MockedStateMap),
        (StateParallel, MockedStateParallel),
        (StateTask, MockedStateTask),
        ((StateChoice, StatePass, StateFail, StateSucceed), MockedCommonState),
    )
    for state_types, mocked_cls in dispatch:
        if not isinstance(state_field, state_types):
            continue
        mocked_cls.wrap(state_field, is_single_state)
        return
113

114

115
def find_state(state_name: str, states: dict[str, CommonStateField]) -> CommonStateField | None:
    """Find ``state_name`` in ``states``, descending depth-first into Map and Parallel states.

    The given mapping is checked directly first. Otherwise each state is visited in order.
    A Map state's iteration states, or each Parallel branch's states in branch order, are
    searched recursively.

    :param state_name: name of the state to locate.
    :param states: mapping of state name to state for the current scope.
    :return: the first matching state, or ``None`` when no scope declares ``state_name``.
    """
    if state_name in states:
        return states[state_name]

    for candidate in states.values():
        if isinstance(candidate, StateMap):
            nested_scopes = (candidate.iteration_component._states.states,)
        elif isinstance(candidate, StateParallel):
            # Generator: a branch's states are only accessed if earlier branches miss.
            nested_scopes = (branch.states.states for branch in candidate.branches.programs)
        else:
            continue

        for scope in nested_scopes:
            match = find_state(state_name, scope)
            if match:
                return match

    return None
129

130

131
class TestStatePreprocessor(Preprocessor):
    """ASL preprocessor that builds a ``TestStateProgram`` for the TestState API.

    On top of the base ``Preprocessor`` this visitor:

    * builds the program either from a full state machine definition, selecting the state
      named ``state_name`` (which may be nested inside a Map or Parallel state), or from a
      single state definition;
    * wraps the selected state through ``_decorate_state_field``;
    * decorates the ``_eval_body`` of input/output processing components so that their
      results are recorded in the environment's inspection data.
    """

    STATE_NAME: Final[str] = "StateName"

    # Names read in ``visitState_decl`` and consumed by the matching ``visitState_decl_body``.
    # This must be per instance: a class-level ``[]`` default is one list shared by every
    # preprocessor. A name left on it (e.g. when parsing raises between push and pop) would
    # otherwise be taken as the name of a state in an unrelated, later parse.
    _state_name_stack: list[str]

    def __init__(self, *args, **kwargs):
        self._state_name_stack = []
        super().__init__(*args, **kwargs)

    def to_test_state_program(
        self, tree: ParseTree, state_name: str | None = None
    ) -> TestStateProgram:
        """Build the ``TestStateProgram`` for ``tree``.

        :param tree: parse tree of either a full state machine or a single state body.
        :param state_name: for a full state machine, the name of the state to test.
        :return: a program wrapping the (mocked) state under test. Trees of any other
            production fall back to the base ``visit`` result.
        """
        # Discard names left over from an earlier parse on this instance that did not finish.
        self._state_name_stack.clear()

        if is_production(tree, ASLParser.RULE_state_machine):
            # Full definition passed in: parse everything, then locate the requested state,
            # which may live inside a Map or Parallel state.
            # NOTE(review): find_state returns None when state_name is absent; presumably
            # the caller validates the name beforehand — confirm.
            program = self.visitState_machine(ctx=tree)
            state_field = find_state(state_name, program.states.states)
            _decorate_state_field(state_field, False)
            return TestStateProgram(state_field)

        if is_production(tree, ASLParser.RULE_state_decl_body):
            # Single state definition passed in.
            state_props = self.visitState_decl_body(ctx=tree)
            state_field = self._common_state_field_of(state_props=state_props)
            _decorate_state_field(state_field, True)
            return TestStateProgram(state_field)

        return super().visit(tree)

    def visitState_decl(self, ctx: ASLParser.State_declContext) -> CommonStateField:
        """Build a named state from a full state machine definition."""
        # if we are parsing a full state machine, we need to record the state_name prior to stepping
        # into the state body definition.
        state_name = self._inner_string_of(parser_rule_context=ctx.string_literal())
        self._state_name_stack.append(state_name)
        state_props: TestStateStateProps = self.visit(ctx.state_decl_body())
        state_field = self._common_state_field_of(state_props=state_props)
        return state_field

    def visitState_decl_body(self, ctx: ASLParser.State_decl_bodyContext) -> TestStateStateProps:
        """Collect a state's properties within its own query-language scope.

        The state's name is the one pushed by ``visitState_decl``, or ``STATE_NAME`` when the
        body is visited directly (single state definition).
        """
        self._open_query_language_scope(ctx)
        state_props = TestStateStateProps()
        # Pop before visiting children: nested states (Map/Parallel) push and pop their own
        # names while the children are visited.
        state_props.name = (
            self._state_name_stack.pop(-1) if self._state_name_stack else self.STATE_NAME
        )
        for child in ctx.children:
            cmp = self.visit(child)
            state_props.add(cmp)
        if state_props.get(QueryLanguage) is None:
            state_props.add(self._get_current_query_language())
        self._close_query_language_scope()
        return state_props

    @staticmethod
    def _track_inspection_data(component, inspection_data_key: InspectionDataKey):
        """Decorate ``component._eval_body`` so its result is stored under ``inspection_data_key``.

        :return: ``component`` itself, mutated in place.
        """
        component._eval_body = _decorated_updates_inspection_data(  # noqa
            method=component._eval_body,  # noqa
            inspection_data_key=inspection_data_key,
        )
        return component

    # Each visitor below builds its component with the base preprocessor, then decorates it
    # so that the value its evaluation leaves on the stack is recorded in the inspection
    # data under the given key.

    def visitInput_path_decl(self, ctx: ASLParser.Input_path_declContext) -> InputPath:
        input_path: InputPath = super().visitInput_path_decl(ctx=ctx)
        return self._track_inspection_data(input_path, InspectionDataKey.AFTER_INPUT_PATH)

    def visitParameters_decl(self, ctx: ASLParser.Parameters_declContext) -> Parameters:
        parameters: Parameters = super().visitParameters_decl(ctx=ctx)
        return self._track_inspection_data(parameters, InspectionDataKey.AFTER_PARAMETERS)

    def visitResult_selector_decl(
        self, ctx: ASLParser.Result_selector_declContext
    ) -> ResultSelector:
        result_selector: ResultSelector = super().visitResult_selector_decl(ctx=ctx)
        return self._track_inspection_data(
            result_selector, InspectionDataKey.AFTER_RESULT_SELECTOR
        )

    def visitResult_path_decl(self, ctx: ASLParser.Result_path_declContext) -> ResultPath:
        result_path: ResultPath = super().visitResult_path_decl(ctx=ctx)
        return self._track_inspection_data(result_path, InspectionDataKey.AFTER_RESULT_PATH)

    def visitResult_decl(self, ctx: ASLParser.Result_declContext) -> Result:
        result: Result = super().visitResult_decl(ctx=ctx)
        return self._track_inspection_data(result, InspectionDataKey.RESULT)

    def visitMax_concurrency_int(self, ctx: ASLParser.Max_concurrency_intContext) -> MaxConcurrency:
        max_concurrency: MaxConcurrency = super().visitMax_concurrency_int(ctx)
        return self._track_inspection_data(max_concurrency, InspectionDataKey.MAX_CONCURRENCY)

    def visitMax_concurrency_jsonata(
        self, ctx: ASLParser.Max_concurrency_jsonataContext
    ) -> MaxConcurrencyJSONata:
        max_concurrency_jsonata: MaxConcurrencyJSONata = super().visitMax_concurrency_jsonata(ctx)
        return self._track_inspection_data(
            max_concurrency_jsonata, InspectionDataKey.MAX_CONCURRENCY
        )

    def visitMax_concurrency_path(
        self, ctx: ASLParser.Max_concurrency_pathContext
    ) -> MaxConcurrencyPath:
        max_concurrency_path: MaxConcurrencyPath = super().visitMax_concurrency_path(ctx)
        return self._track_inspection_data(
            max_concurrency_path, InspectionDataKey.MAX_CONCURRENCY
        )

    def visitTolerated_failure_count_int(self, ctx) -> ToleratedFailureCountInt:
        tolerated_failure_count: ToleratedFailureCountInt = (
            super().visitTolerated_failure_count_int(ctx)
        )
        return self._track_inspection_data(
            tolerated_failure_count, InspectionDataKey.TOLERATED_FAILURE_COUNT
        )

    def visitTolerated_failure_count_path(self, ctx) -> ToleratedFailureCountPath:
        tolerated_failure_count_path: ToleratedFailureCountPath = (
            super().visitTolerated_failure_count_path(ctx)
        )
        return self._track_inspection_data(
            tolerated_failure_count_path, InspectionDataKey.TOLERATED_FAILURE_COUNT
        )

    def visitTolerated_failure_count_string_jsonata(
        self, ctx
    ) -> ToleratedFailureCountStringJSONata:
        tolerated_failure_count_jsonata: ToleratedFailureCountStringJSONata = (
            super().visitTolerated_failure_count_string_jsonata(ctx)
        )
        return self._track_inspection_data(
            tolerated_failure_count_jsonata, InspectionDataKey.TOLERATED_FAILURE_COUNT
        )

    def visitTolerated_failure_percentage_number(self, ctx) -> ToleratedFailurePercentage:
        tolerated_failure_percentage: ToleratedFailurePercentage = (
            super().visitTolerated_failure_percentage_number(ctx)
        )
        return self._track_inspection_data(
            tolerated_failure_percentage, InspectionDataKey.TOLERATED_FAILURE_PERCENTAGE
        )

    def visitTolerated_failure_percentage_path(self, ctx) -> ToleratedFailurePercentagePath:
        tolerated_failure_percentage_path: ToleratedFailurePercentagePath = (
            super().visitTolerated_failure_percentage_path(ctx)
        )
        return self._track_inspection_data(
            tolerated_failure_percentage_path, InspectionDataKey.TOLERATED_FAILURE_PERCENTAGE
        )

    def visitTolerated_failure_percentage_string_jsonata(
        self, ctx
    ) -> ToleratedFailurePercentageStringJSONata:
        tolerated_failure_percentage_jsonata: ToleratedFailurePercentageStringJSONata = (
            super().visitTolerated_failure_percentage_string_jsonata(ctx)
        )
        return self._track_inspection_data(
            tolerated_failure_percentage_jsonata, InspectionDataKey.TOLERATED_FAILURE_PERCENTAGE
        )

    def visitItems_path_decl(self, ctx) -> ItemsPath:
        items_path: ItemsPath = super().visitItems_path_decl(ctx)
        return self._track_inspection_data(items_path, InspectionDataKey.AFTER_ITEMS_PATH)

    def visitArguments_string_jsonata(self, ctx) -> ArgumentsStringJSONata:
        arguments: ArgumentsStringJSONata = super().visitArguments_string_jsonata(ctx)
        return self._track_inspection_data(arguments, InspectionDataKey.AFTER_ARGUMENTS)

    def visitArguments_jsonata_template_value_object(
        self, ctx
    ) -> ArgumentsJSONataTemplateValueObject:
        arguments: ArgumentsJSONataTemplateValueObject = (
            super().visitArguments_jsonata_template_value_object(ctx)
        )
        return self._track_inspection_data(arguments, InspectionDataKey.AFTER_ARGUMENTS)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc