• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pytransitions / transitions / 8938991339

03 May 2024 12:30PM UTC coverage: 98.432% (+0.2%) from 98.217%
8938991339

push

github

aleneum
use coverage only for mypy job and update setup.py tags

5149 of 5231 relevant lines covered (98.43%)

0.98 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

98.4
/transitions/extensions/asyncio.py
1
"""
1✔
2
    transitions.extensions.asyncio
3
    ------------------------------
4

5
    This module contains machine, state and event implementations for asynchronous callback processing.
6
    `AsyncMachine` and `HierarchicalAsyncMachine` use `asyncio` for concurrency. The extension `transitions-anyio`
7
    found at https://github.com/pytransitions/transitions-anyio illustrates how they can be extended to
8
    make use of other concurrency libraries.
9
    The module also contains the state mixin `AsyncTimeout` to asynchronously trigger timeout-related callbacks.
10
"""
11

12
# Overriding base methods of states, transitions and machines with async variants is not considered good practice.
13
# However, the alternative would mean to either increase the complexity of the base classes or copy code fragments
14
# and thus increase code complexity and reduce maintainability. If you know a better solution, please file an issue.
15
# pylint: disable=invalid-overridden-method
16

17
import logging
1✔
18
import asyncio
1✔
19
import contextvars
1✔
20
import inspect
1✔
21
from collections import deque
1✔
22
from functools import partial, reduce
1✔
23
import copy
1✔
24

25
from ..core import State, Condition, Transition, EventData, listify
1✔
26
from ..core import Event, MachineError, Machine
1✔
27
from .nesting import HierarchicalMachine, NestedState, NestedEvent, NestedTransition, resolve_order
1✔
28

29

30
# Module logger. A NullHandler is attached so that records from this module are
# silently dropped unless the application configures logging on its own.
_LOGGER = logging.getLogger(name=__name__)
_LOGGER.addHandler(logging.NullHandler())
32

33

34
class AsyncState(State):
    """A persistent representation of a state managed by a ``Machine``. Callback execution is done asynchronously."""

    async def enter(self, event_data):
        """Process all ``on_enter`` callbacks of this state.

        Args:
            event_data (AsyncEventData): The currently processed event.
        """
        machine = event_data.machine
        _LOGGER.debug("%sEntering state %s. Processing callbacks...", machine.name, self.name)
        await machine.callbacks(self.on_enter, event_data)
        _LOGGER.info("%sFinished processing state %s enter callbacks.", machine.name, self.name)

    async def exit(self, event_data):
        """Process all ``on_exit`` callbacks of this state.

        Args:
            event_data (AsyncEventData): The currently processed event.
        """
        machine = event_data.machine
        _LOGGER.debug("%sExiting state %s. Processing callbacks...", machine.name, self.name)
        await machine.callbacks(self.on_exit, event_data)
        _LOGGER.info("%sFinished processing state %s exit callbacks.", machine.name, self.name)
54

55

56
class NestedAsyncState(NestedState, AsyncState):
    """A state that allows substates. Callback execution is done asynchronously."""

    async def scoped_enter(self, event_data, scope=None):
        """Enter this state while ``_scope`` is temporarily set to ``scope``.

        ``_scope`` is reset to ``[]`` afterwards, also when an enter callback raises
        or the coroutine is cancelled, so the state never keeps a stale scope.

        Args:
            event_data (AsyncEventData): The currently processed event.
            scope (list): Scope to use while entering; ``None`` means ``[]``.
        """
        self._scope = scope or []
        try:
            await self.enter(event_data)
        finally:
            self._scope = []

    async def scoped_exit(self, event_data, scope=None):
        """Exit this state while ``_scope`` is temporarily set to ``scope``.

        ``_scope`` is reset to ``[]`` afterwards, also when an exit callback raises
        or the coroutine is cancelled, so the state never keeps a stale scope.

        Args:
            event_data (AsyncEventData): The currently processed event.
            scope (list): Scope to use while exiting; ``None`` means ``[]``.
        """
        self._scope = scope or []
        try:
            await self.exit(event_data)
        finally:
            self._scope = []
68

69

70
class AsyncCondition(Condition):
    """A helper class to await condition checks in the intended way."""

    async def check(self, event_data):
        """Evaluate the condition callable and compare its outcome with ``self.target``.

        The callable receives ``event_data`` itself when the machine sends events,
        otherwise the positional and keyword arguments stored in ``event_data``.
        Awaitable results are awaited before the comparison.

        Args:
            event_data (EventData): The currently processed event; also provides the
                machine used to resolve ``self.func``.
        Returns:
            bool: True if the (awaited) result equals ``self.target``.
        """
        machine = event_data.machine
        condition_func = machine.resolve_callable(self.func, event_data)
        if machine.send_event:
            outcome = condition_func(event_data)
        else:
            outcome = condition_func(*event_data.args, **event_data.kwargs)
        if inspect.isawaitable(outcome):
            outcome = await outcome
        return outcome == self.target
87

88

89
class AsyncTransition(Transition):
    """Representation of an asynchronous transition managed by a ``AsyncMachine`` instance."""

    condition_cls = AsyncCondition

    async def _eval_conditions(self, event_data):
        # All condition checks are awaited together via the machine's ``await_all``.
        pending_checks = [partial(cond.check, event_data) for cond in self.conditions]
        outcomes = await event_data.machine.await_all(pending_checks)
        if all(outcomes):
            return True
        _LOGGER.debug("%sTransition condition failed: Transition halted.", event_data.machine.name)
        return False

    async def execute(self, event_data):
        """Executes the transition.

        Prepare callbacks run first, then the conditions are evaluated. Only when all
        conditions pass does the machine switch the model context, run the before
        callbacks, change the state (skipped when ``dest`` is falsy, i.e. an internal
        transition) and run the after callbacks.

        Args:
            event_data (EventData): An instance of class EventData.
        Returns: boolean indicating whether or not the transition was
            successfully executed (True if successful, False if not).
        """
        machine = event_data.machine
        _LOGGER.debug("%sInitiating transition from state %s to state %s...",
                      machine.name, self.source, self.dest)

        await machine.callbacks(self.prepare, event_data)
        _LOGGER.debug("%sExecuted callbacks before conditions.", machine.name)

        conditions_passed = await self._eval_conditions(event_data)
        if not conditions_passed:
            return False

        # cancel running tasks since the transition will happen
        await machine.switch_model_context(event_data.model)

        await machine.callbacks(machine.before_state_change, event_data)
        await machine.callbacks(self.before, event_data)
        _LOGGER.debug("%sExecuted callback before transition.", machine.name)

        if self.dest:  # a missing dest marks an internal transition without state change
            await self._change_state(event_data)

        await machine.callbacks(self.after, event_data)
        await machine.callbacks(machine.after_state_change, event_data)
        _LOGGER.debug("%sExecuted callback after transition.", machine.name)
        return True

    async def _change_state(self, event_data):
        machine = event_data.machine
        model = event_data.model
        # Machines with diagram support expose ``model_graphs``; update the styling of the model's graph.
        if hasattr(machine, "model_graphs"):
            graph = machine.model_graphs[id(model)]
            graph.reset_styling()
            graph.set_previous_transition(self.source, self.dest)
        await machine.get_state(self.source).exit(event_data)
        machine.set_state(self.dest, model)
        event_data.update(getattr(model, machine.model_attribute))
        await machine.get_state(self.dest).enter(event_data)
142

143

144
class NestedAsyncTransition(AsyncTransition, NestedTransition):
    """Representation of an asynchronous transition managed by a ``HierarchicalMachine`` instance."""

    async def _change_state(self, event_data):
        machine = event_data.machine
        # Machines with diagram support expose ``model_graphs``; update the styling of the model's graph.
        if hasattr(machine, "model_graphs"):
            graph = machine.model_graphs[id(event_data.model)]
            graph.reset_styling()
            graph.set_previous_transition(self.source, self.dest)
        state_tree, exit_partials, enter_partials = self._resolve_transition(event_data)
        # Exit callbacks are awaited one after another before the model is updated ...
        for exit_callback in exit_partials:
            await exit_callback()
        self._update_model(event_data, state_tree)
        # ... and enter callbacks one after another afterwards.
        for enter_callback in enter_partials:
            await enter_callback()
157

158

159
class AsyncEventData(EventData):
    """A redefinition of the base EventData intended to ease type checking.

    It adds no attributes or behavior of its own.
    """
161

162

163
class AsyncEvent(Event):
    """A collection of transitions assigned to the same trigger"""

    async def trigger(self, model, *args, **kwargs):
        """Serially execute all transitions that match the current state,
        halting as soon as one successfully completes. Note that `AsyncEvent` triggers must be awaited.
        Args:
            args and kwargs: Optional positional or named arguments that will
                be passed onto the EventData object, enabling arbitrary state
                information to be passed on to downstream triggered functions.
        Returns: boolean indicating whether or not a transition was
            successfully executed (True if successful, False if not).
        """
        # AsyncEventData is a plain EventData subclass; using it here keeps event data consistent
        # with the data created by AsyncMachine._can_trigger and HierarchicalAsyncMachine.
        event_data = AsyncEventData(None, self, self.machine, model, args=args, kwargs=kwargs)
        func = partial(self._trigger, event_data)
        return await self.machine.process_context(func, model)

    async def _trigger(self, event_data):
        """Process the event for ``event_data.model`` if its current state is a valid source.

        Exceptions are stored in ``event_data.error`` and passed to ``on_exception`` callbacks
        if any are configured; otherwise they are re-raised. ``finalize_event`` callbacks always
        run, and errors raised by them are only logged.

        Returns:
            The value of ``event_data.result``.
        """
        event_data.state = self.machine.get_state(getattr(event_data.model, self.machine.model_attribute))
        try:
            if self._is_valid_source(event_data.state):
                await self._process(event_data)
        except BaseException as err:  # pylint: disable=broad-except; Exception will be handled elsewhere
            _LOGGER.error("%sException was raised while processing the trigger: %s", self.machine.name, err)
            event_data.error = err
            if self.machine.on_exception:
                await self.machine.callbacks(self.machine.on_exception, event_data)
            else:
                raise
        finally:
            try:
                await self.machine.callbacks(self.machine.finalize_event, event_data)
                _LOGGER.debug("%sExecuted machine finalize callbacks", self.machine.name)
            except BaseException as err:  # pylint: disable=broad-except; Exception will be handled elsewhere
                _LOGGER.error("%sWhile executing finalize callbacks a %s occurred: %s.",
                              self.machine.name,
                              type(err).__name__,
                              str(err))
        return event_data.result

    async def _process(self, event_data):
        """Run ``prepare_event`` callbacks, then try the transitions of the current state in order
        until one of them executes successfully (its result is stored in ``event_data.result``)."""
        await self.machine.callbacks(self.machine.prepare_event, event_data)
        _LOGGER.debug("%sExecuted machine preparation callbacks before conditions.", self.machine.name)
        for trans in self.transitions[event_data.state.name]:
            event_data.transition = trans
            event_data.result = await trans.execute(event_data)
            if event_data.result:
                break
210

211

212
class NestedAsyncEvent(NestedEvent):
    """A collection of transitions assigned to the same trigger.
    This Event requires a (subclass of) `HierarchicalAsyncMachine`.
    """

    async def trigger_nested(self, event_data):
        """Serially execute all transitions that match the current state,
        halting as soon as one successfully completes. NOTE: This should only
        be called by HierarchicalMachine instances.
        Args:
            event_data (AsyncEventData): The currently processed event.
        Returns: boolean indicating whether or not a transition was
            successfully executed (True if successful, False if not).
        """
        machine = event_data.machine
        separator = machine.state_cls.separator
        full_tree = machine.build_state_tree(getattr(event_data.model, machine.model_attribute), separator)
        # Narrow the tree down to the machine's current scope.
        scoped_tree = reduce(dict.get, machine.get_global_name(join=False), full_tree)
        ordered_paths = resolve_order(scoped_tree)
        handled = set()
        event_data.event = self
        for state_path in ordered_paths:
            state_name = separator.join(state_path)
            if state_name in handled or state_name not in self.transitions:
                continue
            event_data.state = machine.get_state(state_name)
            event_data.source_name = state_name
            event_data.source_path = copy.copy(state_path)
            await self._process(event_data)
            if not event_data.result:
                continue
            # Mark the state and all of its ancestors as handled (this consumes ``state_path``).
            while state_path:
                handled.add(separator.join(state_path))
                state_path.pop()
        return event_data.result

    async def _process(self, event_data):
        """Run ``prepare_event`` callbacks, then try the transitions registered for
        ``event_data.source_name`` until one of them executes successfully."""
        machine = event_data.machine
        await machine.callbacks(machine.prepare_event, event_data)
        _LOGGER.debug("%sExecuted machine preparation callbacks before conditions.", machine.name)

        for candidate in self.transitions[event_data.source_name]:
            event_data.transition = candidate
            event_data.result = await candidate.execute(event_data)
            if event_data.result:
                break
257

258

259
class AsyncMachine(Machine):
    """Machine manages states, transitions and models. In case it is initialized without a specific model
    (or specifically no model), it will also act as a model itself. Machine takes also care of decorating
    models with convenience functions related to added transitions and states during runtime.

    Attributes:
        states (OrderedDict): Collection of all registered states.
        events (dict): Collection of transitions ordered by trigger/event.
        models (list): List of models attached to the machine.
        initial (str): Name of the initial state for new models.
        prepare_event (list): Callbacks executed when an event is triggered.
        before_state_change (list): Callbacks executed after condition checks but before transition is conducted.
            Callbacks will be executed BEFORE the custom callbacks assigned to the transition.
        after_state_change (list): Callbacks executed after the transition has been conducted.
            Callbacks will be executed AFTER the custom callbacks assigned to the transition.
        finalize_event (list): Callbacks will be executed after all transitions callbacks have been executed.
            Callbacks mentioned here will also be called if a transition or condition check raised an error.
        on_exception: A callable called when an event raises an exception. If not set,
            the Exception will be raised instead.
        queued (bool or str): Whether transitions in callbacks should be executed immediately (False) or sequentially.
        send_event (bool): When True, any arguments passed to trigger methods will be wrapped in an EventData
            object, allowing indirect and encapsulated access to data. When False, all positional and keyword
            arguments will be passed directly to all callback methods.
        auto_transitions (bool):  When True (default), every state will automatically have an associated
            to_{state}() convenience trigger in the base model.
        ignore_invalid_triggers (bool): When True, any calls to trigger methods that are not valid for the
            present state (e.g., calling an a_to_b() trigger when the current state is c) will be silently
            ignored rather than raising an invalid transition exception.
        name (str): Name of the ``Machine`` instance mainly used for easier log message distinction.
        async_tasks (dict): Class attribute (shared by all instances) mapping ``id(model)`` to the tasks
            currently processing an event for that model.
        protected_tasks (list): Class attribute listing tasks that ``switch_model_context`` never cancels.
        current_context (ContextVar): Task that started the outermost event processing in the
            current context, or None.
    """

    state_cls = AsyncState
    transition_cls = AsyncTransition
    event_cls = AsyncEvent
    async_tasks = {}
    protected_tasks = []
    current_context = contextvars.ContextVar('current_context', default=None)

    def __init__(self, model=Machine.self_literal, states=None, initial='initial', transitions=None,
                 send_event=False, auto_transitions=True,
                 ordered_transitions=False, ignore_invalid_triggers=None,
                 before_state_change=None, after_state_change=None, name=None,
                 queued=False, prepare_event=None, finalize_event=None, model_attribute='state', on_exception=None,
                 **kwargs):
        # Must exist before ``super().__init__`` because the base constructor may call ``add_model``.
        self._transition_queue_dict = {}
        super().__init__(model=model, states=states, initial=initial, transitions=transitions,
                         send_event=send_event, auto_transitions=auto_transitions,
                         ordered_transitions=ordered_transitions, ignore_invalid_triggers=ignore_invalid_triggers,
                         before_state_change=before_state_change, after_state_change=after_state_change, name=name,
                         queued=queued, prepare_event=prepare_event, finalize_event=finalize_event,
                         model_attribute=model_attribute, on_exception=on_exception, **kwargs)
        if self.has_queue is True:
            # _DictionaryMock sets and returns ONE internal value and ignores the passed key
            self._transition_queue_dict = _DictionaryMock(self._transition_queue)

    def add_model(self, model, initial=None):
        """Add model(s) to the machine.

        With ``queued='model'`` every model gets its own event queue. A model that already
        owns a queue keeps it, so re-adding a model does not discard its pending events.
        """
        super().add_model(model, initial)
        if self.has_queue == 'model':
            for mod in listify(model):
                key = id(self) if mod is self.self_literal else id(mod)
                self._transition_queue_dict.setdefault(key, deque())

    async def dispatch(self, trigger, *args, **kwargs):
        """Trigger an event on all models assigned to the machine.
        Args:
            trigger (str): Event name
            *args (list): List of arguments passed to the event trigger
            **kwargs (dict): Dictionary of keyword arguments passed to the event trigger
        Returns:
            bool The truth value of all triggers combined with AND
        """
        results = await self.await_all([partial(getattr(model, trigger), *args, **kwargs) for model in self.models])
        return all(results)

    async def callbacks(self, funcs, event_data):
        """Triggers a list of callbacks"""
        await self.await_all([partial(event_data.machine.callback, func, event_data) for func in funcs])

    async def callback(self, func, event_data):
        """Trigger a callback function with passed event_data parameters. In case func is a string,
            the callable will be resolved from the passed model in event_data. This function is not intended to
            be called directly but through state and transition callback definitions.
        Args:
            func (string, callable): The callback function.
                1. First, if the func is callable, just call it
                2. Second, we try to import string assuming it is a path to a func
                3. Fallback to a model attribute
            event_data (EventData): An EventData instance to pass to the
                callback (if event sending is enabled) or to extract arguments
                from (if event sending is disabled).
        """
        func = self.resolve_callable(func, event_data)
        res = func(event_data) if self.send_event else func(*event_data.args, **event_data.kwargs)
        if inspect.isawaitable(res):
            await res

    @staticmethod
    async def await_all(callables):
        """
        Executes callables without parameters in parallel and collects their results.
        Args:
            callables (list): A list of callable functions

        Returns:
            list: A list of results. Using asyncio the list will be in the same order as the passed callables.
        """
        return await asyncio.gather(*[func() for func in callables])

    async def switch_model_context(self, model):
        """
        This method is called by an `AsyncTransition` when all conditional tests have passed
        and the transition will happen. This requires already running tasks to be cancelled.
        The task of the current context and tasks in ``protected_tasks`` are never cancelled.
        Args:
            model (object): The currently processed model
        """
        for running_task in self.async_tasks.get(id(model), []):
            if self.current_context.get() == running_task or running_task in self.protected_tasks:
                continue
            if running_task.done() is False:
                _LOGGER.debug("Cancel running tasks...")
                running_task.cancel()

    async def process_context(self, func, model):
        """
        This function is called by an `AsyncEvent` to make callbacks processed in Event._trigger cancellable.
        Using asyncio this will result in a try-catch block catching `asyncio.CancelledError`.

        If no context is active yet, the current task is stored in ``current_context`` and registered in
        ``async_tasks`` for ``model`` so that ``switch_model_context`` can cancel it. Both are undone once
        processing has finished. If a context is already active, ``func`` is processed directly.
        Args:
            func (partial): The partial of Event._trigger with all parameters already assigned
            model (object): The currently processed model

        Returns:
            bool: returns the success state of the triggered event (False if processing was cancelled)
        """
        if self.current_context.get() is not None:
            return await self._process_async(func, model)

        task = asyncio.current_task()
        token = self.current_context.set(task)
        self.async_tasks.setdefault(id(model), []).append(task)
        try:
            res = await self._process_async(func, model)
        except asyncio.CancelledError:
            res = False
        finally:
            self.async_tasks[id(model)].remove(task)
            if len(self.async_tasks[id(model)]) == 0:
                del self.async_tasks[id(model)]
            # Without this reset, every later trigger awaited by the same task would be treated as
            # nested: it would neither be registered (and thus not be cancellable) nor shielded from
            # CancelledError.
            self.current_context.reset(token)
        return res

    def remove_model(self, model):
        """Remove a model from the state machine. The model will still contain all previously added triggers
        and callbacks, but will not receive updates when states or transitions are added to the Machine.
        If an event queue is used, all queued events of that model will be removed."""
        models = listify(model)
        if self.has_queue == 'model':
            for mod in models:
                del self._transition_queue_dict[id(mod)]
                self.models.remove(mod)
        else:
            for mod in models:
                self.models.remove(mod)
        if len(self._transition_queue) > 0:
            queue = self._transition_queue
            # the first entry is currently being processed and therefore kept
            new_queue = [queue.popleft()] + [e for e in queue if e.args[0].model not in models]
            self._transition_queue.clear()
            self._transition_queue.extend(new_queue)

    async def _can_trigger(self, model, trigger, *args, **kwargs):
        evt = AsyncEventData(None, None, self, model, args, kwargs)
        state = self.get_model_state(model).name

        for trigger_name in self.get_triggers(state):
            if trigger_name != trigger:
                continue
            for transition in self.events[trigger_name].transitions[state]:
                try:
                    _ = self.get_state(transition.dest)
                except ValueError:
                    continue
                await self.callbacks(self.prepare_event, evt)
                await self.callbacks(transition.prepare, evt)
                if all(await self.await_all([partial(c.check, evt) for c in transition.conditions])):
                    return True
        return False

    def _process(self, trigger):
        raise RuntimeError("AsyncMachine should not call `Machine._process`. Use `Machine._process_async` instead.")

    async def _process_async(self, trigger, model):
        # default processing
        if not self.has_queue:
            if not self._transition_queue:
                # if trigger raises an Error, it has to be handled by the Machine.process caller
                return await trigger()
            raise MachineError("Attempt to process events synchronously while transition queue is not empty!")

        self._transition_queue_dict[id(model)].append(trigger)
        # another entry in the queue implies a running transition; skip immediate execution
        if len(self._transition_queue_dict[id(model)]) > 1:
            return True

        while self._transition_queue_dict[id(model)]:
            try:
                await self._transition_queue_dict[id(model)][0]()
            except BaseException:
                # if a transition raises an exception, clear queue and delegate exception handling
                self._transition_queue_dict[id(model)].clear()
                raise
            try:
                self._transition_queue_dict[id(model)].popleft()
            except KeyError:
                # the model (and with it its queue) was removed while the trigger was processed
                return True
        return True
473

474

475
class HierarchicalAsyncMachine(HierarchicalMachine, AsyncMachine):
    """Asynchronous variant of transitions.extensions.nesting.HierarchicalMachine.
        An asynchronous hierarchical machine REQUIRES NestedAsyncState, NestedAsyncEvent and NestedAsyncTransition
        (or any subclass of them) to operate.
    """

    state_cls = NestedAsyncState
    transition_cls = NestedAsyncTransition
    event_cls = NestedAsyncEvent

    async def trigger_event(self, model, trigger, *args, **kwargs):
        """Processes events recursively and forwards arguments if suitable events are found.
        This function is usually bound to models with model and trigger arguments already
        resolved as a partial. Execution will halt when a nested transition has been executed
        successfully.
        Args:
            model (object): targeted model
            trigger (str): event name
            *args: positional parameters passed to the event and its callbacks
            **kwargs: keyword arguments passed to the event and its callbacks
        Returns:
            bool: whether a transition has been executed successfully
        Raises:
            MachineError: When no suitable transition could be found and ignore_invalid_trigger
                          is not True. Note that a transition which is not executed due to conditions
                          is still considered valid.
        """
        event_data = AsyncEventData(state=None, event=None, machine=self, model=model, args=args, kwargs=kwargs)
        event_data.result = None

        return await self.process_context(partial(self._trigger_event, event_data, trigger), model)

    async def _trigger_event(self, event_data, trigger):
        try:
            with self():
                res = await self._trigger_event_nested(event_data, trigger, None)
            event_data.result = self._check_event_result(res, event_data.model, trigger)
        except BaseException as err:  # pylint: disable=broad-except; Exception will be handled elsewhere
            event_data.error = err
            if self.on_exception:
                await self.callbacks(self.on_exception, event_data)
            else:
                raise
        finally:
            try:
                await self.callbacks(self.finalize_event, event_data)
                _LOGGER.debug("%sExecuted machine finalize callbacks", self.name)
            except BaseException as err:  # pylint: disable=broad-except; Exception will be handled elsewhere
                _LOGGER.error("%sWhile executing finalize callbacks a %s occurred: %s.",
                              self.name,
                              type(err).__name__,
                              str(err))
        return event_data.result

    async def _trigger_event_nested(self, event_data, _trigger, _state_tree):
        """Depth-first processing of ``_trigger`` over the model's state tree.

        Children are processed first (within their scope). The event of the current
        scope is only triggered for a key whose children did not already report a truthy result.

        Returns:
            None if no event was processed at all, otherwise whether any processing succeeded.
        """
        model = event_data.model
        if _state_tree is None:
            _state_tree = self.build_state_tree(listify(getattr(model, self.model_attribute)),
                                                self.state_cls.separator)
        res = {}
        for key, value in _state_tree.items():
            if value:
                with self(key):
                    tmp = await self._trigger_event_nested(event_data, _trigger, value)
                    if tmp is not None:
                        res[key] = tmp
            if not res.get(key, None) and _trigger in self.events:
                tmp = await self.events[_trigger].trigger_nested(event_data)
                if tmp is not None:
                    res[key] = tmp
        return None if not res or all(v is None for v in res.values()) else any(res.values())

    async def _can_trigger(self, model, trigger, *args, **kwargs):
        """Check whether ``trigger`` could currently be executed for ``model``.

        Every resolved state path is checked, so with parallel states a trigger that is only
        valid in a later region is still found. Returns False when no path allows the trigger.
        """
        state_tree = self.build_state_tree(getattr(model, self.model_attribute), self.state_cls.separator)
        ordered_states = resolve_order(state_tree)
        for state_path in ordered_states:
            with self():
                if await self._can_trigger_nested(model, trigger, state_path, *args, **kwargs):
                    return True
        return False

    async def _can_trigger_nested(self, model, trigger, path, *args, **kwargs):
        # NOTE: ``path`` is consumed (popped from the front) while descending into child scopes.
        evt = AsyncEventData(None, None, self, model, args, kwargs)
        if trigger in self.events:
            source_path = copy.copy(path)
            while source_path:
                state_name = self.state_cls.separator.join(source_path)
                for transition in self.events[trigger].transitions.get(state_name, []):
                    try:
                        _ = self.get_state(transition.dest)
                    except ValueError:
                        continue
                    await self.callbacks(self.prepare_event, evt)
                    await self.callbacks(transition.prepare, evt)
                    if all(await self.await_all([partial(c.check, evt) for c in transition.conditions])):
                        return True
                source_path.pop(-1)
        if path:
            with self(path.pop(0)):
                return await self._can_trigger_nested(model, trigger, path, *args, **kwargs)
        return False
574

575

576
class AsyncTimeout(AsyncState):
    """
    Adds timeout functionality to an asynchronous state. Timeouts are handled model-specific.

    Attributes:
        timeout (float): Seconds after which a timeout function should be
                         called.
        on_timeout (list): Functions to call when a timeout is triggered.
        runner (dict): Maps ``id(model)`` to the timer task started when that model
            entered this state. Pending tasks are cancelled when the model exits.
    """

    dynamic_methods = ["on_timeout"]

    def __init__(self, *args, **kwargs):
        """
        Args:
            **kwargs: If kwargs contain 'timeout', assign the float value to
                self.timeout. If timeout is larger than 0, 'on_timeout' must be
                passed with kwargs as well. Without a positive timeout,
                'on_timeout' is optional. Both keys are removed from kwargs
                before the remaining arguments are forwarded to the parent
                constructor.

        Raises:
            AttributeError: If timeout is larger than 0 but 'on_timeout' is missing.
        """
        self.timeout = kwargs.pop("timeout", 0)
        self._on_timeout = None
        if self.timeout > 0:
            try:
                self.on_timeout = kwargs.pop("on_timeout")
            except KeyError:
                raise AttributeError("Timeout state requires 'on_timeout' when timeout is set.") from None
        else:
            self.on_timeout = kwargs.pop("on_timeout", None)
        self.runner = {}
        super().__init__(*args, **kwargs)

    async def enter(self, event_data):
        """
        Extends `transitions.core.State.enter` by starting a timeout timer for
        the current model when the state is entered and self.timeout is larger
        than 0.

        Args:
            event_data (EventData): events representing the currently processed event.
        """
        if self.timeout > 0:
            # A still-pending timer from an earlier visit would otherwise be orphaned
            # (no longer reachable via self.runner) and could fire on_timeout later.
            self._cancel_timer(event_data.model)
            self.runner[id(event_data.model)] = self.create_timer(event_data)
        await super().enter(event_data)

    async def exit(self, event_data):
        """
        Cancels the model's running timeout task stored in `self.runner` (when one
        exists and is not yet done) before calling further exit callbacks.

        Args:
            event_data (EventData): Data representing the currently processed event.
        """
        self._cancel_timer(event_data.model)
        await super().exit(event_data)

    def _cancel_timer(self, model):
        """Cancel the timer task registered for ``model`` if it is still pending."""
        timer_task = self.runner.get(id(model), None)
        if timer_task is not None and not timer_task.done():
            timer_task.cancel()

    def create_timer(self, event_data):
        """
        Creates and returns a running timer. Shields self._process_timeout to prevent cancellation when
        transitioning away from the current state (which cancels the timer) while processing timeout callbacks.

        Args:
            event_data (EventData): Data representing the currently processed event.

        Returns:
            asyncio.Task: The scheduled timer task. Cancelling it stops a pending timeout.
        """
        async def _timeout():
            try:
                await asyncio.sleep(self.timeout)
                await asyncio.shield(self._process_timeout(event_data))
            except asyncio.CancelledError:
                # Cancellation is the normal way a timer ends (see exit), so it is not an error.
                pass

        return asyncio.ensure_future(_timeout())

    async def _process_timeout(self, event_data):
        _LOGGER.debug("%sTimeout state %s. Processing callbacks...", event_data.machine.name, self.name)
        await event_data.machine.callbacks(self.on_timeout, event_data)
        _LOGGER.info("%sTimeout state %s processed.", event_data.machine.name, self.name)

    @property
    def on_timeout(self):
        """
        List of strings and callables to be called when the state timeouts.
        """
        return self._on_timeout

    @on_timeout.setter
    def on_timeout(self, value):
        """Listifies passed values and assigns them to on_timeout."""
        self._on_timeout = listify(value)
672

673

674
class _DictionaryMock(dict):
1✔
675

676
    def __init__(self, item):
1✔
677
        super().__init__()
1✔
678
        self._value = item
1✔
679

680
    def __setitem__(self, key, item):
1✔
681
        self._value = item
×
682

683
    def __getitem__(self, key):
1✔
684
        return self._value
1✔
685

686
    def __repr__(self):
1✔
687
        return repr("{{'*': {0}}}".format(self._value))
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc