Coveralls logob
Coveralls logo
  • Home
  • Features
  • Pricing
  • Docs
  • Sign In

matrix-org / synapse / 4532

23 Sep 2019 - 19:39 coverage decreased (-49.7%) to 17.596%
4532

Pull #6079

buildkite

Richard van der Hoff
update changelog
Pull Request #6079: Add submit_url response parameter to msisdn /requestToken

359 of 12986 branches covered (2.76%)

Branch coverage included in aggregate %.

0 of 7 new or added lines in 1 file covered. (0.0%)

18869 existing lines in 281 files now uncovered.

8809 of 39116 relevant lines covered (22.52%)

0.23 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

14.99
/synapse/storage/_base.py
1
# -*- coding: utf-8 -*-
2
# Copyright 2014-2016 OpenMarket Ltd
3
# Copyright 2017-2018 New Vector Ltd
4
# Copyright 2019 The Matrix.org Foundation C.I.C.
5
#
6
# Licensed under the Apache License, Version 2.0 (the "License");
7
# you may not use this file except in compliance with the License.
8
# You may obtain a copy of the License at
9
#
10
#     http://www.apache.org/licenses/LICENSE-2.0
11
#
12
# Unless required by applicable law or agreed to in writing, software
13
# distributed under the License is distributed on an "AS IS" BASIS,
14
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
# See the License for the specific language governing permissions and
16
# limitations under the License.
17
import itertools
1×
18
import logging
1×
19
import random
1×
20
import sys
1×
21
import threading
1×
22
import time
1×
23

24
from six import PY2, iteritems, iterkeys, itervalues
1×
25
from six.moves import builtins, intern, range
1×
26

27
from canonicaljson import json
1×
28
from prometheus_client import Histogram
1×
29

30
from twisted.internet import defer
1×
31

32
from synapse.api.errors import StoreError
1×
33
from synapse.logging.context import LoggingContext, PreserveLoggingContext
1×
34
from synapse.metrics.background_process_metrics import run_as_background_process
1×
35
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
1×
36
from synapse.types import get_domain_from_id
1×
37
from synapse.util import batch_iter
1×
38
from synapse.util.caches.descriptors import Cache
1×
39
from synapse.util.stringutils import exception_to_unicode
1×
40

41
# import a function which will return a monotonic time, in seconds
42
try:
1×
43
    # on python 3, use time.monotonic, since time.clock can go backwards
44
    from time import monotonic as monotonic_time
1×
45
except ImportError:
!
46
    # ... but python 2 doesn't have it
47
    from time import clock as monotonic_time
!
48

49
logger = logging.getLogger(__name__)

# Upper bound for the per-store transaction counter, which wraps modulo this
# value. Python 2 exposes sys.maxint; Python 3 ints are unbounded, so fall back
# to the largest signed 64-bit integer instead.
MAX_TXN_ID = getattr(sys, "maxint", 2 ** 63) - 1
56

57
# Separate loggers so that individual SQL statements ("[SQL]" lines),
# transaction start/end ("[TXN ...]" lines) and the periodic "Total database
# time" summary can each be enabled independently.
sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
perf_logger = logging.getLogger("synapse.storage.TIME")

# Delay between asking the pool for a connection and our callback starting to
# run on it (observed in runWithConnection).
sql_scheduling_timer = Histogram("synapse_storage_schedule_time", "sec")

# Duration of each SQL statement, labelled by its leading SQL keyword.
sql_query_timer = Histogram("synapse_storage_query_time", "sec", ["verb"])
# Duration of each whole transaction, labelled by the transaction description.
sql_txn_timer = Histogram("synapse_storage_transaction_time", "sec", ["desc"])


# Unique indexes which have been added in background updates. Maps from table name
# to the name of the background update which added the unique index to that table.
#
# This is used by the upsert logic to figure out which tables are safe to do a proper
# UPSERT on: until the relevant background update has completed, we
# have to emulate an upsert by locking the table.
#
UNIQUE_INDEX_BACKGROUND_UPDATES = {
    "user_ips": "user_ips_device_unique_index",
    "device_lists_remote_extremeties": "device_lists_remote_extremeties_unique_idx",
    "device_lists_remote_cache": "device_lists_remote_cache_unique_idx",
    "event_search": "event_search_event_id_idx",
}

# This is a special cache name we use to batch multiple invalidations of caches
# based on the current state when notifying workers over replication.
_CURRENT_STATE_CACHE_NAME = "cs_cache_fake"
84

85

86
class LoggingTransaction(object):
    """An object that almost-transparently proxies for the 'txn' object
    passed to the constructor. Adds logging and metrics to the .execute()
    method.

    Args:
        txn: The database transaction object to wrap.
        name (str): The name of this transaction for logging.
        database_engine (Sqlite3Engine|PostgresEngine)
        after_callbacks(list|None): A list that callbacks will be appended to
            that have been added by `call_after` which should be run on
            successful completion of the transaction. None indicates that no
            callbacks should be allowed to be scheduled to run.
        exception_callbacks(list|None): A list that callbacks will be appended
            to that have been added by `call_on_exception` which should be run
            if transaction ends with an error. None indicates that no callbacks
            should be allowed to be scheduled to run.
    """

    __slots__ = [
        "txn",
        "name",
        "database_engine",
        "after_callbacks",
        "exception_callbacks",
    ]

    def __init__(
        self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None
    ):
        # Our own __setattr__ forwards every assignment to the wrapped txn, so
        # the proxy's own slots must be filled in via object.__setattr__.
        object.__setattr__(self, "txn", txn)
        object.__setattr__(self, "name", name)
        object.__setattr__(self, "database_engine", database_engine)
        object.__setattr__(self, "after_callbacks", after_callbacks)
        object.__setattr__(self, "exception_callbacks", exception_callbacks)

    def call_after(self, callback, *args, **kwargs):
        """Call the given callback on the main twisted thread after the
        transaction has finished. Used to invalidate the caches on the
        correct thread.
        """
        # NOTE: fails with AttributeError if constructed with after_callbacks=None.
        self.after_callbacks.append((callback, args, kwargs))

    def call_on_exception(self, callback, *args, **kwargs):
        """Queue `callback(*args, **kwargs)` to be run if the transaction ends
        with an error.
        """
        # NOTE: fails with AttributeError if constructed with
        # exception_callbacks=None.
        self.exception_callbacks.append((callback, args, kwargs))

    def __getattr__(self, name):
        # Only reached for names not found on the proxy itself (its slots and
        # methods): forward the lookup to the wrapped txn.
        return getattr(self.txn, name)

    def __setattr__(self, name, value):
        # All attribute assignment is forwarded to the wrapped txn.
        setattr(self.txn, name, value)

    def __iter__(self):
        return self.txn.__iter__()

    def execute_batch(self, sql, args):
        """Execute `sql` once for each parameter set in `args`.

        On Postgres this uses psycopg2's `execute_batch` helper, which groups
        the statements to reduce server round trips; on other engines it falls
        back to one `execute` call per parameter set.
        """
        if isinstance(self.database_engine, PostgresEngine):
            from psycopg2.extras import execute_batch

            self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
        else:
            for val in args:
                self.execute(sql, val)

    def execute(self, sql, *args):
        self._do_execute(self.txn.execute, sql, *args)

    def executemany(self, sql, *args):
        self._do_execute(self.txn.executemany, sql, *args)

    def _make_sql_one_line(self, sql):
        "Strip newlines out of SQL so that the loggers in the DB are on one line"
        stripped_lines = (line.strip() for line in sql.splitlines())
        return " ".join(line for line in stripped_lines if line)

    @staticmethod
    def _sql_verb(sql):
        """Return the first word of `sql` (e.g. "SELECT") for use as a metrics
        label, or "" if the statement is blank.
        """
        words = sql.split(None, 1)
        return words[0] if words else ""

    def _do_execute(self, func, sql, *args):
        """Run `func(sql, *args)` with SQL logging and per-statement timing.

        The SQL is flattened to one line for logging and converted to the
        engine's parameter style before `func` is called.
        """
        sql = self._make_sql_one_line(sql)

        # TODO(paul): Maybe use 'info' and 'debug' for values?
        sql_logger.debug("[SQL] {%s} %s", self.name, sql)

        sql = self.database_engine.convert_param_style(sql)
        if args:
            try:
                sql_logger.debug("[SQL values] {%s} %r", self.name, args[0])
            except Exception:
                # Don't let logging failures stop SQL from working
                pass

        start = time.time()

        try:
            return func(sql, *args)
        except Exception as e:
            logger.debug("[SQL FAIL] {%s} %s", self.name, e)
            raise
        finally:
            secs = time.time() - start
            sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
            # Use a helper that tolerates blank SQL: an IndexError raised in
            # this finally clause would otherwise mask the database's own error.
            sql_query_timer.labels(self._sql_verb(sql)).observe(secs)
185

186

187
class PerformanceCounters(object):
    """Tracks, per key, how many times an operation ran and how long it took in
    total, and summarises the heaviest keys over each reporting interval.
    """

    def __init__(self):
        # key -> (number of calls, cumulative duration in seconds) since creation
        self.current_counters = {}
        # copy of current_counters taken by the most recent interval() call
        self.previous_counters = {}

    def update(self, key, duration_secs):
        """Record one more call of `key` that took `duration_secs` seconds."""
        calls, total_secs = self.current_counters.get(key, (0, 0))
        self.current_counters[key] = (calls + 1, total_secs + duration_secs)

    def interval(self, interval_duration_secs, limit=3):
        """Summarise activity since the previous call.

        Args:
            interval_duration_secs (float): length of the interval being
                reported on; each key's time is expressed as a share of it.
            limit (int): how many keys to include in the summary.

        Returns:
            str: up to `limit` comma-separated "name(calls): percentage%"
            entries, ordered by descending share of the interval.
        """
        rows = []
        for key, (calls, total_secs) in self.current_counters.items():
            old_calls, old_secs = self.previous_counters.get(key, (0, 0))
            share = (total_secs - old_secs) / interval_duration_secs
            rows.append((share, calls - old_calls, key))

        # Snapshot the running totals so the next interval reports only new work.
        self.previous_counters = dict(self.current_counters)

        rows.sort(reverse=True)
        return ", ".join(
            "%s(%d): %.3f%%" % (key, calls, 100 * share)
            for share, calls, key in rows[:limit]
        )
220

221

222
class SQLBaseStore(object):
1×
223
    _TXN_ID = 0
1×
224

225
    def __init__(self, db_conn, hs):
        self.hs = hs
        self._clock = hs.get_clock()
        self._db_pool = hs.get_db_pool()

        # Running totals read by start_profiling() to report the share of time
        # spent inside database transactions.
        self._previous_txn_total_time = 0
        self._current_txn_total_time = 0
        self._previous_loop_ts = 0

        # TODO(paul): These can eventually be removed once the metrics code
        #   is running in mainline, and we have some nice monitoring frontends
        #   to watch it
        self._txn_perf_counters = PerformanceCounters()

        self._get_event_cache = Cache(
            "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size
        )

        self._event_fetch_lock = threading.Condition()
        self._event_fetch_list = []
        self._event_fetch_ongoing = 0

        self._pending_ds = []

        self.database_engine = hs.database_engine

        # A set of tables that are not safe to use native upserts in. Tables
        # listed in UNIQUE_INDEX_BACKGROUND_UPDATES are removed again by
        # _check_safe_to_upsert once their background update has completed.
        self._unsafe_to_upsert_tables = set(UNIQUE_INDEX_BACKGROUND_UPDATES)

        self._account_validity = self.hs.config.account_validity

        # We add the user_directory_search table to the blacklist on SQLite
        # because the existing search table does not have an index, making it
        # unsafe to use native upserts.
        if isinstance(self.database_engine, Sqlite3Engine):
            self._unsafe_to_upsert_tables.add("user_directory_search")

        if self.database_engine.can_native_upsert:
            # Check ASAP (and then, while background updates remain pending,
            # every 15s) to see if we have finished background updates of tables
            # that aren't safe to update.
            self._clock.call_later(
                0.0,
                run_as_background_process,
                "upsert_safety_check",
                self._check_safe_to_upsert,
            )

        self.rand = random.SystemRandom()

        if self._account_validity.enabled:
            # Back-fill expiration dates for users who do not have one yet.
            self._clock.call_later(
                0.0,
                run_as_background_process,
                "account_validity_set_expiration_dates",
                self._set_expiration_date_when_missing,
            )
281

282
    @defer.inlineCallbacks
    def _check_safe_to_upsert(self):
        """
        Work out which tables may now use native UPSERT.

        The unique indexes that native UPSERT relies on are added by background
        updates. Each table whose index-adding update no longer appears in
        `background_updates` is removed from the unsafe set. If any background
        update at all is still pending, this check is rescheduled to run again
        in 15 seconds.
        """
        rows = yield self._simple_select_list(
            "background_updates",
            keyvalues=None,
            retcols=["update_name"],
            desc="check_background_updates",
        )
        pending = [row["update_name"] for row in rows]

        for table, update_name in UNIQUE_INDEX_BACKGROUND_UPDATES.items():
            if update_name in pending:
                continue
            logger.debug("Now safe to upsert in %s", table)
            self._unsafe_to_upsert_tables.discard(table)

        # Nothing left running: no need to look again.
        if not pending:
            return

        self._clock.call_later(
            15.0,
            run_as_background_process,
            "upsert_safety_check",
            self._check_safe_to_upsert,
        )
313

314
    @defer.inlineCallbacks
    def _set_expiration_date_when_missing(self):
        """
        Give every non-deactivated user without an account_validity row an
        expiration date, using a randomised (use_delta=True) date for each.
        """

        def select_users_with_no_expiration_date_txn(txn):
            """Find non-deactivated users that have no account_validity row and
            insert an expiration date for each of them.
            """
            txn.execute(
                "SELECT users.name FROM users"
                " LEFT JOIN account_validity ON (users.name = account_validity.user_id)"
                " WHERE account_validity.user_id is NULL AND users.deactivated = 0;",
                [],
            )
            for row in self.cursor_to_dict(txn):
                self.set_expiration_date_for_user_txn(txn, row["name"], use_delta=True)

        yield self.runInteraction(
            "get_users_with_no_expiration_date",
            select_users_with_no_expiration_date_txn,
        )
343

344
    def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
1×
345
        """Sets an expiration date to the account with the given user ID.
346

347
        Args:
348
             user_id (str): User ID to set an expiration date for.
349
             use_delta (bool): If set to False, the expiration date for the user will be
350
                now + validity period. If set to True, this expiration date will be a
351
                random value in the [now + period - d ; now + period] range, d being a
352
                delta equal to 10% of the validity period.
353
        """
354
        now_ms = self._clock.time_msec()
!
355
        expiration_ts = now_ms + self._account_validity.period
!
356

357
        if use_delta:
Branches [[0, 358], [0, 363]] missed. !
358
            expiration_ts = self.rand.randrange(
!
359
                expiration_ts - self._account_validity.startup_job_max_delta,
360
                expiration_ts,
361
            )
362

363
        self._simple_insert_txn(
!
364
            txn,
365
            "account_validity",
366
            values={
367
                "user_id": user_id,
368
                "expiration_ts_ms": expiration_ts,
369
                "email_sent": False,
370
            },
371
        )
372

373
    def start_profiling(self):
        """Start periodically logging (via looping_call with interval 10000,
        presumably milliseconds) the percentage of elapsed time spent in
        database transactions, along with the three heaviest transaction
        descriptions.
        """
        self._previous_loop_ts = monotonic_time()

        def report():
            txn_total = self._current_txn_total_time
            txn_delta = txn_total - self._previous_txn_total_time
            self._previous_txn_total_time = txn_total

            loop_now = monotonic_time()
            elapsed = loop_now - self._previous_loop_ts
            self._previous_loop_ts = loop_now

            share = txn_delta / elapsed

            busiest = self._txn_perf_counters.interval(elapsed, limit=3)

            perf_logger.info(
                "Total database time: %.3f%% {%s}", share * 100, busiest
            )

        self._clock.looping_call(report, 10000)
395

396
    def _new_transaction(
        self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
    ):
        """Run `func` inside a database transaction on `conn`, retrying on
        some transient failures.

        `func` is called with a LoggingTransaction wrapping a fresh cursor,
        followed by `args` and `kwargs`. On success the connection is committed
        and func's return value is returned.

        An OperationalError, or a DatabaseError that the engine classifies as a
        deadlock, causes a rollback and a retry with a new cursor, up to N (5)
        retries. After that, or for any other exception, the error is re-raised.

        Whatever the outcome, the total elapsed time (including retries) is
        recorded on the current logging context, in the profiling totals and
        in the `sql_txn_timer` metric.

        Args:
            conn: the database connection to run the transaction on.
            desc (str): description of the transaction, for logging and metrics.
            after_callbacks (list): passed through to LoggingTransaction.
            exception_callbacks (list): passed through to LoggingTransaction.
            func: the function to run within the transaction.
        """
        start = monotonic_time()
        txn_id = self._TXN_ID

        # We don't really need these to be unique, so lets stop it from
        # growing really large.
        # (Assigning self._TXN_ID creates an instance attribute shadowing the
        # class-level default.)
        self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)

        name = "%s-%x" % (desc, txn_id)

        transaction_logger.debug("[TXN START] {%s}", name)

        try:
            # i = number of retries so far; N = maximum number of retries.
            i = 0
            N = 5
            while True:
                try:
                    # NOTE(review): cursors from failed attempts are not
                    # explicitly closed here.
                    txn = conn.cursor()
                    txn = LoggingTransaction(
                        txn,
                        name,
                        self.database_engine,
                        after_callbacks,
                        exception_callbacks,
                    )
                    r = func(txn, *args, **kwargs)
                    conn.commit()
                    return r
                except self.database_engine.module.OperationalError as e:
                    # This can happen if the database disappears mid
                    # transaction.
                    # This clause must stay before the DatabaseError one below:
                    # under DB-API 2.0, OperationalError subclasses DatabaseError.
                    logger.warning(
                        "[TXN OPERROR] {%s} %s %d/%d",
                        name,
                        exception_to_unicode(e),
                        i,
                        N,
                    )
                    if i < N:
                        i += 1
                        try:
                            conn.rollback()
                        except self.database_engine.module.Error as e1:
                            # A failed rollback is logged, but we still retry.
                            logger.warning(
                                "[TXN EROLL] {%s} %s", name, exception_to_unicode(e1)
                            )
                        continue
                    raise
                except self.database_engine.module.DatabaseError as e:
                    # Only deadlocks are retried; any other DatabaseError (or a
                    # deadlock once retries are exhausted) is re-raised.
                    if self.database_engine.is_deadlock(e):
                        logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
                        if i < N:
                            i += 1
                            try:
                                conn.rollback()
                            except self.database_engine.module.Error as e1:
                                logger.warning(
                                    "[TXN EROLL] {%s} %s",
                                    name,
                                    exception_to_unicode(e1),
                                )
                            continue
                    raise
        except Exception as e:
            logger.debug("[TXN FAIL] {%s} %s", name, e)
            raise
        finally:
            # Timing covers every attempt, successful or not.
            end = monotonic_time()
            duration = end - start

            LoggingContext.current_context().add_database_transaction(duration)

            transaction_logger.debug("[TXN END] {%s} %f sec", name, duration)

            self._current_txn_total_time += duration
            self._txn_perf_counters.update(desc, duration)
            sql_txn_timer.labels(desc).observe(duration)
475

476
    @defer.inlineCallbacks
    def runInteraction(self, desc, func, *args, **kwargs):
        """Starts a transaction on the database and runs a given function

        Callbacks registered on the transaction with `call_after` are run once
        the transaction has completed successfully; those registered with
        `call_on_exception` are run if anything raises.

        Arguments:
            desc (str): description of the transaction, for logging and metrics
            func (func): callback function, which will be called with a
                database transaction (a LoggingTransaction wrapping a DB-API
                cursor) as its first argument, followed by `args` and `kwargs`.

            args (list): positional args to pass to `func`
            kwargs (dict): named args to pass to `func`

        Returns:
            Deferred: The result of func
        """
        after_callbacks = []
        exception_callbacks = []

        if LoggingContext.current_context() == LoggingContext.sentinel:
            logger.warning("Starting db txn '%s' from sentinel context", desc)

        try:
            result = yield self.runWithConnection(
                self._new_transaction,
                desc,
                after_callbacks,
                exception_callbacks,
                func,
                *args,
                **kwargs
            )

            # The transaction succeeded: run the callbacks queued via call_after.
            for after_callback, after_args, after_kwargs in after_callbacks:
                after_callback(*after_args, **after_kwargs)
        except:  # noqa: E722, as we reraise the exception this is fine.
            # NOTE: this also runs if one of the after_callbacks above raised,
            # i.e. even though the transaction itself succeeded.
            for after_callback, after_args, after_kwargs in exception_callbacks:
                after_callback(*after_args, **after_kwargs)
            raise

        return result
517

518
    @defer.inlineCallbacks
    def runWithConnection(self, func, *args, **kwargs):
        """Wraps the .runWithConnection() method on the underlying db_pool.

        `func` is run inside a new logging context named "runWithConnection",
        whose parent is the caller's context (or None if called from the
        sentinel context). The time spent waiting for a connection is recorded.
        A connection the engine reports as closed is reconnected before `func`
        is called.

        Arguments:
            func (func): callback function, which will be called with a
                database connection (twisted.enterprise.adbapi.Connection) as
                its first argument, followed by `args` and `kwargs`.
            args (list): positional args to pass to `func`
            kwargs (dict): named args to pass to `func`

        Returns:
            Deferred: The result of func
        """
        parent_context = LoggingContext.current_context()
        if parent_context == LoggingContext.sentinel:
            logger.warning(
                "Starting db connection from sentinel context: metrics will be lost"
            )
            parent_context = None

        start_time = monotonic_time()

        def inner_func(conn, *args, **kwargs):
            with LoggingContext("runWithConnection", parent_context) as context:
                # Time between requesting the connection and getting to run.
                sched_duration_sec = monotonic_time() - start_time
                sql_scheduling_timer.observe(sched_duration_sec)
                context.add_database_scheduled(sched_duration_sec)

                if self.database_engine.is_connection_closed(conn):
                    logger.debug("Reconnecting closed database connection")
                    conn.reconnect()

                return func(conn, *args, **kwargs)

        # NOTE(review): presumably the pool runs inner_func on another thread,
        # which is why the context is reset here and re-created inside
        # inner_func. Confirm against the pool implementation.
        with PreserveLoggingContext():
            result = yield self._db_pool.runWithConnection(inner_func, *args, **kwargs)

        return result
557

558
    @staticmethod
1×
559
    def cursor_to_dict(cursor):
560
        """Converts a SQL cursor into an list of dicts.
561

562
        Args:
563
            cursor : The DBAPI cursor which has executed a query.
564
        Returns:
565
            A list of dicts where the key is the column header.
566
        """
UNCOV
567
        col_headers = list(intern(str(column[0])) for column in cursor.description)
Branches [[0, 567], [0, 568]] missed. !
UNCOV
568
        results = list(dict(zip(col_headers, row)) for row in cursor)
Branches [[0, 568], [0, 569]] missed. !
UNCOV
569
        return results
!
570

571
    def _execute(self, desc, decoder, query, *args):
1×
572
        """Runs a single query for a result set.
573

574
        Args:
575
            decoder - The function which can resolve the cursor results to
576
                something meaningful.
577
            query - The query string to execute
578
            *args - Query args.
579
        Returns:
580
            The result of decoder(results)
581
        """
582

UNCOV
583
        def interaction(txn):
!
UNCOV
584
            txn.execute(query, args)
!
UNCOV
585
            if decoder:
Branches [[0, 586], [0, 588]] missed. !
UNCOV
586
                return decoder(txn)
!
587
            else:
UNCOV
588
                return txn.fetchall()
!
589

UNCOV
590
        return self.runInteraction(desc, interaction)
!
591

592
    # "Simple" SQL API methods that operate on a single table with no JOINs,
593
    # no complex WHERE clauses, just a dict of values for columns.
594

595
    @defer.inlineCallbacks
    def _simple_insert(self, table, values, or_ignore=False, desc="_simple_insert"):
        """Insert a single row into `table` in its own transaction.

        Args:
            table (str): the table name.
            values (dict): column name -> value for the new row.
            or_ignore (bool): if True, an IntegrityError (e.g. a conflicting
                row already exists) is swallowed and False is returned instead
                of raising.
            desc (str): description of the transaction, for logging and metrics.

        Returns:
            Deferred[bool]: True if the row was inserted. False only when
            `or_ignore` is set and the insert hit an IntegrityError.
        """
        try:
            yield self.runInteraction(desc, self._simple_insert_txn, table, values)
        except self.database_engine.module.IntegrityError:
            # or_ignore has to be handled out here rather than inside the
            # transaction: a cursor can't be reused once the db reports an error.
            if not or_ignore:
                raise
            inserted = False
        else:
            inserted = True
        return inserted
620

621
    @staticmethod
1×
622
    def _simple_insert_txn(txn, table, values):
UNCOV
623
        keys, vals = zip(*values.items())
!
624

UNCOV
625
        sql = "INSERT INTO %s (%s) VALUES(%s)" % (
Branches [[0, 628], [0, 627], [0, 631]] missed. !
626
            table,
627
            ", ".join(k for k in keys),
628
            ", ".join("?" for _ in keys),
629
        )
630

UNCOV
631
        txn.execute(sql, vals)
!
632

633
    def _simple_insert_many(self, table, values, desc):
1×
634
        return self.runInteraction(desc, self._simple_insert_many_txn, table, values)
!
635

636
    @staticmethod
1×
637
    def _simple_insert_many_txn(txn, table, values):
UNCOV
638
        if not values:
Branches [[0, 639], [0, 649]] missed. !
UNCOV
639
            return
!
640

641
        # This is a *slight* abomination to get a list of tuples of key names
642
        # and a list of tuples of value names.
643
        #
644
        # i.e. [{"a": 1, "b": 2}, {"c": 3, "d": 4}]
645
        #         => [("a", "b",), ("c", "d",)] and [(1, 2,), (3, 4,)]
646
        #
647
        # The sort is to ensure that we don't rely on dictionary iteration
648
        # order.
UNCOV
649
        keys, vals = zip(
Branches [[0, 650], [0, 653]] missed. !
650
            *[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i]
651
        )
652

UNCOV
653
        for k in keys:
Branches [[0, 654], [0, 657]] missed. !
UNCOV
654
            if k != keys[0]:
Branches [[0, 653], [0, 655]] missed. !
655
                raise RuntimeError("All items must have the same keys")
!
656

UNCOV
657
        sql = "INSERT INTO %s (%s) VALUES(%s)" % (
Branches [[0, 660], [0, 659], [0, 663]] missed. !
658
            table,
659
            ", ".join(k for k in keys[0]),
660
            ", ".join("?" for _ in keys[0]),
661
        )
662

UNCOV
663
        txn.executemany(sql, vals)
!
664

665
    @defer.inlineCallbacks
    def _simple_upsert(
        self,
        table,
        keyvalues,
        values,
        insertion_values=None,
        desc="_simple_upsert",
        lock=True,
    ):
        """

        `lock` should generally be set to True (the default), but can be set
        to False if either of the following are true:

        * there is a UNIQUE INDEX on the key columns. In this case a conflict
          will cause an IntegrityError in which case this function will retry
          the update.

        * we somehow know that we are the only thread which will be updating
          this table.

        The upsert is attempted up to 5 times; an IntegrityError on the final
        attempt is re-raised.

        Args:
            table (str): The table to upsert into
            keyvalues (dict): The unique key columns and their new values
            values (dict): The nonunique columns and their new values
            insertion_values (dict|None): additional key/values to use only when
                inserting. None (the default) means no additional values.
            desc (str): description of the transaction, for logging and metrics
            lock (bool): True to lock the table when doing the upsert.
        Returns:
            Deferred(None or bool): Native upserts always return None. Emulated
            upserts return True if a new entry was created, False if an existing
            one was updated.
        """
        # Avoid a mutable default argument shared between calls.
        if insertion_values is None:
            insertion_values = {}

        attempts = 0
        while True:
            try:
                result = yield self.runInteraction(
                    desc,
                    self._simple_upsert_txn,
                    table,
                    keyvalues,
                    values,
                    insertion_values,
                    lock=lock,
                )
                return result
            except self.database_engine.module.IntegrityError as e:
                attempts += 1
                if attempts >= 5:
                    # don't retry forever, because things other than races
                    # can cause IntegrityErrors
                    raise

                # presumably we raced with another transaction: let's retry.
                logger.warning(
                    "IntegrityError when upserting into %s; retrying: %s", table, e
                )
723

724
    def _simple_upsert_txn(
1×
725
        self, txn, table, keyvalues, values, insertion_values={}, lock=True
726
    ):
727
        """
728
        Pick the UPSERT method which works best on the platform. Either the
729
        native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
730

731
        Args:
732
            txn: The transaction to use.
733
            table (str): The table to upsert into
734
            keyvalues (dict): The unique key tables and their new values
735
            values (dict): The nonunique columns and their new values
736
            insertion_values (dict): additional key/values to use only when
737
                inserting
738
            lock (bool): True to lock the table when doing the upsert.
739
        Returns:
740
            None or bool: Native upserts always return None. Emulated
741
            upserts return True if a new entry was created, False if an existing
742
            one was updated.
743
        """
UNCOV
744
        if (
Branches [[0, 748], [0, 752]] missed. !
745
            self.database_engine.can_native_upsert
746
            and table not in self._unsafe_to_upsert_tables
747
        ):
UNCOV
748
            return self._simple_upsert_txn_native_upsert(
!
749
                txn, table, keyvalues, values, insertion_values=insertion_values
750
            )
751
        else:
752
            return self._simple_upsert_txn_emulated(
!
753
                txn,
754
                table,
755
                keyvalues,
756
                values,
757
                insertion_values=insertion_values,
758
                lock=lock,
759
            )
760

761
    def _simple_upsert_txn_emulated(
1×
762
        self, txn, table, keyvalues, values, insertion_values={}, lock=True
763
    ):
764
        """
765
        Args:
766
            table (str): The table to upsert into
767
            keyvalues (dict): The unique key tables and their new values
768
            values (dict): The nonunique columns and their new values
769
            insertion_values (dict): additional key/values to use only when
770
                inserting
771
            lock (bool): True to lock the table when doing the upsert.
772
        Returns:
773
            bool: Return True if a new entry was created, False if an existing
774
            one was updated.
775
        """
776
        # We need to lock the table :(, unless we're *really* careful
777
        if lock:
Branches [[0, 778], [0, 780]] missed. !
778
            self.database_engine.lock_table(txn, table)
!
779

780
        def _getwhere(key):
!
781
            # If the value we're passing in is None (aka NULL), we need to use
782
            # IS, not =, as NULL = NULL equals NULL (False).
783
            if keyvalues[key] is None:
Branches [[0, 784], [0, 786]] missed. !
784
                return "%s IS ?" % (key,)
!
785
            else:
786
                return "%s = ?" % (key,)
!
787

788
        if not values:
Branches [[0, 792], [0, 803]] missed. !
789
            # If `values` is empty, then all of the values we care about are in
790
            # the unique key, so there is nothing to UPDATE. We can just do a
791
            # SELECT instead to see if it exists.
792
            sql = "SELECT 1 FROM %s WHERE %s" % (
Branches [[0, 794], [0, 796]] missed. !
793
                table,
794
                " AND ".join(_getwhere(k) for k in keyvalues),
795
            )
796
            sqlargs = list(keyvalues.values())
!
797
            txn.execute(sql, sqlargs)
!
798
            if txn.fetchall():
Branches [[0, 800], [0, 816]] missed. !
799
                # We have an existing record.
800
                return False
!
801
        else:
802
            # First try to update.
803
            sql = "UPDATE %s SET %s WHERE %s" % (
Branches [[0, 806], [0, 805], [0, 808]] missed. !
804
                table,
805
                ", ".join("%s = ?" % (k,) for k in values),
806
                " AND ".join(_getwhere(k) for k in keyvalues),
807
            )
808
            sqlargs = list(values.values()) + list(keyvalues.values())
!
809

810
            txn.execute(sql, sqlargs)
!
811
            if txn.rowcount > 0:
Branches [[0, 813], [0, 816]] missed. !
812
                # successfully updated at least one row.
813
                return False
!
814

815
        # We didn't find any existing rows, so insert a new one
816
        allvalues = {}
!
817
        allvalues.update(keyvalues)
!
818
        allvalues.update(values)
!
819
        allvalues.update(insertion_values)
!
820

821
        sql = "INSERT INTO %s (%s) VALUES (%s)" % (
Branches [[0, 824], [0, 823], [0, 826]] missed. !
822
            table,
823
            ", ".join(k for k in allvalues),
824
            ", ".join("?" for _ in allvalues),
825
        )
826
        txn.execute(sql, list(allvalues.values()))
!
827
        # successfully inserted
828
        return True
!
829

830
    def _simple_upsert_txn_native_upsert(
1×
831
        self, txn, table, keyvalues, values, insertion_values={}
832
    ):
833
        """
834
        Use the native UPSERT functionality in recent PostgreSQL versions.
835

836
        Args:
837
            table (str): The table to upsert into
838
            keyvalues (dict): The unique key tables and their new values
839
            values (dict): The nonunique columns and their new values
840
            insertion_values (dict): additional key/values to use only when
841
                inserting
842
        Returns:
843
            None
844
        """
UNCOV
845
        allvalues = {}
!
UNCOV
846
        allvalues.update(keyvalues)
!
UNCOV
847
        allvalues.update(insertion_values)
!
848

UNCOV
849
        if not values:
Branches [[0, 850], [0, 852]] missed. !
UNCOV
850
            latter = "NOTHING"
!
851
        else:
UNCOV
852
            allvalues.update(values)
!
UNCOV
853
            latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
Branches [[0, 853], [0, 855]] missed. !
854

UNCOV
855
        sql = ("INSERT INTO %s (%s) VALUES (%s) " "ON CONFLICT (%s) DO %s") % (
Branches [[0, 859], [0, 858], [0, 857], [0, 862]] missed. !
856
            table,
857
            ", ".join(k for k in allvalues),
858
            ", ".join("?" for _ in allvalues),
859
            ", ".join(k for k in keyvalues),
860
            latter,
861
        )
UNCOV
862
        txn.execute(sql, list(allvalues.values()))
!
863

864
    def _simple_upsert_many_txn(
1×
865
        self, txn, table, key_names, key_values, value_names, value_values
866
    ):
867
        """
868
        Upsert, many times.
869

870
        Args:
871
            table (str): The table to upsert into
872
            key_names (list[str]): The key column names.
873
            key_values (list[list]): A list of each row's key column values.
874
            value_names (list[str]): The value column names. If empty, no
875
                values will be used, even if value_values is provided.
876
            value_values (list[list]): A list of each row's value column values.
877
        Returns:
878
            None
879
        """
UNCOV
880
        if (
Branches [[0, 884], [0, 888]] missed. !
881
            self.database_engine.can_native_upsert
882
            and table not in self._unsafe_to_upsert_tables
883
        ):
UNCOV
884
            return self._simple_upsert_many_txn_native_upsert(
!
885
                txn, table, key_names, key_values, value_names, value_values
886
            )
887
        else:
888
            return self._simple_upsert_many_txn_emulated(
!
889
                txn, table, key_names, key_values, value_names, value_values
890
            )
891

892
    def _simple_upsert_many_txn_emulated(
1×
893
        self, txn, table, key_names, key_values, value_names, value_values
894
    ):
895
        """
896
        Upsert, many times, but without native UPSERT support or batching.
897

898
        Args:
899
            table (str): The table to upsert into
900
            key_names (list[str]): The key column names.
901
            key_values (list[list]): A list of each row's key column values.
902
            value_names (list[str]): The value column names. If empty, no
903
                values will be used, even if value_values is provided.
904
            value_values (list[list]): A list of each row's value column values.
905
        Returns:
906
            None
907
        """
908
        # No value columns, therefore make a blank list so that the following
909
        # zip() works correctly.
910
        if not value_names:
Branches [[0, 911], [0, 913]] missed. !
911
            value_values = [() for x in range(len(key_values))]
Branches [[0, 911], [0, 913]] missed. !
912

913
        for keyv, valv in zip(key_values, value_values):
Branches [[0, 892], [0, 914]] missed. !
914
            _keys = {x: y for x, y in zip(key_names, keyv)}
Branches [[0, 914], [0, 915]] missed. !
915
            _vals = {x: y for x, y in zip(value_names, valv)}
Branches [[0, 915], [0, 917]] missed. !
916

917
            self._simple_upsert_txn_emulated(txn, table, _keys, _vals)
!
918

919
    def _simple_upsert_many_txn_native_upsert(
1×
920
        self, txn, table, key_names, key_values, value_names, value_values
921
    ):
922
        """
923
        Upsert, many times, using batching where possible.
924

925
        Args:
926
            table (str): The table to upsert into
927
            key_names (list[str]): The key column names.
928
            key_values (list[list]): A list of each row's key column values.
929
            value_names (list[str]): The value column names. If empty, no
930
                values will be used, even if value_values is provided.
931
            value_values (list[list]): A list of each row's value column values.
932
        Returns:
933
            None
934
        """
UNCOV
935
        allnames = []
!
UNCOV
936
        allnames.extend(key_names)
!
UNCOV
937
        allnames.extend(value_names)
!
938

UNCOV
939
        if not value_names:
Branches [[0, 942], [0, 945]] missed. !
940
            # No value columns, therefore make a blank list so that the
941
            # following zip() works correctly.
UNCOV
942
            latter = "NOTHING"
!
UNCOV
943
            value_values = [() for x in range(len(key_values))]
Branches [[0, 943], [0, 949]] missed. !
944
        else:
UNCOV
945
            latter = "UPDATE SET " + ", ".join(
Branches [[0, 946], [0, 949]] missed. !
946
                k + "=EXCLUDED." + k for k in value_names
947
            )
948

UNCOV
949
        sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
Branches [[0, 952], [0, 951], [0, 957]] missed. !
950
            table,
951
            ", ".join(k for k in allnames),
952
            ", ".join("?" for _ in allnames),
953
            ", ".join(key_names),
954
            latter,
955
        )
956

UNCOV
957
        args = []
!
958

UNCOV
959
        for x, y in zip(key_values, value_values):
Branches [[0, 960], [0, 962]] missed. !
UNCOV
960
            args.append(tuple(x) + tuple(y))
!
961

UNCOV
962
        return txn.execute_batch(sql, args)
!
963

964
    def _simple_select_one(
1×
965
        self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one"
966
    ):
967
        """Executes a SELECT query on the named table, which is expected to
968
        return a single row, returning multiple columns from it.
969

970
        Args:
971
            table : string giving the table name
972
            keyvalues : dict of column names and values to select the row with
973
            retcols : list of strings giving the names of the columns to return
974

975
            allow_none : If true, return None instead of failing if the SELECT
976
              statement returns no rows
977
        """
UNCOV
978
        return self.runInteraction(
!
979
            desc, self._simple_select_one_txn, table, keyvalues, retcols, allow_none
980
        )
981

982
    def _simple_select_one_onecol(
1×
983
        self,
984
        table,
985
        keyvalues,
986
        retcol,
987
        allow_none=False,
988
        desc="_simple_select_one_onecol",
989
    ):
990
        """Executes a SELECT query on the named table, which is expected to
991
        return a single row, returning a single column from it.
992

993
        Args:
994
            table : string giving the table name
995
            keyvalues : dict of column names and values to select the row with
996
            retcol : string giving the name of the column to return
997
        """
UNCOV
998
        return self.runInteraction(
!
999
            desc,
1000
            self._simple_select_one_onecol_txn,
1001
            table,
1002
            keyvalues,
1003
            retcol,
1004
            allow_none=allow_none,
1005
        )
1006

1007
    @classmethod
1×
1008
    def _simple_select_one_onecol_txn(
1×
1009
        cls, txn, table, keyvalues, retcol, allow_none=False
1010
    ):
UNCOV
1011
        ret = cls._simple_select_onecol_txn(
!
1012
            txn, table=table, keyvalues=keyvalues, retcol=retcol
1013
        )
1014

UNCOV
1015
        if ret:
Branches [[0, 1016], [0, 1018]] missed. !
UNCOV
1016
            return ret[0]
!
1017
        else:
UNCOV
1018
            if allow_none:
Branches [[0, 1019], [0, 1021]] missed. !
UNCOV
1019
                return None
!
1020
            else:
UNCOV
1021
                raise StoreError(404, "No row found")
!
1022

1023
    @staticmethod
1×
1024
    def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
UNCOV
1025
        sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
!
1026

UNCOV
1027
        if keyvalues:
Branches [[0, 1028], [0, 1031]] missed. !
UNCOV
1028
            sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
Branches [[0, 1028], [0, 1029]] missed. !
UNCOV
1029
            txn.execute(sql, list(keyvalues.values()))
!
1030
        else:
UNCOV
1031
            txn.execute(sql)
!
1032

UNCOV
1033
        return [r[0] for r in txn]
Branches [[0, 1033], [0, 1023]] missed. !
1034

1035
    def _simple_select_onecol(
1×
1036
        self, table, keyvalues, retcol, desc="_simple_select_onecol"
1037
    ):
1038
        """Executes a SELECT query on the named table, which returns a list
1039
        comprising of the values of the named column from the selected rows.
1040

1041
        Args:
1042
            table (str): table name
1043
            keyvalues (dict|None): column names and values to select the rows with
1044
            retcol (str): column whos value we wish to retrieve.
1045

1046
        Returns:
1047
            Deferred: Results in a list
1048
        """
UNCOV
1049
        return self.runInteraction(
!
1050
            desc, self._simple_select_onecol_txn, table, keyvalues, retcol
1051
        )
1052

1053
    def _simple_select_list(
1×
1054
        self, table, keyvalues, retcols, desc="_simple_select_list"
1055
    ):
1056
        """Executes a SELECT query on the named table, which may return zero or
1057
        more rows, returning the result as a list of dicts.
1058

1059
        Args:
1060
            table (str): the table name
1061
            keyvalues (dict[str, Any] | None):
1062
                column names and values to select the rows with, or None to not
1063
                apply a WHERE clause.
1064
            retcols (iterable[str]): the names of the columns to return
1065
        Returns:
1066
            defer.Deferred: resolves to list[dict[str, Any]]
1067
        """
UNCOV
1068
        return self.runInteraction(
!
1069
            desc, self._simple_select_list_txn, table, keyvalues, retcols
1070
        )
1071

1072
    @classmethod
1×
1073
    def _simple_select_list_txn(cls, txn, table, keyvalues, retcols):
1074
        """Executes a SELECT query on the named table, which may return zero or
1075
        more rows, returning the result as a list of dicts.
1076

1077
        Args:
1078
            txn : Transaction object
1079
            table (str): the table name
1080
            keyvalues (dict[str, T] | None):
1081
                column names and values to select the rows with, or None to not
1082
                apply a WHERE clause.
1083
            retcols (iterable[str]): the names of the columns to return
1084
        """
UNCOV
1085
        if keyvalues:
Branches [[0, 1086], [0, 1093]] missed. !
UNCOV
1086
            sql = "SELECT %s FROM %s WHERE %s" % (
Branches [[0, 1089], [0, 1091]] missed. !
1087
                ", ".join(retcols),
1088
                table,
1089
                " AND ".join("%s = ?" % (k,) for k in keyvalues),
1090
            )
UNCOV
1091
            txn.execute(sql, list(keyvalues.values()))
!
1092
        else:
UNCOV
1093
            sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
!
UNCOV
1094
            txn.execute(sql)
!
1095

UNCOV
1096
        return cls.cursor_to_dict(txn)
!
1097

1098
    @defer.inlineCallbacks
    def _simple_select_many_batch(
        self,
        table,
        column,
        iterable,
        retcols,
        keyvalues={},
        desc="_simple_select_many_batch",
        batch_size=100,
    ):
        """Run SELECTs on the named table in batches of `batch_size` values,
        returning all matching rows as a single list of dicts.

        Filters rows by if value of `column` is in `iterable`.

        Args:
            table : string giving the table name
            column : column name to test for inclusion against `iterable`
            iterable : list
            retcols : list of strings giving the names of the columns to return
            keyvalues : dict of column names and values to select the rows with
            desc : description of each transaction, passed to runInteraction
            batch_size : maximum number of `iterable` values per transaction
        """
        results = []
        if not iterable:
            return results

        # Materialise the iterable so that it can be sliced into batches.
        items = list(iterable)
        for start in range(0, len(items), batch_size):
            rows = yield self.runInteraction(
                desc,
                self._simple_select_many_txn,
                table,
                column,
                items[start : start + batch_size],
                keyvalues,
                retcols,
            )
            results.extend(rows)

        return results
!
1146

1147
    @classmethod
1×
1148
    def _simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
1149
        """Executes a SELECT query on the named table, which may return zero or
1150
        more rows, returning the result as a list of dicts.
1151

1152
        Filters rows by if value of `column` is in `iterable`.
1153

1154
        Args:
1155
            txn : Transaction object
1156
            table : string giving the table name
1157
            column : column name to test for inclusion against `iterable`
1158
            iterable : list
1159
            keyvalues : dict of column names and values to select the rows with
1160
            retcols : list of strings giving the names of the columns to return
1161
        """
UNCOV
1162
        if not iterable:
Branches [[0, 1163], [0, 1165]] missed. !
1163
            return []
!
1164

UNCOV
1165
        sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
!
1166

UNCOV
1167
        clauses = []
!
UNCOV
1168
        values = []
!
UNCOV
1169
        clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
Branches [[0, 1169], [0, 1170]] missed. !
UNCOV
1170
        values.extend(iterable)
!
1171

UNCOV
1172
        for key, value in iteritems(keyvalues):
Branches [[0, 1173], [0, 1176]] missed. !
UNCOV
1173
            clauses.append("%s = ?" % (key,))
!
UNCOV
1174
            values.append(value)
!
1175

UNCOV
1176
        if clauses:
Branches [[0, 1177], [0, 1179]] missed. !
UNCOV
1177
            sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
!
1178

UNCOV
1179
        txn.execute(sql, values)
!
UNCOV
1180
        return cls.cursor_to_dict(txn)
!
1181

1182
    def _simple_update(self, table, keyvalues, updatevalues, desc):
1×
UNCOV
1183
        return self.runInteraction(
!
1184
            desc, self._simple_update_txn, table, keyvalues, updatevalues
1185
        )
1186

1187
    @staticmethod
1×
1188
    def _simple_update_txn(txn, table, keyvalues, updatevalues):
UNCOV
1189
        if keyvalues:
Branches [[0, 1190], [0, 1192]] missed. !
UNCOV
1190
            where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
Branches [[0, 1190], [0, 1194]] missed. !
1191
        else:
UNCOV
1192
            where = ""
!
1193

UNCOV
1194
        update_sql = "UPDATE %s SET %s %s" % (
Branches [[0, 1196], [0, 1200]] missed. !
1195
            table,
1196
            ", ".join("%s = ?" % (k,) for k in updatevalues),
1197
            where,
1198
        )
1199

UNCOV
1200
        txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values()))
!
1201

UNCOV
1202
        return txn.rowcount
!
1203

1204
    def _simple_update_one(
1×
1205
        self, table, keyvalues, updatevalues, desc="_simple_update_one"
1206
    ):
1207
        """Executes an UPDATE query on the named table, setting new values for
1208
        columns in a row matching the key values.
1209

1210
        Args:
1211
            table : string giving the table name
1212
            keyvalues : dict of column names and values to select the row with
1213
            updatevalues : dict giving column names and values to update
1214
            retcols : optional list of column names to return
1215

1216
        If present, retcols gives a list of column names on which to perform
1217
        a SELECT statement *before* performing the UPDATE statement. The values
1218
        of these will be returned in a dict.
1219

1220
        These are performed within the same transaction, allowing an atomic
1221
        get-and-set.  This can be used to implement compare-and-set by putting
1222
        the update column in the 'keyvalues' dict as well.
1223
        """
UNCOV
1224
        return self.runInteraction(
!
1225
            desc, self._simple_update_one_txn, table, keyvalues, updatevalues
1226
        )
1227

1228
    @classmethod
1×
1229
    def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
UNCOV
1230
        rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues)
!
1231

UNCOV
1232
        if rowcount == 0:
Branches [[0, 1233], [0, 1234]] missed. !
UNCOV
1233
            raise StoreError(404, "No row found (%s)" % (table,))
!
UNCOV
1234
        if rowcount > 1:
Branches [[0, 1228], [0, 1235]] missed. !
1235
            raise StoreError(500, "More than one row matched (%s)" % (table,))
!
1236

1237
    @staticmethod
1×
1238
    def _simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
1×
UNCOV
1239
        select_sql = "SELECT %s FROM %s WHERE %s" % (
Branches [[0, 1242], [0, 1245]] missed. !
1240
            ", ".join(retcols),
1241
            table,
1242
            " AND ".join("%s = ?" % (k,) for k in keyvalues),
1243
        )
1244

UNCOV
1245
        txn.execute(select_sql, list(keyvalues.values()))
!
UNCOV
1246
        row = txn.fetchone()
!
1247

UNCOV
1248
        if not row:
Branches [[0, 1249], [0, 1252]] missed. !
UNCOV
1249
            if allow_none:
Branches [[0, 1250], [0, 1251]] missed. !
UNCOV
1250
                return None
!
UNCOV
1251
            raise StoreError(404, "No row found (%s)" % (table,))
!
UNCOV
1252
        if txn.rowcount > 1:
Branches [[0, 1253], [0, 1255]] missed. !
1253
            raise StoreError(500, "More than one row matched (%s)" % (table,))
!
1254

UNCOV
1255
        return dict(zip(retcols, row))
!
1256

1257
    def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"):
1×
1258
        """Executes a DELETE query on the named table, expecting to delete a
1259
        single row.
1260

1261
        Args:
1262
            table : string giving the table name
1263
            keyvalues : dict of column names and values to select the row with
1264
        """
UNCOV
1265
        return self.runInteraction(desc, self._simple_delete_one_txn, table, keyvalues)
!
1266

1267
    @staticmethod
1×
1268
    def _simple_delete_one_txn(txn, table, keyvalues):
1269
        """Executes a DELETE query on the named table, expecting to delete a
1270
        single row.
1271

1272
        Args:
1273
            table : string giving the table name
1274
            keyvalues : dict of column names and values to select the row with
1275
        """
UNCOV
1276
        sql = "DELETE FROM %s WHERE %s" % (
Branches [[0, 1278], [0, 1281]] missed. !
1277
            table,
1278
            " AND ".join("%s = ?" % (k,) for k in keyvalues),
1279
        )
1280

UNCOV
1281
        txn.execute(sql, list(keyvalues.values()))
!
UNCOV
1282
        if txn.rowcount == 0:
Branches [[0, 1283], [0, 1284]] missed. !
UNCOV
1283
            raise StoreError(404, "No row found (%s)" % (table,))
!
UNCOV
1284
        if txn.rowcount > 1:
Branches [[0, 1267], [0, 1285]] missed. !
1285
            raise StoreError(500, "More than one row matched (%s)" % (table,))
!
1286

1287
    def _simple_delete(self, table, keyvalues, desc):
1×
UNCOV
1288
        return self.runInteraction(desc, self._simple_delete_txn, table, keyvalues)
!
1289

1290
    @staticmethod
1×
1291
    def _simple_delete_txn(txn, table, keyvalues):
UNCOV
1292
        sql = "DELETE FROM %s WHERE %s" % (
Branches [[0, 1294], [0, 1297]] missed. !
1293
            table,
1294
            " AND ".join("%s = ?" % (k,) for k in keyvalues),
1295
        )
1296

UNCOV
1297
        txn.execute(sql, list(keyvalues.values()))
!
UNCOV
1298
        return txn.rowcount
!
1299

1300
    def _simple_delete_many(self, table, column, iterable, keyvalues, desc):
1×
UNCOV
1301
        return self.runInteraction(
!
1302
            desc, self._simple_delete_many_txn, table, column, iterable, keyvalues
1303
        )
1304

1305
    @staticmethod
1×
1306
    def _simple_delete_many_txn(txn, table, column, iterable, keyvalues):
1307
        """Executes a DELETE query on the named table.
1308

1309
        Filters rows by if value of `column` is in `iterable`.
1310

1311
        Args:
1312
            txn : Transaction object
1313
            table : string giving the table name
1314
            column : column name to test for inclusion against `iterable`
1315
            iterable : list
1316
            keyvalues : dict of column names and values to select the rows with
1317

1318
        Returns:
1319
            int: Number rows deleted
1320
        """
UNCOV
1321
        if not iterable:
Branches [[0, 1322], [0, 1324]] missed. !
UNCOV
1322
            return 0
!
1323

UNCOV
1324
        sql = "DELETE FROM %s" % table
!
1325

UNCOV
1326
        clauses = []
!
UNCOV
1327
        values = []
!
UNCOV
1328
        clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
Branches [[0, 1328], [0, 1329]] missed. !
UNCOV
1329
        values.extend(iterable)
!
1330

UNCOV
1331
        for key, value in iteritems(keyvalues):
Branches [[0, 1332], [0, 1335]] missed. !
UNCOV
1332
            clauses.append("%s = ?" % (key,))
!
UNCOV
1333
            values.append(value)
!
1334

UNCOV
1335
        if clauses:
Branches [[0, 1336], [0, 1337]] missed. !
UNCOV
1336
            sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
!
UNCOV
1337
        txn.execute(sql, values)
!
1338

UNCOV
1339
        return txn.rowcount
!
1340

1341
    def _get_cache_dict(
1×
1342
        self, db_conn, table, entity_column, stream_column, max_value, limit=100000
1343
    ):
1344
        # Fetch a mapping of room_id -> max stream position for "recent" rooms.
1345
        # It doesn't really matter how many we get, the StreamChangeCache will
1346
        # do the right thing to ensure it respects the max size of cache.
UNCOV
1347
        sql = (
!
1348
            "SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
1349
            " WHERE %(stream)s > ? - %(limit)s"
1350
            " GROUP BY %(entity)s"
1351
        ) % {
1352
            "table": table,
1353
            "entity": entity_column,
1354
            "stream": stream_column,
1355
            "limit": limit,
1356
        }
1357

UNCOV
1358
        sql = self.database_engine.convert_param_style(sql)
!
1359

UNCOV
1360
        txn = db_conn.cursor()
!
UNCOV
1361
        txn.execute(sql, (int(max_value),))
!
1362

UNCOV
1363
        cache = {row[0]: int(row[1]) for row in txn}
Branches [[0, 1363], [0, 1365]] missed. !
1364

UNCOV
1365
        txn.close()
!
1366

UNCOV
1367
        if cache:
Branches [[0, 1368], [0, 1370]] missed. !
1368
            min_val = min(itervalues(cache))
!
1369
        else:
UNCOV
1370
            min_val = max_value
!
1371

UNCOV
1372
        return cache, min_val
!
1373

1374
    def _invalidate_cache_and_stream(self, txn, cache_func, keys):
1×
1375
        """Invalidates the cache and adds it to the cache stream so slaves
1376
        will know to invalidate their caches.
1377

1378
        This should only be used to invalidate caches where slaves won't
1379
        otherwise know from other replication streams that the cache should
1380
        be invalidated.
1381
        """
UNCOV
1382
        txn.call_after(cache_func.invalidate, keys)
!
UNCOV
1383
        self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
!
1384

1385
    def _invalidate_state_caches_and_stream(self, txn, room_id, members_changed):
        """Special case invalidation of caches based on current state.

        We special case this so that we can batch the cache invalidations into a
        single replication poke.

        Args:
            txn
            room_id (str): Room where state changed
            members_changed (iterable[str]): The user_ids of members that have changed
        """
        txn.call_after(self._invalidate_state_caches, room_id, members_changed)

        if not members_changed:
            # if no members changed, we still need to invalidate the other caches.
            self._send_invalidation_to_replication(
                txn, _CURRENT_STATE_CACHE_NAME, [room_id]
            )
            return

        # We need to be careful that the size of the `members_changed` list
        # isn't so large that it causes problems sending over replication, so we
        # send them in chunks.
        # Max line length is 16K, and max user ID length is 255, so 50 should
        # be safe.
        for chunk in batch_iter(members_changed, 50):
            self._send_invalidation_to_replication(
                txn, _CURRENT_STATE_CACHE_NAME, itertools.chain([room_id], chunk)
            )
1414

1415
    def _invalidate_state_caches(self, room_id, members_changed):
        """Invalidate, in this process only, caches derived from a room's
        current state. Nothing is sent down replication.

        Args:
            room_id (str): room whose current state changed
            members_changed (iterable[str]): user IDs whose membership changed;
                their server names determine which per-host entries are dropped
        """
        affected_hosts = {get_domain_from_id(user_id) for user_id in members_changed}
        for server_name in affected_hosts:
            per_host_key = (room_id, server_name)
            for per_host_cache in ("is_host_joined", "was_host_joined"):
                self._attempt_to_invalidate_cache(per_host_cache, per_host_key)

        per_room_key = (room_id,)
        for per_room_cache in (
            "get_users_in_room",
            "get_room_summary",
            "get_current_state_ids",
        ):
            self._attempt_to_invalidate_cache(per_room_cache, per_room_key)
1431

1432
    def _attempt_to_invalidate_cache(self, cache_name, key):
        """Attempts to invalidate the cache of the given name, ignoring if the
        cache doesn't exist. Mainly used for invalidating caches on workers,
        where they may not have the cache.

        Only a *missing* (or None) cache attribute is tolerated. An
        AttributeError raised from inside the cache's own `invalidate` is a real
        bug and is allowed to propagate instead of being silently swallowed.

        Args:
            cache_name (str): name of the cache attribute on this store
            key (tuple): the cache key to invalidate
        """
        cache = getattr(self, cache_name, None)
        if cache is None:
            # We probably haven't pulled in the cache in this worker,
            # which is fine.
            return
        cache.invalidate(key)
1447

1448
    def _send_invalidation_to_replication(self, txn, cache_name, keys):
        """Write a row to the cache invalidation stream so that replication
        can tell workers about the invalidation.

        The local cache is *not* invalidated here. A stream row is only written
        when the database engine is Postgres; on any other engine this returns
        without doing anything.

        Args:
            txn
            cache_name (str)
            keys (iterable[str])
        """
        if not isinstance(self.database_engine, PostgresEngine):
            return

        # `get_next()` returns a context manager meant to wrap the whole
        # transaction. We only want an ID at this point, so enter it by hand and
        # schedule its exit for when the transaction finishes, whether it
        # succeeds or raises.
        id_ctx = self._cache_id_gen.get_next()
        stream_id = id_ctx.__enter__()
        txn.call_on_exception(id_ctx.__exit__, None, None, None)
        txn.call_after(id_ctx.__exit__, None, None, None)
        txn.call_after(self.hs.get_notifier().on_new_replication_data)

        stream_row = {
            "stream_id": stream_id,
            "cache_func": cache_name,
            "keys": list(keys),
            "invalidation_ts": self.clock.time_msec(),
        }
        self._simple_insert_txn(
            txn, table="cache_invalidation_stream", values=stream_row
        )
1480

1481
    def get_all_updated_caches(self, last_id, current_id, limit):
        """Fetch cache invalidation stream rows with a stream_id above `last_id`.

        Args:
            last_id: stream position already processed; only rows with a
                strictly greater stream_id are returned.
            current_id: if equal to `last_id`, no query is run and an empty
                list is returned.
            limit (int): maximum number of rows to fetch.

        Returns:
            Deferred: resolves to the fetched rows of
            (stream_id, cache_func, keys, invalidation_ts), ordered by stream_id.
        """
        if last_id == current_id:
            return defer.succeed([])

        query = (
            "SELECT stream_id, cache_func, keys, invalidation_ts"
            " FROM cache_invalidation_stream"
            " WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
        )

        def fetch_rows_txn(txn):
            # Deliberately not bounded above by `current_id`: invalidations
            # should reach workers as fast as possible, and replaying one is
            # harmless because invalidation is idempotent.
            txn.execute(query, (last_id, limit))
            return txn.fetchall()

        return self.runInteraction("get_all_updated_caches", fetch_rows_txn)
1498

1499
    def get_cache_stream_token(self):
        """Return the current token of the cache invalidation stream, or 0 if
        this store has no cache stream ID generator."""
        id_gen = self._cache_id_gen
        if not id_gen:
            return 0
        return id_gen.get_current_token()
1504

1505
    def _simple_select_list_paginate(
        self,
        table,
        keyvalues,
        orderby,
        start,
        limit,
        retcols,
        order_direction="ASC",
        desc="_simple_select_list_paginate",
    ):
        """Run `_simple_select_list_paginate_txn` in its own database
        interaction, selecting up to `limit` rows of `table` starting at row
        offset `start`.

        Args:
            table (str): the table name
            keyvalues (dict[str, T] | None): column names and values the rows
                must match, or None for no WHERE clause.
            orderby (str): column to order the results by.
            start (int): row offset to begin at.
            limit (int): maximum number of rows to return.
            retcols (iterable[str]): the names of the columns to return
            order_direction (str): "ASC" or "DESC".
            desc (str): description of the interaction.

        Returns:
            defer.Deferred: resolves to list[dict[str, Any]]
        """
        query_args = (table, keyvalues, orderby, start, limit, retcols)
        return self.runInteraction(
            desc,
            self._simple_select_list_paginate_txn,
            *query_args,
            order_direction=order_direction
        )

1546
    @classmethod
    def _simple_select_list_paginate_txn(
        cls,
        txn,
        table,
        keyvalues,
        orderby,
        start,
        limit,
        retcols,
        order_direction="ASC",
    ):
        """
        Executes a SELECT query on the named table with start and limit,
        of row numbers, which may return zero or number of rows from start to limit,
        returning the result as a list of dicts.

        Args:
            txn : Transaction object
            table (str): the table name
            keyvalues (dict[str, T] | None):
                column names and values to select the rows with, or None to not
                apply a WHERE clause.
            orderby (str): Column to order the results by.
            start (int): Index to begin the query at.
            limit (int): Number of results to return.
            retcols (iterable[str]): the names of the columns to return
            order_direction (str): Whether the results should be ordered "ASC" or "DESC".
        Returns:
            list[dict[str, Any]]: the selected rows. This runs synchronously
            inside the given transaction; it does not return a Deferred.
        Raises:
            ValueError: if `order_direction` is not "ASC" or "DESC".
        """
        if order_direction not in ("ASC", "DESC"):
            raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")

        if keyvalues:
            where_clause = "WHERE " + " AND ".join("%s = ?" % (k,) for k in keyvalues)
            # Iterating the same dict again yields values in the same order as
            # the keys above, so each value lines up with its placeholder.
            where_args = list(keyvalues.values())
        else:
            # `keyvalues` may be None (documented above), so its values must
            # only be read when it is non-empty.
            where_clause = ""
            where_args = []

        sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % (
            ", ".join(retcols),
            table,
            where_clause,
            orderby,
            order_direction,
        )
        txn.execute(sql, where_args + [limit, start])

        return cls.cursor_to_dict(txn)
1595

1596
    def get_user_count_txn(self, txn):
        """Count the rows of the `users` table that are not guests
        (`is_guest = 0`).

        Args:
            txn : Transaction object
        Returns:
            int : number of users
        """
        txn.execute("SELECT COUNT(*) FROM users WHERE is_guest = 0;")
        count_row = txn.fetchone()
        return count_row[0]
1607

1608
    def _simple_search_list(
        self, table, term, col, retcols, desc="_simple_search_list"
    ):
        """Run `_simple_search_list_txn` in its own database interaction,
        searching `table` for rows whose `col` matches `term`.

        Args:
            table (str): the table name
            term (str | None): the text to search for in `col`.
            col (str): the column `term` is matched against.
            retcols (iterable[str]): the names of the columns to return
            desc (str): description of the interaction.
        Returns:
            defer.Deferred: resolves to whatever `_simple_search_list_txn`
            returns, which is a list of row dicts when `term` is non-empty.
        """
        search_txn = self._simple_search_list_txn
        return self.runInteraction(desc, search_txn, table, term, col, retcols)

1628
    @classmethod
    def _simple_search_list_txn(cls, txn, table, term, col, retcols):
        """Executes a SELECT query on the named table, which may return zero or
        more rows, returning the result as a list of dicts.

        Args:
            txn : Transaction object
            table (str): the table name
            term (str | None):
                term for searching the table matched to a column.
            col (str): column to query term should be matched to
            retcols (iterable[str]): the names of the columns to return
        Returns:
            list[dict[str, Any]] | None: rows whose `col` contains `term`, or
            None when `term` is empty/None (no query is run in that case).
        """
        if not term:
            return None

        sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
        # The pattern is passed as a bound parameter, so '%' needs no escaping
        # here. Wrap the term in wildcards for a substring match. NOTE: '%' or
        # '_' inside `term` itself still act as LIKE wildcards.
        txn.execute(sql, ["%" + term + "%"])

        return cls.cursor_to_dict(txn)
1651

1652
    @property
    def database_engine_name(self):
        """The `__name__` of the module exposed as `database_engine.module`."""
        engine_module = self.database_engine.module
        return engine_module.__name__
1655

1656
    def get_server_version(self):
        """Return the `server_version` string reported by this store's
        database engine."""
        engine = self.database_engine
        return engine.server_version
1659

1660

1661
class _RollbackButIsFineException(Exception):
    """Raised to roll a transaction back deliberately.

    Unlike other exceptions, raising this does not mean anything went wrong.
    """
1667

1668

1669
def db_to_json(db_content):
    """
    Decode a JSON value read from a database row.

    Binary column values are first turned into `bytes`, then decoded as UTF-8,
    so that `json.loads` always receives text and the result consistently
    contains Unicode strings.

    Args:
        db_content (memoryview|buffer|bytes|bytearray|unicode)

    Returns:
        The decoded JSON value.

    Raises:
        Whatever `json.loads` raises for undecodable input, after logging a
        warning that includes the offending content.
    """
    text = db_content

    # psycopg2 returns memoryview objects on Python 3 and buffer objects on
    # Python 2. Both must become bytes before they can be decoded.
    if isinstance(text, memoryview):
        text = text.tobytes()
    elif PY2 and isinstance(text, builtins.buffer):
        text = bytes(text)

    if isinstance(text, (bytes, bytearray)):
        text = text.decode("utf8")

    try:
        return json.loads(text)
    except Exception:
        logging.warning("Tried to decode '%r' as JSON and failed", text)
        raise
Troubleshooting · Open an Issue · Sales · Support · ENTERPRISE · CAREERS · STATUS
BLOG · TWITTER · Legal & Privacy · Supported CI Services · What's a CI service? · Automated Testing

© 2019 Coveralls, LLC