• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

datajoint / datajoint-python / #12880

pending completion
#12880

push

travis-ci

web-flow
Merge pull request #1067 from CBroz1/master

Add support for insert CSV

4 of 4 new or added lines in 1 file covered. (100.0%)

3102 of 3424 relevant lines covered (90.6%)

0.91 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

83.33
/datajoint/diagram.py
1
import networkx as nx
1✔
2
import re
1✔
3
import functools
1✔
4
import io
1✔
5
import logging
1✔
6
import inspect
1✔
7
from .table import Table
1✔
8
from .dependencies import unite_master_parts
1✔
9
from .user_tables import Manual, Imported, Computed, Lookup, Part
1✔
10
from .errors import DataJointError
1✔
11
from .table import lookup_class_name
1✔
12

13

14
try:
1✔
15
    from matplotlib import pyplot as plt
1✔
16

17
    plot_active = True
1✔
18
except:
×
19
    plot_active = False
×
20

21
try:
1✔
22
    from networkx.drawing.nx_pydot import pydot_layout
1✔
23

24
    diagram_active = True
1✔
25
except:
×
26
    diagram_active = False
×
27

28

29
# Log under the top-level package name (the text of __name__ before the first ".").
logger = logging.getLogger(__name__.partition(".")[0])
# User table tiers; _get_tier tests them in exactly this order and returns the first match.
user_table_classes = (Manual, Lookup, Computed, Imported, Part)
31

32

33
class _AliasNode:
1✔
34
    """
35
    special class to indicate aliased foreign keys
36
    """
37

38
    pass
1✔
39

40

41
def _get_tier(table_name):
    """
    Classify a dependency-graph node name by table tier.

    :param table_name: a node name, either a backtick-quoted full table name such as
        `database`.`table` or an alias node name.
    :return: _AliasNode for names that do not start with a backtick; otherwise the first
        class in user_table_classes whose tier_regexp fully matches the bare table name,
        or None when no tier matches.
    """
    if table_name.startswith("`"):
        # for `database`.`table` the second-to-last backtick-delimited piece is "table"
        bare_name = table_name.split("`")[-2]
        for tier in user_table_classes:
            if re.fullmatch(tier.tier_regexp, bare_name):
                return tier
        return None
    return _AliasNode
53

54

55
if not diagram_active:

    class Diagram:
        """
        Stand-in for the entity relationship diagram, defined when the optional
        diagramming import above failed.

        Constructing it accepts any arguments and only logs a warning asking for
        matplotlib and pygraphviz to be installed. Installation instructions:
        http://docs.datajoint.io/setup/Install-and-connect.html#python and
        http://tutorials.datajoint.io/setting-up/datajoint-python.html
        """

        def __init__(self, *args, **kwargs):
            warning_text = "Please install matplotlib and pygraphviz libraries to enable the Diagram feature."
            logger.warning(warning_text)
70

71
else:
72

73
    class Diagram(nx.DiGraph):
        """
        Entity relationship diagram.

        Usage:

        >>>  diag = Diagram(source)

        source can be a base table object, a base table class, a schema, or a module that has a schema.

        >>> diag.draw()

        draws the diagram using pyplot

        diag1 + diag2  - combines the two diagrams.
        diag1 - diag2  - removes the tables shown in diag2 from diag1.
        diag1 * diag2  - keeps only the tables shown in both diagrams.
        diag + n   - expands n levels of successors
        diag - n   - expands n levels of predecessors
        Thus dj.Diagram(schema.Table)+1-1 defines the diagram of immediate ancestors and descendants of schema.Table

        Note that diagram + 1 - 1  may differ from diagram - 1 + 1 and so forth.
        Only those tables that are loaded in the connection's dependency graph are displayed.
        """

        def __init__(self, source, context=None):
            """
            :param source: a Diagram (copied), a table object or class, a schema, or a
                module that has a ``schema`` attribute.
            :param context: namespace used to map table names to class names; when None,
                the caller's globals and locals are used.
            :raises DataJointError: if no connection or database can be found in source.
            """
            if isinstance(source, Diagram):
                # copy constructor
                self.nodes_to_show = set(source.nodes_to_show)
                self.context = source.context
                super().__init__(source)
                return

            # get the caller's context
            if context is None:
                frame = inspect.currentframe().f_back
                self.context = dict(frame.f_globals, **frame.f_locals)
                del frame  # do not keep a reference to the frame object
            else:
                self.context = context

            # find connection in the source
            try:
                connection = source.connection
            except AttributeError:
                try:
                    connection = source.schema.connection
                except AttributeError:
                    # repr(source), not repr(source[0]): source need not be indexable
                    raise DataJointError(
                        "Could not find database connection in %s" % repr(source)
                    )

            # initialize graph from dependencies
            connection.dependencies.load()
            super().__init__(connection.dependencies)

            # Enumerate nodes from all the items in the list
            self.nodes_to_show = set()
            try:
                self.nodes_to_show.add(source.full_table_name)
            except AttributeError:
                try:
                    database = source.database
                except AttributeError:
                    try:
                        database = source.schema.database
                    except AttributeError:
                        raise DataJointError(
                            "Cannot plot Diagram for %s" % repr(source)
                        )
                for node in self:
                    if node.startswith("`%s`" % database):
                        self.nodes_to_show.add(node)

        @classmethod
        def from_sequence(cls, sequence, context=None):
            """
            The join Diagram for all objects in sequence

            :param sequence: a non-empty sequence (e.g. list, tuple) of Diagram sources
            :param context: namespace used to map table names to class names; when None,
                the caller's globals and locals are used. It is captured here and passed
                to every item, because a Diagram built inside this method would otherwise
                capture this method's own frame (module globals) as its context.
            :return: Diagram(arg1) + ... + Diagram(argn)
            """
            if context is None:
                frame = inspect.currentframe().f_back
                context = dict(frame.f_globals, **frame.f_locals)
                del frame  # do not keep a reference to the frame object
            return functools.reduce(
                lambda x, y: x + y, (cls(item, context=context) for item in sequence)
            )

        def add_parts(self):
            """
            Adds to the diagram the part tables of tables already included in the diagram

            :return: a new Diagram; self is left unchanged.
            """

            def is_part(part, master):
                """
                :param part:  `database`.`table_name`
                :param master:   `database`.`table_name`
                :return: True if part is part of master. Names without a database
                    qualifier (alias nodes) are never parts or masters.
                """
                part = [s.strip("`") for s in part.split(".")]
                master = [s.strip("`") for s in master.split(".")]
                # alias nodes have no "." and would otherwise raise IndexError below
                if len(part) < 2 or len(master) < 2:
                    return False
                return master[0] == part[0] and part[1].startswith(master[1] + "__")

            self = Diagram(self)  # copy
            # snapshot the masters so the set is not read while it is being updated
            masters = list(self.nodes_to_show)
            self.nodes_to_show.update(
                n for n in self.nodes() if any(is_part(n, m) for m in masters)
            )
            return self

        def topological_sort(self):
            """:return:  list of nodes in topological order"""
            return unite_master_parts(
                list(
                    nx.algorithms.dag.topological_sort(
                        nx.DiGraph(self).subgraph(self.nodes_to_show)
                    )
                )
            )

        def __add__(self, arg):
            """
            :param arg: either another Diagram, a table, or a non-negative integer.
            :return: Union of the diagrams when arg is another Diagram or a table,
                     or an expansion downstream by arg levels when arg is an integer.
            """
            self = Diagram(self)  # copy
            try:
                self.nodes_to_show.update(arg.nodes_to_show)
            except AttributeError:
                try:
                    self.nodes_to_show.add(arg.full_table_name)
                except AttributeError:
                    for _ in range(arg):
                        new = nx.algorithms.boundary.node_boundary(
                            self, self.nodes_to_show
                        )
                        if not new:
                            break
                        # add nodes referenced by aliased nodes
                        new.update(
                            nx.algorithms.boundary.node_boundary(
                                self, (a for a in new if a.isdigit())
                            )
                        )
                        self.nodes_to_show.update(new)
            return self

        def __sub__(self, arg):
            """
            :param arg: either another Diagram, a table, or a non-negative integer.
            :return: Difference of the diagrams when arg is another Diagram or a table
                     (removing a table that is not shown is a no-op), or an expansion
                     upstream by arg levels when arg is an integer.
            """
            self = Diagram(self)  # copy
            try:
                self.nodes_to_show.difference_update(arg.nodes_to_show)
            except AttributeError:
                try:
                    # discard (not remove) to match the set-difference semantics above
                    self.nodes_to_show.discard(arg.full_table_name)
                except AttributeError:
                    # the graph structure does not change inside the loop, so reverse once
                    graph = nx.DiGraph(self).reverse()
                    for _ in range(arg):
                        new = nx.algorithms.boundary.node_boundary(
                            graph, self.nodes_to_show
                        )
                        if not new:
                            break
                        # add nodes referenced by aliased nodes
                        new.update(
                            nx.algorithms.boundary.node_boundary(
                                graph, (a for a in new if a.isdigit())
                            )
                        )
                        self.nodes_to_show.update(new)
            return self

        def __mul__(self, arg):
            """
            Intersection of two diagrams
            :param arg: another Diagram
            :return: a new Diagram comprising nodes that are present in both operands.
            """
            self = Diagram(self)  # copy
            self.nodes_to_show.intersection_update(arg.nodes_to_show)
            return self

        def _make_graph(self):
            """
            Build and return a standalone DiGraph ready for drawing.

            Side effect: sets the "distinguished" attribute on each shown node of self.
            The returned graph contains the shown nodes plus alias nodes sandwiched
            between shown nodes, carries a "node_type" attribute per node, and has its
            nodes relabeled to class names found in self.context.

            :raises DataJointError: if two nodes would be relabeled to the same name.
            """
            # mark "distinguished" tables, i.e. those that introduce new primary key
            # attributes
            for name in self.nodes_to_show:
                foreign_attributes = {
                    attr
                    for p in self.in_edges(name, data=True)
                    if p[2]["primary"]
                    for attr in p[2]["attr_map"]
                }
                self.nodes[name]["distinguished"] = (
                    "primary_key" in self.nodes[name]
                    and foreign_attributes < self.nodes[name]["primary_key"]
                )
            # include aliased nodes that are sandwiched between two displayed nodes
            gaps = set(
                nx.algorithms.boundary.node_boundary(self, self.nodes_to_show)
            ).intersection(
                nx.algorithms.boundary.node_boundary(
                    nx.DiGraph(self).reverse(), self.nodes_to_show
                )
            )
            # isdigit() must be called; the bare bound method is always truthy
            nodes = self.nodes_to_show.union(a for a in gaps if a.isdigit())
            # construct subgraph and rename nodes to class names
            graph = nx.DiGraph(nx.DiGraph(self).subgraph(nodes))
            nx.set_node_attributes(
                graph, name="node_type", values={n: _get_tier(n) for n in graph}
            )
            # relabel nodes to class names
            mapping = {
                node: lookup_class_name(node, self.context) or node
                for node in graph.nodes()
            }
            # a list of the names (not a one-element list holding the values view),
            # so that duplicates are actually detected
            new_names = list(mapping.values())
            if len(new_names) > len(set(new_names)):
                raise DataJointError(
                    "Some classes have identical names. The Diagram cannot be plotted."
                )
            nx.relabel_nodes(graph, mapping, copy=False)
            return graph

        def make_dot(self):
            """
            Build a pydot graph of the shown tables, styled by table tier.

            :return: the object returned by ``nx.drawing.nx_pydot.to_pydot`` with node
                and edge rendering attributes set.
            """
            graph = self._make_graph()

            scale = 1.2  # scaling factor for fonts and boxes
            label_props = {  # http://matplotlib.org/examples/color/named_colors.html
                None: dict(
                    shape="circle",
                    color="#FFFF0040",
                    fontcolor="yellow",
                    fontsize=round(scale * 8),
                    size=0.4 * scale,
                    fixed=False,
                ),
                _AliasNode: dict(
                    shape="circle",
                    color="#FF880080",
                    fontcolor="#FF880080",
                    fontsize=round(scale * 0),
                    size=0.05 * scale,
                    fixed=True,
                ),
                Manual: dict(
                    shape="box",
                    color="#00FF0030",
                    fontcolor="darkgreen",
                    fontsize=round(scale * 10),
                    size=0.4 * scale,
                    fixed=False,
                ),
                Lookup: dict(
                    shape="plaintext",
                    color="#00000020",
                    fontcolor="black",
                    fontsize=round(scale * 8),
                    size=0.4 * scale,
                    fixed=False,
                ),
                Computed: dict(
                    shape="ellipse",
                    color="#FF000020",
                    fontcolor="#7F0000A0",
                    fontsize=round(scale * 10),
                    size=0.3 * scale,
                    fixed=True,
                ),
                Imported: dict(
                    shape="ellipse",
                    color="#00007F40",
                    fontcolor="#00007FA0",
                    fontsize=round(scale * 10),
                    size=0.4 * scale,
                    fixed=False,
                ),
                Part: dict(
                    shape="plaintext",
                    # well-formed #RRGGBBAA; the previous value "#0000000" had only seven
                    # hex digits. Presumably a fully transparent fill was meant — verify
                    # the rendered output.
                    color="#00000000",
                    fontcolor="black",
                    fontsize=round(scale * 8),
                    size=0.1 * scale,
                    fixed=False,
                ),
            }
            node_props = {
                node: label_props[d["node_type"]]
                for node, d in graph.nodes(data=True)
            }

            dot = nx.drawing.nx_pydot.to_pydot(graph)
            for node in dot.get_nodes():
                name = node.get_name().strip('"')
                props = node_props[name]
                node.set_fontsize(props["fontsize"])
                node.set_fontcolor(props["fontcolor"])
                node.set_shape(props["shape"])
                node.set_fontname("arial")
                node.set_fixedsize("shape" if props["fixed"] else False)
                node.set_width(props["size"])
                node.set_height(props["size"])
                if name.split(".")[0] in self.context:
                    # NOTE(review): eval resolves the class name produced by
                    # lookup_class_name in self.context; it is not fed external input here.
                    cls = eval(name, self.context)
                    assert issubclass(cls, Table)
                    description = (
                        cls().describe(context=self.context, printout=False).split("\n")
                    )
                    description = (
                        "-" * 30
                        if q.startswith("---")
                        else q.replace("->", "&#8594;")
                        if "->" in q
                        else q.split(":")[0]
                        for q in description
                        if not q.startswith("#")
                    )
                    node.set_tooltip("&#13;".join(description))
                node.set_label(
                    "<<u>" + name + "</u>>"
                    if node.get("distinguished") == "True"
                    else name
                )
                node.set_color(props["color"])
                node.set_style("filled")

            for edge in dot.get_edges():
                # see https://graphviz.org/doc/info/attrs.html
                src = edge.get_source().strip('"')
                dest = edge.get_destination().strip('"')
                props = graph.get_edge_data(src, dest)
                edge.set_color("#00000040")
                edge.set_style("solid" if props["primary"] else "dashed")
                master_part = graph.nodes[dest][
                    "node_type"
                ] is Part and dest.startswith(src + ".")
                edge.set_weight(3 if master_part else 1)
                edge.set_arrowhead("none")
                edge.set_penwidth(0.75 if props["multi"] else 2)

            return dot

        def make_svg(self):
            """:return: an IPython.display.SVG built from the rendered dot graph."""
            from IPython.display import SVG

            return SVG(self.make_dot().create_svg())

        def make_png(self):
            """:return: an io.BytesIO holding the PNG rendering of the diagram."""
            return io.BytesIO(self.make_dot().create_png())

        def make_image(self):
            """
            :return: the PNG rendering decoded by ``plt.imread``.
            :raises DataJointError: if matplotlib's pyplot could not be imported.
            """
            if plot_active:
                return plt.imread(self.make_png())
            else:
                raise DataJointError("pyplot was not imported")

        def _repr_svg_(self):
            """IPython rich-display hook: return the SVG markup of the diagram."""
            return self.make_svg()._repr_svg_()

        def draw(self):
            """
            Show the diagram with pyplot.

            :raises DataJointError: if matplotlib's pyplot could not be imported.
            """
            if plot_active:
                plt.imshow(self.make_image())
                plt.gca().axis("off")
                plt.show()
            else:
                raise DataJointError("pyplot was not imported")

        def save(self, filename, format=None):
            """
            Write the diagram to a file.

            :param filename: destination path (str or path-like).
            :param format: "png" or "svg" (case-insensitive); inferred from the filename
                extension when None.
            :raises DataJointError: if the format cannot be inferred or is unsupported.
            """
            if format is None:
                lowered = str(filename).lower()
                if lowered.endswith(".png"):
                    format = "png"
                elif lowered.endswith(".svg"):
                    format = "svg"
                else:
                    # previously fell through to None.lower() and raised AttributeError
                    raise DataJointError("Unsupported file format")
            if format.lower() == "png":
                with open(filename, "wb") as f:
                    f.write(self.make_png().getvalue())
            elif format.lower() == "svg":
                # explicit encoding so non-ASCII names do not depend on the platform default
                with open(filename, "w", encoding="utf-8") as f:
                    f.write(self.make_svg().data)
            else:
                raise DataJointError("Unsupported file format")

        @staticmethod
        def _layout(graph, **kwargs):
            """Compute node positions with graphviz ``dot`` via networkx's pydot_layout."""
            return pydot_layout(graph, prog="dot", **kwargs)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc