winter-telescope / winterdrp, build 3699777599 (push, via GitHub)
Status: pending completion
Commit: Add initial mypy integration (#241)

223 of 223 new or added lines in 45 files covered (100.0%)
4575 of 6107 relevant lines covered (74.91%)
0.75 hits per line

Source file: /winterdrp/processors/database/postgres.py (65.06% covered)

"""
Module containing postgres util functions
"""
# pylint: disable=not-context-manager
import logging
import os
from glob import glob
from typing import Optional

import astropy.io.fits
import numpy as np
import psycopg
from psycopg import errors

from winterdrp.errors import ProcessorError

logger = logging.getLogger(__name__)

PG_ADMIN_USER_KEY = "PG_ADMIN_USER"
PG_ADMIN_PWD_KEY = "PG_ADMIN_PWD"


class DataBaseError(ProcessorError):
    """Error relating to postgres interactions"""


def validate_credentials(db_user: str, password: str, admin: bool = False):

    if db_user is None:
        user = "db_user"
        env_user_var = "DB_USER"
        if admin:
            user = "admin_db_user"
            env_user_var = PG_ADMIN_USER_KEY
        err = (
            f"'{user}' is set as None. Please pass a db_user as an argument, "
            f"or set the environment variable '{env_user_var}'."
        )
        logger.error(err)
        raise DataBaseError(err)

    if password is None:
        pwd = "password"
        env_pwd_var = "DB_PWD"
        if admin:
            pwd = "db_admin_password"
            env_pwd_var = PG_ADMIN_PWD_KEY
        err = (
            f"'{pwd}' is set as None. Please pass a password as an argument, "
            f"or set the environment variable '{env_pwd_var}'."
        )
        logger.error(err)
        raise DataBaseError(err)


def create_db(db_name: str):
    admin_user = os.environ.get(PG_ADMIN_USER_KEY)
    admin_password = os.environ.get(PG_ADMIN_PWD_KEY)
    validate_credentials(db_user=admin_user, password=admin_password)

    with psycopg.connect(
        f"dbname=postgres user={admin_user} password={admin_password}"
    ) as conn:
        conn.autocommit = True
        sql = f"""CREATE database {db_name}"""
        conn.execute(sql)
        logger.info(f"Created db {db_name}")


def run_sql_command_from_file(file_path, db_name, db_user, password, admin=False):
    validate_credentials(db_user=db_user, password=password, admin=admin)
    with psycopg.connect(
        f"dbname={db_name} user={db_user} password={password}"
    ) as conn:
        with open(file_path, "r", encoding="utf8") as sql_file:
            conn.execute(sql_file.read())

        logger.info(f"Executed sql commands from file {file_path}")


def create_table(schema_path: str, db_name: str, db_user: str, password: str):
    validate_credentials(db_user, password)

    with psycopg.connect(
        f"dbname={db_name} user={db_user} password={password}"
    ) as conn:
        conn.autocommit = True
        with open(schema_path, "r", encoding="utf8") as schema_file:
            conn.execute(schema_file.read())

    logger.info(f"Created table from schema path {schema_path}")


def create_new_user(new_db_user: str, new_password: str):
    admin_user = os.environ.get(PG_ADMIN_USER_KEY)
    admin_password = os.environ.get(PG_ADMIN_PWD_KEY)

    validate_credentials(new_db_user, new_password)
    validate_credentials(db_user=admin_user, password=admin_password, admin=True)

    with psycopg.connect(
        f"dbname=postgres user={admin_user} password={admin_password}"
    ) as conn:
        conn.autocommit = True
        command = f"""CREATE ROLE {new_db_user} WITH password '{new_password}' LOGIN;"""
        conn.execute(command)


def grant_privileges(db_name: str, db_user: str):
    admin_user = os.environ.get(PG_ADMIN_USER_KEY)
    admin_password = os.environ.get(PG_ADMIN_PWD_KEY)
    validate_credentials(admin_user, admin_password, admin=True)

    with psycopg.connect(
        f"dbname=postgres user={admin_user} password={admin_password}"
    ) as conn:
        conn.autocommit = True
        command = f"""GRANT ALL PRIVILEGES ON DATABASE {db_name} TO {db_user};"""
        conn.execute(command)


def check_if_user_exists(
    user_name: str,
    db_user: str = os.environ.get(PG_ADMIN_USER_KEY),
    password: str = os.environ.get(PG_ADMIN_PWD_KEY),
) -> bool:
    validate_credentials(db_user, password)

    with psycopg.connect(f"dbname=postgres user={db_user} password={password}") as conn:
        conn.autocommit = True
        command = """SELECT usename FROM pg_user;"""
        data = conn.execute(command).fetchall()
    existing_user_names = [x[0] for x in data]
    logger.debug(f"Found the following users: {existing_user_names}")

    user_exist_bool = user_name in existing_user_names
    logger.debug(
        f"User '{user_name}' {['does not exist', 'already exists'][user_exist_bool]}"
    )
    return user_exist_bool


def check_if_db_exists(
    db_name: str,
    db_user: str = os.environ.get(PG_ADMIN_USER_KEY),
    password: str = os.environ.get(PG_ADMIN_PWD_KEY),
) -> bool:
    validate_credentials(db_user, password)

    with psycopg.connect(f"dbname=postgres user={db_user} password={password}") as conn:
        conn.autocommit = True
        command = """SELECT datname FROM pg_database;"""
        data = conn.execute(command).fetchall()

    existing_db_names = [x[0] for x in data]
    logger.debug(f"Found the following databases: {existing_db_names}")

    db_exist_bool = db_name in existing_db_names
    logger.debug(
        f"Database '{db_name}' {['does not exist', 'already exists'][db_exist_bool]}"
    )

    return db_exist_bool


def check_if_table_exists(
    db_name: str, db_table: str, db_user: str, password: str
) -> bool:
    validate_credentials(db_user=db_user, password=password)

    with psycopg.connect(
        f"dbname={db_name} user={db_user} password={password}"
    ) as conn:
        conn.autocommit = True
        command = (
            "SELECT table_name FROM information_schema.tables "
            "WHERE table_schema='public';"
        )
        data = conn.execute(command).fetchall()

    existing_table_names = [x[0] for x in data]
    logger.debug(f"Found the following tables: {existing_table_names}")

    table_exist_bool = db_table in existing_table_names
    logger.debug(
        f"Database table '{db_table}' "
        f"{['does not exist', 'already exists'][table_exist_bool]}"
    )

    return table_exist_bool


def get_foreign_tables_list(schema_files: list[str]) -> np.ndarray:
    foreign_tables_list = []
    for schema_file_path in schema_files:
        table_names = []
        with open(schema_file_path, "r", encoding="utf8") as schema_file:
            schema = schema_file.read()
        if "FOREIGN KEY" in schema:
            schema = schema.replace("\n", "")
            schema = schema.replace("\t", "")
            schema_split = np.array(schema.split(","))
            fk_rows = np.array(["FOREIGN KEY" in x for x in schema_split])
            for row in schema_split[fk_rows]:
                words = np.array(row.split(" "))
                refmask = np.array(["REFERENCES" in x for x in words])
                idx = np.where(refmask)[0][0] + 1
                tablename = words[idx].split("(")[0]
                table_names.append(tablename)
        foreign_tables_list.append(np.array(table_names))
    return np.array(foreign_tables_list)


def get_ordered_schema_list(schema_files: list[str]) -> list[str]:
    foreign_tables_list = get_foreign_tables_list(schema_files)
    ordered_schema_list = []
    tables_created = []
    schema_table_names = [x.split("/")[-1].split(".sql")[0] for x in schema_files]
    while len(tables_created) < len(schema_files):
        for ind, schema_file in enumerate(schema_files):
            table_name = schema_table_names[ind]
            if table_name in tables_created:
                continue
            foreign_tables = foreign_tables_list[ind]
            if len(foreign_tables) == 0:
                ordered_schema_list.append(schema_file)
                tables_created.append(table_name)
            elif np.all(np.isin(foreign_tables, tables_created)):
                ordered_schema_list.append(schema_file)
                tables_created.append(table_name)

    return ordered_schema_list


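# A minimal sketch of the ordering above, with hypothetical file names: if
# images.sql declares a FOREIGN KEY referencing the nights table while
# nights.sql has none, nights.sql is returned first so its table is created
# before the table that references it.
def _example_schema_ordering() -> list[str]:
    """Return two hypothetical schema files in dependency order, i.e.
    ["schema/nights.sql", "schema/images.sql"] under the assumption above."""
    return get_ordered_schema_list(["schema/images.sql", "schema/nights.sql"])

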
def create_tables_from_schema(
    schema_dir: str,
    db_name: str,
    db_user: str = os.environ.get(PG_ADMIN_USER_KEY),
    password: str = os.environ.get(PG_ADMIN_PWD_KEY),
):
    schema_files = glob(f"{schema_dir}/*.sql")
    ordered_schema_files = get_ordered_schema_list(schema_files)
    logger.info(f"Creating the following tables - {ordered_schema_files}")
    for schema_file in ordered_schema_files:
        create_table(
            schema_path=schema_file, db_name=db_name, db_user=db_user, password=password
        )


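# A minimal usage sketch for create_tables_from_schema; the schema directory
# and the credentials are hypothetical placeholders.
def _example_create_tables_from_schema(db_name: str, db_user: str, password: str):
    """Create every table defined by the .sql files in a hypothetical schema
    directory, in foreign-key dependency order."""
    create_tables_from_schema(
        schema_dir="path/to/schema",
        db_name=db_name,
        db_user=db_user,
        password=password,
    )

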
def export_to_db(
    value_dict: dict | astropy.io.fits.Header,
    db_name: str,
    db_table: str,
    db_user: str = os.environ.get(PG_ADMIN_USER_KEY),
    password: str = os.environ.get(PG_ADMIN_PWD_KEY),
    duplicate_protocol: str = "fail",
) -> tuple[list, list]:
    with psycopg.connect(
        f"dbname={db_name} user={db_user} password={password}"
    ) as conn:
        conn.autocommit = True

        sql_query = f"""
        SELECT Col.Column_Name from
            INFORMATION_SCHEMA.TABLE_CONSTRAINTS Tab,
            INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE Col
        WHERE
            Col.Constraint_Name = Tab.Constraint_Name
            AND Col.Table_Name = Tab.Table_Name
            AND Constraint_Type = 'PRIMARY KEY'
            AND Col.Table_Name = '{db_table}'
        """
        serial_keys, serial_key_values = [], []
        with conn.execute(sql_query) as cursor:

            primary_key = [x[0] for x in cursor.fetchall()]
            serial_keys = list(
                get_sequence_keys_from_table(db_table, db_name, db_user, password)
            )
            logger.debug(serial_keys)
            colnames = [
                desc[0]
                for desc in conn.execute(
                    f"SELECT * FROM {db_table} LIMIT 1"
                ).description
                if desc[0] not in serial_keys
            ]

            colnames_str = ""
            for column in colnames:
                colnames_str += f'"{column}",'
            colnames_str = colnames_str[:-1]
            txt = f"INSERT INTO {db_table} ({colnames_str}) VALUES ("

            for char in ["[", "]", "'"]:
                txt = txt.replace(char, "")

            for column in colnames:
                txt += f"'{str(value_dict[column])}', "

            txt = txt + ") "
            txt = txt.replace(", )", ")")

            if len(serial_keys) > 0:
                txt += "RETURNING "
                for key in serial_keys:
                    txt += f"{key},"
                txt += ";"
                txt = txt.replace(",;", ";")

            logger.debug(txt)
            command = txt

            try:
                cursor.execute(command)
                if len(serial_keys) > 0:
                    serial_key_values = cursor.fetchall()[0]
                else:
                    serial_key_values = []

            except errors.UniqueViolation as exc:
                primary_key_values = [value_dict[x] for x in primary_key]

                if duplicate_protocol == "fail":
                    err = (
                        f"Duplicate error, entry with "
                        f"{primary_key}={primary_key_values} "
                        f"already exists in {db_name}."
                    )
                    logger.error(err)
                    raise errors.UniqueViolation(err) from exc

                if duplicate_protocol == "ignore":
                    logger.debug(
                        f"Found duplicate entry with "
                        f"{primary_key}={primary_key_values} in {db_name}. Ignoring."
                    )
                elif duplicate_protocol == "replace":
                    logger.debug(
                        f"Updating duplicate entry with "
                        f"{primary_key}={primary_key_values} in {db_name}."
                    )
                    update_colnames = []
                    for column in colnames:
                        if column not in primary_key:
                            update_colnames.append(column)
                    serial_key_values = modify_db_entry(
                        db_query_columns=primary_key,
                        db_query_values=primary_key_values,
                        value_dict=value_dict,
                        db_alter_columns=update_colnames,
                        db_table=db_table,
                        db_name=db_name,
                        db_user=db_user,
                        password=password,
                        return_columns=serial_keys,
                    )

    return serial_keys, serial_key_values


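# A minimal usage sketch for export_to_db, assuming a hypothetical "exposures"
# table; value_dict must supply every non-serial column of that table.
def _example_export_to_db(
    db_name: str, db_user: str, password: str
) -> tuple[list, list]:
    """Insert one hypothetical row and return any serial (auto-increment) keys
    generated by the insert, replacing the row if it already exists."""
    return export_to_db(
        value_dict={"fieldid": 42, "ra": 210.1, "dec": 54.3},
        db_name=db_name,
        db_table="exposures",
        db_user=db_user,
        password=password,
        duplicate_protocol="replace",
    )

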
def parse_constraints(db_query_columns, db_comparison_types, db_accepted_values):
    assert len(db_comparison_types) == len(db_accepted_values)
    assert np.all(np.isin(np.unique(db_comparison_types), ["=", "<", ">", "between"]))
    constraints = ""
    for i, column in enumerate(db_query_columns):
        if db_comparison_types[i] == "between":
            assert len(db_accepted_values[i]) == 2
            constraints += (
                f"{column} between {db_accepted_values[i][0]} "
                f"and {db_accepted_values[i][1]} AND "
            )
        else:
            constraints += (
                f"{column} {db_comparison_types[i]} {db_accepted_values[i]} AND "
            )

    constraints = constraints[:-4]  # strip the trailing "AND "

    return constraints


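# A minimal sketch of the WHERE-clause fragments built by parse_constraints;
# the column names are hypothetical. Note that the comparison values are
# interpolated directly into the SQL string rather than passed as bound
# parameters.
def _example_parse_constraints() -> str:
    """Return roughly "progid = 2 AND fid between 1 and 4 " for these inputs."""
    return parse_constraints(
        db_query_columns=["progid", "fid"],
        db_comparison_types=["=", "between"],
        db_accepted_values=[2, [1, 4]],
    )

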
def parse_select():
    pass


def import_from_db(
    db_name: str,
    db_table: str,
    db_query_columns: str | list[str],
    db_accepted_values: str | int | float | list[str | float | int | list],
    db_output_columns: str | list[str],
    output_alias_map: str | list[str],
    db_user: str = os.environ.get(PG_ADMIN_USER_KEY),
    password: str = os.environ.get(PG_ADMIN_PWD_KEY),
    max_num_results: Optional[int] = None,
    db_comparison_types: Optional[list[str]] = None,
) -> list[dict]:
    """Query an SQL database with constraints, and return a list of dictionaries.
    One dictionary per entry returned from the query.

    Parameters
    ----------
    db_name: Name of database to query
    db_table: Name of database table to query
    db_query_columns: Name of column to query
    db_accepted_values: Accepted value for query for column
    db_output_columns: Name(s) of columns to return for matched database entries
    output_alias_map: Alias to assign for each output column
    db_user: Username for database
    password: Password for database
    max_num_results: Maximum number of results to return (no limit if None)
    db_comparison_types: Comparison operator for each query column (default "=")

    Returns
    -------
    A list of dictionaries (one per entry)
    """

    if not isinstance(db_query_columns, list):
        db_query_columns = [db_query_columns]

    if not isinstance(db_accepted_values, list):
        db_accepted_values = [db_accepted_values]

    if not isinstance(db_output_columns, list):
        db_output_columns = [db_output_columns]

    assert len(db_query_columns) == len(db_accepted_values)

    if output_alias_map is None:
        output_alias_map = db_output_columns

    if not isinstance(output_alias_map, list):
        output_alias_map = [output_alias_map]

    assert len(output_alias_map) == len(db_output_columns)

    all_query_res = []

    if db_comparison_types is None:
        db_comparison_types = ["="] * len(db_accepted_values)
    assert len(db_comparison_types) == len(db_accepted_values)
    assert np.all(np.isin(np.unique(db_comparison_types), ["=", "<", ">", "between"]))

    constraints = parse_constraints(
        db_query_columns, db_comparison_types, db_accepted_values
    )

    with psycopg.connect(
        f"dbname={db_name} user={db_user} password={password}"
    ) as conn:
        conn.autocommit = True
        sql_query = f"""
        SELECT {', '.join(db_output_columns)} from {db_table}
            WHERE {constraints}
        """

        if max_num_results is not None:
            sql_query += f" LIMIT {max_num_results}"

        sql_query += ";"

        logger.debug(f"Query: {sql_query}")

        with conn.execute(sql_query) as cursor:
            query_output = cursor.fetchall()

        for entry in query_output:

            assert len(entry) == len(db_output_columns)

            query_res = {}

            for i, key in enumerate(output_alias_map):
                query_res[key] = entry[i]

            all_query_res.append(query_res)

    return all_query_res


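# A minimal usage sketch for import_from_db, assuming a hypothetical "exposures"
# table with "fieldid", "ra" and "dec" columns; the names are illustrative only.
def _example_import_from_db(db_name: str, db_user: str, password: str) -> list[dict]:
    """Fetch ra/dec for up to ten rows of the hypothetical exposures table
    where fieldid = 42, returning one dict per matching row."""
    return import_from_db(
        db_name=db_name,
        db_table="exposures",
        db_query_columns=["fieldid"],
        db_accepted_values=[42],
        db_output_columns=["ra", "dec"],
        output_alias_map=["ra", "dec"],
        db_user=db_user,
        password=password,
        max_num_results=10,
    )

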
def execute_query(sql_query, db_name, db_user, password):
    with psycopg.connect(
        f"dbname={db_name} user={db_user} password={password}"
    ) as conn:
        conn.autocommit = True
        logger.debug(f"Query: {sql_query}")

        with conn.execute(sql_query) as cursor:
            query_output = cursor.fetchall()

        return query_output


def xmatch_import_db(
    db_name: str,
    db_table: str,
    db_query_columns: str | list[str],
    db_accepted_values: str | int | float | list[str | float | int],
    db_output_columns: str | list[str],
    output_alias_map: str | list[str],
    ra: float,
    dec: float,
    xmatch_radius_arcsec: float,
    ra_field_name: str = "ra",
    dec_field_name: str = "dec",
    query_dist: bool = False,
    q3c: bool = False,
    db_comparison_types: Optional[list[str]] = None,
    order_field_name: Optional[str] = None,
    num_limit: Optional[int] = None,
    db_user: str = os.environ.get(PG_ADMIN_USER_KEY),
    db_password: str = os.environ.get(PG_ADMIN_PWD_KEY),
) -> list[dict]:

    if output_alias_map is None:
        output_alias_map = {}
        for col in db_output_columns:
            output_alias_map[col] = col

    xmatch_radius_deg = xmatch_radius_arcsec / 3600.0

    if q3c:
        constraints = (
            f"q3c_radial_query({ra_field_name},{dec_field_name},"
            f"{ra},{dec},{xmatch_radius_deg}) "
        )
    else:
        ra_min = ra - xmatch_radius_deg
        ra_max = ra + xmatch_radius_deg
        dec_min = dec - xmatch_radius_deg
        dec_max = dec + xmatch_radius_deg
        constraints = (
            f" {ra_field_name} between {ra_min} and {ra_max} AND "
            f"{dec_field_name} between {dec_min} and {dec_max} "
        )

    parsed_constraints = parse_constraints(
        db_query_columns, db_comparison_types, db_accepted_values
    )
    if len(parsed_constraints) > 0:
        constraints += f"""AND {parsed_constraints}"""

    select = f""" {'"' + '","'.join(db_output_columns) + '"'}"""
    if query_dist:
        if q3c:
            select = (
                f"""q3c_dist({ra_field_name},{dec_field_name},{ra},{dec}) AS xdist,"""
                + select
            )
        else:
            select = f"""{ra_field_name} - ra AS xdist,""" + select

    query = f"""SELECT {select} FROM {db_table} WHERE {constraints}"""

    if order_field_name is not None:
        query += f""" ORDER BY {order_field_name}"""
    if num_limit is not None:
        query += f""" LIMIT {num_limit}"""

    query += ";"

    query_output = execute_query(query, db_name, db_user, db_password)
    all_query_res = []

    for entry in query_output:
        if not query_dist:
            assert len(entry) == len(db_output_columns)
        else:
            assert len(entry) == len(db_output_columns) + 1
        query_res = {}
        # When query_dist is set, the first selected column is the xdist value,
        # so the output columns are offset by one.
        offset = 1 if query_dist else 0
        if query_dist:
            query_res["dist"] = entry[0]
        for i, key in enumerate(output_alias_map):
            query_res[key] = entry[i + offset]
        all_query_res.append(query_res)
    return all_query_res


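# A minimal usage sketch for xmatch_import_db, assuming a hypothetical
# "candidates" table with "candid", "ra", "dec" and "jd" columns. With the
# default q3c=False, the match falls back to a simple ra/dec box of half-width
# xmatch_radius_arcsec.
def _example_xmatch_import_db(
    db_name: str, db_user: str, password: str
) -> list[dict]:
    """Return candidates within 2 arcsec of (210.1, 54.3) with jd above a
    hypothetical threshold, as one dict per matching row."""
    return xmatch_import_db(
        db_name=db_name,
        db_table="candidates",
        db_query_columns=["jd"],
        db_accepted_values=[2459000.5],
        db_output_columns=["candid", "ra", "dec"],
        output_alias_map=None,
        ra=210.1,
        dec=54.3,
        xmatch_radius_arcsec=2.0,
        db_comparison_types=[">"],
        db_user=db_user,
        db_password=password,
    )

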
def get_sequence_keys_from_table(
    db_table: str, db_name: str, db_user: str, password: str
):
    with psycopg.connect(
        f"dbname={db_name} user={db_user} password={password}"
    ) as conn:
        conn.autocommit = True
        sequences = [
            x[0]
            for x in conn.execute(
                "SELECT c.relname FROM pg_class c WHERE c.relkind = 'S';"
            ).fetchall()
        ]
        seq_tables = np.array([x.split("_")[0] for x in sequences])
        seq_columns = np.array([x.split("_")[1] for x in sequences])
        table_sequence_keys = seq_columns[(seq_tables == db_table)]
    return table_sequence_keys


def modify_db_entry(
    db_name: str,
    db_table: str,
    db_query_columns: str | list[str],
    db_query_values: str | int | float | list[str | float | int | list],
    value_dict: dict | astropy.io.fits.Header,
    db_alter_columns: str | list[str],
    return_columns: Optional[str | list[str]] = None,
    db_query_comparison_types: Optional[list[str]] = None,
    db_user: str = os.environ.get(PG_ADMIN_USER_KEY),
    password: str = os.environ.get(PG_ADMIN_PWD_KEY),
):
    if not isinstance(db_query_columns, list):
        db_query_columns = [db_query_columns]
    if not isinstance(db_query_values, list):
        db_query_values = [db_query_values]
    if not isinstance(db_alter_columns, list):
        db_alter_columns = [db_alter_columns]

    if return_columns is None:
        return_columns = db_alter_columns
    if not isinstance(return_columns, list):
        return_columns = [return_columns]

    assert len(db_query_columns) == len(db_query_values)

    if db_query_comparison_types is None:
        db_query_comparison_types = ["="] * len(db_query_values)
    assert len(db_query_comparison_types) == len(db_query_values)
    assert np.all(
        np.isin(np.unique(db_query_comparison_types), ["=", "<", ">", "between"])
    )

    parsed_constraints = parse_constraints(
        db_query_columns, db_query_comparison_types, db_query_values
    )

    constraints = f"""{parsed_constraints}"""
    logger.debug(db_query_columns)
    with psycopg.connect(
        f"dbname={db_name} user={db_user} password={password}"
    ) as conn:
        conn.autocommit = True

        db_alter_values = [str(value_dict[c]) for c in db_alter_columns]

        alter_values_txt = [
            f"{db_alter_columns[ind]}='{db_alter_values[ind]}'"
            for ind in range(len(db_alter_columns))
        ]

        sql_query = f"""
                UPDATE {db_table} SET {', '.join(alter_values_txt)} WHERE {constraints}
                """
        if len(return_columns) > 0:
            logger.debug(return_columns)
            sql_query += f""" RETURNING {', '.join(return_columns)}"""
        sql_query += ";"
        query_output = execute_query(sql_query, db_name, db_user, password)

    return query_output


def get_colnames_from_schema(schema_file_path):
    with open(schema_file_path, "r", encoding="utf8") as schema_file:
        dat = schema_file.read()
    dat = dat.split(");")[0]
    dat = dat.split("\n")[1:-1]
    pkstrip = [x.strip(",").split("PRIMARY KEY")[0].strip() for x in dat]
    fkstrip = [x.strip(",").split("FOREIGN KEY")[0].strip() for x in pkstrip]
    colnames = [x.split(" ")[0].strip('"') for x in fkstrip]
    return colnames