• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

datajoint / datajoint-python / #12898

pending completion
#12898

push

travis-ci

web-flow
<a href="https://github.com/datajoint/datajoint-python/commit/715ab40552f63cd79723ed2830c6691b2cb228b9">Merge 0a4f19303 into 3b6e84588</a>

69 of 69 new or added lines in 9 files covered. (100.0%)

3052 of 3381 relevant lines covered (90.27%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.8
/datajoint/heading.py
1
import numpy as np
1✔
2
from collections import namedtuple, defaultdict
1✔
3
from itertools import chain
1✔
4
import re
1✔
5
import logging
1✔
6
from .errors import DataJointError, _support_filepath_types, FILEPATH_FEATURE_SWITCH
1✔
7
from .declare import (
1✔
8
    UUID_DATA_TYPE,
9
    SPECIAL_TYPES,
10
    TYPE_PATTERN,
11
    EXTERNAL_TYPES,
12
    NATIVE_TYPES,
13
)
14
from .attribute_adapter import get_adapter, AttributeAdapter
1✔
15

16

17
logger = logging.getLogger(__name__.partition(".")[0])

# Property values given to computed attributes (attributes created from an expression).
# The key order also fixes the field order of the Attribute namedtuple built from it.
default_attribute_properties = {
    "name": None,
    "type": "expression",
    "in_key": False,
    "nullable": False,
    "default": None,
    "comment": "calculated attribute",
    "autoincrement": False,
    "numeric": None,
    "string": None,
    "uuid": False,
    "json": None,
    "is_blob": False,
    "is_attachment": False,
    "is_filepath": False,
    "is_external": False,
    "adapter": None,
    "store": None,
    "unsupported": False,
    "attribute_expression": None,
    "database": None,
    "dtype": object,
}
44

45

46
class Attribute(namedtuple("_Attribute", default_attribute_properties)):
    """
    Properties of a table column (attribute).

    The fields are the keys of ``default_attribute_properties``, in the same order.
    """

    def todict(self):
        """Convert namedtuple to dict."""
        return dict(zip(self._fields, self))

    @property
    def sql_type(self):
        """:return: datatype (as string) in database. In most cases, it is the same as self.type"""
        return UUID_DATA_TYPE if self.uuid else self.type

    @property
    def sql_comment(self):
        """:return: full comment for the SQL declaration. Includes custom type specification"""
        return (":uuid:" if self.uuid else "") + self.comment

    @property
    def sql(self):
        """
        Convert primary key attribute tuple into its SQL CREATE TABLE clause.
        Default values are not reflected.
        This is used for declaring foreign keys in referencing tables

        :return: SQL code for attribute declaration
        """
        # The comment is placed inside a double-quoted SQL string literal. Comments
        # read back from the database are unescaped, so backslashes and double quotes
        # must be escaped here (backslash first) or the generated declaration breaks.
        comment = self.sql_comment.replace("\\", "\\\\").replace('"', '\\"')
        return '`{name}` {type} NOT NULL COMMENT "{comment}"'.format(
            name=self.name, type=self.sql_type, comment=comment
        )

    @property
    def original_name(self):
        """
        :return: the name of the source column for a renamed attribute (whose
            attribute_expression is a backtick-quoted name), otherwise self.name.
        """
        if self.attribute_expression is None:
            return self.name
        assert self.attribute_expression.startswith("`")
        return self.attribute_expression.strip("`")
84

85

86
class Heading:
    """
    Local class for table headings.
    Heading contains the property attributes, which is a dict in which the keys are
    the attribute names and the values are Attributes.
    """

    def __init__(self, attribute_specs=None, table_info=None):
        """

        :param attribute_specs: an iterable of dicts with the same keys as Attribute
        :param table_info: a dict with information to load the heading from the database
            (keys "conn", "database", "table_name", "context")
        """
        self.indexes = None
        self.table_info = table_info
        self._table_status = None
        self._attributes = (
            None
            if attribute_specs is None
            else {q["name"]: Attribute(**q) for q in attribute_specs}
        )

    def __len__(self):
        return 0 if self.attributes is None else len(self.attributes)

    @property
    def table_status(self):
        """
        :return: dict of lower-cased SHOW TABLE STATUS fields, or None when the
            heading has no table_info to load it from.
        """
        if self.table_info is None:
            return None
        if self._table_status is None:
            self._init_from_database()
        return self._table_status

    @property
    def attributes(self):
        """
        :return: dict mapping attribute names to Attributes, loaded lazily from the
            database when table_info is set. None if there are neither attribute
            specs nor table_info (or the ~log table could not be found).
        """
        # Only hit the database when there is something to load from; otherwise
        # an empty Heading() would fail on `None[...]` inside _init_from_database.
        if self._attributes is None and self.table_info is not None:
            self._init_from_database()  # lazy loading from database
        return self._attributes

    @property
    def names(self):
        return [k for k in self.attributes]

    @property
    def primary_key(self):
        return [k for k, v in self.attributes.items() if v.in_key]

    @property
    def secondary_attributes(self):
        return [k for k, v in self.attributes.items() if not v.in_key]

    @property
    def blobs(self):
        return [k for k, v in self.attributes.items() if v.is_blob]

    @property
    def non_blobs(self):
        """:return: names of attributes that are not blobs, attachments, filepaths, or json"""
        return [
            k
            for k, v in self.attributes.items()
            if not (v.is_blob or v.is_attachment or v.is_filepath or v.json)
        ]

    @property
    def new_attributes(self):
        """:return: names of attributes defined by an expression (renamed or computed)"""
        return [
            k for k, v in self.attributes.items() if v.attribute_expression is not None
        ]

    def __getitem__(self, name):
        """shortcut to the attribute"""
        return self.attributes[name]

    def __repr__(self):
        """
        :return:  heading representation in DataJoint declaration format but without foreign key expansion
        """
        in_key = True
        ret = ""
        if self._table_status is not None:
            ret += "# " + self.table_status["comment"] + "\n"
        for v in self.attributes.values():
            if in_key and not v.in_key:
                ret += "---\n"
                in_key = False
            ret += "%-20s : %-28s # %s\n" % (
                v.name if v.default is None else "%s=%s" % (v.name, v.default),
                # leading space keeps the type and the flag as separate words
                "%s%s" % (v.type, " auto_increment" if v.autoincrement else ""),
                v.comment,
            )
        return ret

    @property
    def has_autoincrement(self):
        return any(e.autoincrement for e in self.attributes.values())

    @property
    def as_dtype(self):
        """
        represent the heading as a numpy dtype
        """
        return np.dtype(
            dict(names=self.names, formats=[v.dtype for v in self.attributes.values()])
        )

    def as_sql(self, fields, include_aliases=True):
        """
        represent heading as the SQL SELECT clause.

        :param fields: attribute names to select
        :param include_aliases: if True, expression attributes are followed by `as name`
        """
        return ",".join(
            "`%s`" % name
            if self.attributes[name].attribute_expression is None
            else self.attributes[name].attribute_expression
            + (" as `%s`" % name if include_aliases else "")
            for name in fields
        )

    def __iter__(self):
        return iter(self.attributes)

    def _init_from_database(self):
        """
        Initialize heading from an existing database table.

        Uses self.table_info to query SHOW TABLE STATUS, SHOW FULL COLUMNS, and
        SHOW KEYS, and populates _table_status, _attributes, and indexes.

        :raises DataJointError: if the table is not defined (for the ~log table only a
            warning is logged), or if a column has an unknown or disabled type.
        """
        conn, database, table_name, context = (
            self.table_info[k] for k in ("conn", "database", "table_name", "context")
        )
        info = conn.query(
            'SHOW TABLE STATUS FROM `{database}` WHERE name="{table_name}"'.format(
                table_name=table_name, database=database
            ),
            as_dict=True,
        ).fetchone()
        if info is None:
            if table_name == "~log":
                logger.warning("Could not create the ~log table")
                return
            raise DataJointError(
                "The table `{database}`.`{table_name}` is not defined.".format(
                    table_name=table_name, database=database
                )
            )
        self._table_status = {k.lower(): v for k, v in info.items()}
        cur = conn.query(
            "SHOW FULL COLUMNS FROM `{table_name}` IN `{database}`".format(
                table_name=table_name, database=database
            ),
            as_dict=True,
        )

        attributes = cur.fetchall()

        rename_map = {
            "Field": "name",
            "Type": "type",
            "Null": "nullable",
            "Default": "default",
            "Key": "in_key",
            "Comment": "comment",
        }

        fields_to_drop = ("Privileges", "Collation")

        # rename and drop attributes
        attributes = [
            {rename_map.get(k, k): v for k, v in x.items() if k not in fields_to_drop}
            for x in attributes
        ]
        # (base type, is_unsigned) -> numpy dtype
        numeric_types = {
            ("float", False): np.float64,
            ("float", True): np.float64,
            ("double", False): np.float64,
            ("double", True): np.float64,
            ("tinyint", False): np.int64,
            ("tinyint", True): np.int64,
            ("smallint", False): np.int64,
            ("smallint", True): np.int64,
            ("mediumint", False): np.int64,
            ("mediumint", True): np.int64,
            ("int", False): np.int64,
            ("int", True): np.int64,
            ("bigint", False): np.int64,
            ("bigint", True): np.uint64,
        }

        sql_literals = ["CURRENT_TIMESTAMP"]

        # additional attribute properties
        for attr in attributes:

            attr.update(
                in_key=(attr["in_key"] == "PRI"),
                database=database,
                nullable=attr["nullable"] == "YES",
                autoincrement=bool(
                    re.search(r"auto_increment", attr["Extra"], flags=re.I)
                ),
                numeric=any(
                    TYPE_PATTERN[t].match(attr["type"])
                    for t in ("DECIMAL", "INTEGER", "FLOAT")
                ),
                string=any(
                    TYPE_PATTERN[t].match(attr["type"])
                    for t in ("ENUM", "TEMPORAL", "STRING")
                ),
                is_blob=bool(TYPE_PATTERN["INTERNAL_BLOB"].match(attr["type"])),
                uuid=False,
                json=bool(TYPE_PATTERN["JSON"].match(attr["type"])),
                is_attachment=False,
                is_filepath=False,
                adapter=None,
                store=None,
                is_external=False,
                attribute_expression=None,
            )

            if any(TYPE_PATTERN[t].match(attr["type"]) for t in ("INTEGER", "FLOAT")):
                attr["type"] = re.sub(
                    r"\(\d+\)", "", attr["type"], count=1
                )  # strip size off integers and floats
            # blob, numeric, and string columns are natively supported; special
            # DataJoint types handled below reset this flag to False
            attr["unsupported"] = not any(
                (attr["is_blob"], attr["numeric"], attr["string"])
            )
            attr.pop("Extra")

            # process custom DataJoint types
            special = re.match(r":(?P<type>[^:]+):(?P<comment>.*)", attr["comment"])
            if special:
                special = special.groupdict()
                attr.update(special)
            # process adapted attribute types
            if special and TYPE_PATTERN["ADAPTED"].match(attr["type"]):
                assert context is not None, "Declaration context is not set"
                adapter_name = special["type"]
                try:
                    attr.update(adapter=get_adapter(context, adapter_name))
                except DataJointError:
                    # if no adapter, then delay the error until the first invocation
                    attr.update(adapter=AttributeAdapter())
                else:
                    attr.update(type=attr["adapter"].attribute_type)
                    if not any(r.match(attr["type"]) for r in TYPE_PATTERN.values()):
                        raise DataJointError(
                            "Invalid attribute type '{type}' in adapter object <{adapter_name}>.".format(
                                adapter_name=adapter_name, **attr
                            )
                        )
                    special = not any(
                        TYPE_PATTERN[c].match(attr["type"]) for c in NATIVE_TYPES
                    )

            if special:
                try:
                    category = next(
                        c for c in SPECIAL_TYPES if TYPE_PATTERN[c].match(attr["type"])
                    )
                except StopIteration:
                    if attr["type"].startswith("external"):
                        url = (
                            "https://docs.datajoint.io/python/admin/5-blob-config.html"
                            "#migration-between-datajoint-v0-11-and-v0-12"
                        )
                        raise DataJointError(
                            "Legacy datatype `{type}`. Migrate your external stores to "
                            "datajoint 0.12: {url}".format(url=url, **attr)
                        )
                    raise DataJointError(
                        "Unknown attribute type `{type}`".format(**attr)
                    )
                if category == "FILEPATH" and not _support_filepath_types():
                    raise DataJointError(
                        """
                        The filepath data type is disabled until complete validation.
                        To turn it on as experimental feature, set the environment variable
                        {env} = TRUE or upgrade datajoint.
                        """.format(
                            env=FILEPATH_FEATURE_SWITCH
                        )
                    )
                attr.update(
                    unsupported=False,
                    is_attachment=category in ("INTERNAL_ATTACH", "EXTERNAL_ATTACH"),
                    is_filepath=category == "FILEPATH",
                    # INTERNAL_BLOB is not a custom type but is included for completeness
                    is_blob=category in ("INTERNAL_BLOB", "EXTERNAL_BLOB"),
                    uuid=category == "UUID",
                    is_external=category in EXTERNAL_TYPES,
                    store=attr["type"].split("@")[1]
                    if category in EXTERNAL_TYPES
                    else None,
                )

            if attr["in_key"] and any(
                (
                    attr["is_blob"],
                    attr["is_attachment"],
                    attr["is_filepath"],
                    attr["json"],
                )
            ):
                raise DataJointError(
                    "Json, Blob, attachment, or filepath attributes are not allowed in the primary key"
                )

            # quote string defaults unless they are SQL literals such as CURRENT_TIMESTAMP
            if (
                attr["string"]
                and attr["default"] is not None
                and attr["default"] not in sql_literals
            ):
                attr["default"] = '"%s"' % attr["default"]

            if attr["nullable"]:  # nullable fields always default to null
                attr["default"] = "null"

            # fill out dtype. All floats and non-nullable integers are turned into specific dtypes
            attr["dtype"] = object
            if attr["numeric"] and not attr["adapter"]:
                is_integer = TYPE_PATTERN["INTEGER"].match(attr["type"])
                is_float = TYPE_PATTERN["FLOAT"].match(attr["type"])
                if is_integer and not attr["nullable"] or is_float:
                    # the unsigned modifier follows the base type (e.g. "bigint unsigned"),
                    # so search for it anywhere after whitespace rather than anchoring
                    # at the start of the string
                    is_unsigned = bool(
                        re.search(r"\sunsigned\b", attr["type"], flags=re.I)
                    )
                    t = re.sub(r"\(.*\)", "", attr["type"])  # remove parentheses
                    t = re.sub(r" unsigned$", "", t)  # remove unsigned
                    assert (t, is_unsigned) in numeric_types, (
                        "dtype not found for type %s" % t
                    )
                    attr["dtype"] = numeric_types[(t, is_unsigned)]

            if attr["adapter"]:
                # restore adapted type name
                attr["type"] = adapter_name

        self._attributes = {q["name"]: Attribute(**q) for q in attributes}

        # Read and tabulate secondary indexes
        keys = defaultdict(dict)
        for item in conn.query(
            "SHOW KEYS FROM `{db}`.`{tab}`".format(db=database, tab=table_name),
            as_dict=True,
        ):
            if item["Key_name"] != "PRIMARY":
                keys[item["Key_name"]][item["Seq_in_index"]] = dict(
                    # functional indexes have no column name, only an expression
                    column=item["Column_name"]
                    or f"({item['Expression']})".replace(r"\'", "'"),
                    unique=(item["Non_unique"] == 0),
                    nullable=item["Null"].lower() == "yes",
                )
        # map: tuple of indexed columns (in index order) -> dict(unique, nullable)
        self.indexes = {
            tuple(item[k]["column"] for k in sorted(item.keys())): dict(
                unique=item[1]["unique"],
                nullable=any(v["nullable"] for v in item.values()),
            )
            for item in keys.values()
        }

    def select(self, select_list, rename_map=None, compute_map=None):
        """
        derive a new heading by selecting, renaming, or computing attributes.
        In relational algebra these operators are known as project, rename, and extend.

        :param select_list:  the full list of existing attributes to include
        :param rename_map:  dictionary of renamed attributes: keys=new names, values=old names
        :param compute_map: a dictionary of computed attributes: keys=new names, values=SQL expressions
        This low-level method performs no error checking.
        """
        rename_map = rename_map or {}
        compute_map = compute_map or {}
        copy_attrs = list()
        for name in self.attributes:
            if name in select_list:
                copy_attrs.append(self.attributes[name].todict())
            # renamed copies are placed right after their source attribute
            copy_attrs.extend(
                (
                    dict(
                        self.attributes[old_name].todict(),
                        name=new_name,
                        attribute_expression="`%s`" % old_name,
                    )
                    for new_name, old_name in rename_map.items()
                    if old_name == name
                )
            )
        compute_attrs = (
            dict(default_attribute_properties, name=new_name, attribute_expression=expr)
            for new_name, expr in compute_map.items()
        )
        return Heading(chain(copy_attrs, compute_attrs))

    def join(self, other):
        """
        Join two headings into a new one.
        It assumes that self and other are headings that share no common dependent attributes.
        """
        return Heading(
            [self.attributes[name].todict() for name in self.primary_key]
            + [
                other.attributes[name].todict()
                for name in other.primary_key
                if name not in self.primary_key
            ]
            + [
                self.attributes[name].todict()
                for name in self.secondary_attributes
                if name not in other.primary_key
            ]
            + [
                other.attributes[name].todict()
                for name in other.secondary_attributes
                if name not in self.primary_key
            ]
        )

    def set_primary_key(self, primary_key):
        """
        Create a new heading with the specified primary key.
        This low-level method performs no error checking.
        """
        return Heading(
            chain(
                (
                    dict(self.attributes[name].todict(), in_key=True)
                    for name in primary_key
                ),
                (
                    dict(self.attributes[name].todict(), in_key=False)
                    for name in self.names
                    if name not in primary_key
                ),
            )
        )

    def make_subquery_heading(self):
        """
        Create a new heading with removed attribute sql_expressions.
        Used by subqueries, which resolve the sql_expressions.
        """
        return Heading(
            dict(v.todict(), attribute_expression=None)
            for v in self.attributes.values()
        )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc