• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

pytransitions / transitions / 8938991339

03 May 2024 12:30PM UTC coverage: 98.432% (+0.2%) from 98.217%
8938991339

push

github

aleneum
use coverage only for mypy job and update setup.py tags

5149 of 5231 relevant lines covered (98.43%)

0.98 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.02
/transitions/extensions/diagrams_graphviz.py
1
"""
1✔
2
    transitions.extensions.diagrams_graphviz
3
    ----------------------------------------
4

5
    Graphviz support for (nested) machines. This also includes partial views
6
    of currently valid transitions.
7
"""
8
import copy
1✔
9
import logging
1✔
10
from functools import partial
1✔
11
from collections import defaultdict
1✔
12
from os.path import splitext
1✔
13

14
try:
1✔
15
    import graphviz as pgv
1✔
16
except ImportError:
×
17
    pgv = None
×
18

19
from ..core import listify
1✔
20
from .diagrams_base import BaseGraph
1✔
21

22
# Module logger. Attaching a NullHandler follows the logging HOWTO advice for
# libraries: records are only emitted if the application configures logging.
_LOGGER = logging.getLogger(name=__name__)
_LOGGER.addHandler(logging.NullHandler())
24

25

26
class Graph(BaseGraph):
    """Graph creation for transitions.core.Machine using the ``graphviz`` package.

    Attributes:
        custom_styles (dict): Custom styles for the current graph. ``custom_styles["node"][name]`` and
            ``custom_styles["edge"][src][dst]`` hold style keys that are looked up in
            ``machine.style_attributes``; an empty string means "no custom style".
    """

    def __init__(self, machine):
        self.custom_styles = {}
        # set up empty styles before delegating to the base class
        # NOTE(review): presumably BaseGraph.__init__ relies on them already being present
        self.reset_styling()
        super(Graph, self).__init__(machine)

    def set_previous_transition(self, src, dst):
        """Highlight the edge ``src -> dst`` and the node ``src`` with the "previous" style.

        Args:
            src (str): Name of the source state.
            dst (str): Name of the destination state.
        """
        self.custom_styles["edge"][src][dst] = "previous"
        self.set_node_style(src, "previous")

    def set_node_style(self, state, style):
        """Assign a style key to a node.

        Args:
            state: A state object (its ``name`` attribute is used) or a state name.
            style (str): Key into ``machine.style_attributes["node"]``.
        """
        self.custom_styles["node"][state.name if hasattr(state, "name") else state] = style

    def reset_styling(self):
        """Forget all custom node and edge styles."""
        self.custom_styles = {
            "edge": defaultdict(lambda: defaultdict(str)),
            "node": defaultdict(str),
        }

    def _add_nodes(self, states, container):
        """Add one node per state dict in ``states`` to ``container``, applying its custom style."""
        for state in states:
            style = self.custom_styles["node"][state["name"]]
            container.node(
                state["name"],
                label=self._convert_state_attributes(state),
                **self.machine.style_attributes["node"][style]
            )

    def _add_edges(self, transitions, container):
        """Add edges for ``transitions`` to ``container``.

        Transitions sharing the same source and destination are merged into one edge whose labels are
        joined with `` | ``. Transitions without a ``"dest"`` key are drawn as self-loops on their source.
        """
        edge_labels = defaultdict(lambda: defaultdict(list))
        for transition in transitions:
            try:
                dst = transition["dest"]
            except KeyError:
                dst = transition["source"]
            edge_labels[transition["source"]][dst].append(self._transition_label(transition))
        for src, dests in edge_labels.items():
            for dst, labels in dests.items():
                style = self.custom_styles["edge"][src][dst]
                container.edge(
                    src,
                    dst,
                    label=" | ".join(labels),
                    **self.machine.style_attributes["edge"][style]
                )

    def generate(self):
        """Trigger the generation of a graph. With the graphviz backend this does nothing, since graph trees
        have to be built from scratch with the configured styles (see :meth:`get_graph`).

        Raises:
            Exception: If the ``graphviz`` package could not be imported.
        """
        if not pgv:  # pragma: no cover
            raise Exception("AGraph diagram requires graphviz")
        # we cannot really generate a graph in advance with graphviz

    def get_graph(self, title=None, roi_state=None):
        """Build a ``graphviz.Digraph`` of the machine.

        Args:
            title (str): Graph name and label; ``machine.title`` is used when it is falsy.
            roi_state: Optional region of interest. When set, only these states are drawn: the flattened
                ``roi_state`` states and their parent states, the endpoints of transitions that leave them or
                carry a custom edge style, and every state with a custom node style.

        Returns:
            graphviz.Digraph: The graph, with a ``draw`` attribute bound to :meth:`draw`.
        """
        title = title if title else self.machine.title

        fsm_graph = pgv.Digraph(
            name=title,
            node_attr=self.machine.style_attributes["node"]["default"],
            edge_attr=self.machine.style_attributes["edge"]["default"],
            graph_attr=self.machine.style_attributes["graph"]["default"],
        )
        fsm_graph.graph_attr.update(**self.machine.machine_attributes)
        fsm_graph.graph_attr["label"] = title
        # For each state, draw a circle
        states, transitions = self._get_elements()
        if roi_state:
            active_states = set()
            sep = getattr(self.machine.state_cls, "separator", None)
            for state in self._flatten(roi_state):
                active_states.add(state)
                if sep:
                    # also keep every parent of a nested state (with sep "_": "A_B_C" adds "A_B" and "A")
                    state = sep.join(state.split(sep)[:-1])
                    while state:
                        active_states.add(state)
                        state = sep.join(state.split(sep)[:-1])
            edge_styles = self.custom_styles["edge"]
            transitions = [
                t
                for t in transitions
                # internal transitions carry no "dest" key; treat them as self-loops like _add_edges does
                if t["source"] in active_states or edge_styles[t["source"]][t.get("dest", t["source"])]
            ]
            active_states = active_states.union({
                t
                for trans in transitions
                for t in [trans["source"], trans.get("dest", trans["source"])]
            })
            active_states = active_states.union({k for k, style in self.custom_styles["node"].items() if style})
            states = _filter_states(copy.deepcopy(states), active_states, self.machine.state_cls)
        self._add_nodes(states, fsm_graph)
        self._add_edges(transitions, fsm_graph)
        setattr(fsm_graph, "draw", partial(self.draw, fsm_graph))
        return fsm_graph

    # pylint: disable=redefined-builtin,unused-argument
    def draw(self, graph, filename, format=None, prog="dot", args=""):
        """Render ``graph`` with graphviz and save or return the result.

        The signature mimics ``AGraph.draw`` so that graph backends can be switched easily; ``args`` is only
        accepted for that reason and is ignored.

        Args:
            graph (graphviz.Digraph): The graph to render.
            filename (str or file-like object or None): Path of the output file, an object with a ``write``
                method, or None to return the rendered data.
            format (str): Output format. For paths it defaults to the file extension (or ``"png"`` if there
                is none); it is mandatory when ``filename`` is None or not a path.
            prog (str): Layout engine, assigned to ``graph.engine`` (e.g. ``"dot"``).
            args (str): Ignored.

        Returns:
            None or bytes: The rendered graph (binary string) when ``filename`` is None, otherwise None.

        Raises:
            ValueError: If ``format`` is None while ``filename`` is None or not a path.
        """
        graph.engine = prog
        if filename is None:
            if format is None:
                raise ValueError(
                    "Parameter 'format' must not be None when filename is no valid file path."
                )
            return graph.pipe(format)
        # Only the path detection belongs in the try: errors raised by graph.render itself must not be
        # mistaken for "filename is a stream".
        try:
            filename, ext = splitext(filename)
        except (TypeError, AttributeError):
            if format is None:
                raise ValueError(
                    "Parameter 'format' must not be None when filename is no valid file path."
                )  # from None
            filename.write(graph.pipe(format))
            return None
        format = format if format is not None else ext[1:]
        graph.render(filename, format=format if format else "png", cleanup=True)
        return None
157

158

159
class NestedGraph(Graph):
    """Graph creation support for transitions.extensions.nested.HierarchicalGraphMachine."""

    def __init__(self, *args, **kwargs):
        # names of states drawn as clusters (states with children) in the graph currently being built
        self._cluster_states = []
        super(NestedGraph, self).__init__(*args, **kwargs)

    def set_node_style(self, state, style):
        """Assign ``style`` to every state name that ``self._get_state_names(state)`` yields."""
        for state_name in self._get_state_names(state):
            super(NestedGraph, self).set_node_style(state_name, style)

    def set_previous_transition(self, src, dst):
        """Highlight ``src -> dst`` after converting both names with ``self._get_global_name``."""
        src_name = self._get_global_name(src.split(self.machine.state_cls.separator))
        dst_name = self._get_global_name(dst.split(self.machine.state_cls.separator))
        super(NestedGraph, self).set_previous_transition(src_name, dst_name)

    def _add_nodes(self, states, container):
        """Add all (nested) states to ``container`` and record which of them became clusters."""
        # Start from scratch for every new graph. Otherwise the list grows with each get_graph call and
        # may list clusters that do not exist in the current (e.g. region-of-interest) graph.
        self._cluster_states = []
        self._add_nested_nodes(states, container, prefix="", default_style="default")

    def _add_nested_nodes(self, states, container, prefix, default_style):
        """Recursively add ``states`` to ``container``.

        States with a ``"children"`` entry become a ``cluster_<name>`` subgraph. It holds a point-shaped
        anchor node named after the state (zero width for parallel states) plus the children. Other
        states become regular nodes.

        Args:
            states (list): State markup dicts.
            container: graphviz graph or subgraph to add the states to.
            prefix (str): Qualified name of the parent state followed by the separator ("" on top level).
            default_style (str): Style key used when a state has no custom style.
        """
        for state in states:
            name = prefix + state["name"]
            label = self._convert_state_attributes(state)
            if state.get("children", None) is not None:
                cluster_name = "cluster_" + name
                attr = {"label": label, "rank": "source"}
                attr.update(
                    **self.machine.style_attributes["graph"][
                        self.custom_styles["node"][name] or default_style
                    ]
                )
                with container.subgraph(name=cluster_name, graph_attr=attr) as sub:
                    self._cluster_states.append(name)
                    # a list of initial states marks a parallel state
                    is_parallel = isinstance(state.get("initial", ""), list)
                    with sub.subgraph(
                        name=cluster_name + "_root",
                        graph_attr={"label": "", "color": "None", "rank": "min"},
                    ) as root:
                        root.node(
                            name,
                            shape="point",
                            fillcolor="black",
                            width="0.0" if is_parallel else "0.1",
                        )
                    self._add_nested_nodes(
                        state["children"],
                        sub,
                        default_style="parallel" if is_parallel else "default",
                        prefix=prefix + state["name"] + self.machine.state_cls.separator,
                    )
            else:
                style = self.machine.style_attributes["node"][default_style].copy()
                style.update(
                    self.machine.style_attributes["node"][
                        self.custom_styles["node"][name] or default_style
                    ]
                )
                container.node(name, label=label, **style)

    def _add_edges(self, transitions, container):
        """Add the edges for ``transitions`` and for custom-styled edges to ``container``.

        Transitions sharing source and destination are merged into one edge whose labels are joined with
        `` | ``; transitions without a ``"dest"`` key are drawn as self-loops. Edges that have a custom style
        but no matching transition are added using a placeholder transition with an empty trigger.
        """
        edges_attr = defaultdict(lambda: defaultdict(dict))

        for transition in transitions:
            # enable customizable labels
            src = transition["source"]
            try:
                dst = transition["dest"]
            except KeyError:
                dst = src
            if edges_attr[src][dst]:
                attr = edges_attr[src][dst]
                attr[attr["label_pos"]] = " | ".join(
                    [edges_attr[src][dst][attr["label_pos"]], self._transition_label(transition)]
                )
            else:
                edges_attr[src][dst] = self._create_edge_attr(src, dst, transition)

        for custom_src, dests in self.custom_styles["edge"].items():
            for custom_dst, style in dests.items():
                if style and (
                    custom_src not in edges_attr or custom_dst not in edges_attr[custom_src]
                ):
                    edges_attr[custom_src][custom_dst] = self._create_edge_attr(
                        custom_src, custom_dst, {"trigger": "", "dest": ""}
                    )

        for src, dests in edges_attr.items():
            for dst, attr in dests.items():
                # drop bookkeeping keys before handing the attributes to graphviz
                del attr["label_pos"]
                style = self.custom_styles["edge"][src][dst]
                attr.update(**self.machine.style_attributes["edge"][style])
                container.edge(attr.pop("source"), attr.pop("dest"), **attr)

    def _create_edge_attr(self, src, dst, transition):
        """Create the graphviz attributes for the edge ``src -> dst``.

        Edges leaving or entering a cluster state get ``ltail``/``lhead`` set to that cluster. Graphviz uses
        these to clip edges at cluster borders when the graph is ``compound``. The label is then moved to
        ``headlabel``/``taillabel`` accordingly. The returned dict also carries the bookkeeping keys
        ``label_pos`` (name of the key that holds the label), ``source`` and ``dest``, which
        :meth:`_add_edges` removes before emitting the edge.
        """
        sep = self.machine.state_cls.separator
        label_pos = "label"
        attr = {}
        if src in self._cluster_states:
            attr["ltail"] = "cluster_" + src
            label_pos = "headlabel"

        # Is src dst itself or nested inside dst? Compare whole name segments rather than raw string
        # prefixes, so that a sibling such as "AB" is not mistaken for a child of "A".
        src_within_dst = src == dst or src.startswith(dst + sep)
        if dst in self._cluster_states and not src_within_dst:
            attr["lhead"] = "cluster_" + dst
            label_pos = "taillabel" if label_pos.startswith("l") else "label"

        # remove ltail when dst is src itself or nested inside src
        if "ltail" in attr and (dst == src or dst.startswith(src + sep)):
            del attr["ltail"]

        attr[label_pos] = self._transition_label(transition)
        attr["label_pos"] = label_pos
        attr["source"] = src
        attr["dest"] = dst
        return attr
275

276

277
def _filter_states(states, state_names, state_cls, prefix=None):
    """Recursively reduce nested state markup to the states listed in ``state_names``.

    Args:
        states (list): State markup dicts with a ``"name"`` key and an optional ``"children"`` list.
            Dicts that have children are modified in place: their ``"children"`` are replaced by the
            filtered list, so callers should pass a copy.
        state_names (set): Fully qualified names of the states to keep.
        state_cls: State class whose ``separator`` attribute joins nested names (``"_"`` if it has none).
        prefix (list): Name parts of the enclosing states; ``None`` on the top level.

    Returns:
        list: States whose qualified name is in ``state_names`` (with their filtered children), plus
        parent states that still have at least one child left after filtering.
    """
    separator = getattr(state_cls, "separator", "_")
    parent_path = prefix if prefix else []
    kept = []
    for entry in states:
        path = parent_path + [entry["name"]]
        wanted = separator.join(path) in state_names
        if "children" not in entry:
            # leaf state: kept only when explicitly requested
            if wanted:
                kept.append(entry)
            continue
        entry["children"] = _filter_states(entry["children"], state_names, state_cls, prefix=path)
        if wanted or entry["children"]:
            kept.append(entry)
    return kept
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc