Source code for statemachine.contrib.diagram

import html
import importlib
import os
import sys
from urllib.parse import quote
from urllib.request import urlopen

import pydot

from ..statemachine import StateChart


class DotGraphMachine:
    """Build a Graphviz (``pydot``) diagram of a state machine.

    ``machine`` may be a state machine class or a ``StateChart`` instance.
    For an instance, the states in its current configuration are drawn with
    :attr:`state_active_fillcolor` and :attr:`state_active_penwidth`.
    Calling the object is a shortcut for :meth:`get_graph`.
    """

    graph_rankdir = "LR"
    """
    Direction of the graph. Defaults to "LR" (option "TB" for top bottom)
    http://www.graphviz.org/doc/info/attrs.html#d:rankdir
    """

    font_name = "Arial"
    """Graph font face name"""

    state_font_size = "10pt"
    """State font size"""

    state_active_penwidth = 2
    """Active state external line width"""

    state_active_fillcolor = "turquoise"
    """Active state fill color"""

    transition_font_size = "9pt"
    """Transition font size"""

    def __init__(self, machine):
        self.machine = machine

    def _get_graph(self, machine):
        """Create the top-level directed graph, labelled with the machine name.

        ``compound="true"`` enables ``lhead``/``ltail`` on edges, which is how
        edges to and from compound states are clipped at their cluster border.
        """
        return pydot.Dot(
            machine.name,
            graph_type="digraph",
            label=machine.name,
            fontname=self.font_name,
            fontsize=self.state_font_size,
            rankdir=self.graph_rankdir,
            compound="true",
        )

    @staticmethod
    def _subgraph_label(state):
        """Return the cluster label for compound ``state``.

        Parallel states get a bold HTML-like label followed by a trigram marker
        (``&#9783;``). The name is HTML-escaped so characters such as ``&`` or
        ``<`` do not break Graphviz's HTML label parser. Other states use their
        plain name.
        """
        if state.parallel:
            return f"<<b>{html.escape(state.name, quote=False)}</b> &#9783;>"
        return state.name

    def _get_subgraph(self, state):
        """Create the cluster that holds the children of compound ``state``.

        A compound state whose parent is parallel (i.e. a region) is drawn
        with a dashed border; other compound states use a solid border.
        """
        border = "dashed" if state.parent and state.parent.parallel else "solid"
        return pydot.Subgraph(
            label=self._subgraph_label(state),
            graph_name=f"cluster_{state.id}",
            style=f"rounded, {border}",
            cluster="true",
        )

    def _initial_node(self, state):
        """Create the black dot marking the initial pseudo-state inside ``state``.

        The node is named ``_state_id(state)``. For a compound state this is
        its ``<id>_anchor`` node, which edges to/from the compound state also
        attach to (then clipped at the cluster via ``lhead``/``ltail``).
        """
        node = pydot.Node(
            self._state_id(state),
            label="",
            shape="point",
            style="filled",
            fontsize="1pt",
            fixedsize="true",
            width=0.2,
            height=0.2,
        )
        node.set_fillcolor("black")  # type: ignore[attr-defined]
        return node

    def _initial_edge(self, initial_node, state):
        """Create the edge from ``initial_node`` to the initial child ``state``.

        When ``state`` is itself compound, the edge ends at its cluster border.
        """
        extra_params = {}
        if state.states:
            extra_params["lhead"] = f"cluster_{state.id}"
        return pydot.Edge(
            initial_node.get_name(),
            self._state_id(state),
            label="",
            color="blue",
            fontname=self.font_name,
            fontsize=self.transition_font_size,
            **extra_params,
        )

    def _actions_getter(self):
        """Return a callable that renders a callback group as display text.

        For a ``StateChart`` instance, the machine's callback registry formats
        the group. Otherwise, callbacks are listed as ``str(c)`` joined by
        ``", "``. Convention-named callbacks are included only when the machine
        has an attribute of that name.
        """
        machine = self.machine
        if isinstance(machine, StateChart):

            def getter(grouper):  # pyright: ignore[reportRedeclaration]
                return machine._callbacks.str(grouper.key)

            return getter

        # dir() does not change while the diagram is built; compute it once
        # instead of once per callback group.
        all_names = set(dir(machine))

        def getter(grouper):
            return ", ".join(
                str(c) for c in grouper if not c.is_convention or c.func in all_names
            )

        return getter

    def _state_actions(self, state):
        """Return the action lines appended to a state's label.

        Produces ``entry / ...``, ``exit / ...`` and the internal transitions
        (``event / actions`` joined by ``", "``), one per line. The result is
        prefixed with a newline, or is ``""`` when there is nothing to show.
        """
        getter = self._actions_getter()

        entry = str(getter(state.enter))
        exit_ = str(getter(state.exit))
        internal = ", ".join(
            f"{transition.event} / {str(getter(transition.on))}"
            for transition in state.transitions
            if transition.internal
        )

        if entry:
            entry = f"entry / {entry}"
        if exit_:
            exit_ = f"exit / {exit_}"

        actions = "\n".join(x for x in [entry, exit_, internal] if x)

        if actions:
            actions = f"\n{actions}"

        return actions

    @staticmethod
    def _state_id(state):
        """Return the graph node name for ``state``.

        Compound states (those with children) map to an ``<id>_anchor`` node
        placed inside their cluster; atomic states use their ``id``.
        """
        if state.states:
            return f"{state.id}_anchor"
        return state.id

    def _history_node(self, state):
        """Create the small circle for a history pseudo-state ("H*" if deep, else "H")."""
        label = "H*" if state.type.is_deep else "H"
        return pydot.Node(
            self._state_id(state),
            label=label,
            shape="circle",
            style="filled",
            fillcolor="white",
            fontname=self.font_name,
            fontsize="8pt",
            fixedsize="true",
            width=0.3,
            height=0.3,
        )

    def _state_as_node(self, state):
        """Create the rounded-rectangle node for an atomic ``state``.

        Final states get a double border. When the machine is a ``StateChart``
        instance and ``state`` is in its current configuration, the node is
        highlighted with the active pen width and fill colour.
        """
        actions = self._state_actions(state)

        node = pydot.Node(
            self._state_id(state),
            label=f"{state.name}{actions}",
            shape="rectangle",
            style="rounded, filled",
            fontname=self.font_name,
            fontsize=self.state_font_size,
            peripheries=2 if state.final else 1,
        )
        if (
            isinstance(self.machine, StateChart)
            and state.value in self.machine.configuration_values
        ):
            node.set_penwidth(self.state_active_penwidth)  # type: ignore[attr-defined]
            node.set_fillcolor(self.state_active_fillcolor)  # type: ignore[attr-defined]
        else:
            node.set_fillcolor("white")  # type: ignore[attr-defined]
        return node

    def _transition_as_edges(self, transition):
        """Return one ``pydot.Edge`` per target of ``transition``.

        Only the first edge carries the ``event`` (plus ``[guards]``) label.
        A transition without targets is drawn as a self-loop on its source.
        Edges that touch a compound state are clipped at its cluster border
        and get ``minlen=2``.
        """
        targets = transition.targets if transition.targets else [None]
        cond = ", ".join(str(c) for c in transition.cond)
        if cond:
            cond = f"\n[{cond}]"

        source = transition.source
        src = self._state_id(source)
        edges = []
        for i, target in enumerate(targets):
            extra_params = {}
            if source.states:
                extra_params["ltail"] = f"cluster_{source.id}"
            if target and target.states:
                extra_params["lhead"] = f"cluster_{target.id}"
            has_substates = source.states or (target and target.states)

            label = f"{transition.event}{cond}" if i == 0 else ""
            dst = self._state_id(source if target is None else target)
            edges.append(
                pydot.Edge(
                    src,
                    dst,
                    label=label,
                    color="blue",
                    fontname=self.font_name,
                    fontsize=self.transition_font_size,
                    minlen=2 if has_substates else 1,
                    **extra_params,
                )
            )
        return edges

    def get_graph(self):
        """Build and return the complete ``pydot.Dot`` graph for the machine."""
        graph = self._get_graph(self.machine)
        self._graph_states(self.machine, graph)
        return graph

    def _add_transitions(self, graph, state):
        """Add edges for every non-internal transition leaving ``state``.

        Internal transitions are shown inside the state label instead (see
        ``_state_actions``).
        """
        for transition in state.transitions:
            if transition.internal:
                continue
            for edge in self._transition_as_edges(transition):
                graph.add_edge(edge)

    def _graph_states(self, state, graph):
        """Recursively add the children of ``state`` (machine or compound state) to ``graph``.

        Adds an initial-point subgraph, plus a cluster collecting the atomic
        children and history states. For non-parallel containers it also adds
        an edge to the initial child. Compound children become nested
        clusters, and every child's transitions are added to ``graph``.
        """
        initial_node = self._initial_node(state)
        initial_subgraph = pydot.Subgraph(
            graph_name=f"{initial_node.get_name()}_initial",
            label="",
            peripheries=0,
            margin=0,
        )
        atomic_states_subgraph = pydot.Subgraph(
            graph_name=f"cluster_{initial_node.get_name()}_atomic",
            label="",
            peripheries=0,
            cluster="true",
        )
        initial_subgraph.add_node(initial_node)
        graph.add_subgraph(initial_subgraph)
        graph.add_subgraph(atomic_states_subgraph)

        # Parallel states enter all regions at once, so they get no single
        # initial arrow.
        if state.states and not getattr(state, "parallel", False):
            initial = next((s for s in state.states if s.initial), None)
            if initial:  # pragma: no branch
                graph.add_edge(self._initial_edge(initial_node, initial))

        for substate in state.states:
            if substate.states:
                subgraph = self._get_subgraph(substate)
                self._graph_states(substate, subgraph)
                graph.add_subgraph(subgraph)
            else:
                atomic_states_subgraph.add_node(self._state_as_node(substate))
            self._add_transitions(graph, substate)

        for history_state in getattr(state, "history", []):
            atomic_states_subgraph.add_node(self._history_node(history_state))
            self._add_transitions(graph, history_state)

    def __call__(self):
        """Shortcut for :meth:`get_graph`."""
        return self.get_graph()


def quickchart_write_svg(sm: StateChart, path: str):
    """
    If the default dependency of GraphViz installed locally doesn't work for you, you
    can instead generate the image online from the output of the `dot` language,
    using one of the many services available.

    Getting the **dot** representation of your state machine is as easy as follows:

    >>> from tests.examples.order_control_machine import OrderControl
    >>> sm = OrderControl()
    >>> print(sm._graph().to_string())
    digraph OrderControl {
    compound=true;
    fontname=Arial;
    fontsize="10pt";
    label=OrderControl;
    rankdir=LR;
    ...

    As an example, this function serializes the dot, requests the graph from
    https://quickchart.io, and saves the result locally as an ``.svg`` file.

    .. warning::
        Quickchart is an external graph service that supports many formats to generate
        diagrams.

        By using this method, you should trust https://quickchart.io.

        Please read https://quickchart.io/documentation/faq/ for more information.

    >>> quickchart_write_svg(sm, "docs/images/oc_machine_processing.svg")  # doctest: +SKIP

    """
    dot_representation = sm._graph().to_string()
    url = f"https://quickchart.io/graphviz?graph={quote(dot_representation)}"
    # Use the response as a context manager so the HTTP connection is released
    # even if reading the body fails.
    with urlopen(url) as response:
        data = response.read()
    with open(path, "wb") as f:
        f.write(data)
def _find_sm_class(module):
    """Return a StateChart subclass defined in ``module``, or ``None``.

    Only classes whose ``__module__`` is ``module`` itself are considered, so
    classes merely imported into it are ignored. ``inspect.getmembers`` sorts
    members by name, so when several candidates exist the alphabetically first
    one is returned (not necessarily the first in source order).
    """
    import inspect

    for _name, obj in inspect.getmembers(module, inspect.isclass):
        if (
            issubclass(obj, StateChart)
            and obj is not StateChart
            and obj.__module__ == module.__name__
        ):
            return obj
    return None


def import_sm(qualname):
    """Import and return the StateChart subclass referenced by ``qualname``.

    ``qualname`` is either ``"package.module.ClassName"`` or a plain module
    path (dotted or not). In the latter case a StateChart subclass defined in
    that module is returned.

    Raises:
        ImportError: if the module part of a dotted ``qualname`` can't be imported.
        ValueError: if no StateChart subclass can be resolved from ``qualname``.
    """
    if "." in qualname:
        module_name, class_name = qualname.rsplit(".", 1)
        module = importlib.import_module(module_name)
        smclass = getattr(module, class_name, None)
        if isinstance(smclass, type) and issubclass(smclass, StateChart):
            return smclass
    else:
        # A top-level module name has no class part; go straight to the
        # "whole path is a module" lookup below instead of failing to unpack.
        class_name = qualname

    # qualname may be a module path without a class name — try importing
    # the whole path as a module and find the first StateChart subclass.
    try:
        module = importlib.import_module(qualname)
    except ImportError as err:
        raise ValueError(f"{class_name} is not a subclass of StateMachine") from err

    smclass = _find_sm_class(module)
    if smclass is None:
        raise ValueError(f"No StateMachine subclass found in module {qualname!r}")
    return smclass


def write_image(qualname, out):
    """
    Given a `qualname`, that is the fully qualified dotted path to a StateMachine
    class, import the class and generate a dot graph using the `pydot` lib.

    The graph is written to the file ``out`` (created or truncated), in the format
    given by the extension of ``out`` (e.g. ``diagram.png`` -> ``png``).

    Raises:
        ValueError: if ``out`` has no file extension, or ``qualname`` cannot be
            resolved to a StateMachine class.
    """
    # Validate the output path first so a bad filename fails fast, before any
    # user module is imported. splitext only looks at the final path component,
    # so a dotted directory name is not mistaken for an extension.
    out_extension = os.path.splitext(out)[1][1:]
    if not out_extension:
        raise ValueError(
            f"Cannot infer the image format from {out!r}: "
            "the output path needs a file extension (e.g. 'diagram.png')."
        )

    smclass = import_sm(qualname)
    graph = DotGraphMachine(smclass).get_graph()
    graph.write(out, format=out_extension)


def main(argv=None):
    """Command-line entry point: ``<class_path> <out>``.

    ``argv`` defaults to ``sys.argv[1:]`` (argparse behaviour).
    """
    import argparse

    parser = argparse.ArgumentParser(
        usage="%(prog)s [OPTION] <class_path> <out>",
        description="Generate diagrams for StateMachine classes.",
    )
    parser.add_argument(
        "class_path", help="A fully-qualified dotted path to the StateMachine class."
    )
    parser.add_argument(
        "out",
        help="File to generate the image using extension as the output format.",
    )

    args = parser.parse_args(argv)
    write_image(qualname=args.class_path, out=args.out)


if __name__ == "__main__":  # pragma: no cover
    sys.exit(main())