winter-telescope / winterdrp / build 3762336726 (push, GitHub) - pending completion
Commit: Update postgres DB, cap processors (#253)

493 of 493 new or added lines in 11 files covered (100.0%)
4639 of 6135 relevant lines covered (75.62%)
0.76 hits per line
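
The overall figure is simply covered lines divided by relevant lines; a minimal check of the arithmetic, using the totals above:

# Coverage percentage from the summary totals above
covered_lines, relevant_lines = 4639, 6135
print(f"{100 * covered_lines / relevant_lines:.2f}%")  # prints 75.62%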

Source File

/winterdrp/processors/database/postgres.py (64.2% covered)

"""
Module containing postgres util functions
"""
# pylint: disable=not-context-manager
import logging
import os
from glob import glob
from pathlib import Path
from typing import Optional

import numpy as np
import psycopg
from psycopg import errors
from psycopg.rows import Row

from winterdrp.data import DataBlock
from winterdrp.errors import ProcessorError
from winterdrp.processors.database.constraints import DBQueryConstraints
from winterdrp.processors.database.utils import get_ordered_schema_list

logger = logging.getLogger(__name__)

DB_USER_KEY = "DB_USER"
DB_PASSWORD_KEY = "DB_PWD"

PG_ADMIN_USER_KEY = "PG_ADMIN_USER"
PG_ADMIN_PWD_KEY = "PG_ADMIN_PWD"

DB_USER = os.getenv(DB_USER_KEY)
DB_PASSWORD = os.getenv(DB_PASSWORD_KEY)

ADMIN_USER = os.getenv(PG_ADMIN_USER_KEY, DB_USER)
ADMIN_PASSWORD = os.getenv(PG_ADMIN_PWD_KEY, DB_PASSWORD)

POSTGRES_DUPLICATE_PROTOCOLS = ["fail", "ignore", "replace"]


class DataBaseError(ProcessorError):
    """Error relating to postgres interactions"""


class PostgresUser:
    """
    Basic Postgres user class for executing functions
    """

    user_env_varaiable = DB_USER_KEY
    pass_env_variable = DB_PASSWORD_KEY

    def __init__(self, db_user: str = DB_USER, db_password: str = DB_PASSWORD):
        self.db_user = db_user
        self.db_password = db_password

    def validate_credentials(self):
        """
        Checks that user credentials exist
        :return: None
        """
        if self.db_user is None:
            err = (
                f"'db_user' is set as None. Please pass a db_user as an argument, "
                f"or set the environment variable '{self.user_env_varaiable}'."
            )
            logger.error(err)
            raise DataBaseError(err)

        if self.db_password is None:
            err = (
                f"'db_password' is set as None. Please pass a password as an argument, "
                f"or set the environment variable '{self.pass_env_variable}'."
            )
            logger.error(err)
            raise DataBaseError(err)

        # TODO check user exists

    def run_sql_command_from_file(self, file_path: str | Path, db_name: str):
        """
        Execute SQL command from file

        :param file_path: File to execute
        :param db_name: name of database
        :return: None
        """
        with psycopg.connect(
            f"dbname={db_name} user={self.db_user} password={self.db_password}"
        ) as conn:
            with open(file_path, "r", encoding="utf8") as sql_file:
                conn.execute(sql_file.read())

            logger.info(f"Executed sql commands from file {file_path}")

    def create_table(self, schema_path: str | Path, db_name: str):
        """
        Create a database table

        :param schema_path: File to execute
        :param db_name: name of database
        :return: None
        """
        with psycopg.connect(
            f"dbname={db_name} user={self.db_user} password={self.db_password}"
        ) as conn:
            conn.autocommit = True
            with open(schema_path, "r", encoding="utf8") as schema_file:
                conn.execute(schema_file.read())

        logger.info(f"Created table from schema path {schema_path}")

    def create_tables_from_schema(
        self,
        schema_dir: str | Path,
        db_name: str,
    ):
        """
        Creates a db with tables, as described by .sql files in a directory

        :param schema_dir: Directory containing schema files
        :param db_name: name of DB
        :return: None
        """
        schema_files = glob(f"{schema_dir}/*.sql")
        ordered_schema_files = get_ordered_schema_list(schema_files)
        logger.info(f"Creating the following tables - {ordered_schema_files}")
        for schema_file in ordered_schema_files:
            self.create_table(schema_path=schema_file, db_name=db_name)

    def export_to_db(
        self,
        value_dict: dict | DataBlock,
        db_name: str,
        db_table: str,
        duplicate_protocol: str = "fail",
    ) -> tuple[list, list]:
        """
        Export a list of fields in value dict to a database table

        :param value_dict: dictionary/DataBlock/other dictionary-like object to export
        :param db_name: name of db to export to
        :param db_table: table of DB to export to
        :param duplicate_protocol: protocol for handling duplicates,
            in "fail"/"ignore"/"replace"
        :return: serial keys of the table, and their values for the new entry
        """

        assert duplicate_protocol in POSTGRES_DUPLICATE_PROTOCOLS

        with psycopg.connect(
            f"dbname={db_name} user={self.db_user} password={self.db_password}"
        ) as conn:
            conn.autocommit = True

            sql_query = f"""
            SELECT Col.Column_Name from
                INFORMATION_SCHEMA.TABLE_CONSTRAINTS Tab,
                INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE Col
            WHERE
                Col.Constraint_Name = Tab.Constraint_Name
                AND Col.Table_Name = Tab.Table_Name
                AND Constraint_Type = 'PRIMARY KEY'
                AND Col.Table_Name = '{db_table}'
            """
            serial_keys, serial_key_values = [], []
            with conn.execute(sql_query) as cursor:

                primary_key = [x[0] for x in cursor.fetchall()]
                serial_keys = list(self.get_sequence_keys_from_table(db_table, db_name))
                logger.debug(serial_keys)
                colnames = [
                    desc[0]
                    for desc in conn.execute(
                        f"SELECT * FROM {db_table} LIMIT 1"
                    ).description
                    if desc[0] not in serial_keys
                ]

                colnames_str = ""
                for column in colnames:
                    colnames_str += f'"{column}",'
                colnames_str = colnames_str[:-1]
                txt = f"INSERT INTO {db_table} ({colnames_str}) VALUES ("

                for char in ["[", "]", "'"]:
                    txt = txt.replace(char, "")

                for column in colnames:
                    txt += f"'{str(value_dict[column])}', "

                txt = txt + ") "
                txt = txt.replace(", )", ")")

                if len(serial_keys) > 0:
                    txt += "RETURNING "
                    for key in serial_keys:
                        txt += f"{key},"
                    txt += ";"
                    txt = txt.replace(",;", ";")

                logger.debug(txt)
                command = txt

                try:
                    cursor.execute(command)
                    if len(serial_keys) > 0:
                        serial_key_values = cursor.fetchall()[0]
                    else:
                        serial_key_values = []

                except errors.UniqueViolation as exc:
                    primary_key_values = [value_dict[x] for x in primary_key]

                    if duplicate_protocol == "fail":
                        err = (
                            f"Duplicate error, entry with "
                            f"{primary_key}={primary_key_values} "
                            f"already exists in {db_name}."
                        )
                        logger.error(err)
                        raise errors.UniqueViolation from exc

                    if duplicate_protocol == "ignore":
                        logger.debug(
                            f"Found duplicate entry with "
                            f"{primary_key}={primary_key_values} in {db_name}. "
                            f"Ignoring."
                        )
                    elif duplicate_protocol == "replace":
                        logger.debug(
                            f"Updating duplicate entry with "
                            f"{primary_key}={primary_key_values} in {db_name}."
                        )

                        db_constraints = DBQueryConstraints(
                            columns=primary_key,
                            accepted_values=primary_key_values,
                        )

                        update_colnames = []
                        for column in colnames:
                            if column not in primary_key:
                                update_colnames.append(column)

                        serial_key_values = self.modify_db_entry(
                            db_constraints=db_constraints,
                            value_dict=value_dict,
                            db_alter_columns=update_colnames,
                            db_table=db_table,
                            db_name=db_name,
                            return_columns=serial_keys,
                        )

        return serial_keys, serial_key_values

    def modify_db_entry(
        self,
        db_name: str,
        db_table: str,
        db_constraints: DBQueryConstraints,
        value_dict: dict | DataBlock,
        db_alter_columns: str | list[str],
        return_columns: Optional[str | list[str]] = None,
    ) -> list[Row]:
        """
        Modify a db entry

        :param db_name: name of db
        :param db_table: Name of table
        :param db_constraints: constraints selecting which entries to modify
        :param value_dict: dict-like object to provide updated values
        :param db_alter_columns: columns to alter in db
        :param return_columns: columns to return
        :return: db query (return columns)
        """

        if not isinstance(db_alter_columns, list):
            db_alter_columns = [db_alter_columns]

        if return_columns is None:
            return_columns = db_alter_columns
        if not isinstance(return_columns, list):
            return_columns = [return_columns]

        constraints = db_constraints.parse_constraints()

        with psycopg.connect(
            f"dbname={db_name} user={self.db_user} password={self.db_password}"
        ) as conn:
            conn.autocommit = True

            db_alter_values = [str(value_dict[c]) for c in db_alter_columns]

            alter_values_txt = [
                f"{db_alter_columns[ind]}='{db_alter_values[ind]}'"
                for ind in range(len(db_alter_columns))
            ]

            sql_query = (
                f"UPDATE {db_table} SET {', '.join(alter_values_txt)} "
                f"WHERE {constraints}"
            )

            if len(return_columns) > 0:
                logger.debug(return_columns)
                sql_query += f""" RETURNING {', '.join(return_columns)}"""
            sql_query += ";"
            query_output = self.execute_query(sql_query, db_name)

        return query_output

    def get_sequence_keys_from_table(self, db_table: str, db_name: str) -> np.ndarray:
        """
        Gets sequence keys of db table

        :param db_table: database table to use
        :param db_name: database name
        :return: numpy array of keys
        """
        with psycopg.connect(
            f"dbname={db_name} user={self.db_user} password={self.db_password}"
        ) as conn:
            conn.autocommit = True
            sequences = [
                x[0]
                for x in conn.execute(
                    "SELECT c.relname FROM pg_class c WHERE c.relkind = 'S';"
                ).fetchall()
            ]
            seq_tables = np.array([x.split("_")[0] for x in sequences])
            seq_columns = np.array([x.split("_")[1] for x in sequences])
            table_sequence_keys = seq_columns[(seq_tables == db_table)]
        return table_sequence_keys

    def import_from_db(
        self,
        db_name: str,
        db_table: str,
        db_output_columns: str | list[str],
        output_alias_map: Optional[str | list[str]] = None,
        max_num_results: Optional[int] = None,
        db_constraints: Optional[DBQueryConstraints] = None,
    ) -> list[dict]:
        """Query an SQL database with constraints, and return a list of dictionaries.
        One dictionary per entry returned from the query.

        #TODO check admin

        Parameters
        ----------
        db_name: Name of database to query
        db_table: Name of database table to query
        db_output_columns: Name(s) of columns to return for matched database entries
        output_alias_map: Alias to assign for each output column
        max_num_results: Maximum number of results to return
        db_constraints: Constraints to apply to the query

        Returns
        -------
        A list of dictionaries (one per entry)
        """

        if not isinstance(db_output_columns, list):
            db_output_columns = [db_output_columns]

        if output_alias_map is None:
            output_alias_map = db_output_columns

        if not isinstance(output_alias_map, list):
            output_alias_map = [output_alias_map]

        assert len(output_alias_map) == len(db_output_columns)

        all_query_res = []

        if db_constraints is not None:
            constraints = db_constraints.parse_constraints()
        else:
            constraints = ""

        with psycopg.connect(
            f"dbname={db_name} user={self.db_user} password={self.db_password}"
        ) as conn:
            conn.autocommit = True
            sql_query = f"""
            SELECT {', '.join(db_output_columns)} from {db_table}
                WHERE {constraints}
            """

            if max_num_results is not None:
                sql_query += f" LIMIT {max_num_results}"

            sql_query += ";"

            logger.debug(f"Query: {sql_query}")

            with conn.execute(sql_query) as cursor:
                query_output = cursor.fetchall()

            for entry in query_output:

                assert len(entry) == len(db_output_columns)

                query_res = {}

                for i, key in enumerate(output_alias_map):
                    query_res[key] = entry[i]

                all_query_res.append(query_res)

        return all_query_res

    def execute_query(self, sql_query: str, db_name: str) -> list[Row]:
        """
        Generically execute SQL query

        :param sql_query: SQL query to execute
        :param db_name: db name
        :return: rows from db
        """
        with psycopg.connect(
            f"dbname={db_name} user={self.db_user} password={self.db_password}"
        ) as conn:
            conn.autocommit = True
            logger.debug(f"Query: {sql_query}")

            with conn.execute(sql_query) as cursor:
                query_output = cursor.fetchall()

        return query_output

    def crossmatch_with_database(
        self,
        db_name: str,
        db_table: str,
        db_output_columns: str | list[str],
        ra: float,
        dec: float,
        crossmatch_radius_arcsec: float,
        output_alias_map: Optional[dict] = None,
        ra_field_name: str = "ra",
        dec_field_name: str = "dec",
        query_distance_bool: bool = False,
        q3c_bool: bool = False,
        query_constraints: Optional[DBQueryConstraints] = None,
        order_field_name: Optional[str] = None,
        num_limit: Optional[int] = None,
    ) -> list[dict]:
        """
        Crossmatch a given spatial position (ra/dec) with sources in a database,
        and return a list of matches

        #TODO: check admin

        :param db_name: name of db to query
        :param db_table: name of db table
        :param db_output_columns: columns to return
        :param output_alias_map: mapping for renaming columns
        :param ra: RA
        :param dec: dec
        :param crossmatch_radius_arcsec: radius for crossmatch
        :param ra_field_name: name of ra column in database
        :param dec_field_name: name of dec column in database
        :param query_distance_bool: whether to also return the crossmatch distance
        :param q3c_bool: whether to use q3c for the radial query
        :param query_constraints: additional constraints to apply to the query
        :param order_field_name: field to order result by
        :param num_limit: limit on sql query
        :return: list of query result dictionaries
        """

        if output_alias_map is None:
            output_alias_map = {}
            for col in db_output_columns:
                output_alias_map[col] = col

        crossmatch_radius_deg = crossmatch_radius_arcsec / 3600.0

        if q3c_bool:
            constraints = (
                f"q3c_radial_query({ra_field_name},{dec_field_name},"
                f"{ra},{dec},{crossmatch_radius_deg}) "
            )
        else:
            ra_min = ra - crossmatch_radius_deg
            ra_max = ra + crossmatch_radius_deg
            dec_min = dec - crossmatch_radius_deg
            dec_max = dec + crossmatch_radius_deg
            constraints = (
                f" {ra_field_name} between {ra_min} and {ra_max} AND "
                f"{dec_field_name} between {dec_min} and {dec_max} "
            )

        if query_constraints is not None:
            constraints += f"""AND {query_constraints.parse_constraints()}"""

        select = f""" {'"' + '","'.join(db_output_columns) + '"'}"""
        if query_distance_bool:
            if q3c_bool:
                select = (
                    f"q3c_dist({ra_field_name},{dec_field_name},{ra},{dec}) AS xdist,"
                    + select
                )
            else:
                select = f"""{ra_field_name} - ra AS xdist,""" + select

        query = f"""SELECT {select} FROM {db_table} WHERE {constraints}"""

        if order_field_name is not None:
            query += f""" ORDER BY {order_field_name}"""
        if num_limit is not None:
            query += f""" LIMIT {num_limit}"""

        query += ";"

        query_output = self.execute_query(query, db_name)
        all_query_res = []

        for entry in query_output:
            if not query_distance_bool:
                assert len(entry) == len(db_output_columns)
            else:
                assert len(entry) == len(db_output_columns) + 1
            query_res = {}
            for i, key in enumerate(output_alias_map):
                query_res[key] = entry[i]
                if query_distance_bool:
                    query_res["dist"] = entry["xdist"]
            all_query_res.append(query_res)
        return all_query_res

    def check_if_exists(
        self, check_command: str, check_value: str, db_name: str = "postgres"
    ) -> bool:
        """
        Check if a value (e.g. a user, database or table) exists

        :param check_command: SQL command listing the existing values
        :param check_value: value to check for
        :param db_name: name of database to query
        :return: boolean
        """
        with psycopg.connect(
            f"dbname={db_name} user={self.db_user} password={self.db_password}"
        ) as conn:
            conn.autocommit = True
            data = conn.execute(check_command).fetchall()
        existing_user_names = [x[0] for x in data]
        logger.debug(f"Found the following values: {existing_user_names}")

        return check_value in existing_user_names

    def create_db(self, db_name: str):
        """
        Creates a database using credentials

        :param db_name: DB to create
        :return: None
        """

        with psycopg.connect(
            f"dbname=postgres user={self.db_user} password={self.db_password}"
        ) as conn:
            conn.autocommit = True
            sql = f"""CREATE database {db_name}"""
            conn.execute(sql)
            logger.info(f"Created db {db_name}")

    def check_if_db_exists(self, db_name: str) -> bool:
        """
        Check if a database exists

        :param db_name: database to check
        :return: boolean
        """

        check_command = """SELECT datname FROM pg_database;"""

        db_exist_bool = self.check_if_exists(
            check_command=check_command,
            check_value=db_name,
            db_name="postgres",
        )

        logger.debug(f"Database '{db_name}' does {['not ', ''][db_exist_bool]} exist")

        return db_exist_bool

    def check_if_table_exists(self, db_name: str, db_table: str) -> bool:
        """
        Check if a database table exists

        :param db_name: database to check
        :param db_table: table to check
        :return: boolean
        """

        check_command = (
            "SELECT table_name FROM information_schema.tables "
            "WHERE table_schema='public';"
        )

        table_exist_bool = self.check_if_exists(
            check_command=check_command,
            check_value=db_table,
            db_name=db_name,
        )

        logger.debug(f"Table '{db_table}' does {['not ', ''][table_exist_bool]} exist")

        return table_exist_bool


class PostgresAdmin(PostgresUser):
    """
    An Admin postgres user, with additional functionality for creating new users
    """

    user_env_varaiable = PG_ADMIN_USER_KEY
    pass_env_variable = PG_ADMIN_PWD_KEY

    def __init__(self, db_user: str = ADMIN_USER, db_password: str = ADMIN_PASSWORD):
        super().__init__(db_user=db_user, db_password=db_password)

    def create_new_user(self, new_db_user: str, new_password: str):
        """
        Create a new postgres user

        :param new_db_user: new username
        :param new_password: new user password
        :return: None
        """

        with psycopg.connect(
            f"dbname=postgres user={self.db_user} password={self.db_password}"
        ) as conn:
            conn.autocommit = True
            command = f"CREATE ROLE {new_db_user} WITH password '{new_password}' LOGIN;"
            conn.execute(command)

    def grant_privileges(self, db_name: str, db_user: str):
        """
        Grant privileges to a user on a database

        :param db_name: name of database
        :param db_user: username to grant privileges to
        :return: None
        """
        with psycopg.connect(
            f"dbname=postgres user={self.db_user} password={self.db_password}"
        ) as conn:
            conn.autocommit = True
            command = f"""GRANT ALL PRIVILEGES ON DATABASE {db_name} TO {db_user};"""
            conn.execute(command)

    def check_if_user_exists(self, user_name: str) -> bool:
        """
        Check if a user account exists

        :param user_name: username to check
        :return: boolean
        """
        check_command = """SELECT usename FROM pg_user;"""

        user_exist_bool = self.check_if_exists(
            check_command=check_command,
            check_value=user_name,
        )

        logger.debug(f"User '{user_name}' does {['not ', ''][user_exist_bool]} exist")

        return user_exist_bool
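
Below is a brief, illustrative usage sketch of the two classes in this module. It is not part of the covered file: the database name, table name, schema directory, column names and values are hypothetical, and it assumes credentials are supplied via the DB_USER/DB_PWD and PG_ADMIN_USER/PG_ADMIN_PWD environment variables read at the top of the module.

# Illustrative sketch only (not part of winterdrp/processors/database/postgres.py).
# The names "exampledb", "exampleuser", "sources" and the ra/dec/mag columns are hypothetical.
from winterdrp.processors.database.constraints import DBQueryConstraints
from winterdrp.processors.database.postgres import PostgresAdmin, PostgresUser

# An admin (PG_ADMIN_USER / PG_ADMIN_PWD) sets up the database and a regular user.
admin = PostgresAdmin()
admin.validate_credentials()
if not admin.check_if_user_exists("exampleuser"):
    admin.create_new_user(new_db_user="exampleuser", new_password="examplepassword")
if not admin.check_if_db_exists("exampledb"):
    admin.create_db("exampledb")
admin.grant_privileges(db_name="exampledb", db_user="exampleuser")

# A regular user (DB_USER / DB_PWD) creates tables from .sql schema files,
# then inserts a row, replacing any existing entry that shares the primary key.
user = PostgresUser()
user.create_tables_from_schema(schema_dir="schema/", db_name="exampledb")
serial_keys, serial_key_values = user.export_to_db(
    value_dict={"ra": 180.0, "dec": 45.0, "mag": 18.2},
    db_name="exampledb",
    db_table="sources",
    duplicate_protocol="replace",
)

# Query the table back, constrained on one column; DBQueryConstraints is
# constructed here with the same arguments used in export_to_db above.
constraints = DBQueryConstraints(columns=["mag"], accepted_values=[18.2])
rows = user.import_from_db(
    db_name="exampledb",
    db_table="sources",
    db_output_columns=["ra", "dec", "mag"],
    db_constraints=constraints,
)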