from collections import deque
from copy import deepcopy
from functools import partial
from typing import TYPE_CHECKING
from typing import Any
from typing import Dict
from .callbacks import CallbackMetaList
from .callbacks import CallbacksExecutor
from .callbacks import CallbacksRegistry
from .dispatcher import ObjectConfig
from .dispatcher import resolver_factory
from .event import Event
from .event_data import EventData
from .event_data import TriggerData
from .exceptions import InvalidDefinition
from .exceptions import InvalidStateValue
from .exceptions import TransitionNotAllowed
from .factory import StateMachineMetaclass
from .i18n import _
from .model import Model
from .transition import Transition
if TYPE_CHECKING:
from .state import State
[docs]
class StateMachine(metaclass=StateMachineMetaclass):
"""
Args:
model: An optional external object to store state. See :ref:`domain models`.
state_field (str): The model's field which stores the current state.
Default: ``state``.
start_value: An optional start state value if there's no current state assigned
on the :ref:`domain models`. Default: ``None``.
rtc (bool): Controls the :ref:`processing model`. Defaults to ``True``
that corresponds to a **run-to-completion** (RTC) model.
allow_event_without_transition: If ``False`` when an event does not result in a transition,
an exception ``TransitionNotAllowed`` will be raised.
If ``True`` the state machine allows triggering events that may not lead to a state
:ref:`transition`, including tolerance to unknown :ref:`event` triggers.
Default: ``False``.
"""
TransitionNotAllowed = TransitionNotAllowed
"""Shortcut for easy exception handling.
Example::
try:
sm.send("an-inexistent-event")
except sm.TransitionNotAllowed:
pass
"""
def __init__(
self,
model: Any = None,
state_field: str = "state",
start_value: Any = None,
rtc: bool = True,
allow_event_without_transition: bool = False,
):
self.model = model if model else Model()
self.state_field = state_field
self.start_value = start_value
self.allow_event_without_transition = allow_event_without_transition
self.__rtc = rtc
self.__processing: bool = False
self._external_queue: deque = deque()
self._callbacks_registry = CallbacksRegistry()
self._states_for_instance: Dict["State", "State"] = {}
self._observers: Dict[Any, Any] = {}
if self._abstract:
raise InvalidDefinition(_("There are no states or transitions."))
self._initial_transition = Transition(None, self._get_initial_state(), event="__initial__")
self._setup()
self._activate_initial_state()
def __init_subclass__(cls, strict_states: bool = False):
cls._strict_states = strict_states
super().__init_subclass__()
if TYPE_CHECKING:
"""Makes mypy happy with dynamic created attributes"""
def __getattr__(self, attribute: str) -> Any: ...
def __repr__(self):
current_state_id = self.current_state.id if self.current_state_value else None
return (
f"{type(self).__name__}(model={self.model!r}, state_field={self.state_field!r}, "
f"current_state={current_state_id!r})"
)
def __deepcopy__(self, memo):
deepcopy_method = self.__deepcopy__
self.__deepcopy__ = None
try:
cp = deepcopy(self, memo)
finally:
self.__deepcopy__ = deepcopy_method
cp.__deepcopy__ = deepcopy_method
cp._callbacks_registry.clear()
cp._setup()
cp.add_observer(*cp._observers.keys())
return cp
def _get_initial_state(self):
current_state_value = self.start_value if self.start_value else self.initial_state.value
try:
return self.states_map[current_state_value]
except KeyError as err:
raise InvalidStateValue(current_state_value) from err
def _activate_initial_state(self):
if self.current_state_value is None:
# send an one-time event `__initial__` to enter the current state.
# current_state = self.current_state
self._initial_transition.before.clear()
self._initial_transition.on.clear()
self._initial_transition.after.clear()
event_data = EventData(
trigger_data=TriggerData(
machine=self,
event=self._initial_transition.event,
),
transition=self._initial_transition,
)
self._activate(event_data)
def _get_protected_attrs(self):
return {
"_abstract",
"model",
"state_field",
"start_value",
"initial_state",
"final_states",
"states",
"_events",
"states_map",
"send",
} | {s.id for s in self.states}
def _iterate_states_and_transitions(self):
for state in self.states:
yield state
yield from state.transitions
def _setup(self):
machine = ObjectConfig.from_obj(self, skip_attrs=self._get_protected_attrs())
model = ObjectConfig.from_obj(self.model, skip_attrs={self.state_field})
add_observer_visitor = self._build_observers_visitor(machine, model)
check_callbacks = self._callbacks_registry.check
for visited in self._iterate_states_and_transitions():
visited._setup()
for visited in self._iterate_states_and_transitions():
add_observer_visitor(visited)
for visited in self._iterate_states_and_transitions():
visited._check_callbacks(check_callbacks)
def _build_observers_visitor(self, *observers):
resolver = resolver_factory(*observers)
_register = partial(self._callbacks_registry.register, resolver=resolver)
def add_observer_visitor(visited):
visited._add_observer(_register)
return add_observer_visitor
[docs]
def add_observer(self, *observers):
"""Add an observer.
Observers are a way to generically add behavior to a :ref:`StateMachine` without changing
its internal implementation.
.. seealso::
:ref:`observers`.
"""
self._observers.update({o: None for o in observers})
add_observer_visitor = self._build_observers_visitor(*observers)
for visited in self._iterate_states_and_transitions():
add_observer_visitor(visited)
return self
def _repr_html_(self):
return f'<div class="statemachine">{self._repr_svg_()}</div>'
def _repr_svg_(self):
return self._graph().create_svg().decode()
def _graph(self):
from .contrib.diagram import DotGraphMachine
return DotGraphMachine(self).get_graph()
@property
def current_state_value(self):
"""Get/Set the current :ref:`state` value.
This is a low level API, that can be used to assign any valid state value
completely bypassing all the hooks and validations.
"""
value = getattr(self.model, self.state_field, None)
return value
@current_state_value.setter
def current_state_value(self, value):
if value not in self.states_map:
raise InvalidStateValue(value)
setattr(self.model, self.state_field, value)
@property
def current_state(self) -> "State":
"""Get/Set the current :ref:`state`.
This is a low level API, that can be to assign any valid state
completely bypassing all the hooks and validations.
"""
state: State = self.states_map[self.current_state_value].for_instance(
machine=self,
cache=self._states_for_instance,
)
return state
@current_state.setter
def current_state(self, value):
self.current_state_value = value.value
@property
def events(self):
return self.__class__.events
@property
def allowed_events(self):
"""List of the current allowed events."""
return [getattr(self, event) for event in self.current_state.transitions.unique_events]
def _process(self, trigger):
"""Process event triggers.
The simplest implementation is the non-RTC (synchronous),
where the trigger will be run immediately and the result collected as the return.
.. note::
While processing the trigger, if others events are generated, they
will also be processed immediately, so a "nested" behavior happens.
If the machine is on ``rtc`` model (queued), the event is put on a queue, and only the
first event will have the result collected.
.. note::
While processing the queue items, if others events are generated, they
will be processed sequentially (and not nested).
"""
if not self.__rtc:
# The machine is in "synchronous" mode
return trigger()
# The machine is in "queued" mode
# Add the trigger to queue and start processing in a loop.
self._external_queue.append(trigger)
# We make sure that only the first event enters the processing critical section,
# next events will only be put on the queue and processed by the same loop.
if self.__processing:
return
return self._processing_loop()
def _processing_loop(self):
"""Execute the triggers in the queue in order until the queue is empty"""
self.__processing = True
# We will collect the first result as the processing result to keep backwards compatibility
# so we need to use a sentinel object instead of `None` because the first result may
# be also `None`, and on this case the `first_result` may be overridden by another result.
sentinel = object()
first_result = sentinel
try:
while self._external_queue:
trigger = self._external_queue.popleft()
try:
result = trigger()
if first_result is sentinel:
first_result = result
except Exception:
# Whe clear the queue as we don't have an expected behavior
# and cannot keep processing
self._external_queue.clear()
raise
finally:
self.__processing = False
return first_result if first_result is not sentinel else None
def _activate(self, event_data: EventData):
transition = event_data.transition
source = event_data.state
target = transition.target
result = self._callbacks(transition.before).call(
*event_data.args, **event_data.extended_kwargs
)
if source is not None and not transition.internal:
self._callbacks(source.exit).call(*event_data.args, **event_data.extended_kwargs)
result += self._callbacks(transition.on).call(
*event_data.args, **event_data.extended_kwargs
)
self.current_state = target
event_data.state = target
if not transition.internal:
self._callbacks(target.enter).call(*event_data.args, **event_data.extended_kwargs)
self._callbacks(transition.after).call(*event_data.args, **event_data.extended_kwargs)
if len(result) == 0:
result = None
elif len(result) == 1:
result = result[0]
return result
[docs]
def send(self, event, *args, **kwargs):
"""Send an :ref:`Event` to the state machine.
.. seealso::
See: :ref:`triggering events`.
"""
event = Event(event)
return event.trigger(self, *args, **kwargs)
def _callbacks(self, meta_list: CallbackMetaList) -> CallbacksExecutor:
return self._callbacks_registry[meta_list]