• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

datajoint / datajoint-python / #12898

pending completion
#12898

push

travis-ci

web-flow
Merge commit 715ab4055: merge 0a4f19303 into 3b6e84588

69 of 69 new or added lines in 9 files covered. (100.0%)

3052 of 3381 relevant lines covered (90.27%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

89.98
/datajoint/table.py
1
import collections
1✔
2
import itertools
1✔
3
import inspect
1✔
4
import platform
1✔
5
import numpy as np
1✔
6
import pandas
1✔
7
import logging
1✔
8
import uuid
1✔
9
import csv
1✔
10
import re
1✔
11
import json
1✔
12
from pathlib import Path
1✔
13
from .settings import config
1✔
14
from .declare import declare, alter
1✔
15
from .condition import make_condition
1✔
16
from .expression import QueryExpression
1✔
17
from . import blob
1✔
18
from .utils import user_choice, get_master
1✔
19
from .heading import Heading
1✔
20
from .errors import (
1✔
21
    DuplicateError,
22
    AccessError,
23
    DataJointError,
24
    UnknownAttributeError,
25
    IntegrityError,
26
)
27
from typing import Union
1✔
28
from .version import __version__ as version
1✔
29

30
# Package-level logger: keyed on the root package name so every datajoint
# module shares one logger hierarchy.
logger = logging.getLogger(__name__.partition(".")[0])
31

32
# Parses MySQL foreign-key violation messages, e.g.
#   ... a foreign key constraint fails (`db`.`child`, CONSTRAINT `fk`
#   FOREIGN KEY (`a`) REFERENCES `db`.`parent` (`a`) ON DELETE RESTRICT)
# Named groups: `child` and `name` always match; `fk_attrs`, `parent`, and
# `pk_attrs` are present only when the server reports the full constraint
# (the trailing group is optional, so `pk_attrs` may be None).
foreign_key_error_regexp = re.compile(
    r"[\w\s:]*\((?P<child>`[^`]+`.`[^`]+`), CONSTRAINT (?P<name>`[^`]+`) "
    r"(FOREIGN KEY \((?P<fk_attrs>[^)]+)\) REFERENCES "
    r"(?P<parent>`[^`]+`(\.`[^`]+`)?) \((?P<pk_attrs>[^)]+)\)[\s\w]+\))?"
)
38

39
# Fallback lookup used when the FK error message omits constraint details:
# resolves a constraint name to its (fk_attrs, parent, pk_attrs) columns.
# The triple-quoted SQL is collapsed into a single line with uniform spacing.
_constraint_info_sql = """
    SELECT
        COLUMN_NAME as fk_attrs,
        CONCAT('`', REFERENCED_TABLE_SCHEMA, '`.`', REFERENCED_TABLE_NAME, '`') as parent,
        REFERENCED_COLUMN_NAME as pk_attrs
    FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
    WHERE
        CONSTRAINT_NAME = %s AND TABLE_SCHEMA = %s AND TABLE_NAME = %s;
"""
constraint_info_query = " ".join(_constraint_info_sql.split())
50

51

52
class _RenameMap(tuple):
1✔
53
    """for internal use"""
54

55
    pass
1✔
56

57

58
class Table(QueryExpression):
1✔
59
    """
60
    Table is an abstract class that represents a table in the schema.
61
    It implements insert and delete methods and inherits query functionality.
62
    To make it a concrete class, override the abstract properties specifying the connection,
63
    table name, database, and definition.
64
    """
65

66
    _table_name = None  # must be defined in subclass
1✔
67
    _log_ = None  # placeholder for the Log table object
1✔
68

69
    # These properties must be set by the schema decorator (schemas.py) at class level
70
    # or by FreeTable at instance level
71
    database = None
1✔
72
    declaration_context = None
1✔
73

74
    @property
    def table_name(self):
        """:return: the table's name in the database, as set via ``_table_name``."""
        return self._table_name
77

78
    @property
    def definition(self):
        """
        Abstract property: concrete subclasses must supply the table definition.

        :raises NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            "Subclasses of Table must implement the `definition` property"
        )
83

84
    def declare(self, context=None):
        """
        Declare the table in the schema based on ``self.definition``.

        :param context: the context for foreign key resolution. If None, foreign keys are
            not allowed.
        :raises DataJointError: if called inside an ongoing transaction.
        """
        if self.connection.in_transaction:
            raise DataJointError(
                "Cannot declare new tables inside a transaction, "
                "e.g. from inside a populate/make call"
            )
        sql, external_stores = declare(self.full_table_name, self.definition, context)
        sql = sql.format(database=self.database)
        try:
            # accessing each store declares its external table before the main table
            for store_name in external_stores:
                self.connection.schemas[self.database].external[store_name]
            self.connection.query(sql)
        except AccessError:
            # lacking CREATE privilege is tolerated silently
            pass
        else:
            self._log("Declared " + self.full_table_name)
108

109
    def alter(self, prompt=True, context=None):
        """
        Alter the table declaration to match the current ``self.definition``.

        :param prompt: if True (default), display the ALTER statement and ask for
            confirmation before executing; also enables informational log messages.
        :param context: context for foreign key resolution. Defaults to the caller's
            namespace.
        :raises DataJointError: if called inside an ongoing transaction.
        """
        if self.connection.in_transaction:
            raise DataJointError(
                "Cannot update table declaration inside a transaction, "
                "e.g. from inside a populate/make call"
            )
        if context is None:
            # default to the caller's namespace for foreign key resolution
            frame = inspect.currentframe().f_back
            context = dict(frame.f_globals, **frame.f_locals)
            del frame  # break the reference cycle with the frame object
        old_definition = self.describe(context=context)
        sql, external_stores = alter(self.definition, old_definition, context)
        if not sql:
            if prompt:
                # fix: logger.warn is deprecated (removed in Python 3.13)
                logger.warning("Nothing to alter.")
        else:
            sql = "ALTER TABLE {tab}\n\t".format(
                tab=self.full_table_name
            ) + ",\n\t".join(sql)
            if not prompt or user_choice(sql + "\n\nExecute?") == "yes":
                try:
                    # declare all external tables before declaring main table
                    for store in external_stores:
                        self.connection.schemas[self.database].external[store]
                    self.connection.query(sql)
                except AccessError:
                    # skip if no create privilege
                    pass
                else:
                    # reset heading so it is reloaded to reflect the new definition
                    self.__class__._heading = Heading(
                        table_info=self.heading.table_info
                    )
                    if prompt:
                        logger.info("Table altered")
                    self._log("Altered " + self.full_table_name)
148

149
    def from_clause(self):
        """
        :return: the FROM clause of SQL SELECT statements (the full table name).
        """
        return self.full_table_name
154

155
    def get_select_fields(self, select_fields=None):
        """
        :param select_fields: attribute specification to select; None selects all.
        :return: the selected attributes for the SQL SELECT statement.
        """
        if select_fields is None:
            return "*"
        return self.heading.project(select_fields).as_sql
162

163
    def parents(self, primary=None, as_objects=False, foreign_key_info=False):
        """
        :param primary: if None, then all parents are returned. If True, then only foreign keys
            composed of primary key attributes are considered.  If False, return foreign keys
            including at least one secondary attribute.
        :param as_objects: if False, return table names. If True, return table objects.
        :param foreign_key_info: if True, each element in result also includes foreign key info.
        :return: list of parents as table names or table objects
            with (optional) foreign key information.
        """
        lookup = self.connection.dependencies.parents
        edges = []
        for name, props in lookup(self.full_table_name, primary).items():
            if name.isdigit():
                # digit-named nodes are intermediate; follow one more hop to the real edge
                edges.append(next(iter(lookup(name).items())))
            else:
                edges.append((name, props))
        if as_objects:
            edges = [(FreeTable(self.connection, n), p) for n, p in edges]
        if not foreign_key_info:
            return [n for n, _ in edges]
        return edges
184

185
    def children(self, primary=None, as_objects=False, foreign_key_info=False):
        """
        :param primary: if None, then all children are returned. If True, then only foreign keys
            composed of primary key attributes are considered.  If False, return foreign keys
            including at least one secondary attribute.
        :param as_objects: if False, return table names. If True, return table objects.
        :param foreign_key_info: if True, each element in result also includes foreign key info.
        :return: list of children as table names or table objects
            with (optional) foreign key information.
        """
        lookup = self.connection.dependencies.children
        edges = []
        for name, props in lookup(self.full_table_name, primary).items():
            if name.isdigit():
                # digit-named nodes are intermediate; follow one more hop to the real edge
                edges.append(next(iter(lookup(name).items())))
            else:
                edges.append((name, props))
        if as_objects:
            edges = [(FreeTable(self.connection, n), p) for n, p in edges]
        if not foreign_key_info:
            return [n for n, _ in edges]
        return edges
206

207
    def descendants(self, as_objects=False):
        """
        :param as_objects: False - a list of table names; True - a list of table objects.
        :return: list of table descendants in topological order.
        """
        result = []
        for node in self.connection.dependencies.descendants(self.full_table_name):
            if node.isdigit():
                continue  # skip digit-named intermediate nodes
            result.append(FreeTable(self.connection, node) if as_objects else node)
        return result
218

219
    def ancestors(self, as_objects=False):
        """
        :param as_objects: False - a list of table names; True - a list of table objects.
        :return: list of table ancestors in topological order.
        """
        result = []
        for node in self.connection.dependencies.ancestors(self.full_table_name):
            if node.isdigit():
                continue  # skip digit-named intermediate nodes
            result.append(FreeTable(self.connection, node) if as_objects else node)
        return result
230

231
    def parts(self, as_objects=False):
        """
        Return part tables either as entries in a dict with foreign key information
        or as a list of objects.

        :param as_objects: if False (default), the output is a dict describing the
            foreign keys. If True, return table objects.
        """
        # part tables share the master's name followed by a "__" separator
        prefix = self.full_table_name[:-1] + "__"
        part_names = [
            node
            for node in self.connection.dependencies.nodes
            if not node.isdigit() and node.startswith(prefix)
        ]
        if as_objects:
            return [FreeTable(self.connection, name) for name in part_names]
        return part_names
243

244
    @property
    def is_declared(self):
        """
        :return: True if the table is declared in the schema.
        """
        query = 'SHOW TABLES in `{database}` LIKE "{table_name}"'.format(
            database=self.database, table_name=self.table_name
        )
        return self.connection.query(query).rowcount > 0
257

258
    @property
    def full_table_name(self):
        """
        :return: the fully qualified name in the schema: `database`.`table`
        """
        return "`{}`.`{}`".format(self.database, self.table_name)
264

265
    @property
    def _log(self):
        """Lazily construct and cache (in ``_log_``) the Log table for this schema."""
        if self._log_ is None:
            self._log_ = Log(
                self.connection,
                database=self.database,
                # hidden service tables (name starts with "~") are not logged
                skip_logging=self.table_name.startswith("~"),
            )
        return self._log_
274

275
    @property
    def external(self):
        """:return: the ``external`` mapping of this table's schema (keyed by store name)."""
        return self.connection.schemas[self.database].external
278

279
    def update1(self, row):
        """
        ``update1`` updates one existing entry in the table.
        Caution: In DataJoint the primary modes for data manipulation is to ``insert`` and
        ``delete`` entire records since referential integrity works on the level of records,
        not fields. Therefore, updates are reserved for corrective operations outside of main
        workflow. Use UPDATE methods sparingly with full awareness of potential violations of
        assumptions.

        :param row: a ``dict`` containing the primary key values and the attributes to update.
            Setting an attribute value to None will reset it to the default value (if any).

        The primary key attributes must always be provided.

        :raises DataJointError: if ``row`` is not dict-like, misses primary key values,
            names an unknown attribute, the table is restricted, or the key does not
            match exactly one existing entry.

        Examples:

        >>> table.update1({'id': 1, 'value': 3})  # update value in record with id=1
        >>> table.update1({'id': 1, 'value': None})  # reset value to default
        """
        # argument validations
        if not isinstance(row, collections.abc.Mapping):
            raise DataJointError("The argument of update1 must be dict-like.")
        if not set(row).issuperset(self.primary_key):
            raise DataJointError(
                "The argument of update1 must supply all primary key values."
            )
        # raise on the first attribute not present in the heading;
        # StopIteration from next() means all attributes are known
        try:
            raise DataJointError(
                "Attribute `%s` not found."
                % next(k for k in row if k not in self.heading.names)
            )
        except StopIteration:
            pass  # ok
        if len(self.restriction):
            raise DataJointError("Update cannot be applied to a restricted table.")
        key = {k: row[k] for k in self.primary_key}
        # the update target must match exactly one existing row
        if len(self & key) != 1:
            raise DataJointError("Update can only be applied to one existing entry.")
        # UPDATE query
        # each entry appears to be a (name, placeholder, value) triple produced by
        # __make_placeholder -- inferred from the r[:2] / r[2] usage below; confirm there
        row = [
            self.__make_placeholder(k, v)
            for k, v in row.items()
            if k not in self.primary_key
        ]
        query = "UPDATE {table} SET {assignments} WHERE {where}".format(
            table=self.full_table_name,
            assignments=",".join("`%s`=%s" % r[:2] for r in row),
            where=make_condition(self, key, set()),
        )
        # None values are inlined in the placeholder, not passed as query args
        self.connection.query(query, args=list(r[2] for r in row if r[2] is not None))
329

330
    def insert1(self, row, **kwargs):
        """
        Insert one data record into the table. For ``kwargs``, see ``insert()``.

        :param row: a numpy record, a dict-like object, or an ordered sequence to be inserted
            as one row.
        """
        # delegate to insert() with a single-element tuple of rows
        self.insert((row,), **kwargs)
338

339
    def insert(
        self,
        rows,
        replace=False,
        skip_duplicates=False,
        ignore_extra_fields=False,
        allow_direct_insert=None,
    ):
        """
        Insert a collection of rows.

        :param rows: Either (a) an iterable where an element is a numpy record, a
            dict-like object, a pandas.DataFrame, a sequence, or a query expression with
            the same heading as self, or (b) a pathlib.Path object specifying a path
            relative to the current directory with a CSV file, the contents of which
            will be inserted.
        :param replace: If True, replaces the existing tuple.
        :param skip_duplicates: If True, silently skip duplicate inserts.
        :param ignore_extra_fields: If False, fields that are not in the heading raise error.
        :param allow_direct_insert: Only applies in auto-populated tables. If False (default),
            insert may only be called from inside the make callback.

        Example:

            >>> Table.insert([
            >>>     dict(subject_id=7, species="mouse", date_of_birth="2014-09-01"),
            >>>     dict(subject_id=8, species="mouse", date_of_birth="2014-09-02")])
        """
        if isinstance(rows, pandas.DataFrame):
            # drop 'extra' synthetic index for 1-field index case -
            # frames with more advanced indices should be prepared by user.
            rows = rows.reset_index(
                drop=len(rows.index.names) == 1 and not rows.index.names[0]
            ).to_records(index=False)

        if isinstance(rows, Path):
            # a Path is read as a CSV file; each row becomes a dict
            with open(rows, newline="") as data_file:
                rows = list(csv.DictReader(data_file, delimiter=","))

        # prohibit direct inserts into auto-populated tables
        if not allow_direct_insert and not getattr(self, "_allow_insert", True):
            raise DataJointError(
                "Inserts into an auto-populated table can only be done inside "
                "its make method during a populate call."
                " To override, set keyword argument allow_direct_insert=True."
            )

        if inspect.isclass(rows) and issubclass(rows, QueryExpression):
            rows = rows()  # instantiate if a class
        if isinstance(rows, QueryExpression):
            # insert from select
            if not ignore_extra_fields:
                # raise on the first attribute of the source heading missing from
                # this table's heading; StopIteration means all attributes match
                try:
                    raise DataJointError(
                        "Attribute %s not found. To ignore extra attributes in insert, "
                        "set ignore_extra_fields=True."
                        % next(
                            name for name in rows.heading if name not in self.heading
                        )
                    )
                except StopIteration:
                    pass
            fields = list(name for name in rows.heading if name in self.heading)
            # INSERT ... SELECT: the server copies rows without a client round-trip
            query = "{command} INTO {table} ({fields}) {select}{duplicate}".format(
                command="REPLACE" if replace else "INSERT",
                fields="`" + "`,`".join(fields) + "`",
                table=self.full_table_name,
                select=rows.make_sql(fields),
                duplicate=(
                    # no-op self-assignment makes duplicates silently skipped
                    " ON DUPLICATE KEY UPDATE `{pk}`={table}.`{pk}`".format(
                        table=self.full_table_name, pk=self.primary_key[0]
                    )
                    if skip_duplicates
                    else ""
                ),
            )
            self.connection.query(query)
            return

        field_list = []  # collects the field list from first row (passed by reference)
        rows = list(
            self.__make_row_to_insert(row, field_list, ignore_extra_fields)
            for row in rows
        )
        if rows:
            try:
                query = "{command} INTO {destination}(`{fields}`) VALUES {placeholders}{duplicate}".format(
                    command="REPLACE" if replace else "INSERT",
                    destination=self.from_clause(),
                    fields="`,`".join(field_list),
                    placeholders=",".join(
                        "(" + ",".join(row["placeholders"]) + ")" for row in rows
                    ),
                    duplicate=(
                        # no-op self-assignment makes duplicates silently skipped
                        " ON DUPLICATE KEY UPDATE `{pk}`=`{pk}`".format(
                            pk=self.primary_key[0]
                        )
                        if skip_duplicates
                        else ""
                    ),
                )
                self.connection.query(
                    query,
                    args=list(
                        # None values are inlined in the placeholders, not passed as args
                        itertools.chain.from_iterable(
                            (v for v in r["values"] if v is not None) for r in rows
                        )
                    ),
                )
            except UnknownAttributeError as err:
                raise err.suggest(
                    "To ignore extra fields in insert, set ignore_extra_fields=True"
                )
            except DuplicateError as err:
                raise err.suggest(
                    "To ignore duplicate entries in insert, set skip_duplicates=True"
                )
456

457
    def delete_quick(self, get_count=False):
        """
        Deletes the table's (restricted) contents without cascading and without user
        prompt. If this table has populated dependent tables, this will fail.

        :param get_count: if True, also query and return the number of deleted rows.
        :return: the deleted row count when get_count is True, otherwise None.
        """
        query = "DELETE FROM " + self.full_table_name + self.where_clause()
        self.connection.query(query)
        if get_count:
            count = self.connection.query("SELECT ROW_COUNT()").fetchone()[0]
        else:
            count = None
        self._log(query[:255])  # log a truncated copy of the executed statement
        return count
471

472
    def delete(
        self,
        transaction: bool = True,
        safemode: Union[bool, None] = None,
        force_parts: bool = False,
    ) -> int:
        """
        Deletes the contents of the table and its dependent tables, recursively.

        Args:
            transaction: If `True`, use of the entire delete becomes an atomic transaction.
                This is the default and recommended behavior. Set to `False` if this delete is
                nested within another transaction.
            safemode: If `True`, prohibit nested transactions and prompt to confirm. Default
                is `dj.config['safemode']`.
            force_parts: Delete from parts even when not deleting from their masters.

        Returns:
            Number of deleted rows (excluding those from dependent tables).

        Raises:
            DataJointError: Delete exceeds maximum number of delete attempts.
            DataJointError: When deleting within an existing transaction.
            DataJointError: Deleting a part table before its master.
        """
        deleted = set()

        def cascade(table):
            """service function to perform cascading deletes recursively."""
            max_attempts = 50
            for _ in range(max_attempts):
                try:
                    delete_count = table.delete_quick(get_count=True)
                except IntegrityError as error:
                    # parse the FK violation to find the blocking child table
                    match = foreign_key_error_regexp.match(error.args[0]).groupdict()
                    if "`.`" not in match["child"]:  # if schema name missing, use table
                        match["child"] = "{}.{}".format(
                            table.full_table_name.split(".")[0], match["child"]
                        )
                    if (
                        match["pk_attrs"] is not None
                    ):  # fully matched, adjusting the keys
                        match["fk_attrs"] = [
                            k.strip("`") for k in match["fk_attrs"].split(",")
                        ]
                        match["pk_attrs"] = [
                            k.strip("`") for k in match["pk_attrs"].split(",")
                        ]
                    else:  # only partially matched, querying with constraint to determine keys
                        match["fk_attrs"], match["parent"], match["pk_attrs"] = list(
                            map(
                                list,
                                zip(
                                    *table.connection.query(
                                        constraint_info_query,
                                        args=(
                                            match["name"].strip("`"),
                                            *[
                                                _.strip("`")
                                                for _ in match["child"].split("`.`")
                                            ],
                                        ),
                                    ).fetchall()
                                ),
                            )
                        )
                        match["parent"] = match["parent"][0]

                    # Restrict child by table if
                    #   1. if table's restriction attributes are not in child's primary key
                    #   2. if child renames any attributes
                    # Otherwise restrict child by table's restriction.
                    child = FreeTable(table.connection, match["child"])
                    if (
                        set(table.restriction_attributes) <= set(child.primary_key)
                        and match["fk_attrs"] == match["pk_attrs"]
                    ):
                        child._restriction = table._restriction
                    elif match["fk_attrs"] != match["pk_attrs"]:
                        child &= table.proj(
                            **dict(zip(match["fk_attrs"], match["pk_attrs"]))
                        )
                    else:
                        child &= table.proj()
                    cascade(child)
                else:
                    deleted.add(table.full_table_name)
                    logger.info(
                        "Deleting {count} rows from {table}".format(
                            count=delete_count, table=table.full_table_name
                        )
                    )
                    break
            else:
                raise DataJointError("Exceeded maximum number of delete attempts.")
            return delete_count

        safemode = config["safemode"] if safemode is None else safemode

        # Start transaction
        if transaction:
            if not self.connection.in_transaction:
                self.connection.start_transaction()
            else:
                if not safemode:
                    transaction = False
                else:
                    raise DataJointError(
                        "Delete cannot use a transaction within an ongoing transaction. "
                        "Set transaction=False or safemode=False)."
                    )

        # Cascading delete
        try:
            delete_count = cascade(self)
        except:  # noqa: E722 -- deliberately bare: cancel on ANY exception, then re-raise
            if transaction:
                self.connection.cancel_transaction()
            raise

        if not force_parts:
            # Avoid deleting from child before master (See issue #151)
            for part in deleted:
                master = get_master(part)
                if master and master not in deleted:
                    if transaction:
                        self.connection.cancel_transaction()
                    raise DataJointError(
                        "Attempt to delete part table {part} before deleting from "
                        "its master {master} first.".format(part=part, master=master)
                    )

        # Confirm and commit
        if delete_count == 0:
            if safemode:
                # fix: logger.warn is deprecated (removed in Python 3.13)
                logger.warning("Nothing to delete.")
            if transaction:
                self.connection.cancel_transaction()
        else:
            if not safemode or user_choice("Commit deletes?", default="no") == "yes":
                if transaction:
                    self.connection.commit_transaction()
                if safemode:
                    logger.info("Deletes committed.")
            else:
                if transaction:
                    self.connection.cancel_transaction()
                if safemode:
                    # fix: logger.warn is deprecated (removed in Python 3.13)
                    logger.warning("Deletes cancelled")
        return delete_count
622

623
    def drop_quick(self):
        """
        Drops the table without cascading to dependent tables and without user prompt.
        """
        if not self.is_declared:
            # nothing to do; report and exit early
            logger.info(
                "Nothing to drop: table %s is not declared" % self.full_table_name
            )
            return
        query = "DROP TABLE %s" % self.full_table_name
        self.connection.query(query)
        logger.info("Dropped table %s" % self.full_table_name)
        self._log(query[:255])  # log a truncated copy of the executed statement
636

637
    def drop(self):
        """
        Drop the table and all tables that reference it, recursively.
        User is prompted for confirmation if config['safemode'] is set to True.

        :raises DataJointError: if the table is restricted, or if a part table
            would be dropped without its master.
        """
        if self.restriction:
            raise DataJointError(
                "A table with an applied restriction cannot be dropped."
                " Call drop() on the unrestricted Table."
            )
        self.connection.dependencies.load()
        targets = [
            name
            for name in self.connection.dependencies.descendants(self.full_table_name)
            if not name.isdigit()
        ]

        # avoid dropping part tables without their masters: See issue #374
        for part in targets:
            master = get_master(part)
            if master and master not in targets:
                raise DataJointError(
                    "Attempt to drop part table {part} before dropping "
                    "its master. Drop {master} first.".format(part=part, master=master)
                )

        do_drop = True
        if config["safemode"]:
            # list each table with its row count, then ask for confirmation
            for name in targets:
                logger.info(
                    name + " (%d tuples)" % len(FreeTable(self.connection, name))
                )
            do_drop = user_choice("Proceed?", default="no") == "yes"
        if do_drop:
            # reverse topological order: drop dependents before their parents
            for name in reversed(targets):
                FreeTable(self.connection, name).drop_quick()
            logger.info("Tables dropped.  Restart kernel.")
674

675
    @property
    def size_on_disk(self):
        """
        :return: size of data and indices in bytes on the storage device
        """
        status = self.connection.query(
            'SHOW TABLE STATUS FROM `{database}` WHERE NAME="{table}"'.format(
                database=self.database, table=self.table_name
            ),
            as_dict=True,
        ).fetchone()
        # sum of data and index footprints as reported by the server
        return status["Data_length"] + status["Index_length"]
687

688
    def show_definition(self):
        """Deprecated stub: always raises; use :meth:`describe` instead."""
        raise AttributeError(
            "show_definition is deprecated. Use the describe method instead."
        )
692

693
    def describe(self, context=None, printout=False):
        """
        Reconstruct the table's definition string in DataJoint DDL from the
        server-side heading and dependency graph.

        :param context: namespace dict used to map parent table names back to
            class names; defaults to the caller's frame (globals + locals).
        :param printout: if True, also log the definition at INFO level.
        :return:  the definition string for the query using DataJoint DDL.
        """
        if context is None:
            # capture the caller's namespace so foreign keys can be rendered
            # with class names instead of raw `schema`.`table` names
            frame = inspect.currentframe().f_back
            context = dict(frame.f_globals, **frame.f_locals)
            del frame  # drop the frame reference to avoid reference cycles
        if self.full_table_name not in self.connection.dependencies:
            self.connection.dependencies.load()
        parents = self.parents(foreign_key_info=True)
        in_key = True  # True while still iterating over primary-key attributes
        definition = (
            "# " + self.heading.table_status["comment"] + "\n"
            if self.heading.table_status["comment"]
            else ""
        )
        attributes_thus_far = set()
        attributes_declared = set()
        indexes = self.heading.indexes.copy()
        for attr in self.heading.attributes.values():
            if in_key and not attr.in_key:
                # first non-key attribute: emit the primary-key divider
                definition += "---\n"
                in_key = False
            attributes_thus_far.add(attr.name)
            do_include = True
            for parent_name, fk_props in parents:
                if attr.name in fk_props["attr_map"]:
                    # attribute arrives via a foreign key: suppress its direct line
                    do_include = False
                    if attributes_thus_far.issuperset(fk_props["attr_map"]):
                        # all attributes of this FK have been seen; emit the FK line now
                        # foreign key properties
                        try:
                            index_props = indexes.pop(tuple(fk_props["attr_map"]))
                        except KeyError:
                            index_props = ""
                        else:
                            # keep only index flags that are set (e.g. unique, nullable)
                            index_props = [k for k, v in index_props.items() if v]
                            index_props = (
                                " [{}]".format(", ".join(index_props))
                                if index_props
                                else ""
                            )

                        if not fk_props["aliased"]:
                            # simple foreign key
                            definition += "->{props} {class_name}\n".format(
                                props=index_props,
                                class_name=lookup_class_name(parent_name, context)
                                or parent_name,
                            )
                        else:
                            # projected foreign key
                            definition += (
                                "->{props} {class_name}.proj({proj_list})\n".format(
                                    props=index_props,
                                    class_name=lookup_class_name(parent_name, context)
                                    or parent_name,
                                    proj_list=",".join(
                                        '{}="{}"'.format(attr, ref)
                                        for attr, ref in fk_props["attr_map"].items()
                                        if ref != attr
                                    ),
                                )
                            )
                            attributes_declared.update(fk_props["attr_map"])
            if do_include:
                attributes_declared.add(attr.name)
                # plain attribute line: name[=default] : type  # comment
                definition += "%-20s : %-28s %s\n" % (
                    attr.name
                    if attr.default is None
                    else "%s=%s" % (attr.name, attr.default),
                    "%s%s"
                    % (attr.type, " auto_increment" if attr.autoincrement else ""),
                    "# " + attr.comment if attr.comment else "",
                )
        # add remaining indexes
        for k, v in indexes.items():
            definition += "{unique}INDEX ({attrs})\n".format(
                unique="UNIQUE " if v["unique"] else "", attrs=", ".join(k)
            )
        if printout:
            logger.info("\n" + definition)
        return definition
776

777
    def _update(self, attrname, value=None):
        """
        This is a deprecated function to be removed in datajoint 0.14.
        Use ``.update1`` instead.

        Updates a field in one existing tuple. self must be restricted to exactly one entry.
        In DataJoint the principal way of updating data is to delete and re-insert the
        entire record and updates are reserved for corrective actions.
        This is because referential integrity is observed on the level of entire
        records rather than individual attributes.

        Safety constraints:
           1. self must be restricted to exactly one tuple
           2. the update attribute must not be in primary key

        :param attrname: name of the (non-key) attribute to update
        :param value: new value; None (or NaN for numeric attributes) stores NULL

        Example:
        >>> (v2p.Mice() & key)._update('mouse_dob', '2011-01-01')
        >>> (v2p.Mice() & key)._update( 'lens')   # set the value to NULL
        """
        logger.warning(
            "`_update` is a deprecated function to be removed in datajoint 0.14. "
            "Use `.update1` instead."
        )
        if len(self) != 1:
            raise DataJointError("Update is only allowed on one tuple at a time")
        if attrname not in self.heading:
            raise DataJointError("Invalid attribute name")
        if attrname in self.heading.primary_key:
            raise DataJointError("Cannot update a key value.")

        attr = self.heading[attrname]

        if attr.is_blob:
            # serialize through the blob codec and bind as a query argument
            value = blob.pack(value)
            placeholder = "%s"
        elif attr.numeric:
            if value is None or np.isnan(float(value)):  # nans are turned into NULLs
                placeholder = "NULL"
                value = None
            else:
                placeholder = "%s"
                # normalize bools to 0/1; numbers are passed to the driver as strings
                value = str(int(value) if isinstance(value, bool) else value)
        else:
            placeholder = "%s" if value is not None else "NULL"
        command = "UPDATE {full_table_name} SET `{attrname}`={placeholder} {where_clause}".format(
            full_table_name=self.from_clause(),
            attrname=attrname,
            placeholder=placeholder,
            where_clause=self.where_clause(),
        )
        # the value is bound as a query argument only when the placeholder is %s
        self.connection.query(command, args=(value,) if value is not None else ())
828

829
    # --- private helper functions ----
    def __make_placeholder(self, name, value, ignore_extra_fields=False):
        """
        For a given attribute `name` with `value`, return its processed value or value placeholder
        as a string to be included in the query and the value, if any, to be submitted for
        processing by mysql API.

        :param name:  name of attribute to be inserted
        :param value: value of attribute to be inserted
        :param ignore_extra_fields: if True, return None for attributes not in the heading
        :return: ``(name, placeholder, value)`` tuple, or None when the attribute is skipped
        """
        if ignore_extra_fields and name not in self.heading:
            return None
        attr = self.heading[name]
        if attr.adapter:
            # custom attribute type: convert to its stored representation first
            value = attr.adapter.put(value)
        if value is None or (attr.numeric and (value == "" or np.isnan(float(value)))):
            # set default value; "" and NaN count as missing for numeric attributes
            placeholder, value = "DEFAULT", None
        else:  # not NULL
            placeholder = "%s"
            if attr.uuid:
                if not isinstance(value, uuid.UUID):
                    try:
                        value = uuid.UUID(value)
                    except (AttributeError, ValueError):
                        raise DataJointError(
                            "badly formed UUID value {v} for attribute `{n}`".format(
                                v=value, n=name
                            )
                        )
                value = value.bytes  # UUIDs are stored as their 16 raw bytes
            elif attr.is_blob:
                value = blob.pack(value)
                # external blobs store only the hash; internal blobs store the bytes
                value = (
                    self.external[attr.store].put(value).bytes
                    if attr.is_external
                    else value
                )
            elif attr.is_attachment:
                attachment_path = Path(value)
                if attr.is_external:
                    # value is hash of contents
                    value = (
                        self.external[attr.store]
                        .upload_attachment(attachment_path)
                        .bytes
                    )
                else:
                    # value is filename + contents
                    value = (
                        str.encode(attachment_path.name)
                        + b"\0"
                        + attachment_path.read_bytes()
                    )
            elif attr.is_filepath:
                value = self.external[attr.store].upload_filepath(value).bytes
            elif attr.numeric:
                # normalize bools to 0/1; numbers are passed to the driver as strings
                value = str(int(value) if isinstance(value, bool) else value)
            elif attr.json:
                value = json.dumps(value)
        return name, placeholder, value
890

891
    def __make_row_to_insert(self, row, field_list, ignore_extra_fields):
        """
        Helper function for insert and update

        :param row:  A tuple to insert -- a numpy void (record), a mapping, or a
            positional sequence matching the heading
        :param field_list: field names shared by the whole batch; the first row
            populates it and later rows are reordered to match (mutated in place)
        :param ignore_extra_fields: if True, silently drop fields not in the heading
        :return: a dict with fields 'names', 'placeholders', 'values'
        """

        def check_fields(fields):
            """
            Validates that all items in `fields` are valid attributes in the heading

            :param fields: field names of a tuple
            """
            if not field_list:
                # first row of the batch: only verify each name exists in the heading
                if not ignore_extra_fields:
                    for field in fields:
                        if field not in self.heading:
                            raise KeyError(
                                "`{0:s}` is not in the table heading".format(field)
                            )
            # later rows must carry exactly the same heading fields as the first
            elif set(field_list) != set(fields).intersection(self.heading.names):
                raise DataJointError("Attempt to insert rows with different fields.")

        if isinstance(row, np.void):  # np.array
            check_fields(row.dtype.fields)
            attributes = [
                self.__make_placeholder(name, row[name], ignore_extra_fields)
                for name in self.heading
                if name in row.dtype.fields
            ]
        elif isinstance(row, collections.abc.Mapping):  # dict-based
            check_fields(row)
            attributes = [
                self.__make_placeholder(name, row[name], ignore_extra_fields)
                for name in self.heading
                if name in row
            ]
        else:  # positional
            try:
                if len(row) != len(self.heading):
                    raise DataJointError(
                        "Invalid insert argument. Incorrect number of attributes: "
                        "{given} given; {expected} expected".format(
                            given=len(row), expected=len(self.heading)
                        )
                    )
            except TypeError:
                # len() failed: row is not a sequence at all
                raise DataJointError("Datatype %s cannot be inserted" % type(row))
            else:
                attributes = [
                    self.__make_placeholder(name, value, ignore_extra_fields)
                    for name, value in zip(self.heading, row)
                ]
        if ignore_extra_fields:
            # __make_placeholder returns None for skipped fields; drop those entries
            attributes = [a for a in attributes if a is not None]

        assert len(attributes), "Empty tuple"
        row_to_insert = dict(zip(("names", "placeholders", "values"), zip(*attributes)))
        if not field_list:
            # first row sets the composition of the field list
            field_list.extend(row_to_insert["names"])
        else:
            #  reorder attributes in row_to_insert to match field_list
            order = list(row_to_insert["names"].index(field) for field in field_list)
            row_to_insert["names"] = list(row_to_insert["names"][i] for i in order)
            row_to_insert["placeholders"] = list(
                row_to_insert["placeholders"][i] for i in order
            )
            row_to_insert["values"] = list(row_to_insert["values"][i] for i in order)
        return row_to_insert
962

963

964
def lookup_class_name(name, context, depth=3):
    """
    given a table name in the form `schema_name`.`table_name`, find its class in the context.

    :param name: `schema_name`.`table_name`
    :param context: dictionary representing the namespace
    :param depth: search depth into imported modules, helps avoid infinite recursion.
    :return: class name found in the context or None if not found
    """
    # breadth-first search over the namespace and, recursively, imported modules
    queue = [dict(context=context, context_name="", depth=depth)]
    while queue:
        current = queue.pop(0)
        for symbol, obj in current["context"].items():
            if symbol.startswith("_"):  # skip IPython's implicit variables
                continue
            if inspect.isclass(obj) and issubclass(obj, Table):
                if obj.full_table_name == name:  # found it!
                    return ".".join([current["context_name"], symbol]).lstrip(".")
                try:  # look for part tables
                    class_members = obj.__dict__
                except AttributeError:
                    continue  # not a UserTable -- cannot have part tables.
                for attr_name in class_members:
                    # part table classes are capitalized attributes of the master
                    if not attr_name[0].isupper() or not hasattr(obj, attr_name):
                        continue
                    part = getattr(obj, attr_name)
                    if (
                        inspect.isclass(part)
                        and issubclass(part, Table)
                        and part.full_table_name == name
                    ):
                        return ".".join(
                            [current["context_name"], symbol, part.__name__]
                        ).lstrip(".")
            elif (
                current["depth"] > 0
                and inspect.ismodule(obj)
                and obj.__name__ != "datajoint"
            ):
                # descend into the module's namespace with one less depth budget
                try:
                    queue.append(
                        dict(
                            context=dict(inspect.getmembers(obj)),
                            context_name=current["context_name"] + "." + symbol,
                            depth=current["depth"] - 1,
                        )
                    )
                except ImportError:
                    pass  # could not import, so do not attempt
    return None
1016

1017

1018
class FreeTable(Table):
    """
    A base table without a dedicated class. Each instance is associated with a table
    specified by full_table_name.

    :param conn:  a dj.Connection object
    :param full_table_name: in format `database`.`table_name`
    """

    def __init__(self, conn, full_table_name):
        # split `database`.`table_name` into its two backquoted components
        database_part, table_part = full_table_name.split(".")
        self.database = database_part.strip("`")
        self._table_name = table_part.strip("`")
        self._connection = conn
        self._support = [full_table_name]
        table_info = dict(
            conn=conn,
            database=self.database,
            table_name=self.table_name,
            context=None,
        )
        self._heading = Heading(table_info=table_info)

    def __repr__(self):
        header = "FreeTable(`%s`.`%s`)\n" % (self.database, self._table_name)
        return header + super().__repr__()
1047

1048

1049
class Log(Table):
    """
    The log table for each schema.
    Instances are callable.  Calls log the time and identifying information along with the event.

    :param skip_logging: if True, then log entry is skipped by default. See __call__
    """

    _table_name = "~log"

    def __init__(self, conn, database, skip_logging=False):
        self.database = database
        self.skip_logging = skip_logging
        self._connection = conn
        self._heading = Heading(
            table_info=dict(
                conn=conn, database=database, table_name=self.table_name, context=None
            )
        )
        self._support = [self.full_table_name]

        # DDL for the log table; substituted into the table comment below
        self._definition = """    # event logging table for `{database}`
        id       :int unsigned auto_increment     # event order id
        ---
        timestamp = CURRENT_TIMESTAMP : timestamp # event timestamp
        version  :varchar(12)                     # datajoint version
        user     :varchar(255)                    # user@host
        host=""  :varchar(255)                    # system hostname
        event="" :varchar(255)                    # event message
        """.format(
            database=database
        )

        super().__init__()

        # create the table on first use and refresh the dependency cache
        if not self.is_declared:
            self.declare()
            self.connection.dependencies.clear()
        self._user = self.connection.get_user()

    @property
    def definition(self):
        # DataJoint DDL string built in __init__
        return self._definition

    def __call__(self, event, skip_logging=None):
        """
        Write one event row into the log table.

        :param event: string to write into the log table
        :param skip_logging: If True then do not log. If None, then use self.skip_logging
        """
        skip_logging = self.skip_logging if skip_logging is None else skip_logging
        if not skip_logging:
            try:
                self.insert1(
                    dict(
                        user=self._user,
                        version=version + "py",
                        host=platform.uname().node,
                        event=event,
                    ),
                    skip_duplicates=True,
                    ignore_extra_fields=True,
                )
            except DataJointError:
                # logging is best-effort; never let a log failure break the caller
                logger.info("could not log event in table ~log")

    def delete(self):
        """
        bypass interactive prompts and cascading dependencies

        :return: number of deleted items
        """
        return self.delete_quick(get_count=True)

    def drop(self):
        """bypass interactive prompts and cascading dependencies"""
        self.drop_quick()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc