• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

datajoint / datajoint-python / #12880

pending completion
#12880

push

travis-ci

web-flow
Merge pull request #1067 from CBroz1/master

Add support for insert CSV

4 of 4 new or added lines in 1 file covered. (100.0%)

3102 of 3424 relevant lines covered (90.6%)

0.91 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.54
/datajoint/table.py
1
import collections
1✔
2
import itertools
1✔
3
import inspect
1✔
4
import platform
1✔
5
import numpy as np
1✔
6
import pandas
1✔
7
import logging
1✔
8
import uuid
1✔
9
import csv
1✔
10
import re
1✔
11
from pathlib import Path
1✔
12
from .settings import config
1✔
13
from .declare import declare, alter
1✔
14
from .condition import make_condition
1✔
15
from .expression import QueryExpression
1✔
16
from . import blob
1✔
17
from .utils import user_choice, get_master
1✔
18
from .heading import Heading
1✔
19
from .errors import (
1✔
20
    DuplicateError,
21
    AccessError,
22
    DataJointError,
23
    UnknownAttributeError,
24
    IntegrityError,
25
)
26
from typing import Union
1✔
27
from .version import __version__ as version
1✔
28

29
# Package-level logger: named after the top-level package, regardless of submodule.
logger = logging.getLogger(__name__.partition(".")[0])
30

31
# Parses MySQL foreign-key (IntegrityError) messages during cascading deletes.
# Always captures `child` (the referencing table) and `name` (the constraint name);
# `fk_attrs`, `parent`, and `pk_attrs` are captured only when the server includes
# the full constraint definition in the message (the trailing group is optional).
foreign_key_error_regexp = re.compile(
    r"[\w\s:]*\((?P<child>`[^`]+`.`[^`]+`), "
    r"CONSTRAINT (?P<name>`[^`]+`) "
    r"(FOREIGN KEY \((?P<fk_attrs>[^)]+)\) "
    r"REFERENCES (?P<parent>`[^`]+`(\.`[^`]+`)?) \((?P<pk_attrs>[^)]+)\)[\s\w]+\))?"
)
37

38
# Fallback lookup used by Table.delete when the server's foreign-key error message
# omits the constraint details: fetches the FK columns, the referenced (parent)
# table, and the referenced columns by constraint name from INFORMATION_SCHEMA.
# The split()/join pair collapses the triple-quoted literal to one single-spaced line.
constraint_info_query = " ".join(
    """
    SELECT
        COLUMN_NAME as fk_attrs,
        CONCAT('`', REFERENCED_TABLE_SCHEMA, '`.`', REFERENCED_TABLE_NAME, '`') as parent,
        REFERENCED_COLUMN_NAME as pk_attrs
    FROM INFORMATION_SCHEMA.KEY_COLUMN_USAGE
    WHERE
        CONSTRAINT_NAME = %s AND TABLE_SCHEMA = %s AND TABLE_NAME = %s;
    """.split()
)
49

50

51
class _RenameMap(tuple):
1✔
52
    """for internal use"""
53

54
    pass
1✔
55

56

57
class Table(QueryExpression):
    """
    Table is an abstract class that represents a table in the schema.
    It implements insert and delete methods and inherits query functionality.
    To make it a concrete class, override the abstract properties specifying the connection,
    table name, database, and definition.
    """

    _table_name = None  # must be defined in subclass
    _log_ = None  # placeholder for the Log table object; lazily created by the _log property

    # These properties must be set by the schema decorator (schemas.py) at class level
    # or by FreeTable at instance level
    database = None  # name of the schema (database) this table belongs to
    declaration_context = None  # NOTE(review): presumably the namespace for resolving foreign keys at declaration -- confirm against schemas.py
1✔
72

73
    @property
    def table_name(self):
        """
        :return: the table name within the database (without the schema name)
        """
        return self._table_name
1✔
76

77
    @property
    def definition(self):
        """Abstract property: the DataJoint definition string for this table."""
        message = "Subclasses of Table must implement the `definition` property"
        raise NotImplementedError(message)
82

83
    def declare(self, context=None):
        """
        Declare the table in the schema based on self.definition.

        :param context: the context for foreign key resolution. If None, foreign keys are
            not allowed.
        :raises DataJointError: if called while a transaction is in progress.
        """
        if self.connection.in_transaction:
            raise DataJointError(
                "Cannot declare new tables inside a transaction, "
                "e.g. from inside a populate/make call"
            )
        # `declare` here is the module-level function from .declare, not this method
        create_sql, external_stores = declare(
            self.full_table_name, self.definition, context
        )
        create_sql = create_sql.format(database=self.database)
        try:
            # touch every referenced external store so its table exists
            # before the main table is created
            for store_name in external_stores:
                self.connection.schemas[self.database].external[store_name]
            self.connection.query(create_sql)
        except AccessError:
            # user lacks CREATE privilege -- silently skip declaration
            pass
        else:
            self._log("Declared " + self.full_table_name)
107

108
    def alter(self, prompt=True, context=None):
        """
        Alter the table definition from self.definition

        :param prompt: if True (default), print a summary and ask for confirmation
            before executing the ALTER statement.
        :param context: namespace for resolving foreign key references; defaults to
            the caller's local + global namespace.
        :raises DataJointError: if called while a transaction is in progress.
        """
        if self.connection.in_transaction:
            raise DataJointError(
                "Cannot update table declaration inside a transaction, "
                "e.g. from inside a populate/make call"
            )
        if context is None:
            # capture the caller's namespace for foreign key resolution
            frame = inspect.currentframe().f_back
            context = dict(frame.f_globals, **frame.f_locals)
            del frame  # break the reference cycle created by holding a frame
        old_definition = self.describe(context=context, printout=False)
        # `alter` here is the module-level function from .declare, not this method
        sql, external_stores = alter(self.definition, old_definition, context)
        if not sql:
            if prompt:
                print("Nothing to alter.")
        else:
            sql = "ALTER TABLE {tab}\n\t".format(
                tab=self.full_table_name
            ) + ",\n\t".join(sql)
            if not prompt or user_choice(sql + "\n\nExecute?") == "yes":
                try:
                    # declare all external tables before declaring main table
                    for store in external_stores:
                        self.connection.schemas[self.database].external[store]
                    self.connection.query(sql)
                except AccessError:
                    # skip if no create privilege
                    pass
                else:
                    # reset heading at class level so the new definition is re-read
                    self.__class__._heading = Heading(
                        table_info=self.heading.table_info
                    )
                    if prompt:
                        print("Table altered")
                    self._log("Altered " + self.full_table_name)
1✔
147

148
    def from_clause(self):
        """
        :return: the FROM clause of SQL SELECT statements -- for a base table,
            simply its full table name (also used as the INSERT destination).
        """
        return self.full_table_name
1✔
153

154
    def get_select_fields(self, select_fields=None):
        """
        :param select_fields: attribute names to select; None selects everything.
        :return: the selected attributes rendered for the SQL SELECT statement.
        """
        if select_fields is None:
            return "*"
        return self.heading.project(select_fields).as_sql
161

162
    def parents(self, primary=None, as_objects=False, foreign_key_info=False):
        """
        Return the tables referenced by this table's foreign keys.

        :param primary: if None, all parents are returned. If True, only foreign keys
            composed of primary key attributes are considered. If False, return foreign
            keys including at least one secondary attribute.
        :param as_objects: if False, return table names. If True, return table objects.
        :param foreign_key_info: if True, each element in result also includes foreign key info.
        :return: list of parents as table names or table objects
            with (optional) foreign key information.
        """
        get_edge = self.connection.dependencies.parents
        nodes = []
        for name, props in get_edge(self.full_table_name, primary).items():
            if name.isdigit():
                # purely numeric node = alias; follow it to the real parent table
                name, props = next(iter(get_edge(name).items()))
            nodes.append((name, props))
        if as_objects:
            nodes = [(FreeTable(self.connection, name), props) for name, props in nodes]
        if not foreign_key_info:
            nodes = [name for name, _ in nodes]
        return nodes
1✔
183

184
    def children(self, primary=None, as_objects=False, foreign_key_info=False):
        """
        Return the tables whose foreign keys reference this table.

        :param primary: if None, all children are returned. If True, only foreign keys
            composed of primary key attributes are considered. If False, return foreign
            keys including at least one secondary attribute.
        :param as_objects: if False, return table names. If True, return table objects.
        :param foreign_key_info: if True, each element in result also includes foreign key info.
        :return: list of children as table names or table objects
            with (optional) foreign key information.
        """
        get_edge = self.connection.dependencies.children
        nodes = []
        for name, props in get_edge(self.full_table_name, primary).items():
            if name.isdigit():
                # purely numeric node = alias; follow it to the real child table
                name, props = next(iter(get_edge(name).items()))
            nodes.append((name, props))
        if as_objects:
            nodes = [(FreeTable(self.connection, name), props) for name, props in nodes]
        if not foreign_key_info:
            nodes = [name for name, _ in nodes]
        return nodes
1✔
205

206
    def descendants(self, as_objects=False):
        """
        :param as_objects: False - a list of table names; True - a list of table objects.
        :return: list of descendant tables in topological order.
        """
        result = []
        for node in self.connection.dependencies.descendants(self.full_table_name):
            if node.isdigit():
                continue  # skip alias nodes
            result.append(FreeTable(self.connection, node) if as_objects else node)
        return result
217

218
    def ancestors(self, as_objects=False):
        """
        :param as_objects: False - a list of table names; True - a list of table objects.
        :return: list of ancestor tables in topological order.
        """
        result = []
        for node in self.connection.dependencies.ancestors(self.full_table_name):
            if node.isdigit():
                continue  # skip alias nodes
            result.append(FreeTable(self.connection, node) if as_objects else node)
        return result
229

230
    def parts(self, as_objects=False):
        """
        return part tables either as entries in a dict with foreign key information or a list of objects

        :param as_objects: if False (default), the output is a dict describing the foreign keys. If True, return table objects.
        """
        # part tables share the master's name with a "__part" suffix inside the
        # closing backtick: strip the master's trailing backtick and match the prefix
        nodes = [
            node
            for node in self.connection.dependencies.nodes
            if not node.isdigit() and node.startswith(self.full_table_name[:-1] + "__")
        ]
        return [FreeTable(self.connection, c) for c in nodes] if as_objects else nodes
1✔
242

243
    @property
    def is_declared(self):
        """
        :return: True if the table is declared in the schema.
        """
        show_query = 'SHOW TABLES in `{database}` LIKE "{table_name}"'.format(
            database=self.database, table_name=self.table_name
        )
        return self.connection.query(show_query).rowcount > 0
256

257
    @property
    def full_table_name(self):
        """
        :return: the schema-qualified, backtick-quoted table name, e.g. `db`.`table`
        """
        return "`{db:s}`.`{tbl:s}`".format(db=self.database, tbl=self.table_name)
1✔
263

264
    @property
    def _log(self):
        """Lazily construct and cache the Log table for this schema."""
        if self._log_ is None:
            self._log_ = Log(
                self.connection,
                database=self.database,
                # tables prefixed with "~" (hidden/system tables) are not logged
                skip_logging=self.table_name.startswith("~"),
            )
        return self._log_
1✔
273

274
    @property
    def external(self):
        """
        :return: the external storage mapping of this table's schema.
        """
        return self.connection.schemas[self.database].external
1✔
277

278
    def update1(self, row):
        """
        ``update1`` updates one existing entry in the table.
        Caution: In DataJoint the primary modes for data manipulation is to ``insert`` and
        ``delete`` entire records since referential integrity works on the level of records,
        not fields. Therefore, updates are reserved for corrective operations outside of main
        workflow. Use UPDATE methods sparingly with full awareness of potential violations of
        assumptions.

        :param row: a ``dict`` containing the primary key values and the attributes to update.
            Setting an attribute value to None will reset it to the default value (if any).

        The primary key attributes must always be provided.

        :raises DataJointError: if ``row`` is not dict-like, lacks primary key values,
            names an unknown attribute, if this table is restricted, or if the key does
            not match exactly one existing entry.

        Examples:

        >>> table.update1({'id': 1, 'value': 3})  # update value in record with id=1
        >>> table.update1({'id': 1, 'value': None})  # reset value to default
        """
        # argument validations
        if not isinstance(row, collections.abc.Mapping):
            raise DataJointError("The argument of update1 must be dict-like.")
        if not set(row).issuperset(self.primary_key):
            raise DataJointError(
                "The argument of update1 must supply all primary key values."
            )
        # EAFP idiom: next() raises StopIteration when every attribute is known,
        # which aborts the DataJointError before it is raised
        try:
            raise DataJointError(
                "Attribute `%s` not found."
                % next(k for k in row if k not in self.heading.names)
            )
        except StopIteration:
            pass  # ok
        if len(self.restriction):
            raise DataJointError("Update cannot be applied to a restricted table.")
        key = {k: row[k] for k in self.primary_key}
        if len(self & key) != 1:
            raise DataJointError("Update can only be applied to one existing entry.")
        # UPDATE query
        # each element is a (name, placeholder, value) triple for a non-key attribute
        row = [
            self.__make_placeholder(k, v)
            for k, v in row.items()
            if k not in self.primary_key
        ]
        query = "UPDATE {table} SET {assignments} WHERE {where}".format(
            table=self.full_table_name,
            assignments=",".join("`%s`=%s" % r[:2] for r in row),
            where=make_condition(self, key, set()),
        )
        # None values carry no query argument -- presumably encoded in the
        # placeholder itself (see __make_placeholder); only real values are passed
        self.connection.query(query, args=list(r[2] for r in row if r[2] is not None))
1✔
328

329
    def insert1(self, row, **kwargs):
        """
        Insert one data record into the table. For ``kwargs``, see ``insert()``.

        :param row: a numpy record, a dict-like object, or an ordered sequence to be inserted
            as one row.
        """
        # delegate to insert() with a single-element collection
        self.insert((row,), **kwargs)
1✔
337

338
    def insert(
        self,
        rows,
        replace=False,
        skip_duplicates=False,
        ignore_extra_fields=False,
        allow_direct_insert=None,
    ):
        """
        Insert a collection of rows.

        :param rows: Either (a) an iterable where an element is a numpy record, a
            dict-like object, a pandas.DataFrame, a sequence, or a query expression with
            the same heading as self, or (b) a pathlib.Path object specifying a path
            relative to the current directory with a CSV file, the contents of which
            will be inserted.
        :param replace: If True, replaces the existing tuple.
        :param skip_duplicates: If True, silently skip duplicate inserts.
        :param ignore_extra_fields: If False, fields that are not in the heading raise error.
        :param allow_direct_insert: Only applies in auto-populated tables. If False (default),
            insert may only be called from inside the make callback.
        :raises DataJointError: on direct insert into an auto-populated table, or on
            extra attributes when ``ignore_extra_fields`` is False.

        Example:

            >>> Table.insert([
            >>>     dict(subject_id=7, species="mouse", date_of_birth="2014-09-01"),
            >>>     dict(subject_id=8, species="mouse", date_of_birth="2014-09-02")])
        """
        if isinstance(rows, pandas.DataFrame):
            # drop 'extra' synthetic index for 1-field index case -
            # frames with more advanced indices should be prepared by user.
            rows = rows.reset_index(
                drop=len(rows.index.names) == 1 and not rows.index.names[0]
            ).to_records(index=False)

        if isinstance(rows, Path):
            # CSV insert: each data line becomes a dict keyed by the header line
            # NOTE(review): file is opened with the platform-default encoding --
            # confirm whether utf-8 should be forced here
            with open(rows, newline="") as data_file:
                rows = list(csv.DictReader(data_file, delimiter=","))

        # prohibit direct inserts into auto-populated tables
        if not allow_direct_insert and not getattr(self, "_allow_insert", True):
            raise DataJointError(
                "Inserts into an auto-populated table can only be done inside "
                "its make method during a populate call."
                " To override, set keyword argument allow_direct_insert=True."
            )

        if inspect.isclass(rows) and issubclass(rows, QueryExpression):
            rows = rows()  # instantiate if a class
        if isinstance(rows, QueryExpression):
            # insert from select
            if not ignore_extra_fields:
                # EAFP idiom: next() raises StopIteration when no extra attribute
                # exists, which aborts the DataJointError before it is raised
                try:
                    raise DataJointError(
                        "Attribute %s not found. To ignore extra attributes in insert, "
                        "set ignore_extra_fields=True."
                        % next(
                            name for name in rows.heading if name not in self.heading
                        )
                    )
                except StopIteration:
                    pass
            fields = list(name for name in rows.heading if name in self.heading)
            query = "{command} INTO {table} ({fields}) {select}{duplicate}".format(
                command="REPLACE" if replace else "INSERT",
                fields="`" + "`,`".join(fields) + "`",
                table=self.full_table_name,
                select=rows.make_sql(fields),
                # no-op self-assignment of the first primary key attribute makes
                # duplicate keys succeed silently when skip_duplicates is set
                duplicate=(
                    " ON DUPLICATE KEY UPDATE `{pk}`={table}.`{pk}`".format(
                        table=self.full_table_name, pk=self.primary_key[0]
                    )
                    if skip_duplicates
                    else ""
                ),
            )
            self.connection.query(query)
            return

        field_list = []  # collects the field list from first row (passed by reference)
        rows = list(
            self.__make_row_to_insert(row, field_list, ignore_extra_fields)
            for row in rows
        )
        if rows:
            try:
                query = "{command} INTO {destination}(`{fields}`) VALUES {placeholders}{duplicate}".format(
                    command="REPLACE" if replace else "INSERT",
                    destination=self.from_clause(),
                    fields="`,`".join(field_list),
                    placeholders=",".join(
                        "(" + ",".join(row["placeholders"]) + ")" for row in rows
                    ),
                    # same no-op update trick as in the insert-from-select branch
                    duplicate=(
                        " ON DUPLICATE KEY UPDATE `{pk}`=`{pk}`".format(
                            pk=self.primary_key[0]
                        )
                        if skip_duplicates
                        else ""
                    ),
                )
                self.connection.query(
                    query,
                    # None values carry no query argument; flatten the rest row by row
                    args=list(
                        itertools.chain.from_iterable(
                            (v for v in r["values"] if v is not None) for r in rows
                        )
                    ),
                )
            except UnknownAttributeError as err:
                raise err.suggest(
                    "To ignore extra fields in insert, set ignore_extra_fields=True"
                )
            except DuplicateError as err:
                raise err.suggest(
                    "To ignore duplicate entries in insert, set skip_duplicates=True"
                )
455

456
    def delete_quick(self, get_count=False):
        """
        Deletes the table without cascading and without user prompt.
        If this table has populated dependent tables, this will fail.

        :param get_count: if True, return the number of deleted rows; otherwise None.
        """
        query = "DELETE FROM " + self.full_table_name + self.where_clause()
        self.connection.query(query)
        if get_count:
            count = self.connection.query("SELECT ROW_COUNT()").fetchone()[0]
        else:
            count = None
        self._log(query[:255])
        return count
1✔
470

471
    def delete(
        self,
        transaction: bool = True,
        safemode: Union[bool, None] = None,
        force_parts: bool = False,
    ) -> int:
        """
        Deletes the contents of the table and its dependent tables, recursively.

        Args:
            transaction: If `True`, use of the entire delete becomes an atomic transaction.
                This is the default and recommended behavior. Set to `False` if this delete is
                nested within another transaction.
            safemode: If `True`, prohibit nested transactions and prompt to confirm. Default
                is `dj.config['safemode']`.
            force_parts: Delete from parts even when not deleting from their masters.

        Returns:
            Number of deleted rows (excluding those from dependent tables).

        Raises:
            DataJointError: Delete exceeds maximum number of delete attempts.
            DataJointError: When deleting within an existing transaction.
            DataJointError: Deleting a part table before its master.
        """
        deleted = set()  # full names of tables already deleted from (for part/master check)

        def cascade(table):
            """service function to perform cascading deletes recursively."""
            max_attempts = 50
            for _ in range(max_attempts):
                try:
                    delete_count = table.delete_quick(get_count=True)
                except IntegrityError as error:
                    # a child table still references rows being deleted:
                    # parse the foreign key details out of the error message
                    match = foreign_key_error_regexp.match(error.args[0]).groupdict()
                    if "`.`" not in match["child"]:  # if schema name missing, use table
                        match["child"] = "{}.{}".format(
                            table.full_table_name.split(".")[0], match["child"]
                        )
                    if (
                        match["pk_attrs"] is not None
                    ):  # fully matched, adjusting the keys
                        match["fk_attrs"] = [
                            k.strip("`") for k in match["fk_attrs"].split(",")
                        ]
                        match["pk_attrs"] = [
                            k.strip("`") for k in match["pk_attrs"].split(",")
                        ]
                    else:  # only partially matched, querying with constraint to determine keys
                        # fetchall() yields (fk_attr, parent, pk_attr) rows;
                        # zip(*...) transposes them into the three lists
                        match["fk_attrs"], match["parent"], match["pk_attrs"] = list(
                            map(
                                list,
                                zip(
                                    *table.connection.query(
                                        constraint_info_query,
                                        args=(
                                            match["name"].strip("`"),
                                            *[
                                                _.strip("`")
                                                for _ in match["child"].split("`.`")
                                            ],
                                        ),
                                    ).fetchall()
                                ),
                            )
                        )
                        match["parent"] = match["parent"][0]

                    # Restrict child by table if
                    #   1. if table's restriction attributes are not in child's primary key
                    #   2. if child renames any attributes
                    # Otherwise restrict child by table's restriction.
                    child = FreeTable(table.connection, match["child"])
                    if (
                        set(table.restriction_attributes) <= set(child.primary_key)
                        and match["fk_attrs"] == match["pk_attrs"]
                    ):
                        child._restriction = table._restriction
                    elif match["fk_attrs"] != match["pk_attrs"]:
                        child &= table.proj(
                            **dict(zip(match["fk_attrs"], match["pk_attrs"]))
                        )
                    else:
                        child &= table.proj()
                    # delete from the blocking child first, then retry this table
                    cascade(child)
                else:
                    deleted.add(table.full_table_name)
                    logger.info(
                        "Deleting {count} rows from {table}".format(
                            count=delete_count, table=table.full_table_name
                        )
                    )
                    break
            else:
                raise DataJointError("Exceeded maximum number of delete attempts.")
            return delete_count

        safemode = config["safemode"] if safemode is None else safemode

        # Start transaction
        if transaction:
            if not self.connection.in_transaction:
                self.connection.start_transaction()
            else:
                if not safemode:
                    # quietly join the ongoing transaction instead of starting one
                    transaction = False
                else:
                    raise DataJointError(
                        "Delete cannot use a transaction within an ongoing transaction. "
                        "Set transaction=False or safemode=False)."
                    )

        # Cascading delete
        try:
            delete_count = cascade(self)
        except:  # noqa: E722 -- intentionally bare so ANY failure rolls back the transaction
            if transaction:
                self.connection.cancel_transaction()
            raise

        if not force_parts:
            # Avoid deleting from child before master (See issue #151)
            for part in deleted:
                master = get_master(part)
                if master and master not in deleted:
                    if transaction:
                        self.connection.cancel_transaction()
                    raise DataJointError(
                        "Attempt to delete part table {part} before deleting from "
                        "its master {master} first.".format(part=part, master=master)
                    )

        # Confirm and commit
        if delete_count == 0:
            if safemode:
                print("Nothing to delete.")
            if transaction:
                self.connection.cancel_transaction()
        else:
            if not safemode or user_choice("Commit deletes?", default="no") == "yes":
                if transaction:
                    self.connection.commit_transaction()
                if safemode:
                    print("Deletes committed.")
            else:
                if transaction:
                    self.connection.cancel_transaction()
                if safemode:
                    print("Deletes cancelled")
        return delete_count
1✔
621

622
    def drop_quick(self):
        """
        Drops the table without cascading to dependent tables and without user prompt.
        """
        if not self.is_declared:
            logger.info(
                "Nothing to drop: table %s is not declared" % self.full_table_name
            )
            return
        query = "DROP TABLE {}".format(self.full_table_name)
        self.connection.query(query)
        logger.info("Dropped table {}".format(self.full_table_name))
        self._log(query[:255])
635

636
    def drop(self):
        """
        Drop the table and all tables that reference it, recursively.
        User is prompted for confirmation if config['safemode'] is set to True.

        :raises DataJointError: if the table is restricted, or if a part table would
            be dropped before its master.
        """
        if self.restriction:
            raise DataJointError(
                "A table with an applied restriction cannot be dropped."
                " Call drop() on the unrestricted Table."
            )
        self.connection.dependencies.load()
        do_drop = True
        # all dependent tables in topological order (alias nodes excluded)
        tables = [
            table
            for table in self.connection.dependencies.descendants(self.full_table_name)
            if not table.isdigit()
        ]

        # avoid dropping part tables without their masters: See issue #374
        for part in tables:
            master = get_master(part)
            if master and master not in tables:
                raise DataJointError(
                    "Attempt to drop part table {part} before dropping "
                    "its master. Drop {master} first.".format(part=part, master=master)
                )

        if config["safemode"]:
            for table in tables:
                print(table, "(%d tuples)" % len(FreeTable(self.connection, table)))
            do_drop = user_choice("Proceed?", default="no") == "yes"
        if do_drop:
            # drop in reverse topological order: children before parents
            for table in reversed(tables):
                FreeTable(self.connection, table).drop_quick()
            print("Tables dropped.  Restart kernel.")
1✔
671

672
    @property
    def size_on_disk(self):
        """
        :return: size of data and indices in bytes on the storage device
        """
        status_query = 'SHOW TABLE STATUS FROM `{database}` WHERE NAME="{table}"'.format(
            database=self.database, table=self.table_name
        )
        status = self.connection.query(status_query, as_dict=True).fetchone()
        return status["Data_length"] + status["Index_length"]
1✔
684

685
    def show_definition(self):
        """Deprecated: always raises; use :meth:`describe` instead."""
        message = "show_definition is deprecated. Use the describe method instead."
        raise AttributeError(message)
689

690
    def describe(self, context=None, printout=True):
        """
        Reconstruct the table definition in DataJoint DDL.

        :param context: namespace used to resolve parent table names into class
            names; defaults to the caller's globals and locals.
        :param printout: if True, also print the definition to stdout.
        :return:  the definition string for the query using DataJoint DDL.
        """
        if context is None:
            # capture the caller's namespace for class-name lookup
            frame = inspect.currentframe().f_back
            context = dict(frame.f_globals, **frame.f_locals)
            del frame
        if self.full_table_name not in self.connection.dependencies:
            self.connection.dependencies.load()
        parents = self.parents(foreign_key_info=True)
        in_key = True  # True while emitting primary-key attributes
        definition = (
            "# " + self.heading.table_status["comment"] + "\n"
            if self.heading.table_status["comment"]
            else ""
        )
        attributes_thus_far = set()
        attributes_declared = set()
        indexes = self.heading.indexes.copy()
        for attr in self.heading.attributes.values():
            if in_key and not attr.in_key:
                # divider between primary key and secondary attributes
                definition += "---\n"
                in_key = False
            attributes_thus_far.add(attr.name)
            do_include = True
            for parent_name, fk_props in parents:
                if attr.name in fk_props["attr_map"]:
                    # attribute comes from a foreign key; emit the FK line
                    # once all of its attributes have been seen
                    do_include = False
                    if attributes_thus_far.issuperset(fk_props["attr_map"]):
                        # foreign key properties
                        try:
                            index_props = indexes.pop(tuple(fk_props["attr_map"]))
                        except KeyError:
                            index_props = ""
                        else:
                            index_props = [k for k, v in index_props.items() if v]
                            index_props = (
                                " [{}]".format(", ".join(index_props))
                                if index_props
                                else ""
                            )

                        if not fk_props["aliased"]:
                            # simple foreign key
                            definition += "->{props} {class_name}\n".format(
                                props=index_props,
                                class_name=lookup_class_name(parent_name, context)
                                or parent_name,
                            )
                        else:
                            # projected foreign key
                            definition += (
                                "->{props} {class_name}.proj({proj_list})\n".format(
                                    props=index_props,
                                    class_name=lookup_class_name(parent_name, context)
                                    or parent_name,
                                    proj_list=",".join(
                                        '{}="{}"'.format(attr, ref)
                                        for attr, ref in fk_props["attr_map"].items()
                                        if ref != attr
                                    ),
                                )
                            )
                            attributes_declared.update(fk_props["attr_map"])
            if do_include:
                # plain attribute line: name[=default] : type  # comment
                attributes_declared.add(attr.name)
                definition += "%-20s : %-28s %s\n" % (
                    attr.name
                    if attr.default is None
                    else "%s=%s" % (attr.name, attr.default),
                    "%s%s"
                    % (attr.type, " auto_increment" if attr.autoincrement else ""),
                    "# " + attr.comment if attr.comment else "",
                )
        # add remaining indexes
        for k, v in indexes.items():
            definition += "{unique}INDEX ({attrs})\n".format(
                unique="UNIQUE " if v["unique"] else "", attrs=", ".join(k)
            )
        if printout:
            print(definition)
        return definition
1✔
773

774
    def _update(self, attrname, value=None):
        """
        Deprecated; scheduled for removal in datajoint 0.14. Use ``.update1`` instead.

        Update one field of one existing tuple; self must be restricted to exactly
        one entry. In DataJoint the principal way of updating data is to delete and
        re-insert the entire record, and updates are reserved for corrective actions,
        because referential integrity is observed on the level of entire records
        rather than individual attributes.

        Safety constraints:
           1. self must be restricted to exactly one tuple
           2. the update attribute must not be in primary key

        Example:
        >>> (v2p.Mice() & key)._update('mouse_dob', '2011-01-01')
        >>> (v2p.Mice() & key)._update( 'lens')   # set the value to NULL
        """
        logger.warning(
            "`_update` is a deprecated function to be removed in datajoint 0.14. "
            "Use `.update1` instead."
        )
        # guard clauses: exactly one row, a known attribute, not a key field
        if len(self) != 1:
            raise DataJointError("Update is only allowed on one tuple at a time")
        if attrname not in self.heading:
            raise DataJointError("Invalid attribute name")
        if attrname in self.heading.primary_key:
            raise DataJointError("Cannot update a key value.")

        attr = self.heading[attrname]

        # pick the SQL placeholder and serialize the value per attribute type
        if attr.is_blob:
            placeholder, value = "%s", blob.pack(value)
        elif attr.numeric:
            if value is not None and not np.isnan(float(value)):
                placeholder = "%s"
                value = str(int(value) if isinstance(value, bool) else value)
            else:
                # both None and NaN are stored as SQL NULL
                placeholder, value = "NULL", None
        else:
            placeholder = "NULL" if value is None else "%s"
        command = "UPDATE {full_table_name} SET `{attrname}`={placeholder} {where_clause}".format(
            full_table_name=self.from_clause(),
            attrname=attrname,
            placeholder=placeholder,
            where_clause=self.where_clause(),
        )
        self.connection.query(command, args=() if value is None else (value,))
1✔
825

826
    # --- private helper functions ----
827
    def __make_placeholder(self, name, value, ignore_extra_fields=False):
        """
        For a given attribute `name` with `value`, return its processed value or value placeholder
        as a string to be included in the query and the value, if any, to be submitted for
        processing by mysql API.

        :param name:  name of attribute to be inserted
        :param value: value of attribute to be inserted
        :param ignore_extra_fields: if True, return None for names absent from the heading
        :return: None when the attribute is skipped; otherwise a tuple
            (name, placeholder, value) where placeholder is the SQL fragment
            ("DEFAULT" or "%s") and value is the prepared query argument (or None).
        """
        if ignore_extra_fields and name not in self.heading:
            return None
        attr = self.heading[name]
        if attr.adapter:
            # custom attribute type: convert through its adapter first
            value = attr.adapter.put(value)
        if value is None or (attr.numeric and (value == "" or np.isnan(float(value)))):
            # set default value
            placeholder, value = "DEFAULT", None
        else:  # not NULL
            placeholder = "%s"
            if attr.uuid:
                if not isinstance(value, uuid.UUID):
                    try:
                        value = uuid.UUID(value)
                    except (AttributeError, ValueError):
                        raise DataJointError(
                            "badly formed UUID value {v} for attribute `{n}`".format(
                                v=value, n=name
                            )
                        )
                # UUIDs are stored as their 16 raw bytes
                value = value.bytes
            elif attr.is_blob:
                value = blob.pack(value)
                # external blobs store the content hash instead of the bytes
                value = (
                    self.external[attr.store].put(value).bytes
                    if attr.is_external
                    else value
                )
            elif attr.is_attachment:
                attachment_path = Path(value)
                if attr.is_external:
                    # value is hash of contents
                    value = (
                        self.external[attr.store]
                        .upload_attachment(attachment_path)
                        .bytes
                    )
                else:
                    # value is filename + contents
                    value = (
                        str.encode(attachment_path.name)
                        + b"\0"
                        + attachment_path.read_bytes()
                    )
            elif attr.is_filepath:
                value = self.external[attr.store].upload_filepath(value).bytes
            elif attr.numeric:
                # numbers are passed as strings; booleans become 0/1
                value = str(int(value) if isinstance(value, bool) else value)
        return name, placeholder, value
1✔
885

886
    def __make_row_to_insert(self, row, field_list, ignore_extra_fields):
        """
        Helper function for insert and update

        :param row:  A tuple to insert -- may be a numpy record (np.void), a
            mapping, or a positional sequence matching the heading.
        :param field_list: mutable list of field names shared across rows of one
            insert; the first row populates it and later rows must match it.
        :param ignore_extra_fields: if True, silently drop fields not in the heading
        :return: a dict with fields 'names', 'placeholders', 'values'
        """

        def check_fields(fields):
            """
            Validates that all items in `fields` are valid attributes in the heading

            :param fields: field names of a tuple
            """
            if not field_list:
                # first row: every supplied field must exist in the heading
                if not ignore_extra_fields:
                    for field in fields:
                        if field not in self.heading:
                            raise KeyError(
                                "`{0:s}` is not in the table heading".format(field)
                            )
            elif set(field_list) != set(fields).intersection(self.heading.names):
                raise DataJointError("Attempt to insert rows with different fields.")

        if isinstance(row, np.void):  # np.array
            check_fields(row.dtype.fields)
            attributes = [
                self.__make_placeholder(name, row[name], ignore_extra_fields)
                for name in self.heading
                if name in row.dtype.fields
            ]
        elif isinstance(row, collections.abc.Mapping):  # dict-based
            check_fields(row)
            attributes = [
                self.__make_placeholder(name, row[name], ignore_extra_fields)
                for name in self.heading
                if name in row
            ]
        else:  # positional
            try:
                if len(row) != len(self.heading):
                    raise DataJointError(
                        "Invalid insert argument. Incorrect number of attributes: "
                        "{given} given; {expected} expected".format(
                            given=len(row), expected=len(self.heading)
                        )
                    )
            except TypeError:
                # row has no len() -- not an insertable datatype
                raise DataJointError("Datatype %s cannot be inserted" % type(row))
            else:
                attributes = [
                    self.__make_placeholder(name, value, ignore_extra_fields)
                    for name, value in zip(self.heading, row)
                ]
        if ignore_extra_fields:
            # __make_placeholder returns None for skipped extra fields
            attributes = [a for a in attributes if a is not None]

        assert len(attributes), "Empty tuple"
        row_to_insert = dict(zip(("names", "placeholders", "values"), zip(*attributes)))
        if not field_list:
            # first row sets the composition of the field list
            field_list.extend(row_to_insert["names"])
        else:
            #  reorder attributes in row_to_insert to match field_list
            order = list(row_to_insert["names"].index(field) for field in field_list)
            row_to_insert["names"] = list(row_to_insert["names"][i] for i in order)
            row_to_insert["placeholders"] = list(
                row_to_insert["placeholders"][i] for i in order
            )
            row_to_insert["values"] = list(row_to_insert["values"][i] for i in order)
        return row_to_insert
1✔
957

958

959
def lookup_class_name(name, context, depth=3):
    """
    given a table name in the form `schema_name`.`table_name`, find its class in the context.

    :param name: `schema_name`.`table_name`
    :param context: dictionary representing the namespace
    :param depth: search depth into imported modules, helps avoid infinite recursion.
    :return: class name found in the context or None if not found
    """
    # breadth-first search over the namespace and the modules it imports
    queue = collections.deque([(context, "", depth)])
    while queue:
        namespace, prefix, remaining = queue.popleft()
        for attr_name, attr_value in namespace.items():
            if attr_name.startswith("_"):
                continue  # skip IPython's implicit variables
            if inspect.isclass(attr_value) and issubclass(attr_value, Table):
                if attr_value.full_table_name == name:  # found it!
                    return ".".join([prefix, attr_name]).lstrip(".")
                try:  # look for part tables
                    members = attr_value.__dict__
                except AttributeError:
                    pass  # not a UserTable -- cannot have part tables.
                else:
                    candidates = (
                        getattr(attr_value, p)
                        for p in members
                        if p[0].isupper() and hasattr(attr_value, p)
                    )
                    for part in candidates:
                        if (
                            inspect.isclass(part)
                            and issubclass(part, Table)
                            and part.full_table_name == name
                        ):
                            return ".".join(
                                [prefix, attr_name, part.__name__]
                            ).lstrip(".")
            elif (
                remaining > 0
                and inspect.ismodule(attr_value)
                and attr_value.__name__ != "datajoint"
            ):
                # descend one level into an imported module
                try:
                    queue.append(
                        (
                            dict(inspect.getmembers(attr_value)),
                            prefix + "." + attr_name,
                            remaining - 1,
                        )
                    )
                except ImportError:
                    pass  # could not import, so do not attempt
    return None
1✔
1011

1012

1013
class FreeTable(Table):
    """
    A table object generated on the fly for a table with no dedicated class.
    Each instance is bound to the table given by ``full_table_name``.

    :param conn:  a dj.Connection object
    :param full_table_name: in format `database`.`table_name`
    """

    def __init__(self, conn, full_table_name):
        schema_part, table_part = full_table_name.split(".")
        self.database = schema_part.strip("`")
        self._table_name = table_part.strip("`")
        self._connection = conn
        self._support = [full_table_name]
        self._heading = Heading(
            table_info=dict(
                conn=conn,
                database=self.database,
                table_name=self.table_name,
                context=None,
            )
        )

    def __repr__(self):
        banner = "FreeTable(`%s`.`%s`)\n" % (self.database, self._table_name)
        return banner + super().__repr__()
1042

1043

1044
class Log(Table):
    """
    The log table for each schema.
    Instances are callable.  Calls log the time and identifying information along with the event.

    :param skip_logging: if True, then log entry is skipped by default. See __call__
    """

    # fixed table name; the leading ~ marks it as a hidden service table
    _table_name = "~log"

    def __init__(self, conn, database, skip_logging=False):
        self.database = database
        self.skip_logging = skip_logging
        self._connection = conn
        self._heading = Heading(
            table_info=dict(
                conn=conn, database=database, table_name=self.table_name, context=None
            )
        )
        self._support = [self.full_table_name]

        # DataJoint definition of the log table, declared on first use below
        self._definition = """    # event logging table for `{database}`
        id       :int unsigned auto_increment     # event order id
        ---
        timestamp = CURRENT_TIMESTAMP : timestamp # event timestamp
        version  :varchar(12)                     # datajoint version
        user     :varchar(255)                    # user@host
        host=""  :varchar(255)                    # system hostname
        event="" :varchar(255)                    # event message
        """.format(
            database=database
        )

        super().__init__()

        # declare the table on the server the first time this schema is used
        if not self.is_declared:
            self.declare()
            self.connection.dependencies.clear()
        self._user = self.connection.get_user()

    @property
    def definition(self):
        """:return: the DataJoint definition string for the log table"""
        return self._definition

    def __call__(self, event, skip_logging=None):
        """
        Record `event` in the log table unless logging is skipped.

        :param event: string to write into the log table
        :param skip_logging: If True then do not log. If None, then use self.skip_logging
        """
        skip_logging = self.skip_logging if skip_logging is None else skip_logging
        if not skip_logging:
            try:
                self.insert1(
                    dict(
                        user=self._user,
                        version=version + "py",
                        host=platform.uname().node,
                        event=event,
                    ),
                    skip_duplicates=True,
                    ignore_extra_fields=True,
                )
            except DataJointError:
                # logging is best-effort; never let it break the caller
                logger.info("could not log event in table ~log")

    def delete(self):
        """
        bypass interactive prompts and cascading dependencies

        :return: number of deleted items
        """
        return self.delete_quick(get_count=True)

    def drop(self):
        """bypass interactive prompts and cascading dependencies"""
        self.drop_quick()
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc