• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

mozilla-releng / balrog / #5442

04 May 2026 02:04AM UTC coverage: 89.866% (-0.06%) from 89.923%
#5442

Pull #3770

circleci

web-flow
chore(deps): lock file maintenance (pep621) (#3769)
Pull Request #3770: chore(deps): lock file maintenance (pep621)

2174 of 2558 branches covered (84.99%)

Branch coverage included in aggregate %.

5745 of 6254 relevant lines covered (91.86%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.08
/src/auslib/db.py
1
import itertools
1✔
2
import json
1✔
3
import logging
1✔
4
import re
1✔
5
import time
1✔
6
from collections import defaultdict
1✔
7
from copy import copy
1✔
8
from os import path
1✔
9

10
import migrate.versioning.api
1✔
11
import migrate.versioning.schema
1✔
12
import sqlalchemy.event
1✔
13
import sqlalchemy.types
1✔
14
from aiohttp import ClientSession
1✔
15
from sqlalchemy import JSON, BigInteger, Boolean, Column, Integer, MetaData, String, Table, Text, create_engine, func, join, select
1✔
16
from sqlalchemy.exc import SQLAlchemyError
1✔
17
from sqlalchemy.sql.expression import null
1✔
18
from sqlalchemy.sql.functions import max as sql_max
1✔
19

20
from auslib.blobs.base import createBlob, merge_dicts
1✔
21
from auslib.errors import PermissionDeniedError, ReadOnlyError, SignoffRequiredError
1✔
22
from auslib.global_state import cache
1✔
23
from auslib.util.rulematching import (
1✔
24
    matchBoolean,
25
    matchBuildID,
26
    matchChannel,
27
    matchCsv,
28
    matchLocale,
29
    matchMemory,
30
    matchRegex,
31
    matchSimpleExpression,
32
    matchVersion,
33
)
34
from auslib.util.signoffs import get_required_signoffs_for_product_channel
1✔
35
from auslib.util.statsd import statsd
1✔
36
from auslib.util.timestamp import getMillisecondTimestamp
1✔
37
from auslib.util.versions import get_version_class
1✔
38

39

40
def rows_to_dicts(rows):
    """Turn SQLAlchemy result rows into plain, mutable dictionaries.

    SQLAlchemy rows are immutable and get confused if you try to serialize
    them to JSON, so callers that need to modify rows or dump them as JSON
    should convert them with this helper first.

    :param rows: iterable of result rows (anything ``dict()`` accepts)
    :returns: a new list holding one dict per row, in the original order
    """
    return list(map(dict, rows))
49

50

51
class AlreadySetupError(Exception):
    """Signals an attempt to connect to a new database while a previous
    connection is still in place."""

    def __str__(self):
        message = "Can't connect to new database, still connected to previous one"
        return message
54

55

56
class TransactionError(SQLAlchemyError):
    """Wraps any failure that happens while a transaction executes or commits."""
58

59

60
class OutdatedDataError(SQLAlchemyError):
    """Signals that an update or delete did not go through because the
    caller's data is outdated (old_data_version no longer matches)."""
62

63

64
class MismatchedDataVersionError(SQLAlchemyError):
    """Signals that, after an insert or update, a scheduled change and its
    associated conditions row ended up with different data versions."""
67

68

69
class WrongNumberOfRowsError(SQLAlchemyError):
    """Raised when a where clause that must identify exactly one row (e.g. for
    an update or delete) matches no rows, or matches more than one row."""
71

72

73
class UpdateMergeError(SQLAlchemyError):
    """Error signalling a failed merge of an update.

    NOTE(review): not raised anywhere in this part of the file; presumably it
    is raised when an update cannot be merged into pending scheduled changes
    (see ScheduledChangeTable.mergeUpdate) -- TODO confirm.
    """
75

76

77
class ChangeScheduledError(SQLAlchemyError):
    """Signals that a Scheduled Change cannot be created, modified, or
    deleted, or that a change conflicts with one, for data consistency reasons."""
80

81

82
class JSONColumn(sqlalchemy.types.TypeDecorator):
    """Column type for values that are deserialized JSON in memory (usually
    dicts) but must be serialized to text before storage.

    Values are serialized with ``json.dumps`` just before storage and
    deserialized with ``json.loads`` just after retrieval. ``None`` is stored
    as SQL NULL.
    """

    impl = Text
    cache_ok = True

    def process_bind_param(self, value, dialect):
        # Only None means "store NULL". Falsy but valid JSON values such as
        # {}, [], 0 or False must still be serialized; a plain truthiness test
        # would hand the raw Python object to the driver for a Text column.
        if value is not None:
            value = json.dumps(value)
        return value

    def process_result_value(self, value, dialect):
        # NULL and empty strings are returned unchanged ("" is not valid JSON).
        if value:
            value = json.loads(value)
        return value
100

101

102
class CompatibleBooleanColumn(sqlalchemy.types.TypeDecorator):
    """Boolean column stored as an Integer so that it works with all of our
    supported database engines (mysql, sqlite).

    SQLAlchemy's built-in Boolean is avoided because it creates a CHECK
    constraint, which makes it impossible to downgrade a database with
    sqlalchemy-migrate. True/False are stored as 1/0 and NULL stays None.
    """

    impl = Integer
    cache_ok = True

    def process_bind_param(self, value, dialect):
        # NULL passes straight through; anything else must be a real bool.
        if value is None:
            return value
        if not isinstance(value, bool):
            raise TypeError("{} is invalid type ({}), must be bool".format(value, type(value)))
        return int(value)

    def process_result_value(self, value, dialect):
        # Keep NULL distinct from False: consumers may treat them differently.
        return None if value is None else bool(value)
128

129

130
def BlobColumn(impl=Text):
    """Build a column type that stores Release Blobs.

    Blobs are serialized with their ``getJSON()`` method before storage and
    rebuilt with ``createBlob()`` after retrieval. Falsy values (e.g. None)
    pass through untouched in both directions. Some database engines (eg:
    mysql) need a different underlying type than Text, so the storage type
    can be chosen through ``impl``.

    :param impl: SQLAlchemy type used for storage. Defaults to Text.
    :returns: a newly created TypeDecorator subclass that stores into ``impl``.
    """

    class cls(sqlalchemy.types.TypeDecorator):
        cache_ok = True

        def process_bind_param(self, value, dialect):
            return value.getJSON() if value else value

        def process_result_value(self, value, dialect):
            return createBlob(value) if value else value

    # The storage type is attached to the freshly created class.
    cls.impl = impl
    return cls
152

153

154
def verify_signoffs(potential_required_signoffs, signoffs):
    """Check that the given signoffs satisfy the potential required signoffs.

    For each role, the number of signoffs actually needed is the highest
    ``signoffs_required`` among the potential requirements for that role.

    :param potential_required_signoffs: dicts with "role" and "signoffs_required".
        When empty, nothing is required and the function returns immediately.
    :param signoffs: dicts with a "role" key, one per signoff made.
    :raises SignoffRequiredError: if signoffs are required but none were
        given, or if any role has fewer signoffs than it needs.
    """
    if not potential_required_signoffs:
        return
    if not signoffs:
        raise SignoffRequiredError("No Signoffs given")

    given_per_role = defaultdict(int)
    for made in signoffs:
        given_per_role[made["role"]] += 1

    needed_per_role = {}
    for requirement in potential_required_signoffs:
        role = requirement["role"]
        needed_per_role[role] = max(needed_per_role.get(role, 0), requirement["signoffs_required"])

    for role, needed in needed_per_role.items():
        if given_per_role[role] < needed:
            raise SignoffRequiredError("Not enough signoffs for role '{}'".format(role))
177

178

179
class AUSTransaction(object):
    """Manages a single database transaction on top of an engine.

    :param engine: engine used to open the connection for this transaction
    :type engine: sqlalchemy.engine.base.Engine

    The connection and transaction are opened lazily on the first call to
    execute(). Used as a context manager, the transaction is committed on a
    clean exit, rolled back if the managed block raises, and the connection
    is closed in both cases.
    """

    def __init__(self, engine):
        self.engine = engine
        self.conn = None
        self.trans = None
        self.log = logging.getLogger(self.__class__.__name__)

    def _ensure_connection(self):
        """Open the connection and begin the transaction, if not done yet."""
        if self.conn is None:
            self.conn = self.engine.connect()
            self.trans = self.conn.begin()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Nothing was executed, so there is no connection or transaction.
        if self.conn is None:
            return
        try:
            # If something that executed in the context raised an Exception,
            # rollback and let it propagate (returning False re-raises it).
            if exc_type:
                self.log.debug("exc is:", exc_info=(exc_type, exc_value, exc_traceback))
                self.rollback()
                return False
            # self.commit will issue a rollback if it raises
            self.commit()
        finally:
            # Always make sure the connection is closed, bug 740360
            self.close()

    def close(self):
        """Close the connection, if one was opened and is still open."""
        # For some reason, sometimes the connection appears to close itself...
        if self.conn is not None and not self.conn.closed:
            self.conn.close()

    def execute(self, statement):
        """Execute ``statement`` in this transaction and return its result.

        :raises TransactionError: if execution fails. The transaction is rolled
            back first and the original exception is chained as the cause.
        """
        self._ensure_connection()
        try:
            # Lazy %-argument: the statement is only rendered to SQL text when
            # debug logging is actually enabled.
            self.log.debug("Attempting to execute %s", statement)
            return self.conn.execute(statement)
        except Exception as exc:
            self.log.debug("Caught exception")
            # Raise our own Exception so errors are easily caught by consumers;
            # chaining keeps the original, more informative traceback.
            self.rollback()
            raise TransactionError() from exc

    def commit(self):
        """Commit the transaction (if started); on failure roll back and raise TransactionError."""
        if self.trans is None:
            return
        try:
            self.trans.commit()
        except Exception as exc:
            self.rollback()
            raise TransactionError() from exc

    def rollback(self):
        """Roll back the transaction, if one was started."""
        if self.trans is None:
            return
        self.trans.rollback()
251

252

253
class AUSTable(object):
1✔
254
    """Base class for all AUS Tables. By default, all tables have a history
255
    table created for them, too, which mirrors their own structure and adds
256
    a record of who made a change, and when the change happened.
257

258
    :param history: Whether or not to create a history table for this table.
259
                    When True, a History object will be created for this
260
                    table, and all changes will be logged to it. Defaults
261
                    to True.
262
    :type history: bool
263
    :param versioned: Whether or not this table is versioned. When True,
264
                      an additional 'data_version' column will be added
265
                      to the Table, and its version increased with every
266
                      update. This is useful for detecting colliding
267
                      updates.
268

269
    :type versioned: bool
270
    :param scheduled_changes: Whether or not this table should allow changes
271
                              to be scheduled. When True, two additional tables
272
                              will be created: a $name_scheduled_changes, which
273
                              will contain data needed to schedule changes to
274
                              $name, and $name_scheduled_changes_history, which
275
                              tracks the history of a scheduled change.
276

277
    :type scheduled_changes: bool
278
    """
279

280
    def __init__(
1✔
281
        self,
282
        db,
283
        dialect,
284
        historyClass=None,
285
        historyKwargs={},
286
        versioned=True,
287
        scheduled_changes=False,
288
        scheduled_changes_kwargs={},
289
    ):
290
        self.db = db
1✔
291
        self.t = self.table
1✔
292
        # Enable versioning, if required
293
        if versioned:
1✔
294
            self.t.append_column(Column("data_version", Integer, nullable=False))
1✔
295
        self.versioned = versioned
1✔
296
        # Mirror the columns as attributes for easy access
297
        self.primary_key = []
1✔
298
        for col in self.table.columns:
1✔
299
            setattr(self, col.name, col)
1✔
300
            if col.primary_key:
1✔
301
                self.primary_key.append(col)
1✔
302
        # Set-up a history table to do logging in, if required
303
        if historyClass:
1✔
304
            self.history = historyClass(db, dialect, self.t.metadata, self, **historyKwargs)
1✔
305
        else:
306
            self.history = None
1✔
307
        # Set-up a scheduled changes table if required
308
        if scheduled_changes:
1✔
309
            self.scheduled_changes = ScheduledChangeTable(db, dialect, self.t.metadata, self, **scheduled_changes_kwargs)
1✔
310
        else:
311
            self.scheduled_changes = None
1✔
312
        self.log = logging.getLogger(self.__class__.__name__)
1✔
313

314
    # Can't do this in the constructor, because the engine is always
315
    # unset when we're instantiated
316
    def getEngine(self):
1✔
317
        return self.t.metadata.bind
1✔
318

319
    def _returnRowOrRaise(self, where, columns=None, transaction=None):
1✔
320
        """Return the row matching the where clause supplied. If no rows match or multiple rows match,
321
        a WrongNumberOfRowsError will be raised."""
322
        rows = self.select(where=where, columns=columns, transaction=transaction)
1✔
323
        if len(rows) == 0:
1!
324
            raise WrongNumberOfRowsError("where clause matched no rows")
×
325
        if len(rows) > 1:
1!
326
            raise WrongNumberOfRowsError("where clause matches multiple rows (primary keys: %s)" % rows)
×
327
        return rows[0]
1✔
328

329
    def _selectStatement(self, columns=None, where=None, order_by=None, limit=None, offset=None, distinct=False):
1✔
330
        """Create a SELECT statement on this table.
331

332
        :param columns: Column objects to select. Defaults to None, meaning select all columns
333
        :type columns: A sequence of sqlalchemy.schema.Column objects or column names as strings
334
        :param order_by: Columns to sort the rows by. Defaults to None, meaning no ORDER BY clause
335
        :type order_by: A sequence of sqlalchemy.schema.Column objects
336
        :param limit: Limit results to this many. Defaults to None, meaning no limit
337
        :type limit: int
338
        :param distinct: Whether or not to return only distinct rows. Default: False.
339
        :type distinct: bool
340

341
        :rtype: sqlalchemy.sql.expression.Select
342
        """
343
        if columns:
1✔
344
            table_columns = [(self.t.c[col] if isinstance(col, str) else col) for col in columns]
1✔
345
            query = select(table_columns, order_by=order_by, limit=limit, offset=offset, distinct=distinct)
1✔
346
        else:
347
            query = self.t.select(order_by=order_by, limit=limit, offset=offset, distinct=distinct)
1✔
348
        if where:
1✔
349
            for cond in where:
1✔
350
                query = query.where(cond)
1✔
351
        return query
1✔
352

353
    def select(self, where=None, transaction=None, **kwargs):
        """Perform a SELECT statement on this table and return the rows as dicts.
        See AUSTable._selectStatement for the other accepted keyword arguments.

        :param where: A list of SQLAlchemy clauses, or a key/value pair of columns and values.
        :type where: list of clauses or key/value pairs.

        :param transaction: A transaction object to run the SELECT in. If None, the
                            SELECT runs in a locally-scoped transaction that is
                            committed and closed before returning.

        :returns: the matching rows, converted with rows_to_dicts.
        :rtype: list of dict
        """

        # If "where" is key/value pairs, we need to convert it to SQLAlchemy
        # clauses before proceeding.
        if hasattr(where, "keys"):
            where = [getattr(self, k) == v for k, v in where.items()]

        query = self._selectStatement(where=where, **kwargs)

        if transaction:
            result = transaction.execute(query).fetchall()
        else:
            with AUSTransaction(self.getEngine()) as trans:
                result = trans.execute(query).fetchall()

        return rows_to_dicts(result)
381

382
    def _insertStatement(self, **columns):
1✔
383
        """Create an INSERT statement for this table
384

385
        :param columns: Data to insert
386
        :type colmuns: dict
387

388
        :rtype: sqlalchemy.sql.express.Insert
389
        """
390
        table_columns = {k: columns[k] for k in columns.keys() if k in self.table.c}
1✔
391
        unconsumed_columns = {k: columns[k] for k in columns.keys() if k not in table_columns}
1✔
392
        return self.t.insert(values=table_columns), unconsumed_columns
1✔
393

394
    def _sharedPrepareInsert(self, trans, changed_by, **columns):
1✔
395
        """Prepare an INSERT statement for commit. If this table has versioning enabled,
396
        data_version will be set to 1. If this table has history enabled, two rows
397
        will be created in that table: one representing the current state (NULL),
398
        and one representing the new state.
399

400
        :rtype: sqlalchemy.engine.base.ResultProxy
401
        """
402
        data = columns.copy()
1✔
403
        if self.versioned:
1✔
404
            data["data_version"] = 1
1✔
405
        query, unconsumed_columns = self._insertStatement(**data)
1✔
406
        ret = trans.execute(query)
1✔
407
        return data, ret
1✔
408

409
    def _prepareInsert(self, trans, changed_by, **columns):
1✔
410
        data, ret = self._sharedPrepareInsert(trans, changed_by, **columns)
1✔
411
        if self.history:
1✔
412
            self.history.forInsert(ret.inserted_primary_key, data, changed_by, trans)
1✔
413
        return ret
1✔
414

415
    async def _asyncPrepareInsert(self, trans, changed_by, **columns):
1✔
416
        data, ret = self._sharedPrepareInsert(trans, changed_by, **columns)
1✔
417
        if self.history:
1!
418
            await self.history.forInsert(ret.inserted_primary_key, data, changed_by, trans)
1✔
419
        return ret
1✔
420

421
    def insert(self, changed_by=None, transaction=None, dryrun=False, **columns):
1✔
422
        """Perform an INSERT statement on this table. See AUSTable._insertStatement for
423
        a description of columns.
424

425
        :param changed_by: The username of the person inserting the row. Required when
426
                           history is enabled. Unused otherwise. No authorization checks are done
427
                           at this level.
428
        :type changed_by: str
429
        :param transaction: A transaction object to add the insert statement (and history changes) to.
430
                            If provided, you must commit the transaction yourself. If None, they will
431
                            be added to a locally-scoped transaction and committed.
432
        :param dryrun: If true, this insert statement will not actually be run.
433
        :type dryrun: bool
434

435
        :rtype: sqlalchemy.engine.base.ResultProxy
436
        """
437
        if self.history and not changed_by:
1!
438
            raise ValueError("changed_by must be passed for Tables that have history")
×
439

440
        if dryrun:
1✔
441
            self.log.debug("In dryrun mode, not doing anything...")
1✔
442
            return
1✔
443

444
        if transaction:
1✔
445
            return self._prepareInsert(transaction, changed_by, **columns)
1✔
446
        else:
447
            with AUSTransaction(self.getEngine()) as trans:
1✔
448
                return self._prepareInsert(trans, changed_by, **columns)
1✔
449

450
    async def async_insert(self, changed_by=None, transaction=None, dryrun=False, **columns):
1✔
451
        """Perform an INSERT statement on this table. See AUSTable._insertStatement for
452
        a description of columns.
453

454
        :param changed_by: The username of the person inserting the row. Required when
455
                           history is enabled. Unused otherwise. No authorization checks are done
456
                           at this level.
457
        :type changed_by: str
458
        :param transaction: A transaction object to add the insert statement (and history changes) to.
459
                            If provided, you must commit the transaction yourself. If None, they will
460
                            be added to a locally-scoped transaction and committed.
461
        :param dryrun: If true, this insert statement will not actually be run.
462
        :type dryrun: bool
463

464
        :rtype: sqlalchemy.engine.base.ResultProxy
465
        """
466
        if self.history and not changed_by:
1!
467
            raise ValueError("changed_by must be passed for Tables that have history")
×
468

469
        if dryrun:
1!
470
            self.log.debug("In dryrun mode, not doing anything...")
×
471
            return
×
472

473
        if transaction:
1!
474
            return await self._asyncPrepareInsert(transaction, changed_by, **columns)
1✔
475
        else:
476
            with AUSTransaction(self.getEngine()) as trans:
×
477
                return await self._asyncPrepareInsert(trans, changed_by, **columns)
×
478

479
    def _deleteStatement(self, where):
1✔
480
        """Create a DELETE statement for this table.
481

482
        :param where: Conditions to apply on this select.
483
        :type where: A sequence of sqlalchemy.sql.expression.ClauseElement objects
484

485
        :rtype: sqlalchemy.sql.expression.Delete
486
        """
487
        query = self.t.delete()
1✔
488
        if where:
1!
489
            for cond in where:
1✔
490
                query = query.where(cond)
1✔
491
        return query
1✔
492

493
    def _sharedPrepareDelete(self, trans, where, changed_by, old_data_version):
        """Execute a DELETE of the single row matching ``where`` inside ``trans``.

        Shared by the sync and async delete paths; writing history is left to
        the callers (_prepareDelete / _asyncPrepareDelete). If versioning is
        enabled and old_data_version doesn't match the current version of the
        row to be deleted, an OutdatedDataError will be raised.

        :param trans: transaction to execute the statements in.
        :param where: list of SQLAlchemy clauses selecting the row.
        :param changed_by: not used in this method; accepted for signature parity.
        :param old_data_version: expected current data_version of the row.
        :returns: tuple of (the row's primary key values as a dict, result of the DELETE).
        :raises WrongNumberOfRowsError: if ``where`` matches zero or several rows.
        :raises OutdatedDataError: if the DELETE did not affect exactly one row.
        :raises ChangeScheduledError: if the row has incomplete scheduled changes.
        """
        # Only the primary key columns are fetched: they identify the row's
        # scheduled changes below and are returned to the callers for history.
        row = self._returnRowOrRaise(where=where, columns=self.primary_key, transaction=trans)

        if self.versioned:
            # Copy so the extra condition isn't appended to the caller's list.
            where = copy(where)
            where.append(self.data_version == old_data_version)

        query = self._deleteStatement(where)

        ret = trans.execute(query)
        if ret.rowcount != 1:
            raise OutdatedDataError("Failed to delete row, old_data_version doesn't match current data_version")
        if self.scheduled_changes:
            # If this table has active scheduled changes we cannot allow it to be deleted
            # NOTE(review): this check runs after the DELETE has already executed,
            # so it relies on the raised exception causing ``trans`` to be rolled
            # back. A caller that catches ChangeScheduledError and still commits
            # its own transaction would keep the delete.
            sc_where = [self.scheduled_changes.complete == False]  # noqa
            for pk in self.primary_key:
                sc_where.append(getattr(self.scheduled_changes, "base_%s" % pk.name) == row[pk.name])
            if self.scheduled_changes.select(where=sc_where, transaction=trans):
                raise ChangeScheduledError("Cannot delete rows that have changes scheduled.")

        return row, ret
522

523
    def _prepareDelete(self, trans, where, changed_by, old_data_version):
1✔
524
        row, ret = self._sharedPrepareDelete(trans, where, changed_by, old_data_version)
1✔
525
        if self.history:
1!
526
            self.history.forDelete(row, changed_by, trans)
1✔
527

528
        return ret
1✔
529

530
    async def _asyncPrepareDelete(self, trans, where, changed_by, old_data_version):
1✔
531
        row, ret = self._sharedPrepareDelete(trans, where, changed_by, old_data_version)
1✔
532
        if self.history:
1!
533
            await self.history.forDelete(row, changed_by, trans)
1✔
534

535
        return ret
1✔
536

537
    def delete(self, where, changed_by=None, old_data_version=None, transaction=None, dryrun=False):
1✔
538
        """Perform a DELETE statement on this table. See AUSTable._deleteStatement for
539
        a description of `where`. To simplify versioning, this method can only
540
        delete a single row per invocation. If the where clause given would delete
541
        zero or multiple rows, a WrongNumberOfRowsError is raised.
542

543
        :param where: A list of SQLAlchemy clauses, or a key/value pair of columns and values.
544
        :type where: list of clauses or key/value pairs.
545
        :param changed_by: The username of the person deleting the row(s). Required when
546
                           history is enabled. Unused otherwise. No authorization checks are done
547
                           at this level.
548
        :type changed_by: str
549
        :param old_data_version: Previous version of the row to be deleted. If this version doesn't
550
                                 match the current version of the row, an OutdatedDataError will be
551
                                 raised and the delete will fail. Required when versioning is enabled.
552
        :type old_data_version: int
553
        :param transaction: A transaction object to add the delete statement (and history changes) to.
554
                            If provided, you must commit the transaction yourself. If None, they will
555
                            be added to a locally-scoped transaction and committed.
556
        :param dryrun: If true, this insert statement will not actually be run.
557
        :type dryrun: bool
558

559
        :rtype: sqlalchemy.engine.base.ResultProxy
560
        """
561
        # If "where" is key/value pairs, we need to convert it to SQLAlchemy
562
        # clauses before proceeding.
563
        if hasattr(where, "keys"):
1✔
564
            where = [getattr(self, k) == v for k, v in where.items()]
1✔
565

566
        if self.history and not changed_by:
1!
567
            raise ValueError("changed_by must be passed for Tables that have history")
×
568
        if self.versioned and not old_data_version:
1!
569
            raise ValueError("old_data_version must be passed for Tables that are versioned")
×
570

571
        if dryrun:
1✔
572
            self.log.debug("In dryrun mode, not doing anything...")
1✔
573
            return
1✔
574

575
        if transaction:
1✔
576
            return self._prepareDelete(transaction, where, changed_by, old_data_version)
1✔
577
        else:
578
            with AUSTransaction(self.getEngine()) as trans:
1✔
579
                return self._prepareDelete(trans, where, changed_by, old_data_version)
1✔
580

581
    async def async_delete(self, where, changed_by=None, old_data_version=None, transaction=None, dryrun=False):
1✔
582
        """Perform a DELETE statement on this table. See AUSTable._deleteStatement for
583
        a description of `where`. To simplify versioning, this method can only
584
        delete a single row per invocation. If the where clause given would delete
585
        zero or multiple rows, a WrongNumberOfRowsError is raised.
586

587
        :param where: A list of SQLAlchemy clauses, or a key/value pair of columns and values.
588
        :type where: list of clauses or key/value pairs.
589
        :param changed_by: The username of the person deleting the row(s). Required when
590
                           history is enabled. Unused otherwise. No authorization checks are done
591
                           at this level.
592
        :type changed_by: str
593
        :param old_data_version: Previous version of the row to be deleted. If this version doesn't
594
                                 match the current version of the row, an OutdatedDataError will be
595
                                 raised and the delete will fail. Required when versioning is enabled.
596
        :type old_data_version: int
597
        :param transaction: A transaction object to add the delete statement (and history changes) to.
598
                            If provided, you must commit the transaction yourself. If None, they will
599
                            be added to a locally-scoped transaction and committed.
600
        :param dryrun: If true, this insert statement will not actually be run.
601
        :type dryrun: bool
602

603
        :rtype: sqlalchemy.engine.base.ResultProxy
604
        """
605
        # If "where" is key/value pairs, we need to convert it to SQLAlchemy
606
        # clauses before proceeding.
607
        if hasattr(where, "keys"):
1✔
608
            where = [getattr(self, k) == v for k, v in where.items()]
1✔
609

610
        if self.history and not changed_by:
1!
611
            raise ValueError("changed_by must be passed for Tables that have history")
×
612
        if self.versioned and not old_data_version:
1!
613
            raise ValueError("old_data_version must be passed for Tables that are versioned")
×
614

615
        if dryrun:
1!
616
            self.log.debug("In dryrun mode, not doing anything...")
×
617
            return
×
618

619
        if transaction:
1!
620
            return await self._asyncPrepareDelete(transaction, where, changed_by, old_data_version)
1✔
621
        else:
622
            with AUSTransaction(self.getEngine()) as trans:
×
623
                return await self._asyncPrepareDelete(trans, where, changed_by, old_data_version)
×
624

625
    def _updateStatement(self, where, what):
1✔
626
        """Create an UPDATE statement for this table
627

628
        :param where: Conditions to apply to this UPDATE.
629
        :type where: A sequence of sqlalchemy.sql.expression.ClauseElement objects.
630
        :param what: Data to update
631
        :type what: dict
632

633
        :rtype: sqlalchemy.sql.expression.Update
634
        """
635
        table_what = {k: what[k] for k in what.keys() if k in self.table.c}
1✔
636
        unconsumed_columns = {k: what[k] for k in what.keys() if k not in table_what}
1✔
637
        query = self.t.update(values=table_what)
1✔
638
        if where:
1!
639
            for cond in where:
1✔
640
                query = query.where(cond)
1✔
641
        return query, unconsumed_columns
1✔
642

643
    def _sharedPrepareUpdate(self, trans, where, what, changed_by, old_data_version):
        """Execute an UPDATE of the single row matching ``where`` inside ``trans``.

        Shared by the sync and async update paths; writing history is left to
        the callers. If this table has versioning enabled, data_version will be
        increased by 1. If the table has scheduled changes, the update is passed
        to ``scheduled_changes.mergeUpdate`` after the UPDATE succeeds.

        :param trans: transaction to execute the statements in.
        :param where: list of SQLAlchemy clauses selecting the row.
        :param what: dict of column values to change. NOTE: on versioned tables
                     the new ``data_version`` is written into this dict in place,
                     so the caller's dict is modified.
        :param changed_by: passed through to ``scheduled_changes.mergeUpdate``.
        :param old_data_version: expected current data_version of the row.
        :returns: tuple of (the complete new row as a dict, result of the UPDATE).
        :raises WrongNumberOfRowsError: if ``where`` matches zero or several rows.
        :raises OutdatedDataError: if the UPDATE did not affect exactly one row.
        """
        # To do merge detection for tables with scheduled changes we need a
        # copy of the original row, and what will be changed. To record
        # history, we need a copy of the entire new row.
        orig_row = self._returnRowOrRaise(where=where, transaction=trans)
        new_row = orig_row.copy()
        if self.versioned:
            # Copy so the extra condition isn't appended to the caller's list.
            where = copy(where)
            where.append(self.data_version == old_data_version)
            new_row["data_version"] += 1
            # NOTE(review): this mutates the caller's ``what`` dict; mergeUpdate
            # below receives ``what`` including the new data_version.
            what["data_version"] = new_row["data_version"]

        # Copy the new data into the row
        for col in what:
            new_row[col] = what[col]

        query, unconsumed_columns = self._updateStatement(where, new_row)

        ret = trans.execute(query)
        # It's important that OutdatedDataError is raised as early as possible
        # because callers may be able to handle it gracefully (and continue
        # with their update). If we raise this _after_ adding history or merging
        # with Scheduled Changes, we may end up altering the history or
        # scheduled changes more than once if the caller ends up re-calling
        # AUSTable.update() after handling the OutdatedDataError.
        if ret.rowcount != 1:
            raise OutdatedDataError("Failed to update row, old_data_version doesn't match current data_version")
        if self.scheduled_changes:
            self.scheduled_changes.mergeUpdate(orig_row, what, changed_by, trans)
        return new_row, ret
679

680
    def _prepareUpdate(self, trans, where, what, changed_by, old_data_version):
1✔
681
        new_row, ret = self._sharedPrepareUpdate(trans, where, what, changed_by, old_data_version)
1✔
682
        if self.history:
1✔
683
            self.history.forUpdate(new_row, changed_by, trans)
1✔
684
        return ret
1✔
685

686
    async def _asyncPrepareUpdate(self, trans, where, what, changed_by, old_data_version):
1✔
687
        new_row, ret = self._sharedPrepareUpdate(trans, where, what, changed_by, old_data_version)
1✔
688
        if self.history:
1!
689
            await self.history.forUpdate(new_row, changed_by, trans)
1✔
690

691
        return ret
1✔
692

693
    def update(self, where, what, changed_by=None, old_data_version=None, transaction=None, dryrun=False):
        """Perform an UPDATE statement on this table, affecting a single row.

        See AUSTable._updateStatement for how `where` and `what` become SQL. If
        the where clause given would match zero or multiple rows, an error is
        raised by the row lookup (documented as WrongNumberOfRowsError).

        :param where: Either a list of SQLAlchemy clauses, or a mapping of column
                      names to values, which is turned into equality clauses.
        :type where: list of clauses or key/value pairs.
        :param what: Column names mapped to their new values. When the table is
                     versioned, the new ``data_version`` is written into this
                     dict as a side effect.
        :type what: key/value pairs
        :param changed_by: Username of the person updating the row. Required when
                           history is enabled, unused otherwise. No authorization
                           checks are done at this level.
        :type changed_by: str
        :param old_data_version: The data_version the caller expects the row to
                                 currently have. Required when versioning is
                                 enabled; if it does not match, an
                                 OutdatedDataError is raised and the update fails.
        :type old_data_version: int
        :param transaction: Transaction to add the update (and history changes) to.
                            If provided, the caller must commit it. If None, a
                            locally-scoped transaction is used and committed.
        :param dryrun: If true, the arguments are validated but the update is not
                       run, and None is returned.
        :type dryrun: bool

        :raises ValueError: if `changed_by` or `old_data_version` is missing but required.
        :rtype: sqlalchemy.engine.base.ResultProxy
        """
        # A mapping of column name -> value is shorthand for equality clauses.
        if hasattr(where, "keys"):
            where = [getattr(self, column) == value for column, value in where.items()]

        if self.history and not changed_by:
            raise ValueError("changed_by must be passed for Tables that have history")
        if self.versioned and not old_data_version:
            raise ValueError("update: old_data_version must be passed for Tables that are versioned")

        if dryrun:
            self.log.debug("In dryrun mode, not doing anything...")
            return None

        if transaction:
            return self._prepareUpdate(transaction, where, what, changed_by, old_data_version)

        with AUSTransaction(self.getEngine()) as trans:
            return self._prepareUpdate(trans, where, what, changed_by, old_data_version)
1✔
738

739
    async def async_update(self, where, what, changed_by=None, old_data_version=None, transaction=None, dryrun=False):
        """Async variant of update(): perform an UPDATE affecting a single row.

        See AUSTable._updateStatement for how `where` and `what` become SQL. If
        the where clause given would match zero or multiple rows, an error is
        raised by the row lookup (documented as WrongNumberOfRowsError).

        :param where: Either a list of SQLAlchemy clauses, or a mapping of column
                      names to values, which is turned into equality clauses.
        :type where: list of clauses or key/value pairs.
        :param what: Column names mapped to their new values. When the table is
                     versioned, the new ``data_version`` is written into this
                     dict as a side effect.
        :type what: key/value pairs
        :param changed_by: Username of the person updating the row. Required when
                           history is enabled, unused otherwise. No authorization
                           checks are done at this level.
        :type changed_by: str
        :param old_data_version: The data_version the caller expects the row to
                                 currently have. Required when versioning is
                                 enabled; if it does not match, an
                                 OutdatedDataError is raised and the update fails.
        :type old_data_version: int
        :param transaction: Transaction to add the update (and history changes) to.
                            If provided, the caller must commit it. If None, a
                            locally-scoped transaction is used and committed.
        :param dryrun: If true, the arguments are validated but the update is not
                       run, and None is returned.
        :type dryrun: bool

        :raises ValueError: if `changed_by` or `old_data_version` is missing but required.
        :rtype: sqlalchemy.engine.base.ResultProxy
        """
        # A mapping of column name -> value is shorthand for equality clauses.
        if hasattr(where, "keys"):
            where = [getattr(self, column) == value for column, value in where.items()]

        if self.history and not changed_by:
            raise ValueError("changed_by must be passed for Tables that have history")
        if self.versioned and not old_data_version:
            raise ValueError("update: old_data_version must be passed for Tables that are versioned")

        if dryrun:
            self.log.debug("In dryrun mode, not doing anything...")
            return None

        if transaction:
            return await self._asyncPrepareUpdate(transaction, where, what, changed_by, old_data_version)

        with AUSTransaction(self.getEngine()) as trans:
            return await self._asyncPrepareUpdate(trans, where, what, changed_by, old_data_version)
×
784

785
    def count(self, column="*", where=None, transaction=None):
        """Return the number of rows in this table matching `where`.

        :param column: Argument passed to ``func.count``; defaults to ``"*"``.
        :param where: Optional sequence of SQLAlchemy clauses, each attached with
                      ``.where()``.
        :param transaction: Transaction to run the query in. If None, a
                            locally-scoped transaction is used.
        :returns: The scalar result of the COUNT query.
        """
        statement = select(columns=[func.count(column)], from_obj=self.t)
        for condition in where or ():
            statement = statement.where(condition)

        if transaction:
            return transaction.execute(statement).scalar()
        with AUSTransaction(self.getEngine()) as trans:
            return trans.execute(statement).scalar()
1✔
796

797
    def getRecentChanges(self, limit=10, transaction=None):
        """Return up to `limit` entries from this table's history, newest first.

        Requires history to be enabled (``self.history`` must not be None).
        """
        history = self.history
        newest_first = history.timestamp.desc()
        return history.select(transaction=transaction, limit=limit, order_by=newest_first)
×
799

800

801
class GCSHistory:
    """Records history for a table as JSON blobs in Google Cloud Storage (synchronous).

    Each history entry is written to a blob named
    ``<identifier>/<data_version>-<timestamp>-<changed_by>.json``, where
    ``<identifier>`` is the values of `identifier_columns` joined with ``-``.
    The bucket is the first entry of `buckets` whose key is a substring of the
    identifier; buckets are created with ``use_gcloud_aio=False``.
    """

    def __init__(self, db, dialect, metadata, baseTable, buckets, identifier_columns, data_column):
        # db, dialect, metadata and baseTable are not used here; presumably they
        # are accepted so this class can be used wherever a HistoryTable-style
        # history class is constructed.
        self.buckets = buckets
        self.identifier_columns = identifier_columns
        self.data_column = data_column

    def _getBucket(self, identifier):
        """Return the bucket factory whose key is a substring of `identifier`.

        :raises KeyError: if no configured key matches.
        """
        for substring, bucket in self.buckets.items():
            if substring in identifier:
                return bucket
        raise KeyError("Couldn't find bucket to place {} history in.".format(identifier))

    def _identifier(self, row):
        """Join the row's identifier column values with "-"."""
        return "-".join([row.get(i) for i in self.identifier_columns])

    def _upload(self, identifier, bname, data):
        """Upload the string `data` to blob `bname`, timed as "gcs_upload"."""
        with statsd.timer("gcs_upload"):
            bucket = self._getBucket(identifier)(use_gcloud_aio=False)
            blob = bucket.blob(bname)
            blob.upload_from_string(data, content_type="application/json")

    def forInsert(self, insertedKeys, columns, changed_by, trans):
        """Record an insert as two blobs: an empty one with data_version None,
        timestamped 1ms before the insert (the row did not exist yet), then one
        holding the inserted row's `data_column` as JSON."""
        timestamp = getMillisecondTimestamp()
        identifier = self._identifier(columns)
        entries = (
            (None, timestamp - 1, ""),
            (columns.get("data_version"), timestamp, json.dumps(columns[self.data_column])),
        )
        for data_version, ts, data in entries:
            bname = "{}/{}-{}-{}.json".format(identifier, data_version, ts, changed_by)
            self._upload(identifier, bname, data)

    def forDelete(self, rowData, changed_by, trans):
        """Record a delete as an empty blob named with rowData's data_version."""
        identifier = self._identifier(rowData)
        bname = "{}/{}-{}-{}.json".format(identifier, rowData.get("data_version"), getMillisecondTimestamp(), changed_by)
        self._upload(identifier, bname, "")

    def forUpdate(self, rowData, changed_by, trans):
        """Record an update as a blob holding the new `data_column` value as JSON."""
        identifier = self._identifier(rowData)
        bname = "{}/{}-{}-{}.json".format(identifier, rowData.get("data_version"), getMillisecondTimestamp(), changed_by)
        self._upload(identifier, bname, json.dumps(rowData[self.data_column]))

    def getChange(self, change_id=None, column_values=None, data_version=None, transaction=None):
        """Return the history entry for the row identified by `column_values` at `data_version`.

        `change_id` and `transaction` are accepted for signature compatibility but
        ignored: GCS history can only be looked up by identifier columns and
        data_version.

        :raises ValueError: if an identifier column or data_version is missing, or
                            if the lookup does not find exactly one blob.
        """
        # Treat a missing column_values like an empty one so callers get the
        # documented ValueError instead of an AttributeError on None.
        column_values = column_values or {}
        if not set(self.identifier_columns).issubset(column_values.keys()) or not data_version:
            raise ValueError("Cannot find GCS changes without {} and data_version".format(self.identifier_columns))
        identifier = "-".join([column_values[i] for i in self.identifier_columns])
        bucket = self._getBucket(identifier)(use_gcloud_aio=False)
        # Blob names look like "<identifier>/<data_version>-<ts>-<user>.json".
        # The trailing "-" stops e.g. data_version 1 from also matching the
        # blobs of data_versions 10, 11, ..., 100, ...
        blobs = list(bucket.list_blobs(prefix="{}/{}-".format(identifier, data_version)))
        if len(blobs) != 1:
            raise ValueError("Found {} blobs instead of 1".format(len(blobs)))
        # NOTE(review): the identifier is keyed by the tuple of identifier column
        # names rather than per column; kept as-is for existing callers.
        return {tuple(self.identifier_columns): identifier, "data_version": data_version, self.data_column: json.loads(blobs[0].download_as_string())}
×
849

850

851
class GCSHistoryAsync:
    """Writes history entries for a table as JSON blobs in Google Cloud Storage, asynchronously.

    Blobs are named ``<identifier>/<data_version>-<timestamp>-<changed_by>.json``,
    where ``<identifier>`` joins the row's `identifier_columns` values with ``-``.
    The target bucket is the first entry of `buckets` whose key is a substring of
    the identifier. Every upload opens its own aiohttp ``ClientSession``.
    """

    def __init__(self, db, dialect, metadata, baseTable, buckets, identifier_columns, data_column):
        # dialect, metadata and baseTable are not used by this class.
        self.db = db
        self.buckets = buckets
        self.identifier_columns = identifier_columns
        self.data_column = data_column

    def _getBucket(self, identifier):
        """Return the bucket factory configured for `identifier`.

        :raises KeyError: if no key of `buckets` is a substring of `identifier`.
        """
        not_found = object()
        bucket = next(
            (candidate for substring, candidate in self.buckets.items() if substring in identifier),
            not_found,
        )
        if bucket is not_found:
            raise KeyError("Couldn't find bucket to place {} history in.".format(identifier))
        return bucket

    def _identifierFor(self, row):
        # Join the identifying column values of `row` with "-".
        return "-".join([row.get(column) for column in self.identifier_columns])

    async def _upload(self, identifier, blob_name, make_payload):
        """Upload the string returned by ``make_payload()`` to `blob_name`.

        `make_payload` is only called once the blob object exists, inside the
        timed section, right before the upload itself.
        """
        with statsd.timer("async_gcs_upload"):
            # Using a separate session for each request is not ideal, but it's
            # the only thing that seems to work. Ideally, we'd share one session
            # for the entire application, but we can't for two reasons:
            # 1) gcloud-aio won't close the sessions, which results in a lot of
            # errors (https://github.com/talkiq/gcloud-aio/issues/33)
            # 2) When bhearsum tried this it resulted in hangs that he suspected
            # were caused by connection re-use.
            async with ClientSession() as session:
                bucket = self._getBucket(identifier)(session=session)
                new_blob = bucket.new_blob(blob_name)
                await new_blob.upload(make_payload(), session=session)

    async def forInsert(self, insertedKeys, columns, changed_by, trans):
        """Record an insert: an empty blob with data_version None, timestamped 1ms
        before the insert (the row did not exist yet), then a blob holding the
        inserted row's `data_column` as JSON."""
        now = getMillisecondTimestamp()
        identifier = self._identifierFor(columns)
        versions = (
            (None, now - 1, ""),
            (columns.get("data_version"), now, json.dumps(columns[self.data_column])),
        )
        for version, stamp, payload in versions:
            blob_name = "{}/{}-{}-{}.json".format(identifier, version, stamp, changed_by)
            await self._upload(identifier, blob_name, lambda payload=payload: payload)

    async def forDelete(self, rowData, changed_by, trans):
        """Record a delete as an empty blob named with rowData's data_version."""
        identifier = self._identifierFor(rowData)
        blob_name = "{}/{}-{}-{}.json".format(identifier, rowData.get("data_version"), getMillisecondTimestamp(), changed_by)
        await self._upload(identifier, blob_name, lambda: "")

    async def forUpdate(self, rowData, changed_by, trans):
        """Record an update as a blob holding the new `data_column` value as JSON."""
        identifier = self._identifierFor(rowData)
        blob_name = "{}/{}-{}-{}.json".format(identifier, rowData.get("data_version"), getMillisecondTimestamp(), changed_by)
        await self._upload(identifier, blob_name, lambda: json.dumps(rowData[self.data_column]))
1✔
900

901

902
class HistoryTable(AUSTable):
    """Represents a history table that may be attached to another AUSTable.

    History tables mirror the columns of their `baseTable`, with two tweaks:
    the base table's primary key columns lose their primary_key flag, and all
    other columns are made nullable with their unique flag cleared. In addition,
    every history row has its own auto-incrementing change_id, and records the
    username making the change (changed_by) and the time of the change
    (timestamp, in milliseconds). The methods forInsert, forDelete, and forUpdate
    write the appropriate rows to the history table and are documented below.
    History tables are never versioned, and cannot have history of their own."""

    def __init__(self, db, dialect, metadata, baseTable):
        self.baseTable = baseTable
        self.table = Table(
            "%s_history" % baseTable.t.name,
            metadata,
            Column("change_id", Integer, primary_key=True, autoincrement=True),
            Column("changed_by", String(100), nullable=False),
        )
        # Timestamps are stored as an integer, but actually contain
        # precision down to the millisecond, achieved through
        # multiplication.
        # SQLAlchemy's SQLite dialect doesn't fully support BigInteger.
        # The Column will work, but it ends up being a NullType Column which
        # breaks our upgrade unit tests. Because of this, we make sure to use
        # a plain Integer column for SQLite. In MySQL, an Integer is
        # Integer(11), which is too small for our needs.
        if dialect == "sqlite":
            self.table.append_column(Column("timestamp", Integer, nullable=False))
        else:
            self.table.append_column(Column("timestamp", BigInteger, nullable=False))
        self.base_primary_key = [pk.name for pk in baseTable.primary_key]
        for col in baseTable.t.columns:
            newcol = col._copy()
            if col.primary_key:
                newcol.primary_key = False
            else:
                newcol.nullable = True
                # Setting unique to None because SQLAlchemy marks column attributes as None
                # unless they have been explicitly set to True or False.
                newcol.unique = None
            self.table.append_column(newcol)
        AUSTable.__init__(self, db, dialect, historyClass=None, versioned=False)

    def getPointInTime(self, timestamp, transaction=None):
        """Return the base table's rows as they were at `timestamp` (milliseconds).

        Rows that had been deleted by then are left out.
        """
        # The inner query here gets one change id for every unique object in
        # the base table. Filtering by timestamp <= provided timestamp means
        # we won't get any results more recent than the requested timestamp.
        # Grouping by the primary key and selecting the max change_id means
        # we'll get the most recent change_id (after applying the timestamp
        # filter) for every unique object.
        # The outer query simply retrieves the actual row data for each
        # change_id that the inner query found.
        # Black wants to format this all on one line, which is more difficult
        # to read.
        # fmt: off
        q = (select(self.table.columns)
             .where(self.change_id.in_(
                 select([sql_max(self.change_id)])
                 .where(self.timestamp <= timestamp)
                 .group_by(*self.base_primary_key)
             )
        ))
        # fmt: on
        if transaction:
            result = transaction.execute(q).fetchall()
        else:
            with AUSTransaction(self.getEngine()) as trans:
                result = trans.execute(q).fetchall()

        # Filter out any rows that have no non-primary key data, because this
        # means the row has been deleted.
        # NOTE(review): a live row whose non-primary-key values are all falsy
        # (0, "", False, NULL) would also be filtered out here.
        non_primary_key_columns = [col.name for col in self.baseTable.t.columns if not col.primary_key]
        rows = [row for row in result if any(row[col] for col in non_primary_key_columns)]

        return rows_to_dicts(rows)

    def _historyRow(self, rowData, changed_by):
        """Return insert values for one history row: the entries of `rowData`
        whose keys are columns of this table, plus changed_by and a timestamp."""
        row = {str(k): rowData[k] for k in rowData.keys() if k in self.table.c}
        row["changed_by"] = changed_by
        row["timestamp"] = getMillisecondTimestamp()
        return row

    def forInsert(self, insertedKeys, columns, changed_by, trans):
        """Inserts cause two rows in the History table to be created. The first
        one records the primary key data and NULLs for other row data. This
        represents that the row did not exist prior to the insert. The
        timestamp for this row is 1 millisecond behind the real timestamp to
        reflect this. The second row records the full data of the row at the
        time of insert.

        :param insertedKeys: Primary key values of the new row, in the same order
                             as the base table's primary key columns.
        :param columns: The inserted column values. The primary key values are
                        written into this dict in place.
        """
        primary_key_data = {}
        for i, name in enumerate(self.base_primary_key):
            primary_key_data[name] = insertedKeys[i]
            # Make sure the primary keys are included in the second row as well
            columns[name] = insertedKeys[i]

        ts = getMillisecondTimestamp()
        query, _ = self._insertStatement(changed_by=changed_by, timestamp=ts - 1, **primary_key_data)
        trans.execute(query)
        query, _ = self._insertStatement(changed_by=changed_by, timestamp=ts, **columns)
        trans.execute(query)

    def forDelete(self, rowData, changed_by, trans):
        """Deletes cause a single history row to be created. To mark the row as
        deleted it should only contain the primary key data (see getPointInTime).

        NOTE(review): every entry of `rowData` that names a column of this table
        is copied, so callers are expected to pass only the primary key values.
        """
        query, _ = self._insertStatement(**self._historyRow(rowData, changed_by))
        trans.execute(query)

    def forUpdate(self, rowData, changed_by, trans):
        """Updates cause a single row to be created, which contains the full,
        new data of the row at the time of the update."""
        query, _ = self._insertStatement(**self._historyRow(rowData, changed_by))
        trans.execute(query)

    def getChange(self, change_id=None, column_values=None, data_version=None, transaction=None):
        """Return the unique change matching `change_id`, or else matching
        `data_version` together with the base table's primary key values taken
        from `column_values`.

        :param change_id: If not None, the change is looked up by this id alone
                          and `column_values`/`data_version` are ignored.
        :param column_values: Dict of column names to values. It must contain every
                              primary key column of the base table; non primary key
                              entries are ignored.
        :param data_version: The data_version of the change to find.
        :param transaction: Optional transaction to run the queries in.
        :returns: The matching history row, or None if zero or several rows match.
        :raises ValueError: when looking up without change_id and the primary key is
                            incomplete (including a missing column_values), or the
                            base table is not versioned.
        """
        # if change_id is not None, we use it to get the change, ignoring
        # data_version and column_values
        by_change_id = change_id is not None
        # column_names maps the names of all primary key columns to the column objects
        column_names = {col.name: col for col in self.table.columns if col.name in self.base_primary_key}

        if not by_change_id:
            # Treat a missing column_values as empty so the caller gets a
            # ValueError rather than an AttributeError on None.
            column_values = column_values or {}
            # we check if the entire primary key is present in column_values,
            # since there might be multiple rows that match an incomplete
            # primary key
            for col in column_names.keys():
                if col not in column_values.keys():
                    raise ValueError("Entire primary key not present")
            # data_version can only be queried for versioned tables
            if not self.baseTable.versioned:
                raise ValueError("data_version queried for non-versioned table")

            where = [self.data_version == data_version]
            self.log.debug("Querying for change_id by:")
            self.log.debug("data_version: %s", data_version)
            for col in column_names.keys():
                self.log.debug("%s: %s", column_names[col], column_values[col])
                where.append(column_names[col] == column_values[col])

            # To improve query efficiency we first get the change_id,
            # and _then_ get the entire row. This is because we may not be able
            # to query by an index depending which column_values we were given.
            # If we end up querying by column_values that don't have an index,
            # mysql will read many more rows than will be returned. This is
            # particularly bad on the releases_history table, where the "data"
            # column is often hundreds of kilobytes per row.
            # Additional details in https://github.com/mozilla-releng/balrog/pull/419#issuecomment-334851038
            change_ids = self.select(columns=[self.change_id], where=where, transaction=transaction)
            if len(change_ids) != 1:
                self.log.debug("Found %s changes when not querying by change_id, should have been 1", len(change_ids))
                return None
            change_id = change_ids[0]["change_id"]

        self.log.debug("Querying for full change by change_id %s", change_id)
        changes = self.select(where=[self.change_id == change_id], transaction=transaction)
        if len(changes) != 1:
            self.log.debug("Found %s changes when querying by change_id, should have been 1", len(changes))
            return None
        return changes[0]
1✔
1078

1079

1080
class ConditionsTable(AUSTable):
    """Stores the conditions attached to scheduled changes in "<baseName>_conditions".

    Rows are keyed by sc_id. Only the columns belonging to the condition types
    enabled at construction time are created.
    """

    # Scheduled changes may only have a single type of condition, but some
    # conditions require multiple arguments. This data structure defines
    # each type of condition, and groups their args together for easier
    # processing.
    condition_groups = {"time": ("when",), "uptake": ("telemetry_product", "telemetry_channel", "telemetry_uptake")}

    def __init__(self, db, dialect, metadata, baseName, conditions, historyClass=HistoryTable):
        """
        :param baseName: Prefix for the table name ("<baseName>_conditions").
        :param conditions: Condition types to enable; each must be a key of
                           `condition_groups`.
        :raises ValueError: if `conditions` is empty or contains an unknown type.
        """
        if not conditions:
            raise ValueError("No conditions enabled, cannot initialize conditions for {}".format(baseName))
        if set(conditions) - set(self.condition_groups):
            raise ValueError("Unknown conditions in: {}".format(conditions))

        self.enabled_condition_groups = {k: v for k, v in self.condition_groups.items() if k in conditions}

        self.table = Table("{}_conditions".format(baseName), metadata, Column("sc_id", Integer, primary_key=True))

        if "uptake" in conditions:
            self.table.append_column(Column("telemetry_product", String(15)))
            self.table.append_column(Column("telemetry_channel", String(75)))
            self.table.append_column(Column("telemetry_uptake", Integer))

        if "time" in conditions:
            # Millisecond timestamps need a BigInteger outside SQLite.
            if dialect == "sqlite":
                self.table.append_column(Column("when", Integer))
            else:
                self.table.append_column(Column("when", BigInteger))

        super(ConditionsTable, self).__init__(db, dialect, historyClass=historyClass, versioned=True)

    def validate(self, conditions):
        """Check that `conditions` describes exactly one enabled condition type.

        Entries with falsy values are ignored.

        :raises ValueError: if nothing remains after that; if a key belongs to no
                            condition group or to a disabled one; if the keys do
                            not form exactly one enabled group; or if "when" is not
                            a parseable millisecond timestamp in the future.
        """
        conditions = {k: v for k, v in conditions.items() if v}
        if not conditions:
            raise ValueError("No conditions found")

        enabled_args = set(itertools.chain.from_iterable(self.enabled_condition_groups.values()))
        for c in conditions:
            for condition, args in self.condition_groups.items():
                if c in args:
                    if c not in enabled_args:
                        raise ValueError("{} condition is disabled".format(condition))
                    break
            else:
                # Format the message here; passing c as a second ValueError
                # argument would render the exception as a tuple.
                raise ValueError("Invalid condition: {}".format(c))

        for group in self.enabled_condition_groups.values():
            if set(group) == set(conditions.keys()):
                break
        else:
            raise ValueError("Invalid combination of conditions: {}".format(conditions.keys()))

        if "when" in conditions:
            try:
                time.gmtime(conditions["when"] / 1000)
            except Exception:
                raise ValueError("Cannot parse 'when' as a unix timestamp.")

            if conditions["when"] < getMillisecondTimestamp():
                raise ValueError("Cannot schedule changes in the past")
1✔
1139

1140

1141
class ScheduledChangeTable(AUSTable):
    """A Table that stores the necessary information to schedule changes
    to the baseTable provided. A ScheduledChangeTable ends up mirroring the
    columns of its base, and adding the necessary ones to provide the schedule.
    By default, ScheduledChangeTables enable History on themselves.

    A scheduled change is spread across three tables: this one (bookkeeping
    columns plus a "base_"-prefixed copy of every base table column), a
    ConditionsTable holding the enactment conditions, and a SignoffsTable."""

    def __init__(self, db, dialect, metadata, baseTable, conditions=("time", "uptake"), historyClass=HistoryTable):
        table_name = "{}_scheduled_changes".format(baseTable.t.name)
        self.baseTable = baseTable
        self.table = Table(
            table_name,
            metadata,
            Column("sc_id", Integer, primary_key=True, autoincrement=True),
            Column("scheduled_by", String(100), nullable=False),
            Column("complete", Boolean, default=False),
            Column("change_type", String(50), nullable=False),
        )
        self.conditions = ConditionsTable(db, dialect, metadata, table_name, conditions, historyClass=historyClass)
        # Signoffs are configurable at runtime, which means that we always need
        # a Signoffs table, even if it may not be used immediately.
        self.signoffs = SignoffsTable(db, metadata, dialect, table_name)

        # The primary key column(s) are used to construct "where" clauses for
        # existing rows.
        self.base_primary_key = []
        # A ScheduledChangesTable requires all of the columns from its base
        # table, with a few tweaks:
        for col in baseTable.t.columns:
            if col.primary_key:
                self.base_primary_key.append(col.name)
            newcol = col._copy()
            # 1) Columns are prefixed with "base_", to make them easy to
            # identify and avoid conflicts.
            # Renaming a column requires to change both the key and the name
            # See https://github.com/zzzeek/sqlalchemy/blob/rel_0_7/lib/sqlalchemy/schema.py#L781
            # for background.
            newcol.key = newcol.name = "base_%s" % col.name
            # 2) Primary Key Integer Autoincrement columns from the baseTable become normal nullable
            # columns in ScheduledChanges because we can schedule changes that insert into baseTable
            # and the DB will handle inserting the correct value. However, nulls aren't allowed when
            # we schedule updates or deletes - this is enforced in self.validate().
            # Primary Key columns that are not autoincrementing Integers keep their non-nullability,
            # because we need a value to insert into the baseTable when the scheduled change
            # gets executed.
            # Non-Primary Key columns from the baseTable become nullable and non-unique in ScheduledChanges
            # because they aren't part of the ScheduledChanges business logic and become simple data storage.
            if col.primary_key:
                newcol.primary_key = False

                # Only integer columns can be AUTOINCREMENT. The isinstance statement guards
                # against false positives from SQLAlchemy.
                if col.autoincrement and isinstance(col.type, Integer):
                    newcol.nullable = True
            else:
                newcol.unique = None
                newcol.nullable = True

            self.table.append_column(newcol)

        super(ScheduledChangeTable, self).__init__(db, dialect, historyClass=historyClass, versioned=True)

    def _prefixColumns(self, columns):
        """Helper function which takes key/value pairs of columns for this
        scheduled changes table - which could contain some unprefixed base
        table columns - and returns key/values pairs of the same columns
        with the base table ones prefixed."""
        base_columns = {c.name for c in self.baseTable.t.columns}
        return {("base_%s" % k if k in base_columns else k): v for k, v in columns.items()}

    def _splitColumns(self, columns):
        """Because Scheduled Changes are stored across two Tables, we need to
        split out the parts that are in the main table from the parts that
        are stored in the conditions table in a few different places.

        Returns a ``(base_columns, condition_columns)`` tuple of dicts. A key is
        a condition column if any condition group (enabled or not) declares it."""
        condition_names = set(itertools.chain(*self.conditions.condition_groups.values()))
        base_columns = {}
        condition_columns = {}
        for name in columns:
            if name in condition_names:
                condition_columns[name] = columns[name]
            else:
                base_columns[name] = columns[name]

        return base_columns, condition_columns

    def _checkBaseTablePermissions(self, base_table_where, new_row, changed_by, transaction):
        """Run the base table operation that `new_row["change_type"]` would perform,
        in dryrun mode, so the base table's own checks are applied to `changed_by`
        without writing anything.

        :raises ValueError: if change_type is missing or unknown.
        """
        if "change_type" not in new_row:
            raise ValueError("change_type needed to check Permission")

        change_type = new_row.get("change_type")
        if change_type == "update":
            self.baseTable.update(base_table_where, new_row, changed_by, new_row["data_version"], transaction=transaction, dryrun=True)
        elif change_type == "insert":
            self.baseTable.insert(changed_by, transaction=transaction, dryrun=True, **new_row)
        elif change_type == "delete":
            self.baseTable.delete(base_table_where, changed_by, new_row["data_version"], transaction=transaction, dryrun=True)
        else:
            raise ValueError("Unknown Change Type")

    def _dataVersionsAreSynced(self, sc_id, transaction):
        """Return True when exactly one row exists for `sc_id` in both this table and
        the conditions table, and both rows carry the same data_version."""
        sc_row = super(ScheduledChangeTable, self).select(where=[self.sc_id == sc_id], transaction=transaction, columns=[self.data_version])
        conditions_row = self.conditions.select(where=[self.conditions.sc_id == sc_id], transaction=transaction, columns=[self.conditions.data_version])
        if not sc_row or len(sc_row) != 1 or not conditions_row or len(conditions_row) != 1:
            return False
        self.log.debug("sc_row data version is %s", sc_row[0].get("data_version"))
        self.log.debug("conditions_row data version is %s", conditions_row[0].get("data_version"))
        return sc_row[0].get("data_version") == conditions_row[0].get("data_version")

    def validate(self, base_columns, condition_columns, changed_by, sc_id=None, transaction=None):
        """Validate a new (sc_id is None) or modified scheduled change.

        :param base_columns: unprefixed base table columns plus "change_type".
        :param condition_columns: condition column values, checked by the ConditionsTable.
        :raises ValueError: for missing/None primary keys, a non-existent row on update,
            or a duplicate primary key on insert.
        :raises OutdatedDataError: if the given data_version doesn't match the base table row.
        :raises ChangeScheduledError: if another incomplete change exists for the same row.
        """
        # Depending on the change type, we may do some additional checks
        # against the base table PK columns. It's cleaner to build up these
        # early than do it later.
        base_table_where = []
        sc_table_where = []

        for pk in self.base_primary_key:
            base_column = getattr(self.baseTable, pk)
            if pk in base_columns:
                sc_table_where.append(getattr(self, "base_%s" % pk) == base_columns[pk])
                base_table_where.append(base_column == base_columns[pk])
            # Non-Integer columns can have autoincrement set to True for some reason.
            # Any non-integer columns in the primary key are always required (because
            # autoincrement actually isn't a thing for them), and any Integer columns
            # that _aren't_ autoincrement are required as well.
            elif not isinstance(base_column.type, (sqlalchemy.types.Integer,)) or not base_column.autoincrement:
                raise ValueError("Missing primary key column '%s' which is not autoincrement" % pk)

        change_type = base_columns["change_type"]
        if change_type == "delete":
            for pk in self.base_primary_key:
                if pk not in base_columns:
                    raise ValueError("Missing primary key column %s. PK values needed for deletion" % (pk))
                if base_columns[pk] is None:
                    raise ValueError("%s value found to be None. PK value can not be None for deletion" % (pk))
        elif change_type == "update":
            # For updates, we need to make sure that the baseTable row already
            # exists, and that the data version provided matches the current
            # version to ensure that someone isn't trying to schedule a change
            # against out-of-date data.
            current_data_version = self.baseTable.select(columns=(self.baseTable.data_version,), where=base_table_where, transaction=transaction)
            if not current_data_version:
                raise ValueError("Cannot create scheduled change with data_version for non-existent row")

            if current_data_version[0]["data_version"] != base_columns.get("data_version"):
                raise OutdatedDataError("Wrong data_version given for base table, cannot create scheduled change.")
        elif change_type == "insert" and base_table_where:
            # If the base table row shouldn't already exist, we need to make sure they don't
            # to avoid getting an IntegrityError when the change is enacted.
            if self.baseTable.select(columns=(self.baseTable.data_version,), where=base_table_where, transaction=transaction):
                raise ValueError("Cannot schedule change for duplicate PK")

        # If we're validating a new scheduled change (sc_id is None), we need
        # to make sure that no other scheduled change already exists if a
        # primary key for the base table was provided (sc_table_where is not empty).
        if not sc_id and sc_table_where:
            sc_table_where.append(self.complete == False)  # noqa because we need to use == for sqlalchemy operator overloading to work
            # Query inside the caller's transaction so uncommitted changes are seen.
            if self.select(columns=[self.sc_id], where=sc_table_where, transaction=transaction):
                raise ChangeScheduledError("Cannot schedule a change for a row with one already scheduled")

        self.conditions.validate(condition_columns)
        self._checkBaseTablePermissions(base_table_where, base_columns, changed_by, transaction)

    def auto_signoff(self, changed_by, transaction, sc_id, dryrun, columns):
        """Record a signoff for `changed_by` on `sc_id` when it is unambiguous.

        - If the User scheduling a change only holds one of the required Roles, record a signoff with it.
        - If the User scheduling a change holds more than one of the required Roles, we cannot make a
          Signoff, because we don't know which Role we'd want to signoff with. The user will need to
          signoff manually in these cases.
        """
        user_roles = self.db.getUserRoles(username=changed_by, transaction=transaction)
        if not user_roles:
            return

        required_roles = set()
        required_signoffs = self.baseTable.getPotentialRequiredSignoffs([columns], transaction=transaction)
        if required_signoffs:
            required_roles.update(rs["role"] for signoff_list in required_signoffs.values() for rs in signoff_list)
        possible_signoffs = [user_role for user_role in user_roles if user_role["role"] in required_roles]
        if len(possible_signoffs) == 1:
            self.signoffs.insert(changed_by=changed_by, transaction=transaction, dryrun=dryrun, sc_id=sc_id, role=possible_signoffs[0].get("role"))

    def select(self, where=None, transaction=None, **kwargs):
        """Select scheduled changes, merging each row with its enabled condition columns.

        If `columns` is given, sc_id is added (it is needed to look up conditions).
        The caller's `columns` sequence is copied rather than mutated."""
        # We'll be retrieving condition information for each Scheduled Change,
        # and we'll need sc_id to do so.
        requested = kwargs.get("columns")
        if requested is not None:
            # Columns can be specified as names or Column instances, so we must check for both.
            if "sc_id" not in requested and self.sc_id not in requested:
                kwargs["columns"] = list(requested) + [self.sc_id]

        # The set of enabled condition columns doesn't depend on the row.
        condition_columns = [getattr(self.conditions, c) for c in itertools.chain(*self.conditions.enabled_condition_groups.values())]
        ret = []
        for row in super(ScheduledChangeTable, self).select(where=where, transaction=transaction, **kwargs):
            conditions = self.conditions.select([self.conditions.sc_id == row["sc_id"]], transaction=transaction, columns=condition_columns)
            row.update(conditions[0])
            ret.append(row)
        return ret

    def insert(self, changed_by, transaction=None, dryrun=False, **columns):
        """Validate and store a new scheduled change and its conditions.

        :returns: the new sc_id, or None when `dryrun` is True.
        :raises ValueError: if change_type is missing (plus anything validate() raises).
        :raises MismatchedDataVersionError: if the main and conditions rows end up out of sync.
        """
        base_columns, condition_columns = self._splitColumns(columns)
        if "change_type" not in base_columns:
            raise ValueError("Change type is required")

        self.validate(base_columns=base_columns, condition_columns=condition_columns, changed_by=changed_by, transaction=transaction)

        base_columns = self._prefixColumns(base_columns)
        base_columns["scheduled_by"] = changed_by

        ret = super(ScheduledChangeTable, self).insert(changed_by=changed_by, transaction=transaction, dryrun=dryrun, **base_columns)
        if dryrun:
            return None

        sc_id = ret.inserted_primary_key[0]
        self.conditions.insert(changed_by, transaction, dryrun, sc_id=sc_id, **condition_columns)
        if not self._dataVersionsAreSynced(sc_id, transaction):
            raise MismatchedDataVersionError("Conditions data version is out of sync with main table for sc_id %s" % sc_id)
        self.auto_signoff(changed_by, transaction, sc_id, dryrun, columns)

        return sc_id

    def update(self, where, what, changed_by, old_data_version, transaction=None, dryrun=False):
        """Update the scheduled change(s) matched by `where`.

        `what` may mix unprefixed base table columns and condition columns. Each
        affected scheduled change is re-validated with the merged values first.
        Afterwards the main/conditions data versions are checked, an automatic
        signoff is attempted for `changed_by` on each affected change, and
        signoffs made by other users are removed (skipped when `dryrun`).

        :raises ValueError: if an affected scheduled change is already complete.
        :raises MismatchedDataVersionError: if main and conditions rows are out of sync.
        """
        base_what, condition_what = self._splitColumns(what)

        # (sc_id, merged unprefixed base columns) for each affected scheduled change.
        affected = []
        # We need to check each Scheduled Change that would be affected by this
        # to ensure the new row will be valid.
        for row in self.select(where=where, transaction=transaction):
            # verify whether the scheduled change has already been completed or not. If completed,
            # then cannot modify the scheduled change anymore.
            if row.get("complete"):
                raise ValueError("Scheduled change already completed. Cannot update now.")

            # Before validation, we need to create the new version of the
            # Scheduled Change by combining the old one with the new data.
            # First, separate the primary scheduled changes columns from the conditions...
            sc_columns, condition_columns = self._splitColumns(row)
            # ...then take the baseTable parts of sc_columns and overlay any new
            # values provided in base_what.
            base_columns = {}
            for col, value in sc_columns.items():
                if not col.startswith("base_"):
                    continue
                # Strip only the leading prefix (str.replace would also mangle
                # column names that contain "base_" elsewhere).
                base_col = col[len("base_"):]
                if base_col in base_what:
                    base_columns[base_col] = base_what[base_col]
                elif value:
                    # NOTE(review): falsy stored values (0, "", False) are dropped here;
                    # preserved from the original behavior.
                    base_columns[base_col] = value

            # change_type does not start with "base_" but validation needs it.
            base_columns["change_type"] = sc_columns["change_type"]

            # Similarly, integrate the new values for any conditions with the existing ones.
            condition_columns.update(condition_what)

            # Now that we have all that sorted out, we can validate the new values for everything.
            self.validate(base_columns, condition_columns, changed_by, sc_id=sc_columns["sc_id"], transaction=transaction)

            self.conditions.update([self.conditions.sc_id == sc_columns["sc_id"]], condition_columns, changed_by, old_data_version, transaction, dryrun=dryrun)
            affected.append((sc_columns["sc_id"], base_columns))

        base_what = self._prefixColumns(base_what)
        base_what["scheduled_by"] = changed_by
        super(ScheduledChangeTable, self).update(where, base_what, changed_by, old_data_version, transaction, dryrun=dryrun)

        for sc_id, _ in affected:
            if not self._dataVersionsAreSynced(sc_id, transaction):
                raise MismatchedDataVersionError("Conditions data version is out of sync with main table for sc_id %s" % sc_id)

        for sc_id, base_columns in affected:
            # Use the merged columns of *this* scheduled change, not whichever row was processed last.
            self.auto_signoff(changed_by, transaction, sc_id, dryrun, base_columns)

            if dryrun:
                # A dryrun must not write anything, including signoff removals.
                continue
            for signoff in self.signoffs.select(where={"sc_id": sc_id}, transaction=transaction, columns=["sc_id", "username"]):
                if signoff["username"] != changed_by:
                    self.signoffs.delete(
                        where={"sc_id": sc_id, "username": signoff["username"]}, changed_by=changed_by, transaction=transaction, reset_signoff=True
                    )

    def delete(self, where, changed_by=None, old_data_version=None, transaction=None, dryrun=False):
        """Delete the scheduled change(s) matched by `where` together with their conditions rows.

        :raises ValueError: if a matched scheduled change is already complete.
        """
        conditions_where = []
        for row in self.select(where=where, transaction=transaction):
            # verify whether the scheduled change has already been completed or not. If completed,
            # then cannot modify the scheduled change anymore.
            if row.get("complete"):
                raise ValueError("Scheduled change already completed. Cannot delete now.")

            conditions_where.append(self.conditions.sc_id == row["sc_id"])
            base_row = {col[len("base_"):]: value for col, value in row.items() if col.startswith("base_")}
            # we also need change_type in base_row to check permission
            base_row["change_type"] = row["change_type"]
            base_table_where = {pk: row["base_%s" % pk] for pk in self.base_primary_key}
            # TODO: What permissions *should* be required to delete a scheduled change?
            # It seems a bit odd to be checking base table update/insert here. Maybe
            # something broader should be required?
            self._checkBaseTablePermissions(base_table_where, base_row, changed_by, transaction)

        ret = super(ScheduledChangeTable, self).delete(where, changed_by, old_data_version, transaction, dryrun=dryrun)
        self.conditions.delete(conditions_where, changed_by, old_data_version, transaction, dryrun=dryrun)
        return ret

    def _startEnactment(self, sc_id, enacted_by, transaction):
        """Shared first phase of enactChange/asyncEnactChange.

        Checks the "enact" permission, loads the scheduled change, marks it
        complete (in both this table and the conditions table) and loads its
        signoffs.

        :returns: ``(sc, what, signoffs)`` where `what` holds the unprefixed base table values.
        :raises PermissionDeniedError: if `enacted_by` may not enact scheduled changes.
        """
        if not self.db.hasPermission(enacted_by, "scheduled_change", "enact", transaction=transaction):
            raise PermissionDeniedError("%s is not allowed to enact scheduled changes" % enacted_by)

        sc = self.select(where=[self.sc_id == sc_id], transaction=transaction)[0]
        what = {col[len("base_"):]: value for col, value in sc.items() if col.startswith("base_")}

        # The scheduled change is marked as complete first to avoid it being
        # updated unnecessarily when the base table's update method calls
        # mergeUpdate. If the base table update fails, this will get reverted
        # when the transaction is rolled back.
        # We explicitly avoid using ScheduledChangeTable's update() method here
        # because we don't want to trigger its validation of conditions. Doing so
        # would raise an exception for any timestamp based changes, because
        # they are already in the past when we're ready to enact them.
        # Updating in conditions table also so that history view can work
        # See : https://bugzilla.mozilla.org/show_bug.cgi?id=1333876
        self.conditions.update(
            where=[self.conditions.sc_id == sc_id], what={}, changed_by=sc["scheduled_by"], old_data_version=sc["data_version"], transaction=transaction
        )
        super(ScheduledChangeTable, self).update(
            where=[self.sc_id == sc_id], what={"complete": True}, changed_by=sc["scheduled_by"], old_data_version=sc["data_version"], transaction=transaction
        )

        signoffs = self.signoffs.select(where=[self.signoffs.sc_id == sc_id], transaction=transaction)
        return sc, what, signoffs

    def _basePrimaryKeyWhere(self, sc):
        """Build base table where clauses matching the row a scheduled change targets."""
        return [getattr(self.baseTable, col) == sc["base_%s" % col] for col in self.base_primary_key]

    def _withoutNullPrimaryKeys(self, what):
        """Drop primary key columns whose value is None, so that SQLAlchemy (1.4)
        returns the automatically generated id on insert."""
        return {col: value for col, value in what.items() if not (col in self.base_primary_key and value is None)}

    async def asyncEnactChange(self, sc_id, enacted_by, transaction=None):
        """Enacts a previously scheduled change by running update, insert, or delete on
        the base table (using its async_* methods)."""
        sc, what, signoffs = self._startEnactment(sc_id, enacted_by, transaction)

        # change_type decides which base table operation enacts the change.
        change_type = sc["change_type"]
        if change_type == "delete":
            await self.baseTable.async_delete(self._basePrimaryKeyWhere(sc), sc["scheduled_by"], sc["base_data_version"], transaction=transaction, signoffs=signoffs)
        elif change_type == "update":
            await self.baseTable.async_update(
                self._basePrimaryKeyWhere(sc), what, sc["scheduled_by"], sc["base_data_version"], transaction=transaction, signoffs=signoffs
            )
        elif change_type == "insert":
            await self.baseTable.async_insert(sc["scheduled_by"], transaction=transaction, signoffs=signoffs, **self._withoutNullPrimaryKeys(what))
        else:
            raise ValueError("Unknown Change Type")

    def enactChange(self, sc_id, enacted_by, transaction=None):
        """Enacts a previously scheduled change by running update, insert, or delete on
        the base table."""
        sc, what, signoffs = self._startEnactment(sc_id, enacted_by, transaction)

        # change_type decides which base table operation enacts the change.
        change_type = sc["change_type"]
        if change_type == "delete":
            self.baseTable.delete(self._basePrimaryKeyWhere(sc), sc["scheduled_by"], sc["base_data_version"], transaction=transaction, signoffs=signoffs)
        elif change_type == "update":
            self.baseTable.update(self._basePrimaryKeyWhere(sc), what, sc["scheduled_by"], sc["base_data_version"], transaction=transaction, signoffs=signoffs)
        elif change_type == "insert":
            self.baseTable.insert(sc["scheduled_by"], transaction=transaction, signoffs=signoffs, **self._withoutNullPrimaryKeys(what))
        else:
            raise ValueError("Unknown Change Type")

    def mergeUpdate(self, old_row, what, changed_by, transaction=None):
        """Merges an update to the base table into any changes that may be
        scheduled for the affected row. If the changes are unmergable
        (meaning: the scheduled change and the new version of the row modify
        the same columns), an UpdateMergeError is raised."""

        # Filter the update to only include fields that are different than
        # what's in the base (old_row).
        what = {k: v for k, v in what.items() if v != old_row.get(k)}

        # pyflakes thinks this should be "is False", but that's not how SQLAlchemy
        # works, so we need to shut it up.
        # http://stackoverflow.com/questions/18998010/flake8-complains-on-boolean-comparison-in-filter-clause
        where = [self.complete == False]  # noqa
        where.extend(getattr(self, "base_%s" % col) == old_row[col] for col in self.base_primary_key)

        scheduled_changes = self.select(where=where, transaction=transaction)

        if not scheduled_changes:
            self.log.debug("No scheduled changes found for update; nothing to do")
            return
        for sc in scheduled_changes:
            self.log.debug("Trying to merge update with scheduled change '%s'", sc["sc_id"])

            for col in what:
                # Every column in `what` already differs from old_row (filtered above).
                # If the scheduled change also differs from old_row for that column,
                # both modify it: a conflict the server cannot resolve.
                if sc["base_%s" % col] != old_row.get(col):
                    raise UpdateMergeError("Cannot safely merge change to '%s' with scheduled change '%s'" % (col, sc["sc_id"]))

            # If we get here, the change is safely mergeable
            self.update(
                where=[self.sc_id == sc["sc_id"]], what=what, changed_by=sc["scheduled_by"], old_data_version=sc["data_version"], transaction=transaction
            )
            self.log.debug("Merged %s into scheduled change '%s'", what, sc["sc_id"])
1✔
1581

1582

1583
class RequiredSignoffsTable(AUSTable):
    """RequiredSignoffsTables store and validate information about what types
    and how many signoffs are required for the data provided in
    `decisionColumns`. Subclasses are required to create a Table with the
    necessary columns, and add those columns names to `decisionColumns`.
    When changes are made to a RequiredSignoffsTable, it will look at its own
    rows to determine whether or not that change needs signoff."""

    decisionColumns = []

    def __init__(self, db, dialect):
        # Subclasses must set self.table before calling this.
        self.table.append_column(Column("role", String(50), primary_key=True))
        self.table.append_column(Column("signoffs_required", Integer, nullable=False))

        super(RequiredSignoffsTable, self).__init__(
            db, dialect, scheduled_changes=True, scheduled_changes_kwargs={"conditions": ["time"]}, historyClass=HistoryTable
        )

    def getPotentialRequiredSignoffs(self, affected_rows, transaction=None):
        """Return ``{"rs": [...]}`` with every row of this table whose decision
        columns match those of any (truthy) row in `affected_rows`."""
        potential_required_signoffs = {"rs": []}
        for row in affected_rows:
            if not row:
                continue
            where = {col: row[col] for col in self.decisionColumns}
            potential_required_signoffs["rs"].extend(self.select(where=where, transaction=transaction))
        return potential_required_signoffs

    def _flatPotentialRequiredSignoffs(self, affected_rows, transaction):
        """getPotentialRequiredSignoffs() flattened into a single list."""
        return [obj for v in self.getPotentialRequiredSignoffs(affected_rows, transaction=transaction).values() for obj in v]

    def validate(self, columns, transaction=None):
        """Check that all decision columns are set and that enough users hold
        `columns["role"]` to satisfy `columns["signoffs_required"]`.

        :raises ValueError: if a decision column is missing/None, or too few users hold the role.
        """
        for col in self.decisionColumns:
            # .get() so a missing key yields the intended ValueError rather than KeyError.
            if columns.get(col) is None:
                raise ValueError("{} are required.".format(self.decisionColumns))

        # The role count does not depend on the decision column; query it once.
        user_table = self.db.permissions.user_roles
        users_with_role = user_table.count(where=[user_table.role == columns["role"]], transaction=transaction)

        if users_with_role < columns["signoffs_required"]:
            msg = ", ".join([columns[col] for col in self.decisionColumns])
            raise ValueError(
                "Cannot require {} signoffs for {} - only {} users hold the {} role".format(columns["signoffs_required"], msg, users_with_role, columns["role"])
            )

    def insert(self, changed_by, transaction=None, dryrun=False, signoffs=None, **columns):
        """Validate, check "create" permission and (unless dryrun) verify signoffs, then insert.

        :raises PermissionDeniedError: if `changed_by` may not create Required Signoffs.
        """
        self.validate(columns, transaction=transaction)

        if not self.db.hasPermission(changed_by, "required_signoff", "create", transaction=transaction):
            raise PermissionDeniedError("{} is not allowed to create new Required Signoffs.".format(changed_by))

        if not dryrun:
            verify_signoffs(self._flatPotentialRequiredSignoffs([columns], transaction), signoffs)

        return super(RequiredSignoffsTable, self).insert(changed_by=changed_by, transaction=transaction, dryrun=dryrun, **columns)

    def update(self, where, what, changed_by, old_data_version, transaction=None, dryrun=False, signoffs=None):
        """Validate each matched row with `what` applied, check "modify" permission and
        (unless dryrun) verify signoffs against both the old and new rows, then update.

        :raises PermissionDeniedError: if `changed_by` may not modify Required Signoffs.
        """
        for rs in self.select(where=where, transaction=transaction):
            new_rs = rs.copy()
            new_rs.update(what)
            self.validate(new_rs, transaction=transaction)

            if not self.db.hasPermission(changed_by, "required_signoff", "modify", transaction=transaction):
                raise PermissionDeniedError("{} is not allowed to modify Required Signoffs.".format(changed_by))

            if not dryrun:
                verify_signoffs(self._flatPotentialRequiredSignoffs([rs, new_rs], transaction), signoffs)

        return super(RequiredSignoffsTable, self).update(
            where=where, what=what, changed_by=changed_by, old_data_version=old_data_version, transaction=transaction, dryrun=dryrun
        )

    def delete(self, where, changed_by=None, old_data_version=None, transaction=None, dryrun=False, signoffs=None):
        """Check "delete" permission and (unless dryrun) verify signoffs for each matched row, then delete.

        :raises PermissionDeniedError: if `changed_by` may not remove Required Signoffs.
        """
        if not self.db.hasPermission(changed_by, "required_signoff", "delete", transaction=transaction):
            raise PermissionDeniedError("{} is not allowed to remove Required Signoffs.".format(changed_by))

        if not dryrun:
            for rs in self.select(where=where, transaction=transaction):
                verify_signoffs(self._flatPotentialRequiredSignoffs([rs], transaction), signoffs)

        return super(RequiredSignoffsTable, self).delete(
            where=where, changed_by=changed_by, old_data_version=old_data_version, transaction=transaction, dryrun=dryrun
        )
1664

1665

1666
class ProductRequiredSignoffsTable(RequiredSignoffsTable):
    """Required Signoffs keyed by a (product, channel) pair.

    Backed by the ``product_req_signoffs`` table, whose primary key is
    ``(product, channel)``.
    """

    decisionColumns = ["product", "channel"]

    def __init__(self, db, metadata, dialect):
        product_column = Column("product", String(15), primary_key=True)
        channel_column = Column("channel", String(75), primary_key=True)
        self.table = Table("product_req_signoffs", metadata, product_column, channel_column)
        super().__init__(db, dialect)
1✔
1672

1673

1674
class PermissionsRequiredSignoffsTable(RequiredSignoffsTable):
    """Required Signoffs keyed by product alone.

    Backed by the ``permissions_req_signoffs`` table, whose primary key is
    ``product``.
    """

    decisionColumns = ["product"]

    def __init__(self, db, metadata, dialect):
        product_column = Column("product", String(15), primary_key=True)
        self.table = Table("permissions_req_signoffs", metadata, product_column)
        super().__init__(db, dialect)
1✔
1680

1681

1682
class SignoffsTable(AUSTable):
    """Signoffs made by users, stored in ``<baseName>_signoffs``.

    There is at most one row per (sc_id, username) pair (the primary key).
    Signoffs can be granted (insert) and revoked (delete) but never updated.
    """

    def __init__(self, db, metadata, dialect, baseName):
        table_name = "{}_signoffs".format(baseName)
        self.table = Table(
            table_name,
            metadata,
            Column("sc_id", Integer, primary_key=True, autoincrement=False),
            Column("username", String(100), primary_key=True),
            Column("role", String(50), nullable=False),
        )
        # Signoffs are never modified, so there is no update race to guard
        # against and the table does not need to be versioned.
        super().__init__(db, dialect, versioned=False, historyClass=HistoryTable)

    def insert(self, changed_by=None, transaction=None, dryrun=False, **columns):
        """Grant a signoff as ``changed_by`` using the role in ``columns["role"]``.

        Signing off again with the same role is a no-op.

        Raises:
            ValueError: if ``sc_id`` or ``role`` is missing.
            PermissionDeniedError: on an attempt to sign off for someone else,
                from a system account, with a role the user lacks, or with a
                second role.
        """
        if not ("sc_id" in columns and "role" in columns):
            raise ValueError("sc_id and role must be provided when signing off")
        if "username" in columns and columns["username"] != changed_by:
            raise PermissionDeniedError("Cannot signoff on behalf of another user")

        role = columns["role"]
        if changed_by in self.db.systemAccounts:
            raise PermissionDeniedError("System account cannot signoff")
        if not self.db.hasRole(changed_by, role, transaction=transaction):
            raise PermissionDeniedError("{} cannot signoff with role '{}'".format(changed_by, role))

        previous = self.select({"sc_id": columns["sc_id"], "username": changed_by}, transaction)
        if previous:
            # (sc_id, username) is the primary key, so there can only be one row.
            if previous[0]["role"] != role:
                raise PermissionDeniedError("Cannot signoff with a second role")
            # Already signed off under the same role; nothing left to do.
            return

        columns["username"] = changed_by
        super().insert(changed_by=changed_by, transaction=transaction, dryrun=dryrun, **columns)

    def update(self, where, what, changed_by=None, transaction=None, dryrun=False):
        """Always raises: signoffs are immutable once granted."""
        raise AttributeError("Signoffs cannot be modified (only granted and revoked)")

    def delete(self, where, changed_by=None, transaction=None, dryrun=False, reset_signoff=False):
        """Revoke the signoffs matching ``where``.

        Unless ``reset_signoff`` is set, each matching signoff may only be revoked
        by a non-system account that either holds the signoff's role or is an
        admin.

        Raises:
            PermissionDeniedError: if ``changed_by`` may not revoke a matching signoff.
        """
        if not reset_signoff:
            for signoff in self.select(where, transaction):
                if changed_by in self.db.systemAccounts:
                    raise PermissionDeniedError("System accounts cannot revoke a signoff")
                if self.db.hasRole(changed_by, signoff["role"], transaction=transaction):
                    continue
                if not self.db.isAdmin(changed_by, transaction=transaction):
                    raise PermissionDeniedError("Cannot revoke a signoff made by someone in a group you do not belong to")

        super().delete(where, changed_by=changed_by, transaction=transaction, dryrun=dryrun)
1✔
1731

1732

1733
class Rules(AUSTable):
    """The ``rules`` table.

    Rules are matched against update queries (see getRulesMatchingQuery) and
    point at releases through their ``mapping``/``fallbackMapping`` columns.
    Write operations enforce per-product "rule" permissions and, outside of dry
    runs, verify the Required Signoffs for every affected product/channel.
    """

    def __init__(self, db, metadata, dialect):
        self.table = Table(
            "rules",
            metadata,
            Column("rule_id", Integer, primary_key=True, autoincrement=True),
            Column("alias", String(50), unique=True),
            Column("priority", Integer),
            Column("mapping", String(100)),
            Column("fallbackMapping", String(100)),
            Column("backgroundRate", Integer),
            Column("update_type", String(15), nullable=False),
            Column("product", String(15)),
            Column("version", String(75)),
            Column("channel", String(75)),
            Column("buildTarget", String(75)),
            Column("buildID", String(20)),
            Column("locale", String(200)),
            Column("osVersion", String(1000)),
            Column("memory", String(100)),
            Column("instructionSet", String(1000)),
            Column("jaws", CompatibleBooleanColumn),
            Column("mig64", CompatibleBooleanColumn),
            Column("distribution", String(2000)),
            Column("distVersion", String(100)),
            Column("headerArchitecture", String(10)),
            Column("comment", String(500)),
        )

        AUSTable.__init__(self, db, dialect, scheduled_changes=True, historyClass=HistoryTable)

    def getPotentialRequiredSignoffs(self, affected_rows, transaction=None):
        """Return the Required Signoffs that could apply to changing ``affected_rows``.

        Falsy entries in ``affected_rows`` are ignored. The result maps each
        ``(product, channel)`` pair of the remaining rows to whatever
        get_required_signoffs_for_product_channel returns for it (callers
        iterate these values, so presumably a list of Required Signoff rows).
        """
        # The new row may change the product or channel, so we must look for
        # Signoffs for both.
        rows = [row for row in affected_rows if row]

        where = {}
        products = []
        for row in rows:
            if not row.get("product"):
                # If product isn't present, or is None, it means the Rule affects
                # all products, and we must leave it out of the where clause. If
                # we included it, the query would only match rows where product is
                # NULL. Since we are returning all rs, we can safely break out of this loop.
                break
            products.append(row["product"])
        else:  # nobreak
            where = [self.db.productRequiredSignoffs.product.in_(tuple(products))]

        q = self.db.productRequiredSignoffs.select(where=where, transaction=transaction)

        # Group the query result by product.
        q_map = {}
        for rs in q:
            q_map.setdefault(rs["product"], []).append(rs)

        potential_required_signoffs = {}
        for row in rows:
            product, channel = row.get("product"), row.get("channel")
            potential_required_signoffs[(product, channel)] = get_required_signoffs_for_product_channel(product, channel, q_map, q)
        return potential_required_signoffs

    def _isAlias(self, id_or_alias):
        """Return True if ``id_or_alias`` has the shape of an alias rather than a rule_id.

        Aliases start with an ASCII letter and contain only letters, digits, and
        hyphens. ``re.fullmatch`` is used because ``$`` would also accept a
        trailing newline.
        """
        return re.fullmatch(r"[a-zA-Z][a-zA-Z0-9-]*", str(id_or_alias)) is not None

    def _resolveAlias(self, where):
        """Return ``where`` with an alias passed under ``rule_id`` moved to ``alias``.

        Callers may pass either a rule id or an alias as ``rule_id``. The caller's
        mapping is never mutated: a new dict is returned when a rewrite is needed,
        otherwise ``where`` is returned unchanged.
        """
        if "rule_id" in where and self._isAlias(where["rule_id"]):
            resolved = {key: value for key, value in where.items() if key != "rule_id"}
            resolved["alias"] = where["rule_id"]
            return resolved
        return where

    def insert(self, changed_by, transaction=None, dryrun=False, signoffs=None, **columns):
        """Create a new rule from ``columns`` and return its rule_id (None on dry runs).

        Raises:
            PermissionDeniedError: if ``changed_by`` may not create rules for the product.
        """
        if not self.db.hasPermission(changed_by, "rule", "create", columns.get("product"), transaction):
            raise PermissionDeniedError("%s is not allowed to create new rules for product %s" % (changed_by, columns.get("product")))

        if not dryrun:
            potential_required_signoffs = [obj for v in self.getPotentialRequiredSignoffs([columns], transaction=transaction).values() for obj in v]
            verify_signoffs(potential_required_signoffs, signoffs)

        ret = super(Rules, self).insert(changed_by=changed_by, transaction=transaction, dryrun=dryrun, **columns)
        if not dryrun:
            return ret.inserted_primary_key[0]

    def getOrderedRules(self, where=None, transaction=None):
        """Returns all of the rules, sorted in ascending order of (priority, version, mapping)."""
        return self.select(where=where, order_by=(self.priority, self.version, self.mapping), transaction=transaction)

    def getRulesMatchingQuery(self, updateQuery, fallbackChannel, transaction=None):
        """Returns all of the rules that match the given update query.

        For cases where a particular updateQuery channel has no
        fallback, fallbackChannel should match the channel from the query.

        Candidate rules are first selected (and cached) by product, buildTarget,
        headerArchitecture and distVersion, then filtered in Python by the
        remaining fields.
        """

        def getRawMatches():
            where = [
                ((self.product == updateQuery["product"]) | (self.product == null()))
                & ((self.buildTarget == updateQuery["buildTarget"]) | (self.buildTarget == null()))
            ]

            if "headerArchitecture" in updateQuery:
                where.extend([(self.headerArchitecture == updateQuery.get("headerArchitecture")) | (self.headerArchitecture == null())])
            else:
                where.extend([self.headerArchitecture == null()])

            if "distVersion" in updateQuery:
                where.extend([(self.distVersion == updateQuery["distVersion"]) | (self.distVersion == null())])
            else:
                where.extend([self.distVersion == null()])

            self.log.debug("where: %s", where)
            return self.select(where=where, transaction=transaction)

        # This cache key is constructed from the parts of the updateQuery that
        # are used in the select() to get the "raw" rule matches (plus "force",
        # which getRawMatches() itself does not use). For the most part, product
        # and buildTarget will be the only applicable ones, which means we
        # should get very high cache hit rates, as there's not a ton of
        # variability of possible combinations for those.
        cache_key = "%s:%s:%s:%s:%s" % (
            updateQuery["product"],
            updateQuery["buildTarget"],
            updateQuery.get("headerArchitecture"),
            updateQuery.get("distVersion"),
            updateQuery.get("force"),
        )
        rules = cache.get("rules", cache_key, getRawMatches)

        self.log.debug("Raw matches:")

        matchingRules = []
        for rule in rules:
            self.log.debug(rule)

            # Resolve special means for channel, version, and buildID - dropping
            # rules that don't match after resolution.
            if not matchChannel(rule["channel"], updateQuery["channel"], fallbackChannel):
                self.log.debug("%s doesn't match %s", rule["channel"], updateQuery["channel"])
                continue
            if not matchVersion(rule["version"], updateQuery["version"], get_version_class(updateQuery["product"])):
                self.log.debug("%s doesn't match %s", rule["version"], updateQuery["version"])
                continue
            if not matchBuildID(rule["buildID"], updateQuery.get("buildID", "")):
                self.log.debug("%s doesn't match %s", rule["buildID"], updateQuery.get("buildID"))
                continue
            if not matchMemory(rule["memory"], updateQuery.get("memory")):
                self.log.debug("%s doesn't match %s", rule["memory"], updateQuery.get("memory"))
                continue
            # To help keep the rules table compact, multiple OS versions may be
            # specified in a single rule. They are comma delimited, so we need to
            # break them out and create clauses for each one.
            if not matchSimpleExpression(rule["osVersion"], updateQuery.get("osVersion", "")):
                self.log.debug("%s doesn't match %s", rule["osVersion"], updateQuery.get("osVersion"))
                continue
            if not matchCsv(rule["instructionSet"], updateQuery.get("instructionSet", ""), substring=False):
                self.log.debug("%s doesn't match %s", rule["instructionSet"], updateQuery.get("instructionSet"))
                continue
            if not matchCsv(rule["distribution"], updateQuery.get("distribution", ""), substring=False):
                self.log.debug("%s doesn't match %s", rule["distribution"], updateQuery.get("distribution"))
                continue
            # Locales may be a comma delimited rule too, exact matches only
            if not matchLocale(rule["locale"], updateQuery.get("locale", "")):
                self.log.debug("%s doesn't match %s", rule["locale"], updateQuery.get("locale"))
                continue
            if not matchBoolean(rule["mig64"], updateQuery.get("mig64")):
                self.log.debug("%s doesn't match %s", rule["mig64"], updateQuery.get("mig64"))
                continue
            if not matchBoolean(rule["jaws"], updateQuery.get("jaws")):
                self.log.debug("%s doesn't match %s", rule["jaws"], updateQuery.get("jaws"))
                continue

            matchingRules.append(rule)

        self.log.debug("Reduced matches:")
        if self.log.isEnabledFor(logging.DEBUG):
            for r in matchingRules:
                self.log.debug(r)
        return matchingRules

    def getRule(self, id_or_alias, transaction=None):
        """Returns the unique rule that matches the given rule_id or alias, or None."""
        where = []
        # Figuring out which column to use ahead of time means there's only
        # one potential index for the database to use, which should make
        # queries faster (it will always use the most efficient one).
        if self._isAlias(id_or_alias):
            where.append(self.alias == id_or_alias)
        else:
            where.append(self.rule_id == id_or_alias)

        rules = self.select(where=where, transaction=transaction)
        found = len(rules)
        if found != 1:
            self.log.debug("Found %s rules, should have been 1", found)
            return None
        return rules[0]

    def update(self, where, what, changed_by, old_data_version, transaction=None, dryrun=False, signoffs=None):
        """Update the rule(s) matching ``where`` with ``what``.

        ``where["rule_id"]`` may hold an alias; it is resolved without modifying
        the caller's dict.

        Raises:
            PermissionDeniedError: if ``changed_by`` may not modify rules for the
                current product or for a new product set in ``what``.
        """
        # Rather than forcing callers to figure out whether the identifier
        # they have is an id or an alias, we handle it here.
        where = self._resolveAlias(where)

        # If the product is being changed, we also need to make sure the user
        # has permission to modify _that_ product.
        if "product" in what:
            if not self.db.hasPermission(changed_by, "rule", "modify", what["product"], transaction):
                raise PermissionDeniedError("%s is not allowed to modify rules for product %s" % (changed_by, what["product"]))

        for current_rule in self.select(where=where, transaction=transaction):
            if not self.db.hasPermission(changed_by, "rule", "modify", current_rule["product"], transaction):
                raise PermissionDeniedError("%s is not allowed to modify rules for product %s" % (changed_by, current_rule["product"]))

            new_rule = current_rule.copy()
            new_rule.update(what)
            if not dryrun:
                potential_required_signoffs = [
                    obj for v in self.getPotentialRequiredSignoffs([current_rule, new_rule], transaction=transaction).values() for obj in v
                ]
                verify_signoffs(potential_required_signoffs, signoffs)

        return super(Rules, self).update(
            changed_by=changed_by, where=where, what=what, old_data_version=old_data_version, transaction=transaction, dryrun=dryrun
        )

    def delete(self, where, changed_by=None, old_data_version=None, transaction=None, dryrun=False, signoffs=None):
        """Delete the rule(s) matching ``where``.

        ``where["rule_id"]`` may hold an alias; it is resolved without modifying
        the caller's dict. ``changed_by`` must hold "rule"/"delete" permission
        for the product of every matching rule.

        Raises:
            IndexError: if no rule matches ``where`` (kept for compatibility with
                existing callers).
            PermissionDeniedError: if ``changed_by`` may not delete a matching rule.
        """
        where = self._resolveAlias(where)

        current_rules = self.select(where=where, transaction=transaction)
        if not current_rules:
            raise IndexError("No rule found matching the given criteria")

        # Check every matched rule, not just the first, so a broad ``where``
        # cannot delete rules of a product the user has no rights to.
        for current_rule in current_rules:
            product = current_rule["product"]
            if not self.db.hasPermission(changed_by, "rule", "delete", product, transaction):
                raise PermissionDeniedError("%s is not allowed to delete rules for product %s" % (changed_by, product))

        if not dryrun:
            for current_rule in current_rules:
                potential_required_signoffs = [obj for v in self.getPotentialRequiredSignoffs([current_rule], transaction=transaction).values() for obj in v]
                verify_signoffs(potential_required_signoffs, signoffs)

        super(Rules, self).delete(changed_by=changed_by, where=where, old_data_version=old_data_version, transaction=transaction, dryrun=dryrun)
1✔
1975

1976

1977
class Releases(AUSTable):
1✔
1978
    def __init__(self, db, metadata, dialect, history_buckets, historyClass):
        """Define the ``releases`` table and initialise the AUSTable machinery.

        The ``data`` column is LONGTEXT on MySQL and Text on other dialects.
        History is only kept when ``history_buckets`` is truthy; otherwise
        ``historyClass`` is discarded.
        """
        self.domainAllowlist = []

        self.table = Table(
            "releases",
            metadata,
            Column("name", String(100), primary_key=True),
            Column("product", String(15), nullable=False),
            Column("read_only", Boolean, default=False),
        )
        if dialect == "mysql":
            from sqlalchemy.dialects.mysql import LONGTEXT as data_type
        else:
            data_type = Text
        self.table.append_column(Column("data", BlobColumn(data_type), nullable=False))

        if history_buckets:
            history_kwargs = {"buckets": history_buckets, "identifier_columns": ["name"], "data_column": "data"}
        else:
            # Can't have history without a bucket
            history_kwargs = {}
            historyClass = None

        AUSTable.__init__(
            self, db, dialect, scheduled_changes=True, scheduled_changes_kwargs={"conditions": ["time"]}, historyClass=historyClass, historyKwargs=history_kwargs
        )
2006

2007
    def getPotentialRequiredSignoffs(self, affected_rows, transaction=None):
        """Return the Required Signoffs that could apply to changing ``affected_rows``.

        Releases do not affect live updates on their own, only the
        product+channel combinations specified in Rules that point to them
        (via mapping or fallbackMapping). The Required Signoffs of those Rules
        are therefore returned, keyed by release name. When no affected row
        names an existing release, ``{"rs": []}`` is returned.
        """
        rows = [row for row in affected_rows if row]
        if not rows:
            # getReleaseInfo() applies no name filter for an empty name list and
            # would return every release, demanding signoffs for all of them.
            return {"rs": []}

        info = self.getReleaseInfo(names=[row["name"] for row in rows], transaction=transaction)
        if not info:
            return {"rs": []}

        relevant_rules = [rule_info for release in info for rule_info in release["rule_info"].values()]

        # Get the Required Signoffs for all relevant rules in one query.
        all_rs = self.db.rules.getPotentialRequiredSignoffs(relevant_rules, transaction=transaction)

        potential_required_signoffs = {}
        for release in info:
            rs = []
            for rule in release["rule_info"].values():
                rs.extend(all_rs[(rule["product"], rule["channel"])])
            potential_required_signoffs[release["name"]] = rs
        return potential_required_signoffs
1✔
2035

2036
    def getPotentialRequiredSignoffsForProduct(self, product, transaction=None):
        """Return ``{"rs": [...]}`` with the strictest product Required Signoff per role.

        Every ``productRequiredSignoffs`` row for ``product`` (any channel) is
        considered; for each role, the row with the highest
        ``signoffs_required`` is kept.
        """
        product_rs = self.db.productRequiredSignoffs.select(
            where=[self.db.productRequiredSignoffs.product == product], transaction=transaction
        )
        if not product_rs:
            return {"rs": []}

        by_role = defaultdict(list)
        for entry in product_rs:
            by_role[entry["role"]].append(entry)
        strictest = [max(entries, key=lambda e: e["signoffs_required"]) for entries in by_role.values()]
        return {"rs": strictest}
1✔
2047

2048
    def setDomainAllowlist(self, domainAllowlist):
        """Set the domain allowlist that blobs are validated against in insert() and update()."""
        self.domainAllowlist = domainAllowlist
1✔
2050

2051
    def getReleases(self, name=None, product=None, limit=None, transaction=None):
        """Return release rows (name, product, data_version, data), optionally filtered.

        ``data`` is filled in through getReleaseBlob(), which knows how to use
        cached blobs, instead of being selected directly.
        """
        self.log.debug("Looking for releases with:")
        self.log.debug("name: %s", name)
        self.log.debug("product: %s", product)

        filters = []
        if name:
            filters.append(self.name == name)
        if product:
            filters.append(self.product == product)

        wanted_columns = [self.name, self.product, self.data_version]
        releases = self.select(columns=wanted_columns, where=filters, limit=limit, transaction=transaction)
        for release in releases:
            release["data"] = self.getReleaseBlob(release["name"], transaction)
        return releases
1✔
2067

2068
    def getReleaseInfo(self, names=None, product=None, limit=None, transaction=None, nameOnly=False, name_prefix=None):
        """Return release metadata rows, optionally filtered.

        Args:
            names: only include releases with these names (no filter if empty).
            product: only include releases of this product.
            limit: maximum number of rows.
            transaction: optional transaction to run the queries in.
            nameOnly: if True, rows contain only ``name``.
            name_prefix: only include releases whose name starts with this.

        Unless ``nameOnly`` is set, each row also carries ``rule_ids`` and
        ``rule_info`` ({str(rule_id): {"product", "channel"}}) for the Rules whose
        mapping or fallbackMapping points at the release.
        """
        where = []
        if names:
            where.append(self.name.in_(tuple(names)))
        if product:
            where.append(self.product == product)
        if name_prefix:
            where.append(self.name.startswith(name_prefix))
        if nameOnly:
            columns = [self.name]
        else:
            columns = [self.name, self.product, self.data_version, self.read_only]

        rows = self.select(where=where, columns=columns, limit=limit, transaction=transaction)

        if not nameOnly:
            j = join(
                self.db.releases.t,
                self.db.rules.t,
                ((self.db.releases.name == self.db.rules.mapping) | (self.db.releases.name == self.db.rules.fallbackMapping)),
            )
            query = select([self.db.releases.name, self.db.rules.rule_id, self.db.rules.product, self.db.rules.channel]).select_from(j)
            executor = transaction if transaction else self.getEngine()
            ref_list = executor.execute(query).fetchall()

            # Group references by release name once, instead of re-scanning the
            # whole reference list for every row (which was O(rows * refs)).
            refs_by_name = defaultdict(list)
            for ref in ref_list:
                refs_by_name[ref[0]].append(ref)

            for row in rows:
                refs = refs_by_name.get(row["name"], [])
                row["rule_ids"] = [ref[1] for ref in refs]
                row["rule_info"] = {str(ref[1]): {"product": ref[2], "channel": ref[3]} for ref in refs}

        return rows
1✔
2107

2108
    def getReleaseNames(self, **kwargs):
        """Return rows holding only ``name``; accepts the other getReleaseInfo() filters."""
        return self.getReleaseInfo(**kwargs, nameOnly=True)
1✔
2110

2111
    def getReleaseBlob(self, name, transaction=None):
        """Return the blob of the release called ``name``, preferring cached copies.

        Both the release's data_version (the "blob_version" cache) and the blob
        itself (the "blob" cache) are cached. A cached blob older than the
        data_version is re-read from the database and re-cached.

        Raises:
            KeyError: if no release called ``name`` exists.
        """

        # Putting the data_version and blob getters into these methods lets us
        # delegate the decision about whether or not to use the cached values
        # to the cache class. It will either return as a cached value, or use
        # the getter to return a fresh value (and cache it).
        def getDataVersion():
            try:
                return self.select(where=[self.name == name], columns=[self.data_version], limit=1, transaction=transaction)[0]
            except IndexError:
                raise KeyError("Couldn't find release with name '%s'" % name)

        data_version = cache.get("blob_version", name, getDataVersion)

        def getBlob():
            try:
                row = self.select(where=[self.name == name], columns=[self.data], limit=1, transaction=transaction)[0]
                blob = row["data"]
                return {"data_version": data_version, "blob": blob}
            except IndexError:
                raise KeyError("Couldn't find release with name '%s'" % name)

        def get_data_version(obj):
            # Versions are cached either as a bare int (insert() stores 1) or as
            # the row dict returned by getDataVersion().
            if isinstance(obj, int):
                return obj
            return obj["data_version"]

        cached_blob = cache.get("blob", name, getBlob)

        # Even though we may have retrieved a cached blob, we need to make sure
        # that it's not older than the one in the database. If the data version
        # of the cached blob and the latest data version don't match, we need
        # to update the cache with the latest blob.
        if get_data_version(data_version) > get_data_version(cached_blob["data_version"]):
            blob_info = getBlob()
            cache.put("blob", name, blob_info)
            blob = blob_info["blob"]
        else:
            # And while it's extremely unlikely, there is a remote possibility
            # that the cached blob actually has a newer data version than the
            # blob version cache. This can occur if the blob cache expired
            # between retrieving the cached data version and cached blob.
            # (Because the blob version cache ttl should be shorter than the
            # blob cache ttl, if the blob cache expired prior to retrieving the
            # data version, the blob version cache would've expired as well.)
            # If we hit one of these cases, we bring the blob version cache up
            # to date with the newer version held by the cached blob.
            if get_data_version(cached_blob["data_version"]) > get_data_version(data_version):
                cache.put("blob_version", name, cached_blob["data_version"])
            blob = cached_blob["blob"]

        return blob
1✔
2162

2163
    def insert(self, changed_by, transaction=None, dryrun=False, signoffs=None, **columns):
        """Create a release from ``columns`` and return its primary key (None on dry runs).

        The blob in ``columns["data"]`` is validated against the product and
        the domain allowlist, and its name must equal ``columns["name"]``. On a
        real insert, the blob and data_version 1 are written to the cache.

        Raises:
            ValueError: if name, product or data is missing, or the names differ.
            PermissionDeniedError: if ``changed_by`` may not create releases for the product.
        """
        if not all(key in columns for key in ("name", "product", "data")):
            raise ValueError("name, product, and data are all required")

        name, product, blob = columns["name"], columns["product"], columns["data"]

        blob.validate(product, self.domainAllowlist)
        if name != blob["name"]:
            raise ValueError("name in database (%s) does not match name in blob (%s)" % (name, blob["name"]))

        if not self.db.hasPermission(changed_by, "release", "create", product, transaction):
            raise PermissionDeniedError("%s is not allowed to create releases for product %s" % (changed_by, product))

        if not dryrun:
            required_by_release = self.getPotentialRequiredSignoffs([columns], transaction=transaction)
            verify_signoffs(list(itertools.chain.from_iterable(required_by_release.values())), signoffs)

        result = super(Releases, self).insert(changed_by=changed_by, transaction=transaction, dryrun=dryrun, **columns)
        if dryrun:
            return None

        cache.put("blob", name, {"data_version": 1, "blob": blob})
        cache.put("blob_version", name, 1)
        return result.inserted_primary_key[0]
1✔
2185

2186
    def update(self, where, what, changed_by, old_data_version, transaction=None, dryrun=False, signoffs=None):
1✔
2187
        blob = what.get("data")
1✔
2188

2189
        current_releases = self.select(where=where, columns=[self.name, self.product, self.read_only], transaction=transaction)
1✔
2190
        for current_release in current_releases:
1✔
2191
            name = current_release["name"]
1✔
2192
            is_readonly_change = "read_only" in what and current_release["read_only"] != what["read_only"]
1✔
2193

2194
            if not is_readonly_change:
1✔
2195
                if "product" in what or "data" in what:
1!
2196
                    self._proceedIfNotReadOnly(current_release["name"], transaction=transaction)
1✔
2197

2198
                if blob:
1✔
2199
                    blob.validate(what.get("product", current_release["product"]), self.domainAllowlist)
1✔
2200
                    name = what.get("name", name)
1✔
2201
                    if name != blob["name"]:
1✔
2202
                        raise ValueError("name in database (%s) does not match name in blob (%s)" % (name, blob.get("name")))
1✔
2203

2204
                if not self.db.hasPermission(changed_by, "release", "modify", current_release["product"], transaction):
1✔
2205
                    raise PermissionDeniedError("%s is not allowed to modify releases for product %s" % (changed_by, current_release["product"]))
1✔
2206

2207
                if "product" in what:
1✔
2208
                    # If the product is being changed, we need to make sure the user
2209
                    # has permission to modify releases of that product, too.
2210
                    if not self.db.hasPermission(changed_by, "release", "modify", what["product"], transaction):
1✔
2211
                        raise PermissionDeniedError("%s is not allowed to modify releases for product %s" % (changed_by, what["product"]))
1✔
2212

2213
                new_release = current_release.copy()
1✔
2214
                new_release.update(what)
1✔
2215
                if not dryrun:
1✔
2216
                    potential_required_signoffs = [
1✔
2217
                        obj for v in self.getPotentialRequiredSignoffs([current_release, new_release], transaction=transaction).values() for obj in v
2218
                    ]
2219
                    verify_signoffs(potential_required_signoffs, signoffs)
1✔
2220
            else:
2221
                self.validate_readonly_change(
1✔
2222
                    where, what["read_only"], changed_by, release=current_release, transaction=transaction, dryrun=dryrun, signoffs=signoffs
2223
                )
2224

2225
        for release in current_releases:
1✔
2226
            name = current_release["name"]
1✔
2227
            new_data_version = old_data_version + 1
1✔
2228
            try:
1✔
2229
                super(Releases, self).update(
1✔
2230
                    where={"name": name}, what=what, changed_by=changed_by, old_data_version=old_data_version, transaction=transaction, dryrun=dryrun
2231
                )
2232
            except OutdatedDataError as e:
1✔
2233
                self.log.warning("Trying to merge update to release %s at data_version %s with the latest version.", name, old_data_version)
1✔
2234
                if blob is not None:
1!
2235
                    ancestor_change = self.history.getChange(data_version=old_data_version, column_values={"name": name}, transaction=transaction)
1✔
2236
                    # if we have no historical information about the ancestor blob
2237
                    if ancestor_change is None:
1!
2238
                        self.log.exception("Couldn't find history for release %s at data_version %s", name, old_data_version)
×
2239
                        raise
×
2240
                    ancestor_blob = ancestor_change.get("data")
1✔
2241
                    tip_release = self.getReleases(name=name, transaction=transaction)[0]
1✔
2242
                    tip_blob = tip_release.get("data")
1✔
2243
                    try:
1✔
2244
                        what["data"] = createBlob(merge_dicts(ancestor_blob, tip_blob, blob))
1✔
2245
                        self.log.warning("Successfully merged release %s at data_version %s with the latest version.", name, old_data_version)
1✔
2246
                        # ancestor_change is checked for None a few lines up
2247
                        self.log.warning(
1✔
2248
                            "ancestor_change is change_id %s, data_version %s", ancestor_change.get("change_id"), ancestor_change.get("data_version")
2249
                        )
2250
                        self.log.warning("tip release is data_version %s", tip_release.get("data_version"))
1✔
2251
                    except ValueError:
1✔
2252
                        self.log.exception("Couldn't merge release %s at data_version %s with the latest version.", name, old_data_version)
1✔
2253
                        # ancestor_change is checked for None a few lines up
2254
                        self.log.warning(
1✔
2255
                            "ancestor_change is change_id %s, data_version %s", ancestor_change.get("change_id"), ancestor_change.get("data_version")
2256
                        )
2257
                        self.log.warning("tip release is data_version %s", tip_release.get("data_version"))
1✔
2258
                        raise e
1✔
2259
                    # we want the data_version for the dictdiffer.merged blob to be one
2260
                    # more than that of the latest blob
2261
                    tip_data_version = tip_release["data_version"]
1✔
2262
                    super(Releases, self).update(
1✔
2263
                        where={"name": name}, what=what, changed_by=changed_by, old_data_version=tip_data_version, transaction=transaction, dryrun=dryrun
2264
                    )
2265
                    # cache will have a data_version of one plus the tip
2266
                    # data_version
2267
                    new_data_version = tip_data_version + 1
1✔
2268

2269
            if not dryrun:
1✔
2270
                cache.put("blob", name, {"data_version": new_data_version, "blob": blob})
1✔
2271
                cache.put("blob_version", name, new_data_version)
1✔
2272

2273
    def addLocaleToRelease(self, name, product, platform, locale, data, old_data_version, changed_by, transaction=None, alias=None):
        """Add or replace the data for one platform + locale of the release ``name``.

        ``data`` is stored at ``platforms[platform]["locales"][locale]``. If
        ``platform`` already exists in the blob it is first resolved through the
        blob's alias mapping. Every name in ``alias`` that is not yet a platform
        is created as an alias pointing at ``platform``. The blob is validated
        before being written back with ``self.update``; the original contract
        says a ValueError is raised when it is invalid.

        The ``product`` argument is not used: permissions and validation use the
        product stored in the database for this release.

        :raises ReadOnlyError: if the release is read only.
        :raises PermissionDeniedError: if ``changed_by`` may not modify locales of
            the release's product.
        """
        self._proceedIfNotReadOnly(name, transaction=transaction)

        where = [self.name == name]
        # Ignore the caller-supplied product; trust the one recorded for this release.
        stored_rows = self.select(where=where, columns=[self.product], transaction=transaction)
        product = stored_rows[0]["product"]
        if not self.db.hasPermission(changed_by, "release_locale", "modify", product, transaction):
            raise PermissionDeniedError("%s is not allowed to add builds for product %s" % (changed_by, product))

        releaseBlob = self.getReleaseBlob(name, transaction=transaction)
        if "platforms" not in releaseBlob:
            releaseBlob["platforms"] = {}
        platforms = releaseBlob["platforms"]

        # Data written into an aliased platform would be ignored, so follow the
        # alias to the real platform before modifying anything.
        if platform in platforms:
            platform = releaseBlob.getResolvedPlatform(platform)

        if platform not in platforms:
            platforms[platform] = {}
        platform_entry = platforms[platform]
        if "locales" not in platform_entry:
            platform_entry["locales"] = {}
        platform_entry["locales"][locale] = data

        # Aliases may only add new platforms; existing platforms (aliased or not) are never modified.
        for alias_name in alias or []:
            if alias_name not in platforms:
                platforms[alias_name] = {"alias": platform}

        releaseBlob.validate(product, self.domainAllowlist)
        self.update(where=where, what={"data": releaseBlob}, changed_by=changed_by, old_data_version=old_data_version, transaction=transaction)

        # NOTE(review): Releases.update appears to refresh these cache entries itself
        # when not in dry-run mode (possibly with a merged data_version); this write
        # always records old_data_version + 1. Confirm which value is intended.
        bumped_version = old_data_version + 1
        cache.put("blob", name, {"data_version": bumped_version, "blob": releaseBlob})
        cache.put("blob_version", name, bumped_version)
2317

2318
    def getLocale(self, name, platform, locale, transaction=None):
        """Return the data stored for ``locale`` on ``platform`` in release ``name``.

        :raises KeyError: with a descriptive message if any step of the lookup
            (including fetching the blob) raises KeyError.
        """
        try:
            locales = self.getReleaseBlob(name, transaction=transaction)["platforms"][platform]["locales"]
            return locales[locale]
        except KeyError:
            raise KeyError("Couldn't find locale identified by: %s, %s, %s" % (name, platform, locale))
2324

2325
    def localeExists(self, name, platform, locale, transaction=None):
        """Return True if ``getLocale`` finds the locale, False if it raises KeyError."""
        try:
            self.getLocale(name, platform, locale, transaction)
        except KeyError:
            return False
        return True
2331

2332
    def isMappedTo(self, name, transaction=None):
        """Return True if any rule references release ``name``.

        A release counts as mapped when a rule's ``mapping`` or
        ``fallbackMapping`` column equals ``name``.
        """
        rules = self.db.rules
        # Short-circuit: the fallbackMapping query is only needed when no rule
        # uses the release as its primary mapping.
        if rules.count(where=[rules.mapping == name], transaction=transaction) > 0:
            return True
        return rules.count(where=[rules.fallbackMapping == name], transaction=transaction) > 0
2336

2337
    def delete(self, where, changed_by, old_data_version, transaction=None, dryrun=False, signoffs=None):
        """Delete the single release matched by ``where``.

        Outside dry-run mode the required signoffs are verified before the
        delete, and the release's "blob" and "blob_version" cache entries are
        invalidated afterwards.

        :raises ValueError: if ``where`` does not match exactly one release, or
            if rules still point at the release.
        :raises ReadOnlyError: if the release is read only.
        :raises PermissionDeniedError: if ``changed_by`` may not delete releases
            of the release's product.
        """
        matches = self.select(where=where, columns=[self.name, self.product], transaction=transaction)
        if len(matches) != 1:
            raise ValueError("Where clause must match exactly one release to delete.")
        release = matches[0]
        release_name = release["name"]

        if self.isMappedTo(release_name, transaction):
            raise ValueError("%s has rules pointing to it. Hence it cannot be deleted." % (release_name))

        self._proceedIfNotReadOnly(release_name, transaction=transaction)
        release_product = release["product"]
        if not self.db.hasPermission(changed_by, "release", "delete", release_product, transaction):
            raise PermissionDeniedError("%s is not allowed to delete releases for product %s" % (changed_by, release_product))

        if not dryrun:
            required = []
            for signoff_group in self.getPotentialRequiredSignoffs([release], transaction=transaction).values():
                required.extend(signoff_group)
            verify_signoffs(required, signoffs)

        super(Releases, self).delete(where=where, changed_by=changed_by, old_data_version=old_data_version, transaction=transaction, dryrun=dryrun)

        if not dryrun:
            for cache_name in ("blob", "blob_version"):
                cache.invalidate(cache_name, release_name)
2359

2360
    def isReadOnly(self, name, limit=None, transaction=None):
        """Return the ``read_only`` value of release ``name``.

        Raises IndexError if no release with that name exists.
        """
        rows = self.select(where=[self.name == name], columns=[self.read_only], limit=limit, transaction=transaction)
        first_row = rows[0]
        return first_row["read_only"]
2365

2366
    def _proceedIfNotReadOnly(self, name, limit=None, transaction=None):
1✔
2367
        if self.isReadOnly(name, limit, transaction):
1✔
2368
            raise ReadOnlyError("Release '%s' is read only" % name)
1✔
2369

2370
    def validate_readonly_change(self, where, new_readonly_state, changed_by, release=None, transaction=None, dryrun=False, signoffs=None):
        """Check that ``changed_by`` may set a release's read-only flag to ``new_readonly_state``.

        ``release`` may be passed in to avoid a lookup; otherwise the first row
        matched by ``where`` is used. Setting the flag requires the
        "release_read_only"/"set" permission and clearing it requires
        "release_read_only"/"unset".

        When a non-dry-run change moves the release from read only to read
        write, signoffs are verified. The required signoffs come from the rules
        pointing at the release or, if there are none, from
        ``getPotentialRequiredSignoffsForProduct`` for the release's product.

        :raises PermissionDeniedError: if ``changed_by`` lacks the needed permission.
        """
        if not release:
            release = self.select(where=where, columns=[self.name, self.product, self.read_only], transaction=transaction)[0]

        product = release["product"]
        action = "set" if new_readonly_state else "unset"
        if not self.db.hasPermission(changed_by, "release_read_only", action, product, transaction):
            target_state = "read only" if new_readonly_state else "read write"
            raise PermissionDeniedError(f"{changed_by} is not allowed to mark {product} products {target_state}")

        # Signoffs only matter for a real (non-dry-run) ro -> rw transition.
        if dryrun or not (release["read_only"] and not new_readonly_state):
            return

        def _flatten(signoff_groups):
            return [signoff for group in signoff_groups for signoff in group]

        # Prefer the signoffs required by rules that point at this release...
        required = _flatten(self.getPotentialRequiredSignoffs([release], transaction=transaction).values())
        # ...otherwise fall back to every requirement configured for the release's product.
        if not required:
            required = _flatten(self.getPotentialRequiredSignoffsForProduct(release["product"], transaction=transaction).values())
        verify_signoffs(required, signoffs)
2398

2399
    def change_readonly(self, where, is_readonly, changed_by, old_data_version, transaction=None):
        """Validate, then persist, a change of the release's read-only flag.

        Only the ``read_only`` column is written. The write goes through the
        parent class's ``update`` rather than this class's override.
        """
        self.validate_readonly_change(where, is_readonly, changed_by, transaction=transaction)
        changes = {"read_only": is_readonly}
        super().update(where, changes, changed_by=changed_by, old_data_version=old_data_version, transaction=transaction)
2402

2403

2404
class ReleasesJSON(AUSTable):
    """Releases stored as one JSON document per release name.

    Each row has a ``name`` primary key, a ``product``, a ``read_only`` flag and
    JSON ``data``. Outside dry-run mode, inserts, updates and deletes check the
    signoffs required by the rules whose mapping or fallbackMapping points at the
    affected release. The exception is an update that only locks a release
    (see ``async_update``).
    """

    def __init__(self, db, metadata, dialect, history_buckets, historyClass):
        self.domainAllowlist = []

        self.table = Table(
            "releases_json",
            metadata,
            Column("name", String(100), primary_key=True),
            Column("product", String(15), nullable=False),
            Column("read_only", Boolean, default=False),
            Column("data", JSON),
        )
        historyKwargs = {}
        if history_buckets:
            historyKwargs["buckets"] = history_buckets
            historyKwargs["identifier_columns"] = ["name"]
            historyKwargs["data_column"] = "data"
        else:
            # Can't have history without a bucket
            historyClass = None
        super(ReleasesJSON, self).__init__(
            db,
            dialect,
            scheduled_changes=True,
            scheduled_changes_kwargs={"conditions": ["time"]},
            historyClass=historyClass,
            historyKwargs=historyKwargs,
        )

    def getPotentialRequiredSignoffs(self, affected_rows, transaction=None):
        """Collect the signoffs required by rules that map to each affected release.

        :param affected_rows: row dicts with at least a ``name`` key.
        :returns: ``defaultdict(list)`` keyed by release name. A name only
            appears when at least one rule maps, or falls back, to it.
        """
        potential_required_signoffs = defaultdict(list)

        for release in affected_rows:
            stmt = select([self.db.rules.rule_id, self.db.rules.product, self.db.rules.channel]).where(
                ((self.db.releases_json.name == self.db.rules.mapping) | (self.db.releases_json.name == self.db.rules.fallbackMapping))
                & (self.db.releases_json.name == release["name"])
            )

            if transaction:
                rule_info = transaction.execute(stmt).fetchall()
            else:
                rule_info = self.getEngine().execute(stmt).fetchall()

            # Indexed below by (product, channel) of each matching rule.
            rule_required_signoffs = self.db.rules.getPotentialRequiredSignoffs([dict(r) for r in rule_info], transaction)

            for rule in rule_info:
                rs = rule_required_signoffs[(rule["product"], rule["channel"])]
                potential_required_signoffs[release["name"]].extend(rs)

        return potential_required_signoffs

    def getPotentialRequiredSignoffsForProduct(self, product, transaction=None):
        """Return ``{"rs": [...]}`` holding the product-level required signoffs for ``product``.

        For each role, only the productRequiredSignoffs row with the highest
        ``signoffs_required`` is kept. ``"rs"`` is empty when none are configured.
        """
        potential_required_signoffs = {"rs": []}
        where = [self.db.productRequiredSignoffs.product == product]
        product_rs = self.db.productRequiredSignoffs.select(where=where, transaction=transaction)
        if product_rs:
            role_map = defaultdict(list)
            for rs in product_rs:
                role_map[rs["role"]].append(rs)
            signoffs_required = [max(signoffs, default=None, key=lambda k: k["signoffs_required"]) for signoffs in role_map.values()]
            potential_required_signoffs["rs"] = signoffs_required
        return potential_required_signoffs

    def _verify_row_signoffs(self, rows, signoffs, transaction=None):
        """Flatten the potential required signoffs for ``rows`` and pass them to ``verify_signoffs``."""
        potential_required_signoffs = [obj for v in self.getPotentialRequiredSignoffs(rows, transaction=transaction).values() for obj in v]
        verify_signoffs(potential_required_signoffs, signoffs)

    async def async_insert(self, changed_by, transaction=None, dryrun=False, signoffs=None, **columns):
        """Insert a release, verifying required signoffs first unless ``dryrun``."""
        if not dryrun:
            self._verify_row_signoffs([columns], signoffs, transaction=transaction)

        return await super(ReleasesJSON, self).async_insert(changed_by=changed_by, transaction=transaction, dryrun=dryrun, **columns)

    async def async_update(self, where, what, changed_by, old_data_version, transaction=None, dryrun=False, signoffs=None):
        """Update the releases matched by ``where``, verifying required signoffs first.

        Signoffs are skipped only for a change that leaves ``data`` untouched
        and sets ``read_only`` to a truthy value (locking the release). Any data
        change, or any unlock, requires them.

        The new flag is tested by truthiness, as in
        ``Releases.validate_readonly_change``. This means falsy non-bool values
        such as ``0`` or ``None`` cannot unlock a release without signoffs.
        """
        for row in self.select(where=where, transaction=transaction):
            new_row = row.copy()
            new_row.update(what)
            # True when only the read_only flag changes.
            is_readonly_change = row["data"] == new_row["data"] and "read_only" in what and row["read_only"] != what["read_only"]

            # Only do signoff checks when the data is being changed, or we're moving from read-only to read-write
            if not is_readonly_change or not what["read_only"]:
                if not dryrun:
                    self._verify_row_signoffs([row, new_row], signoffs, transaction=transaction)

        return await super(ReleasesJSON, self).async_update(
            where=where, what=what, changed_by=changed_by, old_data_version=old_data_version, transaction=transaction, dryrun=dryrun
        )

    async def async_delete(self, where, changed_by=None, old_data_version=None, transaction=None, dryrun=False, signoffs=None):
        """Delete the releases matched by ``where``, verifying required signoffs per row unless ``dryrun``."""
        if not dryrun:
            for row in self.select(where=where, transaction=transaction):
                self._verify_row_signoffs([row], signoffs, transaction=transaction)

        return await super(ReleasesJSON, self).async_delete(
            where=where, changed_by=changed_by, old_data_version=old_data_version, transaction=transaction, dryrun=dryrun
        )
2501

2502

2503
class ReleaseAssets(AUSTable):
    """Release assets keyed by (release name, path), each holding JSON ``data``.

    Outside dry-run mode, inserts, updates and deletes verify the signoffs
    required by rules whose mapping or fallbackMapping equals the asset's
    release name.
    """

    def __init__(self, db, metadata, dialect, history_buckets, historyClass):
        self.table = Table(
            "release_assets",
            metadata,
            Column("name", String(100), primary_key=True),
            Column("path", String(200), primary_key=True),
            Column("data", JSON),
        )
        if history_buckets:
            historyKwargs = {"buckets": history_buckets, "identifier_columns": ["name", "path"], "data_column": "data"}
        else:
            historyKwargs = {}
            # History needs a bucket to live in, so disable it entirely.
            historyClass = None

        super(ReleaseAssets, self).__init__(
            db,
            dialect,
            scheduled_changes=True,
            scheduled_changes_kwargs={"conditions": ["time"]},
            historyClass=historyClass,
            historyKwargs=historyKwargs,
        )

    def getPotentialRequiredSignoffs(self, affected_rows, transaction=None):
        """Map each affected (name, path) to the signoffs required by rules pointing at its release.

        Keys only appear for assets whose release is referenced by at least one rule.
        """
        signoffs_by_asset = defaultdict(list)

        for asset in affected_rows:
            rules = self.db.rules
            assets = self.db.release_assets
            query = select([rules.rule_id, rules.product, rules.channel]).where(
                ((assets.name == rules.mapping) | (assets.name == rules.fallbackMapping))
                & (assets.name == asset["name"])
                & (assets.path == asset["path"])
            )

            executor = transaction if transaction else self.getEngine()
            matching_rules = executor.execute(query).fetchall()

            per_product_channel = rules.getPotentialRequiredSignoffs([dict(r) for r in matching_rules], transaction)
            asset_key = (asset["name"], asset["path"])
            for rule in matching_rules:
                signoffs_by_asset[asset_key].extend(per_product_channel[(rule["product"], rule["channel"])])

        return signoffs_by_asset

    def _check_signoffs(self, rows, signoffs, transaction):
        """Flatten the potential required signoffs for ``rows`` and hand them to ``verify_signoffs``."""
        flattened = [item for group in self.getPotentialRequiredSignoffs(rows, transaction=transaction).values() for item in group]
        verify_signoffs(flattened, signoffs)

    async def async_insert(self, changed_by, transaction=None, dryrun=False, signoffs=None, **columns):
        """Insert an asset, verifying required signoffs first unless ``dryrun``."""
        if not dryrun:
            self._check_signoffs([columns], signoffs, transaction)

        return await super(ReleaseAssets, self).async_insert(changed_by=changed_by, transaction=transaction, dryrun=dryrun, **columns)

    async def async_update(self, where, what, changed_by, old_data_version, transaction=None, dryrun=False, signoffs=None):
        """Update the assets matched by ``where``, checking signoffs for old and new values unless ``dryrun``."""
        for current in self.select(where=where, transaction=transaction):
            proposed = current.copy()
            proposed.update(what)

            if not dryrun:
                self._check_signoffs([current, proposed], signoffs, transaction)

        return await super(ReleaseAssets, self).async_update(
            where=where, what=what, changed_by=changed_by, old_data_version=old_data_version, transaction=transaction, dryrun=dryrun
        )

    async def async_delete(self, where, changed_by=None, old_data_version=None, transaction=None, dryrun=False, signoffs=None):
        """Delete the assets matched by ``where``, checking signoffs per row unless ``dryrun``."""
        if not dryrun:
            for current in self.select(where=where, transaction=transaction):
                self._check_signoffs([current], signoffs, transaction)

        return await super(ReleaseAssets, self).async_delete(
            where=where, changed_by=changed_by, old_data_version=old_data_version, transaction=transaction, dryrun=dryrun
        )
2577

2578

2579
class UserRoles(AUSTable):
    """(username, role) assignments; rows can be granted or revoked but never modified."""

    def __init__(self, db, metadata, dialect):
        username_column = Column("username", String(100), primary_key=True)
        role_column = Column("role", String(50), primary_key=True)
        self.table = Table("user_roles", metadata, username_column, role_column)
        super(UserRoles, self).__init__(db, dialect, historyClass=HistoryTable)

    def update(self, where, what, changed_by, old_data_version, transaction=None, dryrun=False):
        """Always raise AttributeError: role assignments are immutable."""
        raise AttributeError("User roles cannot be modified (only granted and revoked)")
2586

2587

2588
class Permissions(AUSTable):
    """Permissions granted to users.

    ``allPermissions`` defines the structure and the possible options of every
    available permission. Permissions can be limited to specific types of
    actions: for example, granting the "rule" permission with "actions" set to
    ["create"] allows rules to be created but not modified or deleted.
    Permissions that relate to rules or releases can be further limited by
    product: for example, granting the "release" permission with "products"
    set to ["GMP"] allows the user to modify GMP releases, but not Firefox
    releases."""

    # Permission name -> option keys it accepts. assertOptionsExist rejects any
    # other key, and getPotentialRequiredSignoffs checks for "products" here to
    # decide whether product-specific signoffs apply.
    allPermissions = {
        "admin": ["products"],
        "emergency_shutoff": ["actions", "products"],
        "release": ["actions", "products"],
        "release_locale": ["actions", "products"],
        "release_read_only": ["actions", "products"],
        "rule": ["actions", "products"],
        "pinnable_release": ["products"],
        "permission": ["actions"],
        "required_signoff": ["products"],
        "scheduled_change": ["actions"],
    }
2609

2610
    def __init__(self, db, metadata, dialect):
        """Define the ``permissions`` table and its companion ``user_roles`` table."""
        permission_columns = (
            Column("permission", String(50), primary_key=True),
            Column("username", String(100), primary_key=True),
            Column("options", JSONColumn),
        )
        self.table = Table("permissions", metadata, *permission_columns)
        # Roles are stored in their own table, managed via grantRole/revokeRole.
        self.user_roles = UserRoles(db, metadata, dialect)
        AUSTable.__init__(self, db, dialect, scheduled_changes=True, scheduled_changes_kwargs={"conditions": ["time"]}, historyClass=HistoryTable)
2620

2621
    def getPotentialRequiredSignoffs(self, affected_rows, transaction=None):
        """Return ``{"rs": [...]}`` with the permission signoffs that may apply to ``affected_rows``.

        Falsy rows are skipped. The product-specific path applies when the row's
        permission accepts a "products" option and its options list products:
        the permissionsRequiredSignoffs rows for each listed product are
        collected. Otherwise every permissionsRequiredSignoffs row is collected.
        """
        # XXX: This gives poor control over the signoffs required for permissions
        # that don't specify products, or don't support them.
        collected = []
        for row in affected_rows:
            if not row:
                continue
            products = "products" in self.allPermissions[row["permission"]] and row.get("options") and row["options"].get("products")
            if products:
                for product in products:
                    collected.extend(self.db.permissionsRequiredSignoffs.select(where={"product": product}, transaction=transaction))
            else:
                collected.extend(self.db.permissionsRequiredSignoffs.select(transaction=transaction))
        return {"rs": collected}
2635

2636
    def assertPermissionExists(self, permission):
        """Raise ValueError unless ``permission`` is a key of ``allPermissions``."""
        if permission in self.allPermissions:
            return
        raise ValueError('Unknown permission "%s"' % permission)
2639

2640
    def assertOptionsExist(self, permission, options):
        """Raise ValueError for the first entry of ``options`` that ``permission`` does not accept.

        ``options`` is iterated directly, so for a dict its keys are checked.
        The allowed options are looked up lazily: an empty ``options`` never
        consults ``allPermissions``.
        """
        not_found = object()
        first_unknown = next((opt for opt in options if opt not in self.allPermissions[permission]), not_found)
        if first_unknown is not not_found:
            raise ValueError('Unknown option "%s" for permission "%s"' % (first_unknown, permission))
2644

2645
    def getAllUsers(self, transaction=None):
        """Return ``{username: {"roles": [...]}}`` for every distinct user holding a permission.

        Each roles list contains the user's ``role`` and ``data_version`` rows from
        the user_roles table.
        """
        usernames = [row["username"] for row in self.select(columns=[self.username], distinct=True, transaction=transaction)]
        roles_table = self.user_roles
        return {
            username: {
                "roles": roles_table.select(
                    where=[roles_table.username == username], columns=[roles_table.role, roles_table.data_version], transaction=transaction
                )
            }
            for username in usernames
        }
2655

2656
    def getAllPermissions(self, retrieving_as, transaction=None):
        """Return ``{username: {permission: {"options", "data_version"}}}`` for every permission row.

        :raises PermissionDeniedError: if ``retrieving_as`` lacks the "permission"/"view" permission.
        """
        if not self.hasPermission(retrieving_as, "permission", "view", transaction=transaction):
            raise PermissionDeniedError("You are not authorized to view permissions of other users.")

        permissions_by_user = defaultdict(dict)
        for row in self.select(transaction=transaction):
            entry = {"options": row["options"], "data_version": row["data_version"]}
            permissions_by_user[row["username"]][row["permission"]] = entry
        return permissions_by_user
2667

2668
    def countAllUsers(self, transaction=None):
        """Return how many distinct usernames hold at least one permission."""
        distinct_users = self.select(columns=[self.username], distinct=True, transaction=transaction)
        user_count = len(distinct_users)
        return user_count
2671

2672
    def insert(self, changed_by, transaction=None, dryrun=False, signoffs=None, **columns):
        """Grant a permission described by ``columns``.

        ``columns`` must contain "permission" and "username". Any "options" are
        checked against ``allPermissions``. Outside dry-run mode, required
        signoffs are verified before the insert. The "usernames" cache entry is
        invalidated afterwards.

        :raises ValueError: if required columns are missing or the permission/options are unknown.
        :raises PermissionDeniedError: if ``changed_by`` may not grant permissions.
        """
        if "permission" not in columns or "username" not in columns:
            raise ValueError("permission and username are required")

        permission = columns["permission"]
        username = columns["username"]
        self.assertPermissionExists(permission)
        options = columns.get("options")
        if options:
            self.assertOptionsExist(permission, options)

        if not self.db.hasPermission(changed_by, "permission", "create", transaction=transaction):
            raise PermissionDeniedError("%s is not allowed to grant permissions" % changed_by)

        if not dryrun:
            required = []
            for group in self.getPotentialRequiredSignoffs([columns], transaction=transaction).values():
                required.extend(group)
            verify_signoffs(required, signoffs)

        self.log.debug("granting %s to %s with options %s", permission, username, options)
        super(Permissions, self).insert(changed_by=changed_by, transaction=transaction, dryrun=dryrun, **columns)
        cache.invalidate("users", "usernames")
        self.log.debug("successfully granted %s to %s with options %s", permission, username, options)
2691

2692
    def grantRole(self, username, role, changed_by, transaction=None):
        """Give ``username`` the ``role`` and return the result of ``user_roles.insert``.

        :raises PermissionDeniedError: if ``changed_by`` lacks the "permission"/"create" permission.
        :raises ValueError: if ``username`` currently has no permissions.
        """
        may_grant = self.hasPermission(changed_by, "permission", "create", transaction=transaction)
        if not may_grant:
            raise PermissionDeniedError("%s is not allowed to grant user roles" % changed_by)

        existing_permissions = self.getUserPermissions(username, changed_by, transaction)
        if len(existing_permissions) == 0:
            raise ValueError("Cannot grant a role to a user without any permissions")

        self.log.debug("granting {} role to {}".format(role, username))
        return self.user_roles.insert(changed_by, transaction, username=username, role=role)
2701

2702
    def update(self, where, what, changed_by, old_data_version, transaction=None, dryrun=False, signoffs=None):
        """Modify the permission rows matched by ``where`` with the values in ``what``.

        The (possibly new) permission name and options are validated. For each
        matched row, ``changed_by`` must hold "permission"/"modify". Outside
        dry-run mode, signoffs are verified for both the current and the
        proposed row.

        :raises ValueError: if the permission or options are unknown.
        :raises PermissionDeniedError: if ``changed_by`` may not modify permissions.
        """
        if "permission" in what:
            self.assertPermissionExists(what["permission"])

        for current in self.select(where=where, transaction=transaction):
            if what.get("options"):
                target_permission = what.get("permission", current["permission"])
                self.assertOptionsExist(target_permission, what["options"])

            if not self.db.hasPermission(changed_by, "permission", "modify", transaction=transaction):
                raise PermissionDeniedError("%s is not allowed to modify permissions" % changed_by)

            if not dryrun:
                proposed = current.copy()
                proposed.update(what)
                signoff_groups = self.getPotentialRequiredSignoffs([current, proposed], transaction=transaction).values()
                verify_signoffs([obj for group in signoff_groups for obj in group], signoffs)

        super(Permissions, self).update(
            where=where, what=what, changed_by=changed_by, old_data_version=old_data_version, transaction=transaction, dryrun=dryrun
        )
2724

2725
    def delete(self, where, changed_by=None, old_data_version=None, transaction=None, dryrun=False, signoffs=None):
        """Revoke the permission rows matched by ``where``.

        Outside dry-run mode this method:

        - verifies required signoffs for each matched row,
        - deletes the rows,
        - revokes every role of any affected user who is left with no permissions.

        The "usernames" cache entry is invalidated in either mode.

        :raises PermissionDeniedError: if ``changed_by`` lacks the "permission"/"delete" permission.
        """
        if not self.db.hasPermission(changed_by, "permission", "delete", transaction=transaction):
            # Interpolate the username into the message; passing it as a separate
            # exception argument would leave a literal "%s" in the error text.
            raise PermissionDeniedError("%s is not allowed to revoke permissions" % changed_by)

        usernames = set()
        for current_permission in self.select(where=where, transaction=transaction):
            usernames.add(current_permission["username"])
            if not dryrun:
                potential_required_signoffs = [
                    obj for v in self.getPotentialRequiredSignoffs([current_permission], transaction=transaction).values() for obj in v
                ]
                verify_signoffs(potential_required_signoffs, signoffs)

        if not dryrun:
            super(Permissions, self).delete(changed_by=changed_by, where=where, old_data_version=old_data_version, transaction=transaction)

            # A user with no remaining permissions should not keep any roles either.
            for u in usernames:
                if len(self.getUserPermissions(u, changed_by, transaction)) == 0:
                    for role in self.user_roles.select([self.user_roles.username == u], transaction=transaction):
                        self.revokeRole(u, role["role"], changed_by=changed_by, old_data_version=role["data_version"], transaction=transaction)

        cache.invalidate("users", "usernames")
2747

2748
    def revokeRole(self, username, role, changed_by=None, old_data_version=None, transaction=None):
        """Remove ``role`` from ``username`` and return the result of ``user_roles.delete``.

        Refuses when the remaining holders of the role could no longer satisfy
        the largest ``signoffs_required`` configured for that role. Both the
        permission and the product required-signoff tables are consulted.

        :raises PermissionDeniedError: if ``changed_by`` lacks the "permission"/"delete" permission.
        :raises ValueError: if revoking would make Required Signoffs impossible to fulfil.
        """
        if not self.hasPermission(changed_by, "permission", "delete", transaction=transaction):
            # Interpolate the username; passing it as a second exception argument
            # would leave a literal "%s" in the error text.
            raise PermissionDeniedError("%s is not allowed to revoke user roles" % changed_by)

        # Concatenate into a new list rather than extending the list returned by select().
        permission_signoffs = self.db.permissionsRequiredSignoffs.select(where={"role": role}, transaction=transaction)
        product_signoffs = self.db.productRequiredSignoffs.select(where={"role": role}, transaction=transaction)
        role_signoffs = permission_signoffs + product_signoffs
        if role_signoffs:
            required = max(rs["signoffs_required"] for rs in role_signoffs)
            users_with_role = len(self.user_roles.select(where={"role": role}, transaction=transaction))
            # Assumes ``username`` is one of the current holders, so one fewer remains afterwards.
            if required > (users_with_role - 1):
                raise ValueError("Revoking {} role would make it impossible for Required Signoffs to be fulfilled".format(role))

        return self.user_roles.delete({"username": username, "role": role}, changed_by=changed_by, old_data_version=old_data_version, transaction=transaction)
2761

2762
    def getPermission(self, username, permission, transaction=None):
        """Look up a single permission row for a user.

        :param username: user whose permission is wanted.
        :param permission: permission name (e.g. "admin").
        :param transaction: optional transaction to query in.
        :returns: the first matching row, or an empty dict when there is none.
        """
        criteria = [self.username == username, self.permission == permission]
        try:
            matches = self.select(where=criteria, transaction=transaction)
            return matches[0]
        except IndexError:
            # No such permission for this user.
            return {}
2767

2768
    def getUserPermissions(self, username, retrieving_as, transaction=None):
        """Return every permission held by ``username``.

        :param username: user whose permissions are listed.
        :param retrieving_as: user making the request.
        :param transaction: optional transaction to query in.
        :returns: dict mapping permission name to ``{"options": ..., "data_version": ...}``.
        :raises PermissionDeniedError: if ``retrieving_as`` is someone else and lacks
            the "permission"/"view" permission.
        """
        # Letting any user read anyone else's permissions could make privilege
        # escalation attacks easier, so cross-user lookups need "view" access.
        viewing_someone_else = username != retrieving_as
        if viewing_someone_else and not self.hasPermission(retrieving_as, "permission", "view", transaction=transaction):
            raise PermissionDeniedError("You are not authorized to view permissions of other users.")

        permission_rows = self.select(
            columns=[self.permission, self.options, self.data_version],
            where=[self.username == username],
            transaction=transaction,
        )
        return {entry["permission"]: {"options": entry["options"], "data_version": entry["data_version"]} for entry in permission_rows}
2781

2782
    def getOptions(self, username, permission, transaction=None):
        """Return the ``options`` value of ``username``'s ``permission``.

        :raises ValueError: if the user does not hold that permission.
        """
        found = self.select(columns=[self.options], where=[self.username == username, self.permission == permission], transaction=transaction)
        if not found:
            raise ValueError('Permission "%s" doesn\'t exist' % permission)
        return found[0]["options"]
2788

2789
    def getUserRoles(self, username, transaction=None):
        """List the distinct roles held by ``username``.

        :returns: list of ``{"role": ..., "data_version": ...}`` dicts.
        """
        roles_table = self.user_roles
        role_rows = roles_table.select(
            where=[roles_table.username == username],
            columns=[roles_table.role, roles_table.data_version],
            distinct=True,
            transaction=transaction,
        )
        return [dict(role=entry["role"], data_version=entry["data_version"]) for entry in role_rows]
2794

2795
    def isAdmin(self, username, transaction=None):
        """Return True if ``username`` holds the "admin" permission (for any product scope)."""
        admin_row = self.getPermission(username, "admin", transaction)
        return bool(admin_row)
2797

2798
    def hasPermission(self, username, thing, action, product=None, transaction=None):
        """Decide whether ``username`` may perform ``action`` on ``thing`` for ``product``.

        An "admin" permission grants everything, unless its options list
        "products" that do not include ``product``. Otherwise the user needs a
        ``thing`` permission. Its optional "actions" and "products" options
        restrict it; an absent or empty option means "no restriction".

        :returns: True if allowed, False otherwise.
        """
        admin = self.getPermission(username, "admin", transaction=transaction)
        if admin:
            # Admin without a product list covers every product.
            admin_products = (admin["options"] or {}).get("products")
            return not admin_products or product in admin_products

        granted = self.getPermission(username, thing, transaction=transaction)
        if not granted:
            return False

        opts = granted["options"] or {}
        allowed_actions = opts.get("actions")
        if allowed_actions and action not in allowed_actions:
            return False
        allowed_products = opts.get("products")
        if allowed_products and product not in allowed_products:
            return False
        return True
2824

2825
    def hasRole(self, username, role, transaction=None):
        """Return True if ``role`` is among the roles ``getUserRoles`` reports for ``username``."""
        held = self.getUserRoles(username, transaction)
        return role in (entry["role"] for entry in held)
2828

2829
    def isKnownUser(self, username):
        """Return True if ``username`` holds at least one permission.

        The set of known usernames is read through the "users"/"usernames"
        cache entry. A falsy ``username`` is never known.
        """
        if not username:
            return False

        column_name = "username"

        def load_usernames():
            # Cache-miss path: every distinct username in the permissions table.
            rows = self.select(columns=[column_name], distinct=True)
            return [entry[column_name] for entry in rows]

        return username in cache.get("users", "usernames", value_getter=load_usernames)
2841

2842

2843
class Dockerflow(AUSTable):
    """Unversioned table holding the Dockerflow watchdog counter.

    The counter is stored in a single integer ``watchdog`` column. There is no
    history table (``historyClass=None``, ``versioned=False``).
    """

    def __init__(self, db, metadata, dialect):
        self.table = Table("dockerflow", metadata, Column("watchdog", Integer, nullable=False))
        AUSTable.__init__(self, db, dialect, historyClass=None, versioned=False)

    def getDockerflowEntry(self, transaction=None):
        """Return the first row of the table.

        :raises IndexError: if the table is empty.
        """
        return self.select(transaction=transaction)[0]

    def incrementWatchdogValue(self, changed_by, transaction=None, dryrun=False):
        """Increment the watchdog counter, creating it with value 1 if absent.

        :param changed_by: username recorded for the change.
        :param transaction: optional transaction used for both the read and the write.
        :param dryrun: forwarded to the underlying insert/update.
        :returns: the new watchdog value.
        """
        try:
            # Read within the caller's transaction so the read and the write that follows are consistent.
            value = self.getDockerflowEntry(transaction=transaction)
        except IndexError:
            value = {"watchdog": 1}
            where = None
        else:
            # Match on the old value before bumping it, so the update targets the row we read.
            where = [self.watchdog == value["watchdog"]]
            value["watchdog"] += 1

        self._putWatchdogValue(changed_by=changed_by, value=value, where=where, transaction=transaction, dryrun=dryrun)

        return value["watchdog"]

    def _putWatchdogValue(self, changed_by, value, where=None, transaction=None, dryrun=False):
        """Insert a new counter row when ``where`` is None, otherwise update the matching row."""
        if where is None:
            super(Dockerflow, self).insert(changed_by=changed_by, transaction=transaction, dryrun=dryrun, watchdog=value["watchdog"])
        else:
            super(Dockerflow, self).update(where=where, what=value, changed_by=changed_by, transaction=transaction, dryrun=dryrun)
2869

2870

2871
class EmergencyShutoffs(AUSTable):
    """Table of (product, channel) pairs for which updates are shut off.

    Rows are keyed on product + channel, carry an optional comment, support
    time-conditioned scheduled changes, and keep history in a ``HistoryTable``.
    """

    def __init__(self, db, metadata, dialect):
        shutoff_columns = (
            Column("product", String(15), nullable=False, primary_key=True),
            Column("channel", String(75), nullable=False, primary_key=True),
            Column("comment", String(500)),
        )
        self.table = Table("emergency_shutoffs", metadata, *shutoff_columns)
        AUSTable.__init__(self, db, dialect, scheduled_changes=True, scheduled_changes_kwargs={"conditions": ["time"]}, historyClass=HistoryTable)

    def insert(self, changed_by, transaction=None, dryrun=False, **columns):
        """Create a shutoff; ``changed_by`` needs "emergency_shutoff"/"create" for the product.

        :returns: the inserted parameters, or None on a dry run.
        :raises PermissionDeniedError: if ``changed_by`` lacks permission.
        """
        target_product = columns.get("product")
        if not self.db.hasPermission(changed_by, "emergency_shutoff", "create", target_product, transaction):
            raise PermissionDeniedError("{} is not allowed to shut off updates for product {}".format(changed_by, target_product))

        result = super(EmergencyShutoffs, self).insert(changed_by=changed_by, transaction=transaction, dryrun=dryrun, **columns)
        if dryrun:
            return None
        return result.last_inserted_params()

    def getPotentialRequiredSignoffs(self, affected_rows, transaction=None):
        """Collect product Required Signoffs that apply to the last affected row.

        A row without a channel matches every Required Signoff for its product.
        Otherwise the row's channel must match the signoff's channel pattern.

        :returns: ``{"rs": [matching product required signoff rows]}``.
        """
        latest = affected_rows[-1]
        candidates = self.db.productRequiredSignoffs.select(where={"product": latest["product"]}, transaction=transaction)
        channel = latest.get("channel")
        return {"rs": [rs for rs in candidates if not channel or matchRegex(channel, rs["channel"])]}

    def delete(self, where, changed_by=None, old_data_version=None, transaction=None, dryrun=False, signoffs=None):
        """Remove the shutoff(s) matched by ``where``.

        :raises PermissionDeniedError: if ``changed_by`` lacks "emergency_shutoff"/"delete"
            for the product of the first matched row.
        :raises IndexError: if ``where`` matches nothing.
        """
        matched = self.select(where=where, columns=[self.product], transaction=transaction)
        product = matched[0]["product"]
        if not self.db.hasPermission(changed_by, "emergency_shutoff", "delete", product, transaction):
            raise PermissionDeniedError("%s is not allowed to delete shutoffs for product %s" % (changed_by, product))

        if not dryrun:
            for shutoff in self.select(where=where, transaction=transaction):
                by_kind = self.getPotentialRequiredSignoffs([shutoff], transaction=transaction)
                verify_signoffs([rs for group in by_kind.values() for rs in group], signoffs)

        super(EmergencyShutoffs, self).delete(changed_by=changed_by, where=where, old_data_version=old_data_version, transaction=transaction, dryrun=dryrun)
2910

2911

2912
class PinnableReleasesTable(AUSTable):
    """Table mapping (product, version, channel) to a release ``mapping`` a client can be pinned to.

    Supports time-conditioned scheduled changes and keeps history in a ``HistoryTable``.
    All writes require the matching "pinnable_release" permission for the product.
    """

    def __init__(self, db, metadata, dialect):
        self.table = Table(
            "pinnable_releases",
            metadata,
            Column("product", String(15), nullable=False, primary_key=True),
            Column("version", String(75), nullable=False, primary_key=True),
            Column("channel", String(75), nullable=False, primary_key=True),
            Column("mapping", String(100), nullable=False),
        )
        AUSTable.__init__(self, db, dialect, scheduled_changes=True, scheduled_changes_kwargs={"conditions": ["time"]}, historyClass=HistoryTable)

    def getPotentialRequiredSignoffs(self, affected_rows, transaction=None):
        # Implementing this is required to schedule changes to this table
        return {}

    def insert(self, changed_by, transaction=None, dryrun=False, **columns):
        """Create a pin; ``changed_by`` needs "pinnable_release"/"create" for the product.

        :returns: the inserted parameters, or None on a dry run.
        :raises PermissionDeniedError: if ``changed_by`` lacks permission.
        """
        if not self.db.hasPermission(changed_by, "pinnable_release", "create", columns.get("product"), transaction):
            raise PermissionDeniedError("{} is not allowed to create pinnable releases for product {}".format(changed_by, columns.get("product")))

        ret = super(PinnableReleasesTable, self).insert(changed_by=changed_by, transaction=transaction, dryrun=dryrun, **columns)
        if not dryrun:
            return ret.last_inserted_params()

    def update(self, where, what, changed_by, old_data_version, transaction=None, dryrun=False, signoffs=None):
        """Modify the pin matched by ``where``.

        ``changed_by`` needs "pinnable_release"/"modify" for the pin's current product.
        If ``what`` moves the pin to a different product, the same permission is
        also required for that product.

        :returns: the updated parameters, or None on a dry run.
        :raises PermissionDeniedError: if ``changed_by`` lacks permission.
        :raises IndexError: if ``where`` matches nothing.
        """
        product = self.select(where=where, columns=[self.product], transaction=transaction)[0]["product"]
        if not self.db.hasPermission(changed_by, "pinnable_release", "modify", product, transaction):
            raise PermissionDeniedError("{} is not allowed to modify pinnable releases for product {}".format(changed_by, product))

        # Without this check, a user allowed to modify one product's pins could move a pin into a product they don't control.
        new_product = what.get("product", product)
        if new_product != product and not self.db.hasPermission(changed_by, "pinnable_release", "modify", new_product, transaction):
            raise PermissionDeniedError("{} is not allowed to modify pinnable releases for product {}".format(changed_by, new_product))

        ret = super(PinnableReleasesTable, self).update(
            where=where,
            what=what,
            changed_by=changed_by,
            old_data_version=old_data_version,
            transaction=transaction,
            dryrun=dryrun,
        )
        if not dryrun:
            return ret.last_updated_params()

    def delete(self, where, changed_by=None, old_data_version=None, transaction=None, dryrun=False, signoffs=None):
        """Remove the pin matched by ``where``; needs "pinnable_release"/"delete" for its product.

        :raises PermissionDeniedError: if ``changed_by`` lacks permission.
        :raises IndexError: if ``where`` matches nothing.
        """
        product = self.select(where=where, columns=[self.product], transaction=transaction)[0]["product"]
        if not self.db.hasPermission(changed_by, "pinnable_release", "delete", product, transaction):
            raise PermissionDeniedError("{} is not allowed to delete pinnable releases for product {}".format(changed_by, product))

        super(PinnableReleasesTable, self).delete(changed_by=changed_by, where=where, old_data_version=old_data_version, transaction=transaction, dryrun=dryrun)

    def getPinRow(self, product, channel, version, transaction=None):
        """Return the ``mapping``/``data_version`` row for the pin, or None if there is no such pin."""
        rows = self.select(
            where=[self.product == product, self.channel == channel, self.version == version],
            columns=[self.mapping, self.data_version],
            transaction=transaction,
        )
        if len(rows) == 0:
            return None
        return rows[0]

    def mappingHasPin(self, mapping, transaction=None):
        """Return True if any pin points at ``mapping``."""
        return self.count(where=[self.mapping == mapping], transaction=transaction) > 0

    def getPinMapping(self, product, channel, version, transaction=None):
        """Return the release mapping the pin points at, or None if there is no such pin."""
        rows = self.select(where=[self.product == product, self.channel == channel, self.version == version], columns=[self.mapping], transaction=transaction)
        if len(rows) == 0:
            return None
        return rows[0]["mapping"]
2977

2978

2979
# A helper that sets sql_mode. This should only be used with MySQL, and
# lets us put the database in a stricter mode that will disallow things like
# automatic data truncation.
# From http://www.enricozini.org/2012/tips/sa-sqlmode-traditional/
def my_on_connect(dbapi_con, connection_record):
    """Switch a freshly opened MySQL DBAPI connection to TRADITIONAL sql_mode.

    Registered as an SQLAlchemy "connect" event listener when MySQL
    traditional mode is requested.

    :param dbapi_con: the raw DBAPI connection.
    :param connection_record: unused; part of the "connect" event signature.
    """
    cur = dbapi_con.cursor()
    try:
        cur.execute("SET SESSION sql_mode='TRADITIONAL'")
    finally:
        # Always release the cursor. SET SESSION applies to the connection, not the cursor.
        cur.close()
2986

2987

2988
class AUSDatabase(object):
    """Owner of the SQLAlchemy engine, the metadata and every table wrapper.

    The tables are only created by ``setDburi``, which ``__init__`` calls when
    a ``dburi`` is given. The table properties are therefore unusable until a
    database URI has been set.
    """

    engine = None
    # sqlalchemy-migrate repository used by create/upgrade/downgrade.
    migrate_repo = path.join(path.dirname(__file__), "migrate")

    def __init__(
        self,
        dburi=None,
        mysql_traditional_mode=False,
        releases_history_buckets=None,
        releases_history_class=GCSHistory,
        async_releases_history_class=GCSHistoryAsync,
    ):
        """Create a new AUSDatabase. Before this object is useful, dburi must be
        set, either through the constructor or setDburi()"""
        if dburi:
            self.setDburi(dburi, mysql_traditional_mode, releases_history_buckets, releases_history_class, async_releases_history_class)
        self.log = logging.getLogger(self.__class__.__name__)
        self.systemAccounts = []

    def setDburi(
        self,
        dburi,
        mysql_traditional_mode=False,
        releases_history_buckets=None,
        releases_history_class=GCSHistory,
        async_releases_history_class=GCSHistoryAsync,
    ):
        """Create the engine, metadata and table wrappers for ``dburi``.

        SQLAlchemy only opens a connection to the database when it needs to.

        :raises AlreadySetupError: if an engine is already configured.
        """
        if self.engine:
            raise AlreadySetupError()
        self.dburi = dburi
        self.metadata = MetaData()
        engine_kwargs = {"pool_recycle": 60}
        if dburi.startswith("sqlite"):
            # connexion 3.x runs the flask app in a thread, so we need to share the db connection
            from sqlalchemy.pool import StaticPool

            engine_kwargs.update({"poolclass": StaticPool, "connect_args": {"check_same_thread": False}})
        self.engine = create_engine(self.dburi, **engine_kwargs)
        if mysql_traditional_mode and "mysql" in dburi:
            sqlalchemy.event.listen(self.engine, "connect", my_on_connect)
        dialect = self.engine.name
        self.rulesTable = Rules(self, self.metadata, dialect)
        self.releasesTable = Releases(self, self.metadata, dialect, releases_history_buckets, releases_history_class)
        self.releasesJSONTable = ReleasesJSON(self, self.metadata, dialect, releases_history_buckets, async_releases_history_class)
        self.releaseAssetsTable = ReleaseAssets(self, self.metadata, dialect, releases_history_buckets, async_releases_history_class)
        self.permissionsTable = Permissions(self, self.metadata, dialect)
        self.dockerflowTable = Dockerflow(self, self.metadata, dialect)
        self.productRequiredSignoffsTable = ProductRequiredSignoffsTable(self, self.metadata, dialect)
        self.permissionsRequiredSignoffsTable = PermissionsRequiredSignoffsTable(self, self.metadata, dialect)
        self.emergencyShutoffsTable = EmergencyShutoffs(self, self.metadata, dialect)
        self.pinnableReleasesTable = PinnableReleasesTable(self, self.metadata, dialect)
        self.metadata.bind = self.engine

    def setSystemAccounts(self, systemAccounts):
        self.systemAccounts = systemAccounts

    def setDomainAllowlist(self, domainAllowlist):
        self.releasesTable.setDomainAllowlist(domainAllowlist)

    # Permission helpers delegate to the Permissions table.
    def isKnownUser(self, username):
        return self.permissions.isKnownUser(username)

    def isAdmin(self, *args, **kwargs):
        return self.permissions.isAdmin(*args, **kwargs)

    def hasPermission(self, *args, **kwargs):
        return self.permissions.hasPermission(*args, **kwargs)

    def hasRole(self, *args, **kwargs):
        return self.permissions.hasRole(*args, **kwargs)

    def getUserRoles(self, *args, **kwargs):
        return self.permissions.getUserRoles(*args, **kwargs)

    def create(self, version=None):
        # Migrate's "create" merely declares a database to be under its control,
        # it doesn't actually create tables or upgrade it. So we need to call it
        # and then do the upgrade to get to the state we want. We also have to
        # tell create that we're creating at version 0 of the database, otherwise
        # upgrade will do nothing!
        migrate.versioning.schema.ControlledSchema.create(self.engine, self.migrate_repo, 0)
        self.upgrade(version)

    def upgrade(self, version=None):
        # This method was taken from Buildbot:
        # https://github.com/buildbot/buildbot/blob/87108ec4088dc7fd5394ac3c1d0bd3b465300d92/master/buildbot/db/model.py#L455
        # http://code.google.com/p/sqlalchemy-migrate/issues/detail?id=100
        # means we cannot use the migrate.versioning.api module.  So these
        # methods perform similar wrapping functions to what is done by the API
        # functions, but without disposing of the engine.
        schema = migrate.versioning.schema.ControlledSchema(self.engine, self.migrate_repo)
        changeset = schema.changeset(version)
        for step, change in changeset:
            # Lazy %-args: the message is only formatted if debug logging is enabled.
            self.log.debug("migrating schema version %s -> %d", step, step + 1)
            schema.runchange(step, change, 1)

    def downgrade(self, version):
        """Migrate the schema down to ``version``.

        :raises ValueError: if ``version`` is below 21.
        """
        if version < 21:
            raise ValueError("Cannot downgrade below version 21")
        schema = migrate.versioning.schema.ControlledSchema(self.engine, self.migrate_repo)
        changeset = schema.changeset(version)
        for step, change in changeset:
            self.log.debug("migrating schema version %s -> %d", step, step - 1)
            schema.runchange(step, change, -1)

    def reset(self):
        """Drop the engine reference and unbind the metadata so ``setDburi`` can run again."""
        self.engine = None
        # ``metadata`` only exists once setDburi() has run; tolerate resetting an unconfigured instance.
        metadata = getattr(self, "metadata", None)
        if metadata is not None:
            metadata.bind = None

    def begin(self):
        return AUSTransaction(self.engine)

    # Read-only accessors for the table wrappers created by setDburi().
    @property
    def rules(self):
        return self.rulesTable

    @property
    def releases(self):
        return self.releasesTable

    @property
    def releases_json(self):
        return self.releasesJSONTable

    @property
    def release_assets(self):
        return self.releaseAssetsTable

    @property
    def permissions(self):
        return self.permissionsTable

    @property
    def productRequiredSignoffs(self):
        return self.productRequiredSignoffsTable

    @property
    def permissionsRequiredSignoffs(self):
        return self.permissionsRequiredSignoffsTable

    @property
    def dockerflow(self):
        return self.dockerflowTable

    @property
    def emergencyShutoffs(self):
        return self.emergencyShutoffsTable

    @property
    def pinnable_releases(self):
        return self.pinnableReleasesTable
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc