• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

datajoint / datajoint-python / #12900

pending completion
#12900

push

travis-ci

web-flow
Commit 864be0ccc: Merge d3b1af131 into 3b6e84588

81 of 81 new or added lines in 10 files covered. (100.0%)

3134 of 3489 relevant lines covered (89.83%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.6
/datajoint/table.py
1
import collections
1✔
2
import itertools
1✔
3
import inspect
1✔
4
import platform
1✔
5
import numpy as np
1✔
6
import pandas
1✔
7
import logging
1✔
8
import uuid
1✔
9
import csv
1✔
10
import re
1✔
11
from pathlib import Path
1✔
12
from .settings import config
1✔
13
from .declare import declare, alter
1✔
14
from .condition import make_condition
1✔
15
from .expression import QueryExpression
1✔
16
from . import blob
1✔
17
from .utils import user_choice, get_master
1✔
18
from .heading import Heading
1✔
19
from .errors import (
1✔
20
    DuplicateError,
21
    AccessError,
22
    DataJointError,
23
    UnknownAttributeError,
24
    IntegrityError,
25
)
26
from typing import Union
1✔
27
from .version import __version__ as version
1✔
28

29
# Package-level logger: keyed on the root package name so all datajoint
# modules share a single logger configuration.
logger = logging.getLogger(__name__.split(".")[0])

# Parses MySQL foreign-key violation messages to extract the child table,
# the constraint name, and (when present) the FK/PK attribute lists.
# The "FOREIGN KEY ... REFERENCES ..." portion is an optional group;
# when it is missing, delete() falls back to constraint_info_query below.
foreign_key_error_regexp = re.compile(
    r"[\w\s:]*\((?P<child>`[^`]+`.`[^`]+`), "
    r"CONSTRAINT (?P<name>`[^`]+`) "
    r"(FOREIGN KEY \((?P<fk_attrs>[^)]+)\) "
    r"REFERENCES (?P<parent>`[^`]+`(\.`[^`]+`)?) \((?P<pk_attrs>[^)]+)\)[\s\w]+\))?"
)

# Looks up a named constraint's FK columns, referenced (parent) table, and PK
# columns from INFORMATION_SCHEMA.  Parameters: constraint name, child schema,
# child table name.  split()/join collapses the literal to a single line.
constraint_info_query = " ".join(
    """
    SELECT
        COLUMN_NAME as fk_attrs,
        CONCAT('`', REFERENCED_TABLE_SCHEMA, '`.`', REFERENCED_TABLE_NAME, '`') as parent,
        REFERENCED_COLUMN_NAME as pk_attrs
    FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
    WHERE
        CONSTRAINT_NAME = %s AND TABLE_SCHEMA = %s AND TABLE_NAME = %s;
    """.split()
)
49

50

51
class _RenameMap(tuple):
1✔
52
    """for internal use"""
53

54
    pass
1✔
55

56

57
class Table(QueryExpression):
    """
    Table is an abstract class that represents a table in the schema.
    It implements insert and delete methods and inherits query functionality.
    To make it a concrete class, override the abstract properties specifying the connection,
    table name, database, and definition.
    """

    _table_name = None  # must be defined in subclass
    _log_ = None  # placeholder for the Log table object (created lazily by the _log property)

    # These properties must be set by the schema decorator (schemas.py) at class level
    # or by FreeTable at instance level
    database = None
    declaration_context = None  # context for resolving foreign-key references at declaration -- TODO confirm
1✔
72

73
    @property
    def table_name(self):
        """
        :return: the table name inside the database (set by the subclass via _table_name)
        """
        return self._table_name
76

77
    @property
    def definition(self):
        """
        Abstract: the DataJoint table definition string (DDL).

        :raises NotImplementedError: always; concrete subclasses must override.
        """
        raise NotImplementedError(
            "Subclasses of Table must implement the `definition` property"
        )
82

83
    def declare(self, context=None):
        """
        Declare the table in the schema based on self.definition.

        :param context: the context for foreign key resolution. If None, foreign keys are
            not allowed.
        :raises DataJointError: if called while a transaction is in progress.
        """
        if self.connection.in_transaction:
            # table declaration is refused mid-transaction (e.g. from a populate/make call)
            raise DataJointError(
                "Cannot declare new tables inside a transaction, "
                "e.g. from inside a populate/make call"
            )
        # declare() (from .declare) returns the CREATE statement plus the names of
        # any external stores referenced by the definition
        sql, external_stores = declare(self.full_table_name, self.definition, context)
        sql = sql.format(database=self.database)
        try:
            # declare all external tables before declaring main table
            # (indexing the external mapping appears to create the store table -- TODO confirm)
            for store in external_stores:
                self.connection.schemas[self.database].external[store]
            self.connection.query(sql)
        except AccessError:
            # skip if no create privilege
            pass
        else:
            # record successful declaration in the schema log
            self._log("Declared " + self.full_table_name)
107

108
    def alter(self, prompt=True, context=None):
        """
        Alter the table definition from self.definition.

        :param prompt: if True, display the generated ALTER statement and ask for
            confirmation before executing; also print progress messages.
        :param context: namespace for foreign key resolution; defaults to the
            caller's globals and locals.
        :raises DataJointError: if called while a transaction is in progress.
        """
        if self.connection.in_transaction:
            # altering tables is refused mid-transaction (e.g. from a populate/make call)
            raise DataJointError(
                "Cannot update table declaration inside a transaction, "
                "e.g. from inside a populate/make call"
            )
        if context is None:
            # borrow the caller's namespace so foreign keys resolve as they would
            # at the call site; delete the frame reference to avoid a reference cycle
            frame = inspect.currentframe().f_back
            context = dict(frame.f_globals, **frame.f_locals)
            del frame
        # diff the current (described) definition against self.definition
        old_definition = self.describe(context=context, printout=False)
        sql, external_stores = alter(self.definition, old_definition, context)
        if not sql:
            if prompt:
                print("Nothing to alter.")
        else:
            # assemble a single ALTER TABLE statement from the clause list
            sql = "ALTER TABLE {tab}\n\t".format(
                tab=self.full_table_name
            ) + ",\n\t".join(sql)
            if not prompt or user_choice(sql + "\n\nExecute?") == "yes":
                try:
                    # declare all external tables before declaring main table
                    for store in external_stores:
                        self.connection.schemas[self.database].external[store]
                    self.connection.query(sql)
                except AccessError:
                    # skip if no create privilege
                    pass
                else:
                    # reset heading so it is re-read from the altered table
                    self.__class__._heading = Heading(
                        table_info=self.heading.table_info
                    )
                    if prompt:
                        print("Table altered")
                    self._log("Altered " + self.full_table_name)
147

148
    def from_clause(self):
1✔
149
        """
150
        :return: the FROM clause of SQL SELECT statements.
151
        """
152
        return self.full_table_name
1✔
153

154
    def get_select_fields(self, select_fields=None):
1✔
155
        """
156
        :return: the selected attributes from the SQL SELECT statement.
157
        """
158
        return (
×
159
            "*" if select_fields is None else self.heading.project(select_fields).as_sql
160
        )
161

162
    def parents(self, primary=None, as_objects=False, foreign_key_info=False):
        """
        Return the tables referenced by this table's foreign keys.

        :param primary: if None, all parents are returned. If True, only foreign keys
            composed entirely of primary key attributes are considered. If False, only
            foreign keys including at least one secondary attribute.
        :param as_objects: if False, return table names. If True, return table objects.
        :param foreign_key_info: if True, each element also includes foreign key info.
        :return: list of parents as table names or table objects
            with (optional) foreign key information.
        """
        lookup = self.connection.dependencies.parents
        edges = []
        for name, props in lookup(self.full_table_name, primary).items():
            if name.isdigit():
                # numeric nodes are aliases; follow them to the real parent
                name, props = next(iter(lookup(name).items()))
            edges.append((name, props))
        if as_objects:
            edges = [(FreeTable(self.connection, n), p) for n, p in edges]
        if foreign_key_info:
            return edges
        return [n for n, _ in edges]
183

184
    def children(self, primary=None, as_objects=False, foreign_key_info=False):
        """
        Return the tables that reference this table through foreign keys.

        :param primary: if None, all children are returned. If True, only foreign keys
            composed entirely of primary key attributes are considered. If False, only
            foreign keys including at least one secondary attribute.
        :param as_objects: if False, return table names. If True, return table objects.
        :param foreign_key_info: if True, each element also includes foreign key info.
        :return: list of children as table names or table objects
            with (optional) foreign key information.
        """
        lookup = self.connection.dependencies.children
        edges = []
        for name, props in lookup(self.full_table_name, primary).items():
            if name.isdigit():
                # numeric nodes are aliases; follow them to the real child
                name, props = next(iter(lookup(name).items()))
            edges.append((name, props))
        if as_objects:
            edges = [(FreeTable(self.connection, n), p) for n, p in edges]
        if foreign_key_info:
            return edges
        return [n for n, _ in edges]
205

206
    def descendants(self, as_objects=False):
        """
        :param as_objects: False - a list of table names; True - a list of table objects.
        :return: list of table descendants in topological order.
        """
        result = []
        for node in self.connection.dependencies.descendants(self.full_table_name):
            if node.isdigit():
                continue  # skip alias nodes
            result.append(FreeTable(self.connection, node) if as_objects else node)
        return result
217

218
    def ancestors(self, as_objects=False):
        """
        :param as_objects: False - a list of table names; True - a list of table objects.
        :return: list of table ancestors in topological order.
        """
        result = []
        for node in self.connection.dependencies.ancestors(self.full_table_name):
            if node.isdigit():
                continue  # skip alias nodes
            result.append(FreeTable(self.connection, node) if as_objects else node)
        return result
229

230
    def parts(self, as_objects=False):
        """
        Return the part tables of this master table.

        :param as_objects: if False (default), return table names; if True, return
            table objects.
        """
        # part tables share the master's name with a "__" suffix before the
        # closing backquote; strip that quote to build the prefix
        prefix = self.full_table_name[:-1] + "__"
        names = [
            node
            for node in self.connection.dependencies.nodes
            if not node.isdigit() and node.startswith(prefix)
        ]
        if as_objects:
            return [FreeTable(self.connection, name) for name in names]
        return names
242

243
    @property
    def is_declared(self):
        """
        :return: True if the table is declared in the schema.
        """
        # a non-empty SHOW TABLES result means the table exists
        return (
            self.connection.query(
                'SHOW TABLES in `{database}` LIKE "{table_name}"'.format(
                    database=self.database, table_name=self.table_name
                )
            ).rowcount
            > 0
        )
256

257
    @property
1✔
258
    def full_table_name(self):
1✔
259
        """
260
        :return: full table name in the schema
261
        """
262
        return r"`{0:s}`.`{1:s}`".format(self.database, self.table_name)
1✔
263

264
    @property
    def _log(self):
        """
        :return: the Log table object for this schema, created on first access.
        """
        if self._log_ is None:
            # lazy initialization: build the Log object only when first needed
            self._log_ = Log(
                self.connection,
                database=self.database,
                # tables whose names start with "~" skip logging -- presumably
                # these are hidden/service tables; verify against Log's contract
                skip_logging=self.table_name.startswith("~"),
            )
        return self._log_
273

274
    @property
    def external(self):
        """
        :return: the external storage mapping of this table's schema.
        """
        schema = self.connection.schemas[self.database]
        return schema.external
277

278
    def update1(self, row):
        """
        ``update1`` updates one existing entry in the table.
        Caution: In DataJoint the primary modes for data manipulation is to ``insert`` and
        ``delete`` entire records since referential integrity works on the level of records,
        not fields. Therefore, updates are reserved for corrective operations outside of main
        workflow. Use UPDATE methods sparingly with full awareness of potential violations of
        assumptions.

        :param row: a ``dict`` containing the primary key values and the attributes to update.
            Setting an attribute value to None will reset it to the default value (if any).

        The primary key attributes must always be provided.

        :raises DataJointError: if ``row`` is not dict-like, misses primary key values,
            names an unknown attribute, if the table is restricted, or if the key does
            not match exactly one existing entry.

        Examples:

        >>> table.update1({'id': 1, 'value': 3})  # update value in record with id=1
        >>> table.update1({'id': 1, 'value': None})  # reset value to default
        """
        # argument validations
        if not isinstance(row, collections.abc.Mapping):
            raise DataJointError("The argument of update1 must be dict-like.")
        if not set(row).issuperset(self.primary_key):
            raise DataJointError(
                "The argument of update1 must supply all primary key values."
            )
        # next() raises StopIteration (caught below) when every attribute is known;
        # otherwise its result is interpolated into the error before DataJointError
        # is even constructed
        try:
            raise DataJointError(
                "Attribute `%s` not found."
                % next(k for k in row if k not in self.heading.names)
            )
        except StopIteration:
            pass  # ok
        if len(self.restriction):
            # a restriction could silently exclude the target entry
            raise DataJointError("Update cannot be applied to a restricted table.")
        key = {k: row[k] for k in self.primary_key}
        if len(self & key) != 1:
            raise DataJointError("Update can only be applied to one existing entry.")
        # UPDATE query
        # __make_placeholder yields (name, placeholder, value) triples; a None value
        # means the placeholder is inline SQL with no bound argument -- TODO confirm
        row = [
            self.__make_placeholder(k, v)
            for k, v in row.items()
            if k not in self.primary_key
        ]
        query = "UPDATE {table} SET {assignments} WHERE {where}".format(
            table=self.full_table_name,
            assignments=",".join("`%s`=%s" % r[:2] for r in row),
            where=make_condition(self, key, set()),
        )
        self.connection.query(query, args=list(r[2] for r in row if r[2] is not None))
328

329
    def insert1(self, row, **kwargs):
1✔
330
        """
331
        Insert one data record into the table. For ``kwargs``, see ``insert()``.
332

333
        :param row: a numpy record, a dict-like object, or an ordered sequence to be inserted
334
            as one row.
335
        """
336
        self.insert((row,), **kwargs)
1✔
337

338
    def _normalize_insert_data(self, rows):
1✔
339
        """Handle types provided to `insert`, including DataFrame, Path and class"""
340

341
        if isinstance(rows, pandas.DataFrame):
1✔
342
            # drop 'extra' synthetic index for 1-field index case -
343
            # frames with more advanced indices should be prepared by user.
344
            rows = rows.reset_index(
1✔
345
                drop=len(rows.index.names) == 1 and not rows.index.names[0]
346
            ).to_records(index=False)
347

348
        if isinstance(rows, Path):
1✔
349
            with open(rows, newline="") as data_file:
1✔
350
                rows = list(csv.DictReader(data_file, delimiter=","))
1✔
351

352
        if inspect.isclass(rows) and issubclass(rows, QueryExpression):
1✔
353
            rows = rows()  # instantiate if a class
1✔
354

355
        return rows
1✔
356

357
    def insert(
        self,
        rows,
        replace=False,
        skip_duplicates=False,
        ignore_extra_fields=False,
        allow_direct_insert=None,
    ):
        """
        Insert a collection of rows.

        :param rows: Either (a) an iterable where an element is a numpy record, a
            dict-like object, a pandas.DataFrame, a sequence, or a query expression with
            the same heading as self, or (b) a pathlib.Path object specifying a path
            relative to the current directory with a CSV file, the contents of which
            will be inserted.
        :param replace: If True, replaces the existing tuple.
        :param skip_duplicates: If True, silently skip duplicate inserts.
        :param ignore_extra_fields: If False, fields that are not in the heading raise error.
        :param allow_direct_insert: Only applies in auto-populated tables. If False (default),
            insert may only be called from inside the make callback.

        Example:

            >>> Table.insert([
            >>>     dict(subject_id=7, species="mouse", date_of_birth="2014-09-01"),
            >>>     dict(subject_id=8, species="mouse", date_of_birth="2014-09-02")])
        """
        # prohibit direct inserts into auto-populated tables
        if not allow_direct_insert and not getattr(self, "_allow_insert", True):
            raise DataJointError(
                "Inserts into an auto-populated table can only be done inside "
                "its make method during a populate call."
                " To override, set keyword argument allow_direct_insert=True."
            )

        # normalize DataFrame / Path / class inputs into an iterable of rows
        rows = self._normalize_insert_data(rows)

        if isinstance(rows, QueryExpression):
            # insert from select: build a single INSERT ... SELECT statement
            if not ignore_extra_fields:
                # next() raises StopIteration (caught below) when all source
                # attributes exist in this table's heading -- the happy path
                try:
                    raise DataJointError(
                        "Attribute %s not found. To ignore extra attributes in insert, "
                        "set ignore_extra_fields=True."
                        % next(
                            name for name in rows.heading if name not in self.heading
                        )
                    )
                except StopIteration:
                    pass
            fields = list(name for name in rows.heading if name in self.heading)
            query = "{command} INTO {table} ({fields}) {select}{duplicate}".format(
                command="REPLACE" if replace else "INSERT",
                fields="`" + "`,`".join(fields) + "`",
                table=self.full_table_name,
                select=rows.make_sql(fields),
                # the no-op self-assignment makes duplicate keys a silent no-op
                duplicate=(
                    " ON DUPLICATE KEY UPDATE `{pk}`={table}.`{pk}`".format(
                        table=self.full_table_name, pk=self.primary_key[0]
                    )
                    if skip_duplicates
                    else ""
                ),
            )
            self.connection.query(query)
            return

        field_list = []  # collects the field list from first row (passed by reference)
        # __make_row_to_insert mutates field_list on the first row; subsequent rows
        # are validated against it
        rows = list(
            self.__make_row_to_insert(row, field_list, ignore_extra_fields)
            for row in rows
        )
        if rows:
            try:
                query = "{command} INTO {destination}(`{fields}`) VALUES {placeholders}{duplicate}".format(
                    command="REPLACE" if replace else "INSERT",
                    destination=self.from_clause(),
                    fields="`,`".join(field_list),
                    placeholders=",".join(
                        "(" + ",".join(row["placeholders"]) + ")" for row in rows
                    ),
                    duplicate=(
                        " ON DUPLICATE KEY UPDATE `{pk}`=`{pk}`".format(
                            pk=self.primary_key[0]
                        )
                        if skip_duplicates
                        else ""
                    ),
                )
                # None values carry no bound argument (inline SQL placeholder) -- TODO confirm
                self.connection.query(
                    query,
                    args=list(
                        itertools.chain.from_iterable(
                            (v for v in r["values"] if v is not None) for r in rows
                        )
                    ),
                )
            except UnknownAttributeError as err:
                raise err.suggest(
                    "To ignore extra fields in insert, set ignore_extra_fields=True"
                )
            except DuplicateError as err:
                raise err.suggest(
                    "To ignore duplicate entries in insert, set skip_duplicates=True"
                )

464
    def delete_quick(self, get_count=False):
        """
        Delete rows matching the current restriction without cascading and
        without a user prompt. Fails if dependent tables still reference them.

        :param get_count: if True, also fetch the number of deleted rows.
        :return: the deleted-row count when get_count is True, otherwise None.
        """
        sql = "DELETE FROM {}{}".format(self.full_table_name, self.where_clause())
        self.connection.query(sql)
        count = None
        if get_count:
            count = self.connection.query("SELECT ROW_COUNT()").fetchone()[0]
        self._log(sql[:255])
        return count
478

479
    def delete(
        self,
        transaction: bool = True,
        safemode: Union[bool, None] = None,
        force_parts: bool = False,
    ) -> int:
        """
        Deletes the contents of the table and its dependent tables, recursively.

        Args:
            transaction: If `True`, use of the entire delete becomes an atomic transaction.
                This is the default and recommended behavior. Set to `False` if this delete is
                nested within another transaction.
            safemode: If `True`, prohibit nested transactions and prompt to confirm. Default
                is `dj.config['safemode']`.
            force_parts: Delete from parts even when not deleting from their masters.

        Returns:
            Number of deleted rows (excluding those from dependent tables).

        Raises:
            DataJointError: Delete exceeds maximum number of delete attempts.
            DataJointError: When deleting within an existing transaction.
            DataJointError: Deleting a part table before its master.
        """
        deleted = set()  # full table names already deleted from (for part/master check)

        def cascade(table):
            """service function to perform cascading deletes recursively."""
            # each FK violation triggers one recursive child delete, then a retry;
            # bounded to avoid infinite loops on pathological dependency graphs
            max_attempts = 50
            for _ in range(max_attempts):
                try:
                    delete_count = table.delete_quick(get_count=True)
                except IntegrityError as error:
                    # recover the offending child table and FK columns from the
                    # server's error message
                    match = foreign_key_error_regexp.match(error.args[0]).groupdict()
                    if "`.`" not in match["child"]:  # if schema name missing, use table
                        match["child"] = "{}.{}".format(
                            table.full_table_name.split(".")[0], match["child"]
                        )
                    if (
                        match["pk_attrs"] is not None
                    ):  # fully matched, adjusting the keys
                        match["fk_attrs"] = [
                            k.strip("`") for k in match["fk_attrs"].split(",")
                        ]
                        match["pk_attrs"] = [
                            k.strip("`") for k in match["pk_attrs"].split(",")
                        ]
                    else:  # only partially matched, querying with constraint to determine keys
                        # INFORMATION_SCHEMA returns one row per FK column;
                        # zip(*...) transposes into per-field column lists
                        match["fk_attrs"], match["parent"], match["pk_attrs"] = list(
                            map(
                                list,
                                zip(
                                    *table.connection.query(
                                        constraint_info_query,
                                        args=(
                                            match["name"].strip("`"),
                                            *[
                                                _.strip("`")
                                                for _ in match["child"].split("`.`")
                                            ],
                                        ),
                                    ).fetchall()
                                ),
                            )
                        )
                        match["parent"] = match["parent"][0]

                    # Restrict child by table if
                    #   1. if table's restriction attributes are not in child's primary key
                    #   2. if child renames any attributes
                    # Otherwise restrict child by table's restriction.
                    child = FreeTable(table.connection, match["child"])
                    if (
                        set(table.restriction_attributes) <= set(child.primary_key)
                        and match["fk_attrs"] == match["pk_attrs"]
                    ):
                        child._restriction = table._restriction
                    elif match["fk_attrs"] != match["pk_attrs"]:
                        # renamed FK attributes: project with the rename map
                        child &= table.proj(
                            **dict(zip(match["fk_attrs"], match["pk_attrs"]))
                        )
                    else:
                        child &= table.proj()
                    cascade(child)
                else:
                    deleted.add(table.full_table_name)
                    logger.info(
                        "Deleting {count} rows from {table}".format(
                            count=delete_count, table=table.full_table_name
                        )
                    )
                    break
            else:
                raise DataJointError("Exceeded maximum number of delete attempts.")
            return delete_count

        safemode = config["safemode"] if safemode is None else safemode

        # Start transaction
        if transaction:
            if not self.connection.in_transaction:
                self.connection.start_transaction()
            else:
                if not safemode:
                    # silently join the caller's existing transaction
                    transaction = False
                else:
                    raise DataJointError(
                        "Delete cannot use a transaction within an ongoing transaction. "
                        "Set transaction=False or safemode=False)."
                    )

        # Cascading delete
        try:
            delete_count = cascade(self)
        except:
            # roll back on any failure, including KeyboardInterrupt
            if transaction:
                self.connection.cancel_transaction()
            raise

        if not force_parts:
            # Avoid deleting from child before master (See issue #151)
            for part in deleted:
                master = get_master(part)
                if master and master not in deleted:
                    if transaction:
                        self.connection.cancel_transaction()
                    raise DataJointError(
                        "Attempt to delete part table {part} before deleting from "
                        "its master {master} first.".format(part=part, master=master)
                    )

        # Confirm and commit
        if delete_count == 0:
            if safemode:
                print("Nothing to delete.")
            if transaction:
                self.connection.cancel_transaction()
        else:
            if not safemode or user_choice("Commit deletes?", default="no") == "yes":
                if transaction:
                    self.connection.commit_transaction()
                if safemode:
                    print("Deletes committed.")
            else:
                if transaction:
                    self.connection.cancel_transaction()
                if safemode:
                    print("Deletes cancelled")
        return delete_count
629

630
    def drop_quick(self):
        """
        Drop this table without cascading to dependent tables and without
        prompting the user.
        """
        if not self.is_declared:
            logger.info(
                "Nothing to drop: table %s is not declared" % self.full_table_name
            )
            return
        sql = "DROP TABLE %s" % self.full_table_name
        self.connection.query(sql)
        logger.info("Dropped table %s" % self.full_table_name)
        self._log(sql[:255])

644
    def drop(self):
        """
        Drop the table and all tables that reference it, recursively.
        User is prompted for confirmation if config['safemode'] is set to True.

        :raises DataJointError: if the table is restricted or if a part table
            would be dropped before its master.
        """
        if self.restriction:
            raise DataJointError(
                "A table with an applied restriction cannot be dropped."
                " Call drop() on the unrestricted Table."
            )
        self.connection.dependencies.load()
        do_drop = True
        # all non-alias descendants must be dropped along with this table
        tables = [
            table
            for table in self.connection.dependencies.descendants(self.full_table_name)
            if not table.isdigit()
        ]

        # avoid dropping part tables without their masters: See issue #374
        for part in tables:
            master = get_master(part)
            if master and master not in tables:
                raise DataJointError(
                    "Attempt to drop part table {part} before dropping "
                    "its master. Drop {master} first.".format(part=part, master=master)
                )

        if config["safemode"]:
            # show the affected tables and their row counts before confirming
            for table in tables:
                print(table, "(%d tuples)" % len(FreeTable(self.connection, table)))
            do_drop = user_choice("Proceed?", default="no") == "yes"
        if do_drop:
            # drop children before parents (reverse topological order)
            for table in reversed(tables):
                FreeTable(self.connection, table).drop_quick()
            print("Tables dropped.  Restart kernel.")
679

680
    @property
    def size_on_disk(self):
        """
        :return: combined size of data and indices in bytes on the storage device
        """
        sql = 'SHOW TABLE STATUS FROM `{database}` WHERE NAME="{table}"'.format(
            database=self.database, table=self.table_name
        )
        status = self.connection.query(sql, as_dict=True).fetchone()
        return status["Data_length"] + status["Index_length"]
692

693
    def show_definition(self):
        """Removed API stub: always raises to direct callers to ``describe``."""
        message = "show_definition is deprecated. Use the describe method instead."
        raise AttributeError(message)

698
    def describe(self, context=None, printout=True):
        """
        Reconstruct and return the DataJoint DDL definition string for this table.

        :param context: namespace dict used to resolve parent table names to class
            names; defaults to the caller's globals and locals.
        :param printout: if True, print the definition to stdout as well.
        :return:  the definition string for the query using DataJoint DDL.
        """
        if context is None:
            # Capture the caller's namespace so parent tables can be shown by
            # their Python class names rather than raw `schema`.`table` names.
            frame = inspect.currentframe().f_back
            context = dict(frame.f_globals, **frame.f_locals)
            del frame  # break the reference cycle created by holding the frame
        if self.full_table_name not in self.connection.dependencies:
            self.connection.dependencies.load()
        parents = self.parents(foreign_key_info=True)
        in_key = True  # tracks whether we are still emitting primary-key attributes
        # Optional first line: the table comment.
        definition = (
            "# " + self.heading.table_status["comment"] + "\n"
            if self.heading.table_status["comment"]
            else ""
        )
        attributes_thus_far = set()
        attributes_declared = set()
        # Copy so foreign-key-backed indexes can be popped as they are consumed.
        indexes = self.heading.indexes.copy()
        for attr in self.heading.attributes.values():
            if in_key and not attr.in_key:
                # The divider between primary-key and secondary attributes.
                definition += "---\n"
                in_key = False
            attributes_thus_far.add(attr.name)
            do_include = True
            for parent_name, fk_props in parents:
                if attr.name in fk_props["attr_map"]:
                    # Attribute comes from a foreign key; emit the FK line instead,
                    # but only once all of the FK's attributes have been seen.
                    do_include = False
                    if attributes_thus_far.issuperset(fk_props["attr_map"]):
                        # foreign key properties
                        try:
                            index_props = indexes.pop(tuple(fk_props["attr_map"]))
                        except KeyError:
                            index_props = ""
                        else:
                            # Render truthy index flags (e.g. unique, nullable)
                            # as a bracketed list after the arrow.
                            index_props = [k for k, v in index_props.items() if v]
                            index_props = (
                                " [{}]".format(", ".join(index_props))
                                if index_props
                                else ""
                            )

                        if not fk_props["aliased"]:
                            # simple foreign key
                            definition += "->{props} {class_name}\n".format(
                                props=index_props,
                                class_name=lookup_class_name(parent_name, context)
                                or parent_name,
                            )
                        else:
                            # projected foreign key: renamed attributes appear
                            # as a .proj() call with new="old" pairs.
                            definition += (
                                "->{props} {class_name}.proj({proj_list})\n".format(
                                    props=index_props,
                                    class_name=lookup_class_name(parent_name, context)
                                    or parent_name,
                                    proj_list=",".join(
                                        '{}="{}"'.format(attr, ref)
                                        for attr, ref in fk_props["attr_map"].items()
                                        if ref != attr
                                    ),
                                )
                            )
                            attributes_declared.update(fk_props["attr_map"])
            if do_include:
                # Plain attribute line: name[=default] : type  # comment
                attributes_declared.add(attr.name)
                definition += "%-20s : %-28s %s\n" % (
                    attr.name
                    if attr.default is None
                    else "%s=%s" % (attr.name, attr.default),
                    "%s%s"
                    % (attr.type, " auto_increment" if attr.autoincrement else ""),
                    "# " + attr.comment if attr.comment else "",
                )
        # add remaining indexes (those not consumed by foreign keys above)
        for k, v in indexes.items():
            definition += "{unique}INDEX ({attrs})\n".format(
                unique="UNIQUE " if v["unique"] else "", attrs=", ".join(k)
            )
        if printout:
            print(definition)
        return definition

782
    def _update(self, attrname, value=None):
        """
        This is a deprecated function to be removed in datajoint 0.14.
        Use ``.update1`` instead.

        Updates a field in one existing tuple. self must be restricted to exactly one entry.
        In DataJoint the principal way of updating data is to delete and re-insert the
        entire record and updates are reserved for corrective actions.
        This is because referential integrity is observed on the level of entire
        records rather than individual attributes.

        Safety constraints:
           1. self must be restricted to exactly one tuple
           2. the update attribute must not be in primary key

        :param attrname: name of the (non-key) attribute to update
        :param value: new value; None sets the attribute to NULL
        :raises DataJointError: if the restriction does not match exactly one row,
            the attribute is unknown, or the attribute is part of the primary key

        Example:
        >>> (v2p.Mice() & key)._update('mouse_dob', '2011-01-01')
        >>> (v2p.Mice() & key)._update( 'lens')   # set the value to NULL
        """
        logger.warning(
            "`_update` is a deprecated function to be removed in datajoint 0.14. "
            "Use `.update1` instead."
        )
        # Enforce the two safety constraints before touching the database.
        if len(self) != 1:
            raise DataJointError("Update is only allowed on one tuple at a time")
        if attrname not in self.heading:
            raise DataJointError("Invalid attribute name")
        if attrname in self.heading.primary_key:
            raise DataJointError("Cannot update a key value.")

        attr = self.heading[attrname]

        # Convert the value and pick the SQL placeholder based on attribute type.
        if attr.is_blob:
            value = blob.pack(value)
            placeholder = "%s"
        elif attr.numeric:
            if value is None or np.isnan(float(value)):  # nans are turned into NULLs
                placeholder = "NULL"
                value = None
            else:
                placeholder = "%s"
                # bools are stored as 0/1 integers
                value = str(int(value) if isinstance(value, bool) else value)
        else:
            placeholder = "%s" if value is not None else "NULL"
        command = "UPDATE {full_table_name} SET `{attrname}`={placeholder} {where_clause}".format(
            full_table_name=self.from_clause(),
            attrname=attrname,
            placeholder=placeholder,
            where_clause=self.where_clause(),
        )
        # Pass the value as a query argument only when a %s placeholder is used.
        self.connection.query(command, args=(value,) if value is not None else ())

834
    # --- private helper functions ----
835
    def __make_placeholder(self, name, value, ignore_extra_fields=False):
        """
        For a given attribute `name` with `value`, return its processed value or value placeholder
        as a string to be included in the query and the value, if any, to be submitted for
        processing by mysql API.

        :param name:  name of attribute to be inserted
        :param value: value of attribute to be inserted
        :param ignore_extra_fields: if True, silently skip names not in the heading
        :return: None when the field is skipped, otherwise a tuple
            (name, placeholder, value) where placeholder is "DEFAULT" or "%s"
        :raises DataJointError: if a uuid attribute receives a malformed value
        """
        if ignore_extra_fields and name not in self.heading:
            return None
        attr = self.heading[name]
        if attr.adapter:
            # custom attribute type: let the adapter convert to the stored form
            value = attr.adapter.put(value)
        if value is None or (attr.numeric and (value == "" or np.isnan(float(value)))):
            # set default value (None and NaN/"" numerics map to SQL DEFAULT)
            placeholder, value = "DEFAULT", None
        else:  # not NULL
            placeholder = "%s"
            if attr.uuid:
                if not isinstance(value, uuid.UUID):
                    try:
                        value = uuid.UUID(value)
                    except (AttributeError, ValueError):
                        raise DataJointError(
                            "badly formed UUID value {v} for attribute `{n}`".format(
                                v=value, n=name
                            )
                        )
                # stored as the 16-byte binary form
                value = value.bytes
            elif attr.is_blob:
                value = blob.pack(value)
                # external blobs store only the hash; internal blobs store the bytes
                value = (
                    self.external[attr.store].put(value).bytes
                    if attr.is_external
                    else value
                )
            elif attr.is_attachment:
                attachment_path = Path(value)
                if attr.is_external:
                    # value is hash of contents
                    value = (
                        self.external[attr.store]
                        .upload_attachment(attachment_path)
                        .bytes
                    )
                else:
                    # value is filename + contents
                    value = (
                        str.encode(attachment_path.name)
                        + b"\0"
                        + attachment_path.read_bytes()
                    )
            elif attr.is_filepath:
                value = self.external[attr.store].upload_filepath(value).bytes
            elif attr.numeric:
                # bools become 0/1; all numerics are passed to MySQL as strings
                value = str(int(value) if isinstance(value, bool) else value)
        return name, placeholder, value

894
    def __make_row_to_insert(self, row, field_list, ignore_extra_fields):
        """
        Helper function for insert and update

        :param row:  A tuple to insert — an np.void record, a mapping, or a
            positional sequence covering all heading attributes
        :param field_list: mutable list shared across rows of one insert call;
            the first row populates it and later rows are reordered to match it
        :param ignore_extra_fields: if True, drop fields not in the heading
        :return: a dict with fields 'names', 'placeholders', 'values'
        :raises DataJointError: on field mismatch between rows or wrong arity
        :raises KeyError: when a field is not in the heading (and extras are not ignored)
        """

        def check_fields(fields):
            """
            Validates that all items in `fields` are valid attributes in the heading

            :param fields: field names of a tuple
            """
            if not field_list:
                # first row: every supplied field must exist in the heading
                if not ignore_extra_fields:
                    for field in fields:
                        if field not in self.heading:
                            raise KeyError(
                                "`{0:s}` is not in the table heading".format(field)
                            )
            elif set(field_list) != set(fields).intersection(self.heading.names):
                # later rows must supply exactly the same heading fields as the first
                raise DataJointError("Attempt to insert rows with different fields.")

        if isinstance(row, np.void):  # np.array
            check_fields(row.dtype.fields)
            attributes = [
                self.__make_placeholder(name, row[name], ignore_extra_fields)
                for name in self.heading
                if name in row.dtype.fields
            ]
        elif isinstance(row, collections.abc.Mapping):  # dict-based
            check_fields(row)
            attributes = [
                self.__make_placeholder(name, row[name], ignore_extra_fields)
                for name in self.heading
                if name in row
            ]
        else:  # positional
            try:
                if len(row) != len(self.heading):
                    raise DataJointError(
                        "Invalid insert argument. Incorrect number of attributes: "
                        "{given} given; {expected} expected".format(
                            given=len(row), expected=len(self.heading)
                        )
                    )
            except TypeError:
                # row has no len(): not a sequence, mapping, or record
                raise DataJointError("Datatype %s cannot be inserted" % type(row))
            else:
                attributes = [
                    self.__make_placeholder(name, value, ignore_extra_fields)
                    for name, value in zip(self.heading, row)
                ]
        if ignore_extra_fields:
            # __make_placeholder returns None for skipped extra fields
            attributes = [a for a in attributes if a is not None]

        assert len(attributes), "Empty tuple"
        row_to_insert = dict(zip(("names", "placeholders", "values"), zip(*attributes)))
        if not field_list:
            # first row sets the composition of the field list
            field_list.extend(row_to_insert["names"])
        else:
            #  reorder attributes in row_to_insert to match field_list
            order = list(row_to_insert["names"].index(field) for field in field_list)
            row_to_insert["names"] = list(row_to_insert["names"][i] for i in order)
            row_to_insert["placeholders"] = list(
                row_to_insert["placeholders"][i] for i in order
            )
            row_to_insert["values"] = list(row_to_insert["values"][i] for i in order)
        return row_to_insert
965

966

967
def lookup_class_name(name, context, depth=3):
    """
    given a table name in the form `schema_name`.`table_name`, find its class in the context.

    :param name: `schema_name`.`table_name`
    :param context: dictionary representing the namespace
    :param depth: search depth into imported modules, helps avoid infinite recursion.
    :return: class name found in the context or None if not found
    """
    # breadth-first search over the namespace and, up to `depth` levels deep,
    # the namespaces of imported modules. A deque gives O(1) popleft, whereas
    # list.pop(0) is O(n) per dequeue.
    nodes = collections.deque([dict(context=context, context_name="", depth=depth)])
    while nodes:
        node = nodes.popleft()
        for member_name, member in node["context"].items():
            if member_name.startswith("_"):
                continue  # skip IPython's implicit variables
            if inspect.isclass(member) and issubclass(member, Table):
                if member.full_table_name == name:  # found it!
                    return ".".join([node["context_name"], member_name]).lstrip(".")
                try:  # look for part tables
                    parts = member.__dict__
                except AttributeError:
                    pass  # not a UserTable -- cannot have part tables.
                else:
                    # part tables are capitalized nested classes of the master
                    for part in (
                        getattr(member, p)
                        for p in parts
                        if p[0].isupper() and hasattr(member, p)
                    ):
                        if (
                            inspect.isclass(part)
                            and issubclass(part, Table)
                            and part.full_table_name == name
                        ):
                            return ".".join(
                                [node["context_name"], member_name, part.__name__]
                            ).lstrip(".")
            elif (
                node["depth"] > 0
                and inspect.ismodule(member)
                and member.__name__ != "datajoint"
            ):
                # descend into the module's namespace with reduced depth
                try:
                    nodes.append(
                        dict(
                            context=dict(inspect.getmembers(member)),
                            context_name=node["context_name"] + "." + member_name,
                            depth=node["depth"] - 1,
                        )
                    )
                except ImportError:
                    pass  # could not import, so do not attempt
    return None
1019

1020

1021
class FreeTable(Table):
    """
    A base table without a dedicated class. Each instance is associated with a table
    specified by full_table_name.

    :param conn:  a dj.Connection object
    :param full_table_name: in format `database`.`table_name`
    """

    def __init__(self, conn, full_table_name):
        # Split `database`.`table_name` into its two backquoted components.
        database_part, table_part = full_table_name.split(".")
        self.database = database_part.strip("`")
        self._table_name = table_part.strip("`")
        self._connection = conn
        self._support = [full_table_name]
        self._heading = Heading(
            table_info=dict(
                conn=conn,
                database=self.database,
                table_name=self.table_name,
                context=None,
            )
        )

    def __repr__(self):
        header = "FreeTable(`%s`.`%s`)\n" % (self.database, self._table_name)
        return header + super().__repr__()

1051

1052
class Log(Table):
    """
    The log table for each schema.
    Instances are callable.  Calls log the time and identifying information along with the event.

    :param skip_logging: if True, then log entry is skipped by default. See __call__
    """

    # all schemas use the same reserved table name for their log
    _table_name = "~log"

    def __init__(self, conn, database, skip_logging=False):
        """
        :param conn: a dj.Connection object
        :param database: name of the schema whose log table this is
        :param skip_logging: default for suppressing log writes; see __call__
        """
        self.database = database
        self.skip_logging = skip_logging
        self._connection = conn
        self._heading = Heading(
            table_info=dict(
                conn=conn, database=database, table_name=self.table_name, context=None
            )
        )
        self._support = [self.full_table_name]

        self._definition = """    # event logging table for `{database}`
        id       :int unsigned auto_increment     # event order id
        ---
        timestamp = CURRENT_TIMESTAMP : timestamp # event timestamp
        version  :varchar(12)                     # datajoint version
        user     :varchar(255)                    # user@host
        host=""  :varchar(255)                    # system hostname
        event="" :varchar(255)                    # event message
        """.format(
            database=database
        )

        super().__init__()

        # create the table on first use and reset cached dependency info
        if not self.is_declared:
            self.declare()
            self.connection.dependencies.clear()
        self._user = self.connection.get_user()

    @property
    def definition(self):
        # DDL definition is built per-instance in __init__, not a class constant
        return self._definition

    def __call__(self, event, skip_logging=None):
        """
        Record an event in the log table.

        :param event: string to write into the log table
        :param skip_logging: If True then do not log. If None, then use self.skip_logging
        """
        skip_logging = self.skip_logging if skip_logging is None else skip_logging
        if not skip_logging:
            # logging is best-effort: failures must not interrupt the caller
            try:
                self.insert1(
                    dict(
                        user=self._user,
                        version=version + "py",
                        host=platform.uname().node,
                        event=event,
                    ),
                    skip_duplicates=True,
                    ignore_extra_fields=True,
                )
            except DataJointError:
                logger.info("could not log event in table ~log")

    def delete(self):
        """
        bypass interactive prompts and cascading dependencies

        :return: number of deleted items
        """
        return self.delete_quick(get_count=True)

    def drop(self):
        """bypass interactive prompts and cascading dependencies"""
        self.drop_quick()
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc