
Coveralls coverage report: matrix-org / synapse, build 4532 (buildkite)

23 Sep 2019, 19:39: coverage decreased (-49.7%) to 17.596%

Pull Request #6079: Add submit_url response parameter to msisdn /requestToken
Commit: "update changelog" by Richard van der Hoff

359 of 12986 branches covered (2.76%)

Branch coverage included in aggregate %.

0 of 7 new or added lines in 1 file covered. (0.0%)

18869 existing lines in 281 files now uncovered.

8809 of 39116 relevant lines covered (22.52%)

0.23 hits per line
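As noted above, branch coverage is folded into the aggregate figure: (8809 + 359) / (39116 + 12986) = 9168 / 52102 ≈ 17.596%.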

Source File

/synapse/storage/state.py: 13.78%
1
# -*- coding: utf-8 -*-
2
# Copyright 2014-2016 OpenMarket Ltd
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
import logging
1×
17
from collections import namedtuple
1×
18

19
from six import iteritems, itervalues
1×
20
from six.moves import range
1×
21

22
import attr
1×
23

24
from twisted.internet import defer
1×
25

26
from synapse.api.constants import EventTypes
1×
27
from synapse.api.errors import NotFoundError
1×
28
from synapse.storage._base import SQLBaseStore
1×
29
from synapse.storage.background_updates import BackgroundUpdateStore
1×
30
from synapse.storage.engines import PostgresEngine
1×
31
from synapse.storage.events_worker import EventsWorkerStore
1×
32
from synapse.util.caches import get_cache_factor_for, intern_string
1×
33
from synapse.util.caches.descriptors import cached, cachedList
1×
34
from synapse.util.caches.dictionary_cache import DictionaryCache
1×
35
from synapse.util.stringutils import to_ascii
1×
36

37
logger = logging.getLogger(__name__)
1×
38

39

40
MAX_STATE_DELTA_HOPS = 100
1×
41

42

43
class _GetStateGroupDelta(
1×
44
    namedtuple("_GetStateGroupDelta", ("prev_group", "delta_ids"))
45
):
46
    """Return type of get_state_group_delta that implements __len__, which lets
47
    us use the iterable flag when caching
48
    """
49

50
    __slots__ = []
1×
51

52
    def __len__(self):
1×
UNCOV
53
        return len(self.delta_ids) if self.delta_ids else 0
!
54

55

56
@attr.s(slots=True)
1×
57
class StateFilter(object):
1×
58
    """A filter used when querying for state.
59

60
    Attributes:
61
        types (dict[str, set[str]|None]): Map from type to set of state keys (or
62
            None). This specifies which state_keys for the given type to fetch
63
            from the DB. If None then all events with that type are fetched. If
64
            the set is empty then no events with that type are fetched.
65
        include_others (bool): Whether to fetch events with types that do not
66
            appear in `types`.
67
    """
68

69
    types = attr.ib()
1×
70
    include_others = attr.ib(default=False)
1×
71

72
    def __attrs_post_init__(self):
1×
73
        # If `include_others` is set we canonicalise the filter by removing
74
        # wildcards from the types dictionary
75
        if self.include_others:
Branches [[0, 72]] missed. 1×
76
            self.types = {k: v for k, v in iteritems(self.types) if v is not None}
1×
77

78
    @staticmethod
1×
79
    def all():
80
        """Creates a filter that fetches everything.
81

82
        Returns:
83
            StateFilter
84
        """
85
        return StateFilter(types={}, include_others=True)
1×
86

87
    @staticmethod
1×
88
    def none():
89
        """Creates a filter that fetches nothing.
90

91
        Returns:
92
            StateFilter
93
        """
UNCOV
94
        return StateFilter(types={}, include_others=False)
!
95

96
    @staticmethod
1×
97
    def from_types(types):
98
        """Creates a filter that only fetches the given types
99

100
        Args:
101
            types (Iterable[tuple[str, str|None]]): A list of type and state
102
                keys to fetch. A state_key of None fetches everything for
103
                that type
104

105
        Returns:
106
            StateFilter
107
        """
UNCOV
108
        type_dict = {}
!
UNCOV
109
        for typ, s in types:
Branches [[0, 110], [0, 120]] missed. !
UNCOV
110
            if typ in type_dict:
Branches [[0, 111], [0, 114]] missed. !
UNCOV
111
                if type_dict[typ] is None:
Branches [[0, 112], [0, 114]] missed. !
112
                    continue
!
113

UNCOV
114
            if s is None:
Branches [[0, 115], [0, 118]] missed. !
UNCOV
115
                type_dict[typ] = None
!
UNCOV
116
                continue
!
117

UNCOV
118
            type_dict.setdefault(typ, set()).add(s)
!
119

UNCOV
120
        return StateFilter(types=type_dict)
!
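    # Illustrative example (added for this report, not part of the original
    # file): from_types canonicalises the requested (type, state_key) pairs
    # into the `types` dict described on the class; a state_key of None acts
    # as a per-type wildcard. The user ID below is a placeholder.
    #
    #     f = StateFilter.from_types(
    #         [("m.room.name", ""), ("m.room.member", "@alice:example.com")]
    #     )
    #     # f.types == {"m.room.name": {""},
    #     #             "m.room.member": {"@alice:example.com"}}
    #     # f.include_others is False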
121

122
    @staticmethod
1×
123
    def from_lazy_load_member_list(members):
124
        """Creates a filter that returns all non-member events, plus the member
125
        events for the given users
126

127
        Args:
128
            members (iterable[str]): Set of user IDs
129

130
        Returns:
131
            StateFilter
132
        """
UNCOV
133
        return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
!
134

135
    def return_expanded(self):
1×
136
        """Creates a new StateFilter where type wild cards have been removed
137
        (except for memberships). The returned filter is a superset of the
138
        current one, i.e. anything that passes the current filter will pass
139
        the returned filter.
140

141
        This helps the caching as the DictionaryCache knows if it has *all* the
142
        state, but does not know if it has all of the keys of a particular type,
143
        which makes wildcard lookups expensive unless we have a complete cache.
144
        Hence, if we are doing a wildcard lookup, populate the cache fully so
145
        that we can do an efficient lookup next time.
146

147
        Note that since we have two caches, one for membership events and one for
148
        other events, we can be a bit more clever than simply returning
149
        `StateFilter.all()` if `has_wildcards()` is True.
150

151
        We return a StateFilter where:
152
            1. the list of membership events to return is the same
153
            2. if there is a wildcard that matches non-member events we
154
               return all non-member events
155

156
        Returns:
157
            StateFilter
158
        """
159

UNCOV
160
        if self.is_full():
Branches [[0, 162], [0, 164]] missed. !
161
            # If we're going to return everything then there's nothing to do
UNCOV
162
            return self
!
163

164
        if not self.has_wildcards():
Branches [[0, 166], [0, 168]] missed. !
165
            # If there are no wild cards, there's nothing to do
166
            return self
!
167

168
        if EventTypes.Member in self.types:
Branches [[0, 169], [0, 171]] missed. !
169
            get_all_members = self.types[EventTypes.Member] is None
!
170
        else:
171
            get_all_members = self.include_others
!
172

173
        has_non_member_wildcard = self.include_others or any(
Branches [[0, 174], [0, 179]] missed. !
174
            state_keys is None
175
            for t, state_keys in iteritems(self.types)
176
            if t != EventTypes.Member
177
        )
178

179
        if not has_non_member_wildcard:
Branches [[0, 181], [0, 183]] missed. !
180
            # If there are no non-member wild cards we can just return ourselves
181
            return self
!
182

183
        if get_all_members:
Branches [[0, 185], [0, 189]] missed. !
184
            # We want to return everything.
185
            return StateFilter.all()
!
186
        else:
187
            # We want to return all non-members, but only particular
188
            # memberships
189
            return StateFilter(
!
190
                types={EventTypes.Member: self.types[EventTypes.Member]},
191
                include_others=True,
192
            )
193

194
    def make_sql_filter_clause(self):
1×
195
        """Converts the filter to an SQL clause.
196

197
        For example:
198

199
            f = StateFilter.from_types([("m.room.create", "")])
200
            clause, args = f.make_sql_filter_clause()
201
            clause == "(type = ? AND state_key = ?)"
202
            args == ['m.room.create', '']
203

204

205
        Returns:
206
            tuple[str, list]: The SQL string (may be empty) and arguments. An
207
            empty SQL string is returned when the filter matches everything
208
            (i.e. is "full").
209
        """
210

UNCOV
211
        where_clause = ""
!
UNCOV
212
        where_args = []
!
213

UNCOV
214
        if self.is_full():
Branches [[0, 215], [0, 217]] missed. !
UNCOV
215
            return where_clause, where_args
!
216

UNCOV
217
        if not self.include_others and not self.types:
Branches [[0, 220], [0, 223]] missed. !
218
            # i.e. this is an empty filter, so we need to return a clause that
219
            # will match nothing
220
            return "1 = 2", []
!
221

222
        # First we build up a list of clauses for each type/state_key combo
UNCOV
223
        clauses = []
!
UNCOV
224
        for etype, state_keys in iteritems(self.types):
Branches [[0, 225], [0, 235]] missed. !
UNCOV
225
            if state_keys is None:
Branches [[0, 226], [0, 230]] missed. !
UNCOV
226
                clauses.append("(type = ?)")
!
UNCOV
227
                where_args.append(etype)
!
UNCOV
228
                continue
!
229

UNCOV
230
            for state_key in state_keys:
Branches [[0, 224], [0, 231]] missed. !
UNCOV
231
                clauses.append("(type = ? AND state_key = ?)")
!
UNCOV
232
                where_args.extend((etype, state_key))
!
233

234
        # This will match anything that appears in `self.types`
UNCOV
235
        where_clause = " OR ".join(clauses)
!
236

237
        # If we want to include stuff that's not in the types dict then we add
238
        # an `OR type NOT IN (...)` clause to the end.
UNCOV
239
        if self.include_others:
Branches [[0, 240], [0, 246]] missed. !
240
            if where_clause:
Branches [[0, 241], [0, 243]] missed. !
241
                where_clause += " OR "
!
242

243
            where_clause += "type NOT IN (%s)" % (",".join(["?"] * len(self.types)),)
!
244
            where_args.extend(self.types)
!
245

UNCOV
246
        return where_clause, where_args
!
247

248
    def max_entries_returned(self):
1×
249
        """Returns the maximum number of entries this filter will return if
250
        known, otherwise returns None.
251

252
        For example a simple state filter asking for `("m.room.create", "")`
253
        will return 1, whereas the default state filter will return None.
254

255
        This is used to bail out early if the right number of entries have been
256
        fetched.
257
        """
258
        if self.has_wildcards():
Branches [[0, 259], [0, 261]] missed. !
259
            return None
!
260

261
        return len(self.concrete_types())
!
262

263
    def filter_state(self, state_dict):
1×
264
        """Returns the state filtered with by this StateFilter
265

266
        Args:
267
            state_dict (dict[tuple[str, str], Any]): The state map to filter
268

269
        Returns:
270
            dict[tuple[str, str], Any]: The filtered state map
271
        """
UNCOV
272
        if self.is_full():
Branches [[0, 273], [0, 275]] missed. !
UNCOV
273
            return dict(state_dict)
!
274

UNCOV
275
        filtered_state = {}
!
UNCOV
276
        for k, v in iteritems(state_dict):
Branches [[0, 277], [0, 285]] missed. !
UNCOV
277
            typ, state_key = k
!
UNCOV
278
            if typ in self.types:
Branches [[0, 279], [0, 282]] missed. !
UNCOV
279
                state_keys = self.types[typ]
!
UNCOV
280
                if state_keys is None or state_key in state_keys:
Branches [[0, 276], [0, 281]] missed. !
UNCOV
281
                    filtered_state[k] = v
!
UNCOV
282
            elif self.include_others:
Branches [[0, 276], [0, 283]] missed. !
283
                filtered_state[k] = v
!
284

UNCOV
285
        return filtered_state
!
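    # Worked example (added for this report, not part of the original file);
    # the event IDs and user ID below are placeholders.
    #
    #     f = StateFilter.from_types([("m.room.create", "")])
    #     f.filter_state({
    #         ("m.room.create", ""): "$create_event_id",
    #         ("m.room.member", "@alice:example.com"): "$member_event_id",
    #     })
    #     # == {("m.room.create", ""): "$create_event_id"}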
286

287
    def is_full(self):
1×
288
        """Whether this filter fetches everything or not
289

290
        Returns:
291
            bool
292
        """
UNCOV
293
        return self.include_others and not self.types
!
294

295
    def has_wildcards(self):
1×
296
        """Whether the filter includes wildcards or is attempting to fetch
297
        specific state.
298

299
        Returns:
300
            bool
301
        """
302

303
        return self.include_others or any(
Branches [[0, 304], [0, 295]] missed. !
304
            state_keys is None for state_keys in itervalues(self.types)
305
        )
306

307
    def concrete_types(self):
1×
308
        """Returns a list of concrete type/state_keys (i.e. not None) that
309
        will be fetched. This will be a complete list if `has_wildcards`
310
        returns False, but otherwise will be a subset (or even empty).
311

312
        Returns:
313
            list[tuple[str,str]]
314
        """
315
        return [
Branches [[0, 316], [0, 307]] missed. !
316
            (t, s)
317
            for t, state_keys in iteritems(self.types)
318
            if state_keys is not None
319
            for s in state_keys
320
        ]
321

322
    def get_member_split(self):
1×
323
        """Return the filter split into two: one which assumes it's exclusively
324
        matching against member state, and one which assumes it's matching
325
        against non-member state.
326

327
        This is useful due to the returned filters giving correct results for
328
        `is_full()`, `has_wildcards()`, etc, when operating against maps that
329
        either exclusively contain member events or only contain non-member
330
        events. (Which is the case when dealing with the member vs non-member
331
        state caches).
332

333
        Returns:
334
            tuple[StateFilter, StateFilter]: The member and non member filters
335
        """
336

UNCOV
337
        if EventTypes.Member in self.types:
Branches [[0, 338], [0, 343]] missed. !
UNCOV
338
            state_keys = self.types[EventTypes.Member]
!
UNCOV
339
            if state_keys is None:
Branches [[0, 340], [0, 342]] missed. !
UNCOV
340
                member_filter = StateFilter.all()
!
341
            else:
UNCOV
342
                member_filter = StateFilter({EventTypes.Member: state_keys})
!
UNCOV
343
        elif self.include_others:
Branches [[0, 344], [0, 346]] missed. !
UNCOV
344
            member_filter = StateFilter.all()
!
345
        else:
UNCOV
346
            member_filter = StateFilter.none()
!
347

UNCOV
348
        non_member_filter = StateFilter(
Branches [[0, 349], [0, 353]] missed. !
349
            types={k: v for k, v in iteritems(self.types) if k != EventTypes.Member},
350
            include_others=self.include_others,
351
        )
352

UNCOV
353
        return member_filter, non_member_filter
!
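    # Worked example (added for this report, not part of the original file);
    # the user ID below is a placeholder.
    #
    #     f = StateFilter(
    #         types={EventTypes.Member: {"@alice:example.com"},
    #                "m.room.name": {""}},
    #         include_others=False,
    #     )
    #     member_f, non_member_f = f.get_member_split()
    #     # member_f.types == {EventTypes.Member: {"@alice:example.com"}}
    #     # non_member_f.types == {"m.room.name": {""}}
    #     # both with include_others == False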
354

355

356
# this inherits from EventsWorkerStore because it calls self.get_events
357
class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
1×
358
    """The parts of StateGroupStore that can be called from workers.
359
    """
360

361
    STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
1×
362
    STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
1×
363
    CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
1×
364

365
    def __init__(self, db_conn, hs):
1×
UNCOV
366
        super(StateGroupWorkerStore, self).__init__(db_conn, hs)
!
367

368
        # Originally the state store used a single DictionaryCache to cache the
369
        # event IDs for the state types in a given state group to avoid hammering
370
        # on the state_group* tables.
371
        #
372
        # The point of using a DictionaryCache is that it can cache a subset
373
        # of the state events for a given state group (i.e. a subset of the keys for a
374
        # given dict which is an entry in the cache for a given state group ID).
375
        #
376
        # However, this poses problems when performing complicated queries
377
        # on the store - for instance: "give me all the state for this group, but
378
        # limit members to this subset of users", as DictionaryCache's API isn't
379
        # rich enough to say "please cache any of these fields, apart from this subset".
380
        # This is problematic when lazy loading members, which requires this behaviour,
381
        # as without it the cache has no choice but to speculatively load all
382
        # state events for the group, which negates the efficiency being sought.
383
        #
384
        # Rather than overcomplicating DictionaryCache's API, we instead split the
385
        # state_group_cache into two halves - one for tracking non-member events,
386
        # and the other for tracking member_events.  This means that lazy loading
387
        # queries can be made in a cache-friendly manner by querying both caches
388
        # separately and then merging the result.  So for the example above, you
389
        # would query the members cache for a specific subset of state keys
390
        # (which DictionaryCache will handle efficiently and fine) and the non-members
391
        # cache for all state (which DictionaryCache will similarly handle fine)
392
        # and then just merge the results together.
393
        #
394
        # We size the non-members cache to be smaller than the members cache as the
395
        # vast majority of state in Matrix (today) is member events.
396

UNCOV
397
        self._state_group_cache = DictionaryCache(
!
398
            "*stateGroupCache*",
399
            # TODO: this hasn't been tuned yet
400
            50000 * get_cache_factor_for("stateGroupCache"),
401
        )
UNCOV
402
        self._state_group_members_cache = DictionaryCache(
!
403
            "*stateGroupMembersCache*",
404
            500000 * get_cache_factor_for("stateGroupMembersCache"),
405
        )
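        # Illustrative sketch (added for this report, not in the original
        # file): a lazy-loading lookup splits its filter and merges results
        # from the two caches, roughly as _get_state_for_groups does below.
        #
        #     member_f, non_member_f = state_filter.get_member_split()
        #     others, _ = self._get_state_for_groups_using_cache(
        #         groups, self._state_group_cache, state_filter=non_member_f
        #     )
        #     members, _ = self._get_state_for_groups_using_cache(
        #         groups, self._state_group_members_cache, state_filter=member_f
        #     )
        #     state = {g: dict(others[g], **members[g]) for g in groups}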
406

407
    @defer.inlineCallbacks
1×
408
    def get_room_version(self, room_id):
409
        """Get the room_version of a given room
410

411
        Args:
412
            room_id (str)
413

414
        Returns:
415
            Deferred[str]
416

417
        Raises:
418
            NotFoundError if the room is unknown
419
        """
420
        # for now we do this by looking at the create event. We may want to cache this
421
        # more intelligently in future.
422

423
        # Retrieve the room's create event
UNCOV
424
        create_event = yield self.get_create_event_for_room(room_id)
!
UNCOV
425
        return create_event.content.get("room_version", "1")
!
426

427
    @defer.inlineCallbacks
1×
428
    def get_room_predecessor(self, room_id):
429
        """Get the predecessor room of an upgraded room if one exists.
430
        Otherwise return None.
431

432
        Args:
433
            room_id (str)
434

435
        Returns:
436
            Deferred[unicode|None]: predecessor room id
437

438
        Raises:
439
            NotFoundError if the room is unknown
440
        """
441
        # Retrieve the room's create event
UNCOV
442
        create_event = yield self.get_create_event_for_room(room_id)
!
443

444
        # Return predecessor if present
UNCOV
445
        return create_event.content.get("predecessor", None)
!
446

447
    @defer.inlineCallbacks
1×
448
    def get_create_event_for_room(self, room_id):
449
        """Get the create state event for a room.
450

451
        Args:
452
            room_id (str)
453

454
        Returns:
455
            Deferred[EventBase]: The room creation event.
456

457
        Raises:
458
            NotFoundError if the room is unknown
459
        """
UNCOV
460
        state_ids = yield self.get_current_state_ids(room_id)
!
UNCOV
461
        create_id = state_ids.get((EventTypes.Create, ""))
!
462

463
        # If we can't find the create event, assume we've hit a dead end
UNCOV
464
        if not create_id:
Branches [[0, 465], [0, 468]] missed. !
465
            raise NotFoundError("Unknown room %s" % (room_id))
!
466

467
        # Retrieve the room's create event and return
UNCOV
468
        create_event = yield self.get_event(create_id)
!
UNCOV
469
        return create_event
!
470

471
    @cached(max_entries=100000, iterable=True)
1×
472
    def get_current_state_ids(self, room_id):
473
        """Get the current state event ids for a room based on the
474
        current_state_events table.
475

476
        Args:
477
            room_id (str)
478

479
        Returns:
480
            deferred: dict of (type, state_key) -> event_id
481
        """
482

UNCOV
483
        def _get_current_state_ids_txn(txn):
!
UNCOV
484
            txn.execute(
!
485
                """SELECT type, state_key, event_id FROM current_state_events
486
                WHERE room_id = ?
487
                """,
488
                (room_id,),
489
            )
490

UNCOV
491
            return {
Branches [[0, 491], [0, 483]] missed. !
492
                (intern_string(r[0]), intern_string(r[1])): to_ascii(r[2]) for r in txn
493
            }
494

UNCOV
495
        return self.runInteraction("get_current_state_ids", _get_current_state_ids_txn)
!
496

497
    # FIXME: how should this be cached?
498
    def get_filtered_current_state_ids(self, room_id, state_filter=StateFilter.all()):
1×
499
        """Get the current state event of a given type for a room based on the
500
        current_state_events table.  This may not be as up-to-date as the result
501
        of doing a fresh state resolution as per state_handler.get_current_state
502

503
        Args:
504
            room_id (str)
505
            state_filter (StateFilter): The state filter used to fetch state
506
                from the database.
507

508
        Returns:
509
            Deferred[dict[tuple[str, str], str]]: Map from type/state_key to
510
            event ID.
511
        """
512

UNCOV
513
        where_clause, where_args = state_filter.make_sql_filter_clause()
!
514

UNCOV
515
        if not where_clause:
Branches [[0, 517], [0, 519]] missed. !
516
            # We delegate to the cached version
UNCOV
517
            return self.get_current_state_ids(room_id)
!
518

UNCOV
519
        def _get_filtered_current_state_ids_txn(txn):
!
UNCOV
520
            results = {}
!
UNCOV
521
            sql = """
!
522
                SELECT type, state_key, event_id FROM current_state_events
523
                WHERE room_id = ?
524
            """
525

UNCOV
526
            if where_clause:
Branches [[0, 527], [0, 529]] missed. !
UNCOV
527
                sql += " AND (%s)" % (where_clause,)
!
528

UNCOV
529
            args = [room_id]
!
UNCOV
530
            args.extend(where_args)
!
UNCOV
531
            txn.execute(sql, args)
!
UNCOV
532
            for row in txn:
Branches [[0, 533], [0, 537]] missed. !
UNCOV
533
                typ, state_key, event_id = row
!
UNCOV
534
                key = (intern_string(typ), intern_string(state_key))
!
UNCOV
535
                results[key] = event_id
!
536

UNCOV
537
            return results
!
538

UNCOV
539
        return self.runInteraction(
!
540
            "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
541
        )
542

543
    @defer.inlineCallbacks
1×
544
    def get_canonical_alias_for_room(self, room_id):
545
        """Get canonical alias for room, if any
546

547
        Args:
548
            room_id (str)
549

550
        Returns:
551
            Deferred[str|None]: The canonical alias, if any
552
        """
553

UNCOV
554
        state = yield self.get_filtered_current_state_ids(
!
555
            room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
556
        )
557

UNCOV
558
        event_id = state.get((EventTypes.CanonicalAlias, ""))
!
UNCOV
559
        if not event_id:
Branches [[0, 560], [0, 562]] missed. !
UNCOV
560
            return
!
561

562
        event = yield self.get_event(event_id, allow_none=True)
!
563
        if not event:
Branches [[0, 564], [0, 566]] missed. !
564
            return
!
565

566
        return event.content.get("canonical_alias")
!
567

568
    @cached(max_entries=10000, iterable=True)
1×
569
    def get_state_group_delta(self, state_group):
570
        """Given a state group try to return a previous group and a delta between
571
        the old and the new.
572

573
        Returns:
574
            (prev_group, delta_ids), where both may be None.
575
        """
576

UNCOV
577
        def _get_state_group_delta_txn(txn):
!
UNCOV
578
            prev_group = self._simple_select_one_onecol_txn(
!
579
                txn,
580
                table="state_group_edges",
581
                keyvalues={"state_group": state_group},
582
                retcol="prev_state_group",
583
                allow_none=True,
584
            )
585

UNCOV
586
            if not prev_group:
Branches [[0, 587], [0, 589]] missed. !
UNCOV
587
                return _GetStateGroupDelta(None, None)
!
588

UNCOV
589
            delta_ids = self._simple_select_list_txn(
!
590
                txn,
591
                table="state_groups_state",
592
                keyvalues={"state_group": state_group},
593
                retcols=("type", "state_key", "event_id"),
594
            )
595

UNCOV
596
            return _GetStateGroupDelta(
Branches [[0, 598], [0, 577]] missed. !
597
                prev_group,
598
                {(row["type"], row["state_key"]): row["event_id"] for row in delta_ids},
599
            )
600

UNCOV
601
        return self.runInteraction("get_state_group_delta", _get_state_group_delta_txn)
!
602

603
    @defer.inlineCallbacks
1×
604
    def get_state_groups_ids(self, _room_id, event_ids):
605
        """Get the event IDs of all the state for the state groups for the given events
606

607
        Args:
608
            _room_id (str): id of the room for these events
609
            event_ids (iterable[str]): ids of the events
610

611
        Returns:
612
            Deferred[dict[int, dict[tuple[str, str], str]]]:
613
                dict of state_group_id -> (dict of (type, state_key) -> event id)
614
        """
UNCOV
615
        if not event_ids:
Branches [[0, 616], [0, 618]] missed. !
UNCOV
616
            return {}
!
617

UNCOV
618
        event_to_groups = yield self._get_state_group_for_events(event_ids)
!
619

UNCOV
620
        groups = set(itervalues(event_to_groups))
!
UNCOV
621
        group_to_state = yield self._get_state_for_groups(groups)
!
622

UNCOV
623
        return group_to_state
!
624

625
    @defer.inlineCallbacks
1×
626
    def get_state_ids_for_group(self, state_group):
627
        """Get the event IDs of all the state in the given state group
628

629
        Args:
630
            state_group (int)
631

632
        Returns:
633
            Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
634
        """
635
        group_to_state = yield self._get_state_for_groups((state_group,))
!
636

637
        return group_to_state[state_group]
!
638

639
    @defer.inlineCallbacks
1×
640
    def get_state_groups(self, room_id, event_ids):
641
        """ Get the state groups for the given list of event_ids
642

643
        Returns:
644
            Deferred[dict[int, list[EventBase]]]:
645
                dict of state_group_id -> list of state events.
646
        """
UNCOV
647
        if not event_ids:
Branches [[0, 648], [0, 650]] missed. !
648
            return {}
!
649

UNCOV
650
        group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
!
651

UNCOV
652
        state_event_map = yield self.get_events(
Branches [[0, 654], [0, 661]] missed. !
653
            [
654
                ev_id
655
                for group_ids in itervalues(group_to_ids)
656
                for ev_id in itervalues(group_ids)
657
            ],
658
            get_prev_content=False,
659
        )
660

UNCOV
661
        return {
Branches [[0, 663], [0, 661], [0, 639]] missed. !
662
            group: [
663
                state_event_map[v]
664
                for v in itervalues(event_id_map)
665
                if v in state_event_map
666
            ]
667
            for group, event_id_map in iteritems(group_to_ids)
668
        }
669

670
    @defer.inlineCallbacks
1×
671
    def _get_state_groups_from_groups(self, groups, state_filter):
672
        """Returns the state groups for a given set of groups, filtering on
673
        types of state events.
674

675
        Args:
676
            groups(list[int]): list of state group IDs to query
677
            state_filter (StateFilter): The state filter used to fetch state
678
                from the database.
679
        Returns:
680
            Deferred[dict[int, dict[tuple[str, str], str]]]:
681
                dict of state_group_id -> (dict of (type, state_key) -> event id)
682
        """
UNCOV
683
        results = {}
!
684

UNCOV
685
        chunks = [groups[i : i + 100] for i in range(0, len(groups), 100)]
Branches [[0, 685], [0, 686]] missed. !
UNCOV
686
        for chunk in chunks:
Branches [[0, 687], [0, 695]] missed. !
UNCOV
687
            res = yield self.runInteraction(
!
688
                "_get_state_groups_from_groups",
689
                self._get_state_groups_from_groups_txn,
690
                chunk,
691
                state_filter,
692
            )
UNCOV
693
            results.update(res)
!
694

UNCOV
695
        return results
!
696

697
    def _get_state_groups_from_groups_txn(
1×
698
        self, txn, groups, state_filter=StateFilter.all()
699
    ):
UNCOV
700
        results = {group: {} for group in groups}
Branches [[0, 700], [0, 702]] missed. !
701

UNCOV
702
        where_clause, where_args = state_filter.make_sql_filter_clause()
!
703

704
        # Unless the filter clause is empty, we're going to append it after an
705
        # existing where clause
UNCOV
706
        if where_clause:
Branches [[0, 707], [0, 709]] missed. !
707
            where_clause = " AND (%s)" % (where_clause,)
!
708

UNCOV
709
        if isinstance(self.database_engine, PostgresEngine):
Branches [[0, 712], [0, 749]] missed. !
710
            # Temporarily disable sequential scans in this transaction. This is
711
            # a temporary hack until we can add the right indices in
UNCOV
712
            txn.execute("SET LOCAL enable_seqscan=off")
!
713

714
            # The below query walks the state_group tree so that the "state"
715
            # table includes all state_groups in the tree. It then joins
716
            # against `state_groups_state` to fetch the latest state.
717
            # It assumes that previous state groups are always numerically
718
            # lesser.
719
            # The PARTITION is used to get the event_id in the greatest state
720
            # group for the given type, state_key.
721
            # This may return multiple rows per (type, state_key), but last_value
722
            # should be the same.
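            # Illustrative example (added for this report, not in the original
            # file): if state_group_edges contains 103 -> 102 -> 101, then for
            # group 103 the recursive CTE yields state = {103, 102, 101}, and
            # the window function picks, for each (type, state_key), the
            # event_id from the numerically greatest group that sets it.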
UNCOV
723
            sql = """
!
724
                WITH RECURSIVE state(state_group) AS (
725
                    VALUES(?::bigint)
726
                    UNION ALL
727
                    SELECT prev_state_group FROM state_group_edges e, state s
728
                    WHERE s.state_group = e.state_group
729
                )
730
                SELECT DISTINCT type, state_key, last_value(event_id) OVER (
731
                    PARTITION BY type, state_key ORDER BY state_group ASC
732
                    ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
733
                ) AS event_id FROM state_groups_state
734
                WHERE state_group IN (
735
                    SELECT state_group FROM state
736
                )
737
            """
738

UNCOV
739
            for group in groups:
Branches [[0, 740], [0, 796]] missed. !
UNCOV
740
                args = [group]
!
UNCOV
741
                args.extend(where_args)
!
742

UNCOV
743
                txn.execute(sql + where_clause, args)
!
UNCOV
744
                for row in txn:
Branches [[0, 739], [0, 745]] missed. !
UNCOV
745
                    typ, state_key, event_id = row
!
UNCOV
746
                    key = (typ, state_key)
!
UNCOV
747
                    results[group][key] = event_id
!
748
        else:
749
            max_entries_returned = state_filter.max_entries_returned()
!
750

751
            # We don't use WITH RECURSIVE on sqlite3 as there are distributions
752
            # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
753
            for group in groups:
Branches [[0, 754], [0, 796]] missed. !
754
                next_group = group
!
755

756
                while next_group:
Branches [[0, 753], [0, 762]] missed. !
757
                    # We did this before by getting the list of group ids, and
758
                    # then passing that list to sqlite to get latest event for
759
                    # each (type, state_key). However, that was terribly slow
760
                    # without the right indices (which we can't add until
761
                    # after we finish deduping state, which requires this func)
762
                    args = [next_group]
!
763
                    args.extend(where_args)
!
764

765
                    txn.execute(
!
766
                        "SELECT type, state_key, event_id FROM state_groups_state"
767
                        " WHERE state_group = ? " + where_clause,
768
                        args,
769
                    )
770
                    results[group].update(
Branches [[0, 771], [0, 782]] missed. !
771
                        ((typ, state_key), event_id)
772
                        for typ, state_key, event_id in txn
773
                        if (typ, state_key) not in results[group]
774
                    )
775

776
                    # If the number of entries in the (type,state_key)->event_id dict
777
                    # matches the number of (type,state_keys) types we were searching
778
                    # for, then we must have found them all, so no need to go walk
779
                    # further down the tree... UNLESS our types filter contained
780
                    # wildcards (i.e. Nones) in which case we have to do an exhaustive
781
                    # search
782
                    if (
Branches [[0, 786], [0, 788]] missed. !
783
                        max_entries_returned is not None
784
                        and len(results[group]) == max_entries_returned
785
                    ):
786
                        break
!
787

788
                    next_group = self._simple_select_one_onecol_txn(
!
789
                        txn,
790
                        table="state_group_edges",
791
                        keyvalues={"state_group": next_group},
792
                        retcol="prev_state_group",
793
                        allow_none=True,
794
                    )
795

UNCOV
796
        return results
!
797

798
    @defer.inlineCallbacks
1×
799
    def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
1×
800
        """Given a list of event_ids and type tuples, return a list of state
801
        dicts for each event.
802

803
        Args:
804
            event_ids (list[string])
805
            state_filter (StateFilter): The state filter used to fetch state
806
                from the database.
807

808
        Returns:
809
            deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
810
        """
UNCOV
811
        event_to_groups = yield self._get_state_group_for_events(event_ids)
!
812

UNCOV
813
        groups = set(itervalues(event_to_groups))
!
UNCOV
814
        group_to_state = yield self._get_state_for_groups(groups, state_filter)
!
815

UNCOV
816
        state_event_map = yield self.get_events(
Branches [[0, 817], [0, 821]] missed. !
817
            [ev_id for sd in itervalues(group_to_state) for ev_id in itervalues(sd)],
818
            get_prev_content=False,
819
        )
820

UNCOV
821
        event_to_state = {
Branches [[0, 822], [0, 821], [0, 830]] missed. !
822
            event_id: {
823
                k: state_event_map[v]
824
                for k, v in iteritems(group_to_state[group])
825
                if v in state_event_map
826
            }
827
            for event_id, group in iteritems(event_to_groups)
828
        }
829

UNCOV
830
        return {event: event_to_state[event] for event in event_ids}
Branches [[0, 830], [0, 798]] missed. !
831

832
    @defer.inlineCallbacks
1×
833
    def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
1×
834
        """
835
        Get the state dicts corresponding to a list of events, containing the event_ids
836
        of the state events (as opposed to the events themselves)
837

838
        Args:
839
            event_ids(list(str)): events whose state should be returned
840
            state_filter (StateFilter): The state filter used to fetch state
841
                from the database.
842

843
        Returns:
844
            A deferred dict from event_id -> (type, state_key) -> event_id
845
        """
UNCOV
846
        event_to_groups = yield self._get_state_group_for_events(event_ids)
!
847

UNCOV
848
        groups = set(itervalues(event_to_groups))
!
UNCOV
849
        group_to_state = yield self._get_state_for_groups(groups, state_filter)
!
850

UNCOV
851
        event_to_state = {
Branches [[0, 851], [0, 856]] missed. !
852
            event_id: group_to_state[group]
853
            for event_id, group in iteritems(event_to_groups)
854
        }
855

UNCOV
856
        return {event: event_to_state[event] for event in event_ids}
Branches [[0, 856], [0, 832]] missed. !
857

858
    @defer.inlineCallbacks
1×
859
    def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
1×
860
        """
861
        Get the state dict corresponding to a particular event
862

863
        Args:
864
            event_id(str): event whose state should be returned
865
            state_filter (StateFilter): The state filter used to fetch state
866
                from the database.
867

868
        Returns:
869
            A deferred dict from (type, state_key) -> state_event
870
        """
871
        state_map = yield self.get_state_for_events([event_id], state_filter)
!
872
        return state_map[event_id]
!
873

874
    @defer.inlineCallbacks
1×
875
    def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
1×
876
        """
877
        Get the state dict corresponding to a particular event
878

879
        Args:
880
            event_id(str): event whose state should be returned
881
            state_filter (StateFilter): The state filter used to fetch state
882
                from the database.
883

884
        Returns:
885
            A deferred dict from (type, state_key) -> event_id
886
        """
UNCOV
887
        state_map = yield self.get_state_ids_for_events([event_id], state_filter)
!
UNCOV
888
        return state_map[event_id]
!
889

890
    @cached(max_entries=50000)
1×
891
    def _get_state_group_for_event(self, event_id):
892
        return self._simple_select_one_onecol(
!
893
            table="event_to_state_groups",
894
            keyvalues={"event_id": event_id},
895
            retcol="state_group",
896
            allow_none=True,
897
            desc="_get_state_group_for_event",
898
        )
899

900
    @cachedList(
1×
901
        cached_method_name="_get_state_group_for_event",
902
        list_name="event_ids",
903
        num_args=1,
904
        inlineCallbacks=True,
905
    )
906
    def _get_state_group_for_events(self, event_ids):
907
        """Returns mapping event_id -> state_group
908
        """
UNCOV
909
        rows = yield self._simple_select_many_batch(
!
910
            table="event_to_state_groups",
911
            column="event_id",
912
            iterable=event_ids,
913
            keyvalues={},
914
            retcols=("event_id", "state_group"),
915
            desc="_get_state_group_for_events",
916
        )
917

UNCOV
918
        return {row["event_id"]: row["state_group"] for row in rows}
Branches [[0, 918], [0, 900]] missed. !
919

920
    def _get_state_for_group_using_cache(self, cache, group, state_filter):
1×
921
        """Checks if group is in cache. See `_get_state_for_groups`
922

923
        Args:
924
            cache(DictionaryCache): the state group cache to use
925
            group(int): The state group to lookup
926
            state_filter (StateFilter): The state filter used to fetch state
927
                from the database.
928

929
        Returns 2-tuple (`state_dict`, `got_all`).
930
        `got_all` is a bool indicating if we successfully retrieved all
931
        requested state from the cache; if False we need to query the DB for the
932
        missing state.
933
        """
UNCOV
934
        is_all, known_absent, state_dict_ids = cache.get(group)
!
935

UNCOV
936
        if is_all or state_filter.is_full():
Branches [[0, 939], [0, 942]] missed. !
937
            # Either we have everything or want everything, either way
938
            # `is_all` tells us whether we've gotten everything.
UNCOV
939
            return state_filter.filter_state(state_dict_ids), is_all
!
940

941
        # tracks whether any of our requested types are missing from the cache
942
        missing_types = False
!
943

944
        if state_filter.has_wildcards():
Branches [[0, 948], [0, 952]] missed. !
945
            # We don't know if we fetched all the state keys for the types in
946
            # the filter that are wildcards, so we have to assume that we may
947
            # have missed some.
948
            missing_types = True
!
949
        else:
950
            # There aren't any wild cards, so `concrete_types()` returns the
951
            # complete list of event types we're wanting.
952
            for key in state_filter.concrete_types():
Branches [[0, 953], [0, 957]] missed. !
953
                if key not in state_dict_ids and key not in known_absent:
Branches [[0, 952], [0, 954]] missed. !
954
                    missing_types = True
!
955
                    break
!
956

957
        return state_filter.filter_state(state_dict_ids), not missing_types
!
958

959
    @defer.inlineCallbacks
1×
960
    def _get_state_for_groups(self, groups, state_filter=StateFilter.all()):
1×
961
        """Gets the state at each of a list of state groups, optionally
962
        filtering by type/state_key
963

964
        Args:
965
            groups (iterable[int]): list of state groups for which we want
966
                to get the state.
967
            state_filter (StateFilter): The state filter used to fetch state
968
                from the database.
969
        Returns:
970
            Deferred[dict[int, dict[tuple[str, str], str]]]:
971
                dict of state_group_id -> (dict of (type, state_key) -> event id)
972
        """
973

UNCOV
974
        member_filter, non_member_filter = state_filter.get_member_split()
!
975

976
        # Now we look them up in the member and non-member caches
UNCOV
977
        non_member_state, incomplete_groups_nm, = (
!
978
            yield self._get_state_for_groups_using_cache(
979
                groups, self._state_group_cache, state_filter=non_member_filter
980
            )
981
        )
982

UNCOV
983
        member_state, incomplete_groups_m, = (
!
984
            yield self._get_state_for_groups_using_cache(
985
                groups, self._state_group_members_cache, state_filter=member_filter
986
            )
987
        )
988

UNCOV
989
        state = dict(non_member_state)
!
UNCOV
990
        for group in groups:
Branches [[0, 991], [0, 995]] missed. !
UNCOV
991
            state[group].update(member_state[group])
!
992

993
        # Now fetch any missing groups from the database
994

UNCOV
995
        incomplete_groups = incomplete_groups_m | incomplete_groups_nm
!
996

UNCOV
997
        if not incomplete_groups:
Branches [[0, 998], [0, 1000]] missed. !
UNCOV
998
            return state
!
999

UNCOV
1000
        cache_sequence_nm = self._state_group_cache.sequence
!
UNCOV
1001
        cache_sequence_m = self._state_group_members_cache.sequence
!
1002

1003
        # Help the cache hit ratio by expanding the filter a bit
UNCOV
1004
        db_state_filter = state_filter.return_expanded()
!
1005

UNCOV
1006
        group_to_state_dict = yield self._get_state_groups_from_groups(
!
1007
            list(incomplete_groups), state_filter=db_state_filter
1008
        )
1009

1010
        # Now lets update the caches
UNCOV
1011
        self._insert_into_cache(
!
1012
            group_to_state_dict,
1013
            db_state_filter,
1014
            cache_seq_num_members=cache_sequence_m,
1015
            cache_seq_num_non_members=cache_sequence_nm,
1016
        )
1017

1018
        # And finally update the result dict, by filtering out any extra
1019
        # stuff we pulled out of the database.
UNCOV
1020
        for group, group_state_dict in iteritems(group_to_state_dict):
Branches [[0, 1023], [0, 1025]] missed. !
1021
            # We just replace any existing entries, as we will have loaded
1022
            # everything we need from the database anyway.
UNCOV
1023
            state[group] = state_filter.filter_state(group_state_dict)
!
1024

UNCOV
1025
        return state
!
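    # Illustrative call (added for this report, not part of the original
    # file): from another @defer.inlineCallbacks method on this store, fetch
    # just the m.room.create event ID for each of the given state groups.
    #
    #     state_by_group = yield self._get_state_for_groups(
    #         groups, StateFilter.from_types([(EventTypes.Create, "")])
    #     )
    #     create_ids = {
    #         g: sm.get((EventTypes.Create, ""))
    #         for g, sm in iteritems(state_by_group)
    #     }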
1026

1027
    def _get_state_for_groups_using_cache(self, groups, cache, state_filter):
1×
1028
        """Gets the state at each of a list of state groups, optionally
1029
        filtering by type/state_key, querying from a specific cache.
1030

1031
        Args:
1032
            groups (iterable[int]): list of state groups for which we want
1033
                to get the state.
1034
            cache (DictionaryCache): the cache of group ids to state dicts which
1035
                we will pass through - either the normal state cache or the specific
1036
                members state cache.
1037
            state_filter (StateFilter): The state filter used to fetch state
1038
                from the database.
1039

1040
        Returns:
1041
            tuple[dict[int, dict[tuple[str, str], str]], set[int]]: Tuple of
1042
            dict of state_group_id -> (dict of (type, state_key) -> event id)
1043
            of entries in the cache, and the state group ids either missing
1044
            from the cache or incomplete.
1045
        """
UNCOV
1046
        results = {}
!
UNCOV
1047
        incomplete_groups = set()
!
UNCOV
1048
        for group in set(groups):
Branches [[0, 1049], [0, 1057]] missed. !
UNCOV
1049
            state_dict_ids, got_all = self._get_state_for_group_using_cache(
!
1050
                cache, group, state_filter
1051
            )
UNCOV
1052
            results[group] = state_dict_ids
!
1053

UNCOV
1054
            if not got_all:
Branches [[0, 1048], [0, 1055]] missed. !
UNCOV
1055
                incomplete_groups.add(group)
!
1056

UNCOV
1057
        return results, incomplete_groups
!
1058

1059
    def _insert_into_cache(
1×
1060
        self,
1061
        group_to_state_dict,
1062
        state_filter,
1063
        cache_seq_num_members,
1064
        cache_seq_num_non_members,
1065
    ):
1066
        """Inserts results from querying the database into the relevant cache.
1067

1068
        Args:
1069
            group_to_state_dict (dict): The new entries pulled from database.
1070
                Map from state group to state dict
1071
            state_filter (StateFilter): The state filter used to fetch state
1072
                from the database.
1073
            cache_seq_num_members (int): Sequence number of member cache since
1074
                last lookup in cache
1075
            cache_seq_num_non_members (int): Sequence number of non-member cache since
1076
                last lookup in cache
1077
        """
1078

1079
        # We need to work out which types we've fetched from the DB for the
1080
        # member vs non-member caches. This should be as accurate as possible,
1081
        # but can be an underestimate (e.g. when we have wild cards)
1082

UNCOV
1083
        member_filter, non_member_filter = state_filter.get_member_split()
!
UNCOV
1084
        if member_filter.is_full():
Branches [[0, 1086], [0, 1090]] missed. !
1085
            # We fetched all member events
UNCOV
1086
            member_types = None
!
1087
        else:
1088
            # `concrete_types()` will only return a subset when there are wild
1089
            # cards in the filter, but that's fine.
1090
            member_types = member_filter.concrete_types()
!
1091

UNCOV
1092
        if non_member_filter.is_full():
Branches [[0, 1094], [0, 1096]] missed. !
1093
            # We fetched all non member events
UNCOV
1094
            non_member_types = None
!
1095
        else:
1096
            non_member_types = non_member_filter.concrete_types()
!
1097

UNCOV
1098
        for group, group_state_dict in iteritems(group_to_state_dict):
Branches [[0, 1059], [0, 1099]] missed. !
UNCOV
1099
            state_dict_members = {}
!
UNCOV
1100
            state_dict_non_members = {}
!
1101

UNCOV
1102
            for k, v in iteritems(group_state_dict):
Branches [[0, 1103], [0, 1108]] missed. !
1103
                if k[0] == EventTypes.Member:
Branches [[0, 1104], [0, 1106]] missed. !
1104
                    state_dict_members[k] = v
!
1105
                else:
1106
                    state_dict_non_members[k] = v
!
1107

UNCOV
1108
            self._state_group_members_cache.update(
!
1109
                cache_seq_num_members,
1110
                key=group,
1111
                value=state_dict_members,
1112
                fetched_keys=member_types,
1113
            )
1114

UNCOV
1115
            self._state_group_cache.update(
!
1116
                cache_seq_num_non_members,
1117
                key=group,
1118
                value=state_dict_non_members,
1119
                fetched_keys=non_member_types,
1120
            )
1121

1122
    def store_state_group(
1×
1123
        self, event_id, room_id, prev_group, delta_ids, current_state_ids
1124
    ):
1125
        """Store a new set of state, returning a newly assigned state group.
1126

1127
        Args:
1128
            event_id (str): The event ID for which the state was calculated
1129
            room_id (str)
1130
            prev_group (int|None): A previous state group for the room, optional.
1131
            delta_ids (dict|None): The delta between state at `prev_group` and
1132
                `current_state_ids`, if `prev_group` was given. Same format as
1133
                `current_state_ids`.
1134
            current_state_ids (dict): The state to store. Map of (type, state_key)
1135
                to event_id.
1136

1137
        Returns:
1138
            Deferred[int]: The state group ID
1139
        """
1140

UNCOV
1141
        def _store_state_group_txn(txn):
!
UNCOV
1142
            if current_state_ids is None:
Branches [[0, 1144], [0, 1146]] missed. !
1143
                # AFAIK, this can never happen
1144
                raise Exception("current_state_ids cannot be None")
!
1145

UNCOV
1146
            state_group = self.database_engine.get_next_state_group_id(txn)
!
1147

UNCOV
1148
            self._simple_insert_txn(
!
1149
                txn,
1150
                table="state_groups",
1151
                values={"id": state_group, "room_id": room_id, "event_id": event_id},
1152
            )
1153

1154
            # We persist as a delta if we can, while also ensuring the chain
1155
            # of deltas isn't tooo long, as otherwise read performance degrades.
UNCOV
1156
            if prev_group:
Branches [[0, 1157], [0, 1171]] missed. !
UNCOV
1157
                is_in_db = self._simple_select_one_onecol_txn(
!
1158
                    txn,
1159
                    table="state_groups",
1160
                    keyvalues={"id": prev_group},
1161
                    retcol="id",
1162
                    allow_none=True,
1163
                )
UNCOV
1164
                if not is_in_db:
Branches [[0, 1165], [0, 1170]] missed. !
1165
                    raise Exception(
!
1166
                        "Trying to persist state with unpersisted prev_group: %r"
1167
                        % (prev_group,)
1168
                    )
1169

UNCOV
1170
                potential_hops = self._count_state_group_hops_txn(txn, prev_group)
!
UNCOV
1171
            if prev_group and potential_hops < MAX_STATE_DELTA_HOPS:
Branches [[0, 1172], [0, 1193]] missed. !
UNCOV
1172
                self._simple_insert_txn(
!
1173
                    txn,
1174
                    table="state_group_edges",
1175
                    values={"state_group": state_group, "prev_state_group": prev_group},
1176
                )
1177

UNCOV
1178
                self._simple_insert_many_txn(
Branches [[0, 1182], [0, 1213]] missed. !
1179
                    txn,
1180
                    table="state_groups_state",
1181
                    values=[
1182
                        {
1183
                            "state_group": state_group,
1184
                            "room_id": room_id,
1185
                            "type": key[0],
1186
                            "state_key": key[1],
1187
                            "event_id": state_id,
1188
                        }
1189
                        for key, state_id in iteritems(delta_ids)
1190
                    ],
1191
                )
1192
            else:
UNCOV
1193
                self._simple_insert_many_txn(
Branches [[0, 1197], [0, 1213]] missed. !
1194
                    txn,
1195
                    table="state_groups_state",
1196
                    values=[
1197
                        {
1198
                            "state_group": state_group,
1199
                            "room_id": room_id,
1200
                            "type": key[0],
1201
                            "state_key": key[1],
1202
                            "event_id": state_id,
1203
                        }
1204
                        for key, state_id in iteritems(current_state_ids)
1205
                    ],
1206
                )
1207

1208
            # Prefill the state group caches with this group.
1209
            # It's fine to use the sequence like this as the state group map
1210
            # is immutable. (If the map wasn't immutable then this prefill could
1211
            # race with another update)
1212

UNCOV
1213
            current_member_state_ids = {
Branches [[0, 1213], [0, 1218]] missed. !
1214
                s: ev
1215
                for (s, ev) in iteritems(current_state_ids)
1216
                if s[0] == EventTypes.Member
1217
            }
UNCOV
1218
            txn.call_after(
!
1219
                self._state_group_members_cache.update,
1220
                self._state_group_members_cache.sequence,
1221
                key=state_group,
1222
                value=dict(current_member_state_ids),
1223
            )
1224

UNCOV
1225
            current_non_member_state_ids = {
Branches [[0, 1225], [0, 1230]] missed. !
1226
                s: ev
1227
                for (s, ev) in iteritems(current_state_ids)
1228
                if s[0] != EventTypes.Member
1229
            }
UNCOV
1230
            txn.call_after(
!
1231
                self._state_group_cache.update,
1232
                self._state_group_cache.sequence,
1233
                key=state_group,
1234
                value=dict(current_non_member_state_ids),
1235
            )
1236

UNCOV
1237
            return state_group
!
1238

UNCOV
1239
        return self.runInteraction("store_state_group", _store_state_group_txn)
!
1240

1241
    def _count_state_group_hops_txn(self, txn, state_group):
1×
1242
        """Given a state group, count how many hops there are in the tree.
1243

1244
        This is used to ensure the delta chains don't get too long.
1245
        """
        if isinstance(self.database_engine, PostgresEngine):
            sql = """
                WITH RECURSIVE state(state_group) AS (
                    VALUES(?::bigint)
                    UNION ALL
                    SELECT prev_state_group FROM state_group_edges e, state s
                    WHERE s.state_group = e.state_group
                )
                SELECT count(*) FROM state;
            """

            txn.execute(sql, (state_group,))
            row = txn.fetchone()
            if row and row[0]:
                return row[0]
            else:
                return 0
        else:
            # We don't use WITH RECURSIVE on sqlite3 as there are distributions
            # that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
            next_group = state_group
            count = 0

            while next_group:
                next_group = self._simple_select_one_onecol_txn(
                    txn,
                    table="state_group_edges",
                    keyvalues={"state_group": next_group},
                    retcol="prev_state_group",
                    allow_none=True,
                )
                if next_group:
                    count += 1

            return count


class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
    """ Keeps track of the state at a given event.

    This is done by the concept of `state groups`. Every event is assigned
    a state group (identified by an arbitrary string), which references a
    collection of state events. The current state of an event is then the
    collection of state events referenced by the event's state group.

    Hence, every change in the current state causes a new state group to be
    generated. However, if no change happens (e.g., if we get a message event
    with only one parent), it inherits the state group from its parent.

    There are three tables:
      * `state_groups`: Stores group name, first event within the group and
        room id.
      * `event_to_state_groups`: Maps events to state groups.
      * `state_groups_state`: Maps state group to state events.
    """

    STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
    STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
    CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
    EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"

    def __init__(self, db_conn, hs):
        super(StateStore, self).__init__(db_conn, hs)
        self.register_background_update_handler(
            self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
            self._background_deduplicate_state,
        )
        self.register_background_update_handler(
            self.STATE_GROUP_INDEX_UPDATE_NAME, self._background_index_state
        )
        self.register_background_index_update(
            self.CURRENT_STATE_INDEX_UPDATE_NAME,
            index_name="current_state_events_member_index",
            table="current_state_events",
            columns=["state_key"],
            where_clause="type='m.room.member'",
        )
        self.register_background_index_update(
            self.EVENT_STATE_GROUP_INDEX_UPDATE_NAME,
            index_name="event_to_state_groups_sg_index",
            table="event_to_state_groups",
            columns=["state_group"],
        )

    def _store_event_state_mappings_txn(self, txn, events_and_contexts):
        state_groups = {}
        for event, context in events_and_contexts:
            if event.internal_metadata.is_outlier():
                continue

            # if the event was rejected, just give it the same state as its
            # predecessor.
            if context.rejected:
                state_groups[event.event_id] = context.prev_group
                continue

            state_groups[event.event_id] = context.state_group

        self._simple_insert_many_txn(
            txn,
            table="event_to_state_groups",
            values=[
                {"state_group": state_group_id, "event_id": event_id}
                for event_id, state_group_id in iteritems(state_groups)
            ],
        )

        for event_id, state_group_id in iteritems(state_groups):
            txn.call_after(
                self._get_state_group_for_event.prefill, (event_id,), state_group_id
            )

    @defer.inlineCallbacks
    def _background_deduplicate_state(self, progress, batch_size):
        """This background update will slowly deduplicate state by re-encoding
        it as deltas.
        """
        last_state_group = progress.get("last_state_group", 0)
        rows_inserted = progress.get("rows_inserted", 0)
        max_group = progress.get("max_group", None)

        BATCH_SIZE_SCALE_FACTOR = 100

        batch_size = max(1, int(batch_size / BATCH_SIZE_SCALE_FACTOR))

        if max_group is None:
            rows = yield self._execute(
                "_background_deduplicate_state",
                None,
                "SELECT coalesce(max(id), 0) FROM state_groups",
            )
            max_group = rows[0][0]

        def reindex_txn(txn):
            new_last_state_group = last_state_group
            for count in range(batch_size):
                txn.execute(
                    "SELECT id, room_id FROM state_groups"
                    " WHERE ? < id AND id <= ?"
                    " ORDER BY id ASC"
                    " LIMIT 1",
                    (new_last_state_group, max_group),
                )
                row = txn.fetchone()
                if row:
                    state_group, room_id = row

                if not row or not state_group:
                    return True, count

                txn.execute(
                    "SELECT state_group FROM state_group_edges"
                    " WHERE state_group = ?",
                    (state_group,),
                )

                # If we reach a point where we've already started inserting
                # edges we should stop.
                if txn.fetchall():
                    return True, count

                txn.execute(
                    "SELECT coalesce(max(id), 0) FROM state_groups"
                    " WHERE id < ? AND room_id = ?",
                    (state_group, room_id),
                )
                prev_group, = txn.fetchone()
                new_last_state_group = state_group

                if prev_group:
                    potential_hops = self._count_state_group_hops_txn(txn, prev_group)
                    if potential_hops >= MAX_STATE_DELTA_HOPS:
                        # We want to ensure chains are at most this long,
                        # otherwise read performance degrades.
                        continue

                    prev_state = self._get_state_groups_from_groups_txn(
                        txn, [prev_group]
                    )
                    prev_state = prev_state[prev_group]

                    curr_state = self._get_state_groups_from_groups_txn(
                        txn, [state_group]
                    )
                    curr_state = curr_state[state_group]

                    if not set(prev_state.keys()) - set(curr_state.keys()):
                        # We can only do a delta if the current state's keys
                        # are a superset of the previous state's keys.

                        delta_state = {
                            key: value
                            for key, value in iteritems(curr_state)
                            if prev_state.get(key, None) != value
                        }

                        self._simple_delete_txn(
                            txn,
                            table="state_group_edges",
                            keyvalues={"state_group": state_group},
                        )

                        self._simple_insert_txn(
                            txn,
                            table="state_group_edges",
                            values={
                                "state_group": state_group,
                                "prev_state_group": prev_group,
                            },
                        )

                        self._simple_delete_txn(
                            txn,
                            table="state_groups_state",
                            keyvalues={"state_group": state_group},
                        )

                        self._simple_insert_many_txn(
                            txn,
                            table="state_groups_state",
                            values=[
                                {
                                    "state_group": state_group,
                                    "room_id": room_id,
                                    "type": key[0],
                                    "state_key": key[1],
                                    "event_id": state_id,
                                }
                                for key, state_id in iteritems(delta_state)
                            ],
                        )

            progress = {
                "last_state_group": state_group,
                "rows_inserted": rows_inserted + batch_size,
                "max_group": max_group,
            }

            self._background_update_progress_txn(
                txn, self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, progress
            )

            return False, batch_size

        finished, result = yield self.runInteraction(
            self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME, reindex_txn
        )

        if finished:
            yield self._end_background_update(
                self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME
            )

        return result * BATCH_SIZE_SCALE_FACTOR

    @defer.inlineCallbacks
    def _background_index_state(self, progress, batch_size):
        def reindex_txn(conn):
            conn.rollback()
            if isinstance(self.database_engine, PostgresEngine):
                # Postgres insists on autocommit for CREATE INDEX CONCURRENTLY,
                # which cannot run inside a transaction.
                conn.set_session(autocommit=True)
                try:
                    txn = conn.cursor()
                    txn.execute(
                        "CREATE INDEX CONCURRENTLY state_groups_state_type_idx"
                        " ON state_groups_state(state_group, type, state_key)"
                    )
                    txn.execute("DROP INDEX IF EXISTS state_groups_state_id")
                finally:
                    conn.set_session(autocommit=False)
            else:
                txn = conn.cursor()
                txn.execute(
                    "CREATE INDEX state_groups_state_type_idx"
                    " ON state_groups_state(state_group, type, state_key)"
                )
                txn.execute("DROP INDEX IF EXISTS state_groups_state_id")

        yield self.runWithConnection(reindex_txn)

        yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)

        return 1