• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

datajoint / datajoint-python / #12880

pending completion
#12880

push

travis-ci

web-flow
Merge pull request #1067 from CBroz1/master

Add support for insert CSV

4 of 4 new or added lines in 1 file covered. (100.0%)

3102 of 3424 relevant lines covered (90.6%)

0.91 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

87.56
/datajoint/connection.py
1
"""
2
This module contains the Connection class that manages the connection to the database, and
3
the ``conn`` function that provides access to a persistent connection in datajoint.
4
"""
5
import warnings
1✔
6
from contextlib import contextmanager
1✔
7
import pymysql as client
1✔
8
import logging
1✔
9
from getpass import getpass
1✔
10
import re
1✔
11
import pathlib
1✔
12

13
from .settings import config
1✔
14
from . import errors
1✔
15
from .dependencies import Dependencies
1✔
16
from .blob import pack, unpack
1✔
17
from .hash import uuid_from_buffer
1✔
18
from .plugin import connection_plugins
1✔
19

20
# Logger named after the top-level package (first dotted component of this module's name).
logger = logging.getLogger(__name__.partition(".")[0])
# Maximum number of characters of a query written to the debug log.
query_log_max_length = 300

# dj.config key whose value is the folder used for the query cache.
cache_key = "query_cache"
1✔
25

26

27
def get_host_hook(host_input):
    """
    Resolve the database host, delegating to a connection plugin when requested.

    A ``host_input`` of the form ``"<plugin>://..."`` is passed to the ``get_host``
    method of the object loaded from ``connection_plugins[<plugin>]["object"]``.
    Any other value is returned unchanged.

    :param host_input: host string as supplied by the caller or the config
    :return: the host to connect to
    :raises DataJointError: if no connection plugin named ``<plugin>`` is registered
    """
    if "://" not in host_input:
        return host_input
    plugin_name = host_input.split("://")[0]
    try:
        plugin = connection_plugins[plugin_name]["object"]
    except KeyError as err:
        raise errors.DataJointError(
            "Connection plugin '{}' not found.".format(plugin_name)
        ) from err
    # Invoke the plugin outside the try block so that a KeyError raised by the
    # plugin itself is not misreported as a missing plugin.
    return plugin.load().get_host(host_input)
1✔
38

39

40
def connect_host_hook(connection_obj):
    """
    Open the connection, delegating to a connection plugin when requested.

    If ``connection_obj.conn_info["host_input"]`` has the form ``"<plugin>://..."``,
    the ``connect_host`` method of the object loaded from
    ``connection_plugins[<plugin>]["object"]`` is called with ``connection_obj``.
    Otherwise ``connection_obj.connect()`` is called.

    :param connection_obj: object exposing ``conn_info`` (with a ``"host_input"`` key)
        and a ``connect()`` method
    :raises DataJointError: if no connection plugin named ``<plugin>`` is registered
    """
    host_input = connection_obj.conn_info["host_input"]
    if "://" not in host_input:
        connection_obj.connect()
        return
    plugin_name = host_input.split("://")[0]
    try:
        plugin = connection_plugins[plugin_name]["object"]
    except KeyError as err:
        raise errors.DataJointError(
            "Connection plugin '{}' not found.".format(plugin_name)
        ) from err
    # Invoke the plugin outside the try block so that a KeyError raised by the
    # plugin itself is not misreported as a missing plugin.
    plugin.load().connect_host(connection_obj)
1✔
53

54

55
def translate_query_error(client_error, query):
    """
    Convert an error raised by the client library into the matching DataJoint exception.

    The first element of ``client_error.args`` is treated as the MySQL error code.
    The remaining elements are forwarded to the DataJoint exception.

    :param client_error: exception raised by the client interface
    :param query: SQL query (with placeholders) that produced the error
    :return: an instance of the corresponding subclass of datajoint.errors.DataJointError
        for recognized codes; otherwise ``client_error`` itself, unchanged
    """
    logger.debug("type: {}, args: {}".format(type(client_error), client_error.args))
    code, *details = client_error.args

    if code in (0, "(0, '')"):  # connection lost at the interface level
        translated = errors.LostConnectionError(
            "Server connection lost due to an interface error.", *details
        )
    elif code == 2006:
        translated = errors.LostConnectionError("Connection timed out", *details)
    elif code == 2013:
        translated = errors.LostConnectionError("Server connection lost", *details)
    elif code in (1044, 1142):  # access errors
        translated = errors.AccessError("Insufficient privileges.", details[0], query)
    elif code == 1062:
        translated = errors.DuplicateError(*details)
    elif code in (1451, 1452):  # integrity errors
        translated = errors.IntegrityError(*details)
    elif code == 1064:  # syntax error
        translated = errors.QuerySyntaxError(details[0], query)
    elif code == 1146:
        translated = errors.MissingTableError(details[0], query)
    elif code == 1364:
        translated = errors.MissingAttributeError(*details)
    elif code == 1054:
        translated = errors.UnknownAttributeError(*details)
    else:
        # unrecognized codes: hand the original exception back untouched
        translated = client_error
    return translated
×
98

99

100
def conn(
    host=None, user=None, password=None, *, init_fun=None, reset=False, use_tls=None
):
    """
    Return a persistent connection object shared by multiple modules.

    The connection is created on the first call, or again when ``reset=True``.
    Otherwise the previously created connection is returned as-is. Settings that
    are not passed in are read from ``config``, which takes them from
    dj_local_conf.json. If the user name or password is still missing after that,
    the user is prompted for it.

    :param host: hostname
    :param user: mysql user
    :param password: mysql password
    :param init_fun: initialization function
    :param reset: whether the connection should be reset or not
    :param use_tls: TLS encryption option. Valid options are: True (required), False
        (required no TLS), None (TLS preferred, default), dict (Manually specify values per
        https://dev.mysql.com/doc/refman/5.7/en/connection-options.html#encrypted-connection-options).
    :return: the shared Connection object
    """
    if hasattr(conn, "connection") and not reset:
        # reuse the connection cached on the function object
        return conn.connection

    if host is None:
        host = config["database.host"]
    if user is None:
        user = config["database.user"]
    if password is None:
        password = config["database.password"]
    if user is None:  # pragma: no cover
        user = input("Please enter DataJoint username: ")
    if password is None:  # pragma: no cover
        password = getpass(prompt="Please enter DataJoint password: ")
    if init_fun is None:
        init_fun = config["connection.init_function"]
    if use_tls is None:
        use_tls = config["database.use_tls"]

    conn.connection = Connection(host, user, password, None, init_fun, use_tls)
    return conn.connection
1✔
133

134

135
class EmulatedCursor:
    """
    Minimal stand-in for a DB-API cursor over rows that are already in memory.

    It is used to serve query results from the query cache. It follows DB-API
    semantics:

    * ``fetchone`` returns ``None`` once every row has been consumed.
    * ``fetchall`` returns only the rows not yet consumed and exhausts the cursor.
    * Iterating yields the remaining rows.
    """

    def __init__(self, data):
        """
        :param data: the result rows; assumed to be a sequence supporting ``len``,
            indexing and slicing (e.g. a tuple or list of rows)
        """
        self._data = data
        self._position = 0  # index of the next row to hand out

    def __iter__(self):
        return self

    def __next__(self):
        if self._position >= len(self._data):
            raise StopIteration
        row = self._data[self._position]
        self._position += 1
        return row

    def fetchall(self):
        """Return all remaining rows and mark the cursor as exhausted."""
        remaining = self._data[self._position :]
        self._position = len(self._data)
        return remaining

    def fetchone(self):
        """Return the next row, or ``None`` when no rows remain (as a DB-API cursor does)."""
        return next(self, None)

    @property
    def rowcount(self):
        """Total number of rows held, regardless of how many have been consumed."""
        return len(self._data)
×
157

158

159
class Connection:
    """
    A dj.Connection object manages a connection to a database server.
    It also catalogues modules, schemas, tables, and their dependencies (foreign keys).

    Most of the parameters below should be set in the local configuration file.

    :param host: host name, may include port number as hostname:port, in which case it overrides the value in port
    :param user: user name
    :param password: password
    :param port: port number
    :param init_fun: connection initialization function (SQL)
    :param use_tls: TLS encryption option
    """

    def __init__(self, host, user, password, port=None, init_fun=None, use_tls=None):
        host_input, host = (host, get_host_hook(host))
        if ":" in host:
            # the port in the hostname overrides the port argument
            host, port = host.split(":")
            port = int(port)
        elif port is None:
            port = config["database.port"]
        self.conn_info = dict(host=host, port=port, user=user, passwd=password)
        if use_tls is not False:
            # NOTE(review): a non-empty placeholder dict is presumably needed so the
            # client treats TLS as requested when no explicit options are given -- confirm
            self.conn_info["ssl"] = (
                use_tls if isinstance(use_tls, dict) else {"ssl": {}}
            )
        self.conn_info["ssl_input"] = use_tls
        self.conn_info["host_input"] = host_input
        self.init_fun = init_fun
        logger.info("Connecting {user}@{host}:{port}".format(**self.conn_info))
        self._conn = None
        self._query_cache = None
        # Set before the first query: the reconnect path in query() reads this flag.
        self._in_transaction = False
        connect_host_hook(self)
        if self.is_connected:
            logger.info("Connected {user}@{host}:{port}".format(**self.conn_info))
            self.connection_id = self.query("SELECT connection_id()").fetchone()[0]
        else:
            raise errors.LostConnectionError("Connection failed.")
        self.schemas = dict()
        self.dependencies = Dependencies(self)

    def __eq__(self, other):
        """Connections are equal when their connection parameters are equal."""
        try:
            other_info = other.conn_info
        except AttributeError:
            # let Python fall back to the other operand / identity comparison
            return NotImplemented
        return self.conn_info == other_info

    def __repr__(self):
        connected = "connected" if self.is_connected else "disconnected"
        return "DataJoint connection ({connected}) {user}@{host}:{port}".format(
            connected=connected, **self.conn_info
        )

    def connect(self):
        """
        Connect to the database server.

        The first attempt passes every client-relevant entry of ``conn_info``. If the
        client raises ``InternalError``, the connection is retried. When TLS was only
        "preferred" (``ssl_input`` is None), the retry omits the ``ssl`` option.
        """

        def _open(allow_tls_fallback):
            # Open a client connection from conn_info; "ssl_input" and "host_input"
            # are bookkeeping entries of this class, not client.connect arguments.
            def _keep(key):
                if key in ["ssl_input", "host_input"]:
                    return False
                return not (
                    allow_tls_fallback
                    and key == "ssl"
                    and self.conn_info["ssl_input"] is None
                )

            return client.connect(
                init_command=self.init_fun,
                sql_mode="NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO,"
                "STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION,ONLY_FULL_GROUP_BY",
                charset=config["connection.charset"],
                **{k: v for k, v in self.conn_info.items() if _keep(k)},
            )

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", ".*deprecated.*")
            try:
                self._conn = _open(allow_tls_fallback=False)
            except client.err.InternalError:
                self._conn = _open(allow_tls_fallback=True)
        self._conn.autocommit(True)

    def set_query_cache(self, query_cache=None):
        """
        When query_cache is not None, the connection switches into the query caching mode, which entails:
        1. Only SELECT (or SHOW) queries are allowed.
        2. The results of queries are cached under the path indicated by dj.config['query_cache']
        3. query_cache is a string that differentiates different cache states.

        :param query_cache: a string to initialize the hash for query results
        """
        self._query_cache = query_cache

    def purge_query_cache(self):
        """Delete all files directly inside the query cache folder (subfolders are left alone)."""
        if (
            isinstance(config.get(cache_key), str)
            and pathlib.Path(config[cache_key]).is_dir()
        ):
            for path in pathlib.Path(config[cache_key]).iterdir():
                if not path.is_dir():
                    path.unlink()

    def close(self):
        """Close the underlying client connection."""
        self._conn.close()

    def register(self, schema):
        """Register ``schema`` under its database name and clear the cached dependencies."""
        self.schemas[schema.database] = schema
        self.dependencies.clear()

    def ping(self):
        """Ping the connection or raises an exception if the connection is closed."""
        self._conn.ping(reconnect=False)

    @property
    def is_connected(self):
        """Return true if the object is connected to the database server."""
        try:
            self.ping()
        except Exception:
            # any failure to ping (including no connection object yet) means disconnected;
            # KeyboardInterrupt/SystemExit are deliberately not swallowed
            return False
        return True

    @staticmethod
    def _execute_query(cursor, query, args, suppress_warnings):
        """Execute ``query`` on ``cursor``, translating client errors into DataJoint errors."""
        try:
            with warnings.catch_warnings():
                if suppress_warnings:
                    # suppress all warnings arising from underlying SQL library
                    warnings.simplefilter("ignore")
                cursor.execute(query, args)
        except client.err.Error as err:
            raise translate_query_error(err, query)

    def query(
        self, query, args=(), *, as_dict=False, suppress_warnings=True, reconnect=None
    ):
        """
        Execute the specified query and return the tuple generator (cursor).

        :param query: SQL query
        :param args: additional arguments for the client.cursor
        :param as_dict: If as_dict is set to True, the returned cursor objects returns
                        query results as dictionary.
        :param suppress_warnings: If True, suppress all warnings arising from underlying query library
        :param reconnect: when None, get from config, when True, attempt to reconnect if disconnected
        :return: a client cursor, or an EmulatedCursor when query caching is on
        :raises DataJointError: if query caching is on and the query is not a SELECT/SHOW
            statement, or no cache folder is configured
        """
        # check cache first:
        use_query_cache = bool(self._query_cache)
        # SQL keywords are case-insensitive, so match them regardless of case
        if use_query_cache and not re.match(
            r"\s*(SELECT|SHOW)", query, flags=re.IGNORECASE
        ):
            raise errors.DataJointError(
                "Only SELECT queries are allowed when query caching is on."
            )
        if use_query_cache:
            if not config[cache_key]:
                raise errors.DataJointError(
                    f"Provide filepath dj.config['{cache_key}'] when using query caching."
                )
            # backtick-quoted names starting with `$` are removed before hashing,
            # so they do not affect the cache key
            hash_ = uuid_from_buffer(
                (str(self._query_cache) + re.sub(r"`\$\w+`", "", query)).encode()
                + pack(args)
            )
            cache_path = pathlib.Path(config[cache_key]) / str(hash_)
            try:
                buffer = cache_path.read_bytes()
            except FileNotFoundError:
                pass  # proceed to query the database
            else:
                return EmulatedCursor(unpack(buffer))

        if reconnect is None:
            reconnect = config["database.reconnect"]
        logger.debug("Executing SQL:" + query[:query_log_max_length])
        cursor_class = client.cursors.DictCursor if as_dict else client.cursors.Cursor
        cursor = self._conn.cursor(cursor=cursor_class)
        try:
            self._execute_query(cursor, query, args, suppress_warnings)
        except errors.LostConnectionError:
            if not reconnect:
                raise
            logger.warning("MySQL server has gone away. Reconnecting to the server.")
            connect_host_hook(self)
            if self._in_transaction:
                # work done in the lost transaction cannot be recovered
                self.cancel_transaction()
                raise errors.LostConnectionError(
                    "Connection was lost during a transaction."
                )
            logger.debug("Re-executing")
            cursor = self._conn.cursor(cursor=cursor_class)
            self._execute_query(cursor, query, args, suppress_warnings)

        if use_query_cache:
            data = cursor.fetchall()
            cache_path.write_bytes(pack(data))
            return EmulatedCursor(data)

        return cursor

    def get_user(self):
        """
        :return: the user name and host name provided by the client to the server.
        """
        return self.query("SELECT user()").fetchone()[0]

    # ---------- transaction processing
    @property
    def in_transaction(self):
        """
        :return: True if there is an open transaction (a lost connection clears the flag).
        """
        self._in_transaction = self._in_transaction and self.is_connected
        return self._in_transaction

    def start_transaction(self):
        """
        Start a new transaction.

        :raises DataJointError: if a transaction is already open
        """
        if self.in_transaction:
            raise errors.DataJointError("Nested connections are not supported.")
        self.query("START TRANSACTION WITH CONSISTENT SNAPSHOT")
        self._in_transaction = True
        logger.debug("Transaction started")

    def cancel_transaction(self):
        """
        Cancels the current transaction and rolls back all changes made during the transaction.
        """
        self.query("ROLLBACK")
        self._in_transaction = False
        logger.debug("Transaction cancelled. Rolling back ...")

    def commit_transaction(self):
        """
        Commit all changes made during the transaction and close it.
        """
        self.query("COMMIT")
        self._in_transaction = False
        logger.debug("Transaction committed and closed.")

    # -------- context manager for transactions
    @property
    @contextmanager
    def transaction(self):
        """
        Context manager for transactions. Opens a transaction and closes it after the with statement.
        If an error is caught during the transaction, the changes are automatically rolled back.
        All errors are raised again.

        Example:
        >>> import datajoint as dj
        >>> with dj.conn().transaction as conn:
        >>>     # transaction is open here
        """
        try:
            self.start_transaction()
            yield self
        except BaseException:
            # roll back on any exception, including KeyboardInterrupt, then re-raise
            self.cancel_transaction()
            raise
        else:
            self.commit_transaction()
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc