
Coveralls coverage report: winter-telescope / winterdrp, build 3596924942 (pending completion)

Pull Request #224: Code with Style (GitHub)
Merge 0ea201d54 into 00fbdf6f7

1490 of 1490 new or added lines in 93 files covered (100.0%)
4571 of 6109 relevant lines covered (74.82%)
0.75 hits per line

Source File: /winterdrp/processors/database/postgres.py (64.78% of lines covered)
import logging
import os
from glob import glob

import astropy.io.fits
import numpy as np
import psycopg
from psycopg import errors

from winterdrp.errors import ProcessorError

logger = logging.getLogger(__name__)

schema_dir = os.path.join(os.path.dirname(__file__), "schema")

pg_admin_user_key = "PG_ADMIN_USER"
pg_admin_pwd_key = "PG_ADMIN_PWD"


class DataBaseError(ProcessorError):
    pass


def validate_credentials(db_user: str, password: str, admin=False):

    if db_user is None:
        user = "db_user"
        env_user_var = "DB_USER"
        if admin:
            user = "admin_db_user"
            env_user_var = pg_admin_user_key
        err = (
            f"'{user}' is set as None. Please pass a db_user as an argument, "
            f"or set the environment variable '{env_user_var}'."
        )
        logger.warning(err)
        raise DataBaseError(err)

    if password is None:
        pwd = "password"
        env_pwd_var = "DB_PWD"
        if admin:
            pwd = "db_admin_password"
            env_pwd_var = pg_admin_pwd_key
        err = (
            f"'{pwd}' is set as None. Please pass a password as an argument, "
            f"or set the environment variable '{env_pwd_var}'."
        )
        logger.error(err)
        raise DataBaseError(err)


def create_db(db_name: str):
    admin_user = os.environ.get(pg_admin_user_key)
    admin_password = os.environ.get(pg_admin_pwd_key)
    validate_credentials(db_user=admin_user, password=admin_password)

    with psycopg.connect(
        f"dbname=postgres user={admin_user} password={admin_password}"
    ) as conn:
        conn.autocommit = True
        sql = f"""CREATE database {db_name}"""
        conn.execute(sql)
        logger.info(f"Created db {db_name}")


def run_sql_command_from_file(file_path, db_name, db_user, password, admin=False):
    validate_credentials(db_user, password, admin)
    with psycopg.connect(
        f"dbname={db_name} user={db_user} password={password}"
    ) as conn:
        with open(file_path, "r") as f:
            conn.execute(f.read())

        logger.info(f"Executed sql commands from file {file_path}")


def create_table(schema_path: str, db_name: str, db_user: str, password: str):
    validate_credentials(db_user, password)

    with psycopg.connect(
        f"dbname={db_name} user={db_user} password={password}"
    ) as conn:
        conn.autocommit = True
        with open(schema_path, "r") as f:
            conn.execute(f.read())

    logger.info(f"Created table from schema path {schema_path}")


def create_new_user(new_db_user: str, new_password: str):
    admin_user = os.environ.get(pg_admin_user_key)
    admin_password = os.environ.get(pg_admin_pwd_key)

    validate_credentials(new_db_user, new_password)
    validate_credentials(db_user=admin_user, password=admin_password, admin=True)

    with psycopg.connect(
        f"dbname=postgres user={admin_user} password={admin_password}"
    ) as conn:
        conn.autocommit = True
        command = f"""CREATE ROLE {new_db_user} WITH password '{new_password}' LOGIN;"""
        conn.execute(command)


def grant_privileges(db_name: str, db_user: str):
    admin_user = os.environ.get(pg_admin_user_key)
    admin_password = os.environ.get(pg_admin_pwd_key)
    validate_credentials(admin_user, admin_password, admin=True)

    with psycopg.connect(
        f"dbname=postgres user={admin_user} password={admin_password}"
    ) as conn:
        conn.autocommit = True
        command = f"""GRANT ALL PRIVILEGES ON DATABASE {db_name} TO {db_user};"""
        conn.execute(command)


def check_if_user_exists(
    user_name: str,
    db_user: str = os.environ.get(pg_admin_user_key),
    password: str = os.environ.get(pg_admin_pwd_key),
) -> bool:
    validate_credentials(db_user, password)

    with psycopg.connect(f"dbname=postgres user={db_user} password={password}") as conn:
        conn.autocommit = True
        command = """SELECT usename FROM pg_user;"""
        data = conn.execute(command).fetchall()
    existing_user_names = [x[0] for x in data]
    logger.debug(f"Found the following users: {existing_user_names}")

    user_exist_bool = user_name in existing_user_names
    logger.debug(
        f"User '{user_name}' {['does not exist', 'already exists'][user_exist_bool]}"
    )
    return user_exist_bool


def check_if_db_exists(
    db_name: str,
    db_user: str = os.environ.get(pg_admin_user_key),
    password: str = os.environ.get(pg_admin_pwd_key),
) -> bool:
    validate_credentials(db_user, password)

    with psycopg.connect(f"dbname=postgres user={db_user} password={password}") as conn:
        conn.autocommit = True
        command = """SELECT datname FROM pg_database;"""
        data = conn.execute(command).fetchall()

    existing_db_names = [x[0] for x in data]
    logger.debug(f"Found the following databases: {existing_db_names}")

    db_exist_bool = db_name in existing_db_names
    logger.debug(
        f"Database '{db_name}' {['does not exist', 'already exists'][db_exist_bool]}"
    )

    return db_exist_bool


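# Editor's note: the function below is an illustrative usage sketch, not part of
# the original module. It shows one way the helpers above could be combined to
# bootstrap a database; "mydb", "myuser" and "mypassword" are hypothetical names,
# and the admin credentials are assumed to come from PG_ADMIN_USER / PG_ADMIN_PWD.
def _example_bootstrap_db(
    db_name: str = "mydb", db_user: str = "myuser", password: str = "mypassword"
):
    if not check_if_db_exists(db_name):
        create_db(db_name)
    if not check_if_user_exists(db_user):
        create_new_user(db_user, password)
    grant_privileges(db_name, db_user)

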
def check_if_table_exists(
    db_name: str, db_table: str, db_user: str, password: str
) -> bool:
    validate_credentials(db_user=db_user, password=password)

    with psycopg.connect(
        f"dbname={db_name} user={db_user} password={password}"
    ) as conn:
        conn.autocommit = True
        command = """SELECT table_name FROM information_schema.tables WHERE table_schema='public';"""
        data = conn.execute(command).fetchall()

    existing_table_names = [x[0] for x in data]
    logger.debug(f"Found the following tables: {existing_table_names}")

    table_exist_bool = db_table in existing_table_names
    logger.debug(
        f"Database table '{db_table}' {['does not exist', 'already exists'][table_exist_bool]}"
    )

    return table_exist_bool


def get_foreign_tables_list(schema_files: list[str]) -> np.ndarray:
    foreign_tables_list = []
    for schema_file in schema_files:
        table_names = []
        with open(schema_file, "r") as f:
            schema = f.read()
        if "FOREIGN KEY" not in schema:
            pass
        else:
            schema = schema.replace("\n", "")
            schema = schema.replace("\t", "")
            schema_split = np.array(schema.split(","))
            fk_rows = np.array(["FOREIGN KEY" in x for x in schema_split])
            for row in schema_split[fk_rows]:
                words = np.array(row.split(" "))
                refmask = np.array(["REFERENCES" in x for x in words])
                idx = np.where(refmask)[0][0] + 1
                tablename = words[idx].split("(")[0]
                table_names.append(tablename)
        foreign_tables_list.append(np.array(table_names))
    return np.array(foreign_tables_list)


def get_ordered_schema_list(schema_files: list[str]) -> list[str]:
    foreign_tables_list = get_foreign_tables_list(schema_files)
    ordered_schema_list = []
    tables_created = []
    schema_table_names = [x.split("/")[-1].split(".sql")[0] for x in schema_files]
    while len(tables_created) < len(schema_files):
        for ind, schema_file in enumerate(schema_files):
            table_name = schema_table_names[ind]
            if table_name in tables_created:
                pass
            else:
                foreign_tables = foreign_tables_list[ind]
                if len(foreign_tables) == 0:
                    ordered_schema_list.append(schema_file)
                    tables_created.append(table_name)
                else:
                    if np.all(np.isin(foreign_tables, tables_created)):
                        ordered_schema_list.append(schema_file)
                        tables_created.append(table_name)

    return ordered_schema_list


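# Editor's note: illustrative example with hypothetical schema file names.
# If "raw.sql" declares no foreign keys and "proc.sql" contains a clause such as
# "FOREIGN KEY (rawid) REFERENCES raw(rawid)", then
#     get_ordered_schema_list(["proc.sql", "raw.sql"])
# returns ["raw.sql", "proc.sql"], i.e. referenced tables are ordered before the
# tables that reference them. Note the while loop only terminates if every
# referenced table has a matching <table>.sql file in the input list.

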
def create_tables_from_schema(
    schema_dir: str,
    db_name: str,
    db_user: str = os.environ.get(pg_admin_user_key),
    password: str = os.environ.get(pg_admin_pwd_key),
):
    schema_files = glob(f"{schema_dir}/*.sql")
    ordered_schema_files = get_ordered_schema_list(schema_files)
    logger.info(f"Creating the following tables - {ordered_schema_files}")
    for schema_file in ordered_schema_files:
        create_table(
            schema_path=schema_file, db_name=db_name, db_user=db_user, password=password
        )


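# Editor's note: illustrative sketch only. Continuing the hypothetical bootstrap
# above, the tables for the bundled schema directory could then be created with
#     create_tables_from_schema(schema_dir, db_name="mydb")
# using the module-level schema_dir default and admin credentials from the
# environment.

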
def export_to_db(
    value_dict: dict | astropy.io.fits.Header,
    db_name: str,
    db_table: str,
    db_user: str = os.environ.get(pg_admin_user_key),
    password: str = os.environ.get(pg_admin_pwd_key),
    duplicate_protocol: str = "fail",
) -> tuple[list, list]:
    with psycopg.connect(
        f"dbname={db_name} user={db_user} password={password}"
    ) as conn:
        conn.autocommit = True

        sql_query = f"""
        SELECT Col.Column_Name from
            INFORMATION_SCHEMA.TABLE_CONSTRAINTS Tab,
            INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE Col
        WHERE
            Col.Constraint_Name = Tab.Constraint_Name
            AND Col.Table_Name = Tab.Table_Name
            AND Constraint_Type = 'PRIMARY KEY'
            AND Col.Table_Name = '{db_table}'
        """
        serial_keys, serial_key_values = [], []
        with conn.execute(sql_query) as cursor:

            primary_key = [x[0] for x in cursor.fetchall()]
            serial_keys = [
                x
                for x in get_sequence_keys_from_table(
                    db_table, db_name, db_user, password
                )
            ]
            logger.debug(serial_keys)
            colnames = [
                desc[0]
                for desc in conn.execute(
                    f"SELECT * FROM {db_table} LIMIT 1"
                ).description
                if desc[0] not in serial_keys
            ]

            colnames_str = ""
            for x in colnames:
                colnames_str += f'"{x}",'
            colnames_str = colnames_str[:-1]
            txt = f"INSERT INTO {db_table} ({colnames_str}) VALUES ("

            for char in ["[", "]", "'"]:
                txt = txt.replace(char, "")

            for c in colnames:
                txt += f"'{str(value_dict[c])}', "

            txt = txt + ") "
            txt = txt.replace(", )", ")")

            if len(serial_keys) > 0:
                txt += f"RETURNING "
                for key in serial_keys:
                    txt += f"{key},"
                txt += ";"
                txt = txt.replace(",;", ";")

            logger.debug(txt)
            command = txt

            try:
                cursor.execute(command)
                if len(serial_keys) > 0:
                    serial_key_values = cursor.fetchall()[0]
                else:
                    serial_key_values = []

            except errors.UniqueViolation:
                primary_key_values = [value_dict[x] for x in primary_key]
                if duplicate_protocol == "fail":
                    err = f"Duplicate error, entry with {primary_key}={primary_key_values} already exists in {db_name}."
                    logger.error(err)
                    raise errors.UniqueViolation
                elif duplicate_protocol == "ignore":
                    logger.debug(
                        f"Found duplicate entry with {primary_key}={primary_key_values} in {db_name}. Ignoring."
                    )
                    pass
                elif duplicate_protocol == "replace":
                    logger.debug(
                        f"Updating duplicate entry with {primary_key}={primary_key_values} in {db_name}."
                    )
                    update_colnames = []
                    for x in colnames:
                        if x not in primary_key:
                            update_colnames.append(x)
                    serial_key_values = modify_db_entry(
                        db_query_columns=primary_key,
                        db_query_values=primary_key_values,
                        value_dict=value_dict,
                        db_alter_columns=update_colnames,
                        db_table=db_table,
                        db_name=db_name,
                        db_user=db_user,
                        password=password,
                        return_columns=serial_keys,
                    )

    return serial_keys, serial_key_values


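# Editor's note: illustrative sketch only; the table and column names are
# hypothetical. export_to_db builds an INSERT covering every non-serial column
# of the target table, so value_dict must supply a value for each of them:
#
#     header = {"ra": 210.9, "dec": 54.3, "fid": 2}  # or an astropy.io.fits.Header
#     serial_keys, serial_values = export_to_db(
#         header, db_name="mydb", db_table="exposures", duplicate_protocol="ignore"
#     )
#
# Any SERIAL key generated for the new row is reported back via the returned
# (serial_keys, serial_values) pair.

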
def parse_constraints(db_query_columns, db_comparison_types, db_accepted_values):
    assert len(db_comparison_types) == len(db_accepted_values)
    assert np.all(np.isin(np.unique(db_comparison_types), ["=", "<", ">", "between"]))
    constraints = ""
    for i, x in enumerate(db_query_columns):
        if db_comparison_types[i] == "between":
            assert len(db_accepted_values[i]) == 2
            constraints += f"{x} between {db_accepted_values[i][0]} and {db_accepted_values[i][1]} AND "
        else:
            constraints += f"{x} {db_comparison_types[i]} {db_accepted_values[i]} AND "

    constraints = constraints[:-4]  # strip the trailing AND

    return constraints


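# Editor's note: worked example with hypothetical column names. With the
# trailing "AND " stripped after the loop (as above),
#     parse_constraints(["progid", "mjd"], ["=", "between"], [1, [59000, 59001]])
# yields the WHERE fragment "progid = 1 AND mjd between 59000 and 59001".
# Values are interpolated verbatim, so string values must be passed pre-quoted,
# e.g. "'r'".

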
def parse_select():
    pass


def import_from_db(
    db_name: str,
    db_table: str,
    db_query_columns: str | list[str],
    db_accepted_values: str | int | float | list[str | float | int | list],
    db_output_columns: str | list[str],
    output_alias_map: str | list[str],
    db_user: str = os.environ.get(pg_admin_user_key),
    password: str = os.environ.get(pg_admin_pwd_key),
    max_num_results: int = None,
    db_comparison_types: list[str] = None,
) -> list[dict]:
    """Query an SQL database with constraints, and return a list of dictionaries,
    one dictionary per entry returned from the query.

    Parameters
    ----------
    db_name: Name of database to query
    db_table: Name of database table to query
    db_query_columns: Name(s) of column(s) to query
    db_accepted_values: Accepted value(s) for each query column
    db_output_columns: Name(s) of columns to return for matched database entries
    output_alias_map: Alias to assign for each output column
    db_user: Username for database
    password: Password for database
    max_num_results: Limit on the number of rows to return (optional)
    db_comparison_types: Comparison operator for each query column ("=", "<", ">", "between")

    Returns
    -------
    A list of dictionaries (one per entry)
    """

    if not isinstance(db_query_columns, list):
        db_query_columns = [db_query_columns]

    if not isinstance(db_accepted_values, list):
        db_accepted_values = [db_accepted_values]

    if not isinstance(db_output_columns, list):
        db_output_columns = [db_output_columns]

    assert len(db_query_columns) == len(db_accepted_values)

    if output_alias_map is None:
        output_alias_map = db_output_columns

    if not isinstance(output_alias_map, list):
        output_alias_map = [output_alias_map]

    assert len(output_alias_map) == len(db_output_columns)

    all_query_res = []

    if db_comparison_types is None:
        db_comparison_types = ["="] * len(db_accepted_values)
    assert len(db_comparison_types) == len(db_accepted_values)
    assert np.all(np.isin(np.unique(db_comparison_types), ["=", "<", ">", "between"]))

    constraints = parse_constraints(
        db_query_columns, db_comparison_types, db_accepted_values
    )

    with psycopg.connect(
        f"dbname={db_name} user={db_user} password={password}"
    ) as conn:
        conn.autocommit = True
        sql_query = f"""
        SELECT {', '.join(db_output_columns)} from {db_table}
            WHERE {constraints}
        """

        if max_num_results is not None:
            sql_query += f" LIMIT {max_num_results}"

        sql_query += ";"

        logger.debug(f"Query: {sql_query}")

        with conn.execute(sql_query) as cursor:
            query_output = cursor.fetchall()

        for entry in query_output:

            assert len(entry) == len(db_output_columns)

            query_res = dict()

            for i, key in enumerate(output_alias_map):
                query_res[key] = entry[i]

            all_query_res.append(query_res)

    return all_query_res


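# Editor's note: illustrative sketch only; the table and column names are
# hypothetical.
#
#     matches = import_from_db(
#         db_name="mydb",
#         db_table="exposures",
#         db_query_columns="fieldid",
#         db_accepted_values=42,
#         db_output_columns=["ra", "dec"],
#         output_alias_map=["objra", "objdec"],
#         max_num_results=10,
#     )
#
# returns at most ten dictionaries of the form {"objra": ..., "objdec": ...},
# one per row of "exposures" with fieldid = 42.

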
def execute_query(sql_query, db_name, db_user, password):
    with psycopg.connect(
        f"dbname={db_name} user={db_user} password={password}"
    ) as conn:
        conn.autocommit = True
        logger.debug(f"Query: {sql_query}")

        with conn.execute(sql_query) as cursor:
            query_output = cursor.fetchall()

        return query_output


def xmatch_import_db(
    db_name: str,
    db_table: str,
    db_query_columns: str | list[str],
    db_accepted_values: str | int | float | list[str | float | int],
    db_output_columns: str | list[str],
    output_alias_map: str | list[str],
    ra: float,
    dec: float,
    xmatch_radius_arcsec: float,
    ra_field_name: str = "ra",
    dec_field_name: str = "dec",
    query_dist=False,
    q3c=False,
    db_comparison_types: list[str] = None,
    order_field_name: str = None,
    order_ascending: bool = True,
    num_limit: int = None,
    db_user: str = os.environ.get(pg_admin_user_key),
    db_password: str = os.environ.get(pg_admin_pwd_key),
) -> list[dict]:

    if output_alias_map is None:
        output_alias_map = {}
        for col in db_output_columns:
            output_alias_map[col] = col

    xmatch_radius_deg = xmatch_radius_arcsec / 3600.0

    if q3c:
        constraints = f"""q3c_radial_query({ra_field_name},{dec_field_name},{ra},{dec},{xmatch_radius_deg}) """
    else:
        ra_min = ra - xmatch_radius_deg
        ra_max = ra + xmatch_radius_deg
        dec_min = dec - xmatch_radius_deg
        dec_max = dec + xmatch_radius_deg
        constraints = f""" {ra_field_name} between {ra_min} and {ra_max} AND {dec_field_name} between {dec_min} and {dec_max} """

    parsed_constraints = parse_constraints(
        db_query_columns, db_comparison_types, db_accepted_values
    )
    if len(parsed_constraints) > 0:
        constraints += f"""AND {parsed_constraints}"""

    select = f""" {'"' + '","'.join(db_output_columns) + '"'}"""
    if query_dist:
        if q3c:
            select = (
                f"""q3c_dist({ra_field_name},{dec_field_name},{ra},{dec}) AS xdist,"""
                + select
            )
        else:
            select = f"""{ra_field_name} - ra AS xdist,""" + select

    query = f"""SELECT {select} FROM {db_table} WHERE {constraints}"""
    order_seq = "ASC" if order_ascending else "DESC"
    if order_field_name is not None:
        query += f""" ORDER BY {order_field_name} {order_seq}"""
    if num_limit is not None:
        query += f""" LIMIT {num_limit}"""

    query += ";"

    query_output = execute_query(query, db_name, db_user, db_password)
    all_query_res = []

    for entry in query_output:
        if not query_dist:
            assert len(entry) == len(db_output_columns)
        else:
            assert len(entry) == len(db_output_columns) + 1
        query_res = dict()
        if query_dist:
            # when query_dist is set, the first selected column is the cross-match distance
            query_res["dist"] = entry[0]
            entry = entry[1:]
        for i, key in enumerate(output_alias_map):
            query_res[key] = entry[i]
        all_query_res.append(query_res)
    return all_query_res


def get_sequence_keys_from_table(
    db_table: str, db_name: str, db_user: str, password: str
):
    with psycopg.connect(
        f"dbname={db_name} user={db_user} password={password}"
    ) as conn:
        conn.autocommit = True
        sequences = [
            x[0]
            for x in conn.execute(
                "SELECT c.relname FROM pg_class c WHERE c.relkind = 'S';"
            ).fetchall()
        ]
        seq_tables = np.array([x.split("_")[0] for x in sequences])
        seq_columns = np.array([x.split("_")[1] for x in sequences])
        table_sequence_keys = seq_columns[(seq_tables == db_table)]
    return table_sequence_keys


def modify_db_entry(
    db_name: str,
    db_table: str,
    db_query_columns: str | list[str],
    db_query_values: str | int | float | list[str | float | int | list],
    value_dict: dict | astropy.io.fits.Header,
    db_alter_columns: str | list[str],
    return_columns: str | list[str] = None,
    db_query_comparison_types: list[str] = None,
    db_user: str = os.environ.get(pg_admin_user_key),
    password: str = os.environ.get(pg_admin_pwd_key),
):
    if not isinstance(db_query_columns, list):
        db_query_columns = [db_query_columns]
    if not isinstance(db_query_values, list):
        db_query_values = [db_query_values]
    if not isinstance(db_alter_columns, list):
        db_alter_columns = [db_alter_columns]

    if return_columns is None:
        return_columns = db_alter_columns
    if not isinstance(return_columns, list):
        return_columns = [return_columns]

    assert len(db_query_columns) == len(db_query_values)

    if db_query_comparison_types is None:
        db_query_comparison_types = ["="] * len(db_query_values)
    assert len(db_query_comparison_types) == len(db_query_values)
    assert np.all(
        np.isin(np.unique(db_query_comparison_types), ["=", "<", ">", "between"])
    )

    parsed_constraints = parse_constraints(
        db_query_columns, db_query_comparison_types, db_query_values
    )

    constraints = f"""{parsed_constraints}"""
    logger.debug(db_query_columns)
    with psycopg.connect(
        f"dbname={db_name} user={db_user} password={password}"
    ) as conn:
        conn.autocommit = True

        db_alter_values = [str(value_dict[c]) for c in db_alter_columns]

        alter_values_txt = [
            f"{db_alter_columns[ind]}='{db_alter_values[ind]}'"
            for ind in range(len(db_alter_columns))
        ]

        sql_query = f"""
                UPDATE {db_table} SET {', '.join(alter_values_txt)} WHERE {constraints}
                """
        if len(return_columns) > 0:
            logger.debug(return_columns)
            sql_query += f""" RETURNING {', '.join(return_columns)}"""
        sql_query += ";"
        query_output = execute_query(sql_query, db_name, db_user, password)

    return query_output


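# Editor's note: illustrative sketch only; names are hypothetical.
#
#     modify_db_entry(
#         db_name="mydb",
#         db_table="exposures",
#         db_query_columns="expid",
#         db_query_values=42,
#         value_dict={"ra": 211.0, "dec": 54.4},
#         db_alter_columns=["ra", "dec"],
#     )
#
# issues roughly "UPDATE exposures SET ra='211.0', dec='54.4' WHERE expid = 42
# RETURNING ra, dec;" and returns the RETURNING values via execute_query.

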
def get_colnames_from_schema(schema_file):
    with open(schema_file, "r") as f:
        dat = f.read()
    dat = dat.split(");")[0]
    dat = dat.split("\n")[1:-1]
    pkstrip = [x.strip(",").split("PRIMARY KEY")[0].strip() for x in dat]
    fkstrip = [x.strip(",").split("FOREIGN KEY")[0].strip() for x in pkstrip]
    colnames = [x.split(" ")[0].strip('"') for x in fkstrip]
    return colnames
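

# Editor's note: worked example with a hypothetical schema file. For a file
# containing
#     CREATE TABLE raw (
#         rawid SERIAL PRIMARY KEY,
#         "ra" REAL,
#         "dec" REAL
#     );
# get_colnames_from_schema would return ['rawid', 'ra', 'dec'].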