• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

segasai / sqlutilpy / 14062891522

25 Mar 2025 02:57PM UTC coverage: 91.612%. Remained the same
14062891522

push

github

segasai
allow reading of timestamp with timezone

699 of 763 relevant lines covered (91.61%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.53
src/sqlutilpy/sqlutil.py
1
"""Sqlutilpy module to access SQL databases
2
"""
3
from __future__ import print_function
1✔
4
import numpy
1✔
5
import numpy as np
1✔
6
import psycopg
1✔
7
import threading
1✔
8
import collections
1✔
9
import warnings
1✔
10
from numpy.core import numeric as sb
1✔
11

12
try:
1✔
13
    import astropy.table as atpy
1✔
14
except ImportError:
×
15
    # astropy is not installed
16
    atpy = None
×
17
try:
1✔
18
    import pandas
1✔
19
except ImportError:
×
20
    # pandas is not installed
21
    pandas = None
×
22

23
import queue
1✔
24

25
_WAIT_SELECT_TIMEOUT = 10
1✔
26
STRLEN_DEFAULT = 20
1✔
27

28

29
class config:
    """Module-wide tunables."""

    # Value assigned to ``cursor.arraysize`` of the named psycopg cursor
    # created in getCursor().
    arraysize = 100000
1✔
31

32

33
class SqlUtilException(Exception):
    """Exception type raised by this module for its own error conditions."""
1✔
35

36

37
def getConnection(db=None,
                  driver=None,
                  user=None,
                  password=None,
                  host=None,
                  port=None,
                  timeout=None):
    """
    Obtain the connection object to the DB.
    It may be useful to avoid reconnecting to the DB repeatedly.

    Parameters
    ----------

    db : string
        The name of the database (in case of PostgreSQL) or filename in
        case of sqlite3/duckdb
    driver :  string
        The db driver ('psycopg', 'sqlite3' or 'duckdb').
        'psycopg2' is accepted with a warning and treated as 'psycopg'
    user : string, optional
        Username (psycopg only)
    password: string, optional
        Password (psycopg only)
    host : string, optional
        Host-name (psycopg only)
    port : integer
        Connection port (psycopg only)
    timeout : integer
        Connection timeout for sqlite (5 is used if not given)

    Returns
    -------
    conn : object
         Database Connection

    Raises
    ------
    SqlUtilException
        If the driver name is not recognized

    """
    if driver == 'psycopg2':
        warnings.warn(
            'psycopg2 driver is not supported anymore using psycopg instead')
        driver = 'psycopg'
    if driver == 'psycopg':
        # only forward the arguments that were actually given, so that
        # the driver falls back on its own defaults for the others
        candidates = {
            'dbname': db,
            'host': host,
            'port': port,
            'user': user,
            'password': password,
        }
        conn_dict = {
            key: val
            for key, val in candidates.items() if val is not None
        }
        return psycopg.connect(**conn_dict)
    if driver == 'sqlite3':
        import sqlite3
        return sqlite3.connect(db, timeout=5 if timeout is None else timeout)
    if driver == 'duckdb':
        import duckdb
        return duckdb.connect(db)
    # SqlUtilException derives from Exception, so callers that catch
    # Exception keep working; this matches what getCursor() raises
    raise SqlUtilException("Unknown driver")
1✔
101

102

103
def getCursor(conn, driver=None, preamb=None, notNamed=False):
    """
    Retrieve the database cursor

    Parameters
    ----------
    conn : object
        The database connection
    driver : string
        'psycopg', 'sqlite3' or 'duckdb' ('psycopg2' is accepted with a
        warning and treated as 'psycopg')
    preamb : string, optional
        SQL to be executed on the connection before the cursor is returned
    notNamed : bool
        psycopg only: if True return a regular cursor instead of
        the named cursor 'sqlutilcursor'

    Returns
    -------
    cur : object
        The cursor

    Raises
    ------
    SqlUtilException
        If the driver is not recognized
    """
    if driver == 'psycopg2':
        warnings.warn(
            'psycopg2 driver is not supported anymore using psycopg instead')
        driver = 'psycopg'
    if driver == 'psycopg':
        setup_cur = conn.cursor()
        if preamb is not None:
            setup_cur.execute(preamb)
        else:
            setup_cur.execute('set cursor_tuple_fraction TO 1')
            # this is required because otherwise PG may decide to execute a
            # different plan
        if notNamed:
            return setup_cur
        # the set-up cursor is not needed anymore, do not leave it open
        setup_cur.close()
        cur = conn.cursor(name='sqlutilcursor')
        cur.arraysize = config.arraysize
        return cur
    if driver in ('sqlite3', 'duckdb'):
        cur = conn.cursor()
        if preamb is not None:
            cur.execute(preamb)
        return cur
    raise SqlUtilException('unrecognized driver')
1✔
134

135

136
def __fromrecords(recList, dtype=None, intNullVal=None):
1✔
137
    """
138
    This function was taken from np.core.records and updated to
139
    support conversion null integers to intNullVal
140
    """
141

142
    shape = None
1✔
143
    descr = sb.dtype((np.record, dtype))
1✔
144
    try:
1✔
145
        retval = sb.array(recList, dtype=descr)
1✔
146
    except TypeError:  # list of lists instead of list of tuples
1✔
147
        shape = (len(recList), )
1✔
148
        _array = np.recarray(shape, descr)
1✔
149
        try:
1✔
150
            for k in range(_array.size):
1✔
151
                _array[k] = tuple(recList[k])
1✔
152
        except TypeError:
1✔
153
            convs = []
1✔
154
            ncols = len(dtype.fields)
1✔
155
            for _k in dtype.names:
1✔
156
                _v = dtype.fields[_k]
1✔
157
                if _v[0] in [np.int16, np.int32, np.int64]:
1✔
158
                    convs.append(lambda x: intNullVal if x is None else x)
1✔
159
                else:
160
                    convs.append(lambda x: x)
1✔
161
            convs = tuple(convs)
1✔
162

163
            def convF(x):
1✔
164
                return [convs[_](x[_]) for _ in range(ncols)]
1✔
165

166
            for k in range(k, _array.size):
1✔
167
                try:
1✔
168
                    _array[k] = tuple(recList[k])
1✔
169
                except TypeError:
1✔
170
                    _array[k] = tuple(convF(recList[k]))
1✔
171
        return _array
1✔
172
    else:
173
        if shape is not None and retval.shape != shape:
1✔
174
            retval.shape = shape
×
175

176
    res = retval.view(numpy.recarray)
1✔
177

178
    return res
1✔
179

180

181
def __converter(qIn, qOut, endEvent, dtype, intNullVal):
    """
    Worker loop: take batches of tuples from qIn, convert them to
    numpy record arrays and put the arrays on qOut.

    The loop runs until endEvent is set. If a conversion fails the
    function sets endEvent itself and re-raises.
    """
    while True:
        if endEvent.is_set():
            return
        try:
            # short timeout so that endEvent is re-checked regularly
            batch = qIn.get(True, 0.1)
        except queue.Empty:
            continue
        try:
            converted = __fromrecords(batch,
                                      dtype=dtype,
                                      intNullVal=intNullVal)
        except Exception:
            print('Failed to convert input data into array')
            endEvent.set()
            raise
        qOut.put(converted)
1✔
195

196

197
def __getDType(row, typeCodes, strLength):
    """
    Build the structured numpy dtype for a result set.

    Parameters
    ----------
    row : sequence
        One row of values (entries may be None), used to size string
        columns and to detect values that have a length
    typeCodes : sequence of int
        Type code of each column
    strLength : integer
        Minimum width of the string columns

    Returns
    -------
    dtype : numpy.dtype
        dtype with the fields named a0, a1, ...

    Raises
    ------
    SqlUtilException
        If a type code is not in the lookup table
    """
    text_fmt = '|U%d'
    text_oids = (25, 1042, 1043)  # text, character(), varchar
    oid_to_type = {
        16: bool,
        18: str,
        19: str,  # name type used in information schema
        20: 'i8',
        21: 'i2',
        23: 'i4',
        700: 'f4',
        701: 'f8',
        1000: bool,
        1005: 'i2',
        1007: 'i4',
        1016: 'i8',
        1021: 'f4',
        1022: 'f8',
        1700: 'f8',  # numeric
        1114: '<M8[us]',  # timestamp
        1184: '<M8[us]',  # timestamp with timezone
        1082: '<M8[us]',  # date
    }
    for oid in text_oids:
        oid_to_type[oid] = text_fmt

    fields = []
    for pos, (value, oid) in enumerate(zip(row, typeCodes)):
        if oid not in oid_to_type:
            raise SqlUtilException('Unknown PG type %d' % oid)
        col_type = oid_to_type[oid]
        if oid in text_oids:
            # if the value is longer than strLength use its length,
            # for a null value just use strLength
            width = strLength if value is None else max(strLength,
                                                        len(value))
            col_type = col_type % (width, )
        else:
            # any non-text value that has a length is stored as an object
            try:
                len(value)
            except TypeError:
                pass
            else:
                col_type = 'O'
        fields.append(('a%d' % pos, col_type))
    return numpy.dtype(fields)
1✔
249

250

251
def get(query,
        params=None,
        db="wsdb",
        driver="psycopg",
        user=None,
        password=None,
        host=None,
        preamb=None,
        conn=None,
        port=None,
        strLength=STRLEN_DEFAULT,
        timeout=None,
        notNamed=False,
        asDict=False,
        intNullVal=-9999):
    '''
    Executes the sql query and returns the tuple or dictionary
    with the numpy arrays.

    Parameters
    ----------
    query : string
        Query you want to execute, can include placeholders
        referring to query parameters (in the syntax of the driver)
    params : tuple
        Query parameters
    conn : object
        The connection object to the DB (optional) to avoid reconnecting
    asDict : boolean
        Flag whether to retrieve the results as a dictionary with column
        names as keys
    strLength : integer
        The default length of the string columns. The length of the
        string in the first row is used if it is longer.
    intNullVal : integer, optional
        All the integer columns with nulls will have null replaced by
        this value
    db : string
        The name of the database
    driver : string, optional
        The sql driver to be used (psycopg, sqlite3 or duckdb)
    user : string, optional
        User name for the DB connection
    password : string, optional
        DB connection password
    host : string, optional
        Hostname of the database
    port : integer, optional
        Port of the database
    timeout : integer, optional
        Connection timeout (used for sqlite3 connections)
    notNamed : boolean
        psycopg only: use a regular cursor instead of a named one
    preamb : string
        SQL code to be executed before the query

    Returns
    -------
    ret : List or dictionary
        By default you get a list with numpy arrays for each column
        in your query.
        If you specified asDict keyword then you get an ordered dictionary with
        your columns.
        For sqlite3/duckdb queries that return no rows the columns are
        empty lists.

    Examples
    --------
    >>> a, b, c = sqlutilpy.get('select ra,dec,d25 from rc3')

    You can also use the parameters in your query:

    >>> a, b = sqlutilpy.get('select ra,dec from rc3 where name=?',
    ...                      ("NGC 3166",), driver='sqlite3')
    '''

    connSupplied = (conn is not None)
    if not connSupplied:
        conn = getConnection(db=db,
                             driver=driver,
                             user=user,
                             password=password,
                             host=host,
                             port=port,
                             timeout=timeout)
    try:
        cur = getCursor(conn, driver=driver, preamb=preamb, notNamed=notNamed)

        if params is None:
            cur.execute(query)
        else:
            cur.execute(query, params)

        colNames = []
        res = None
        if driver == 'psycopg2':
            warnings.warn('psycopg2 driver is not supported anymore. '
                          'We using psycopg instead')
            driver = 'psycopg'
        if driver == 'psycopg':
            qIn = queue.Queue(1)
            qOut = queue.Queue()
            endEvent = threading.Event()
            nrec = 0  # keeps the number of arrays sent to the other thread
            # minus number received
            reslist = []
            proc = None
            typeCodes = None
            try:
                while True:
                    # Iterating over the cursor, retrieving batches of results
                    # and then sending them for conversion
                    tups = cur.fetchmany()
                    desc = cur.description

                    # the column information is needed only once
                    if typeCodes is None:
                        typeCodes = [_tmp.type_code for _tmp in desc]
                        colNames = [_tmp.name for _tmp in desc]

                    # No more data
                    if not tups:
                        # this raises if any column type is unknown,
                        # even when the query returned no rows
                        __getDType([None] * len(typeCodes), typeCodes,
                                   strLength)
                        break

                    # If this is the first batch we need to launch the
                    # thread doing the conversion. We test proc rather than
                    # nrec, because nrec goes back to zero whenever
                    # the converter has caught up, and that must not start
                    # a second thread
                    if proc is None:
                        dtype = __getDType(tups[0], typeCodes, strLength)
                        proc = threading.Thread(target=__converter,
                                                args=(qIn, qOut, endEvent,
                                                      dtype, intNullVal))
                        proc.start()

                    # Send the new batch for conversion. The put is done
                    # with a timeout, so that we do not block forever on
                    # the full queue if the converter thread has died
                    while True:
                        if endEvent.is_set():
                            raise SqlUtilException('Child thread failed')
                        try:
                            qIn.put(tups, True, 0.1)
                            break
                        except queue.Full:
                            pass

                    # nrec is the number of batches in conversion currently
                    nrec += 1

                    # Try to retrieve one processed batch without waiting
                    # on it
                    try:
                        reslist.append(qOut.get(False))
                        nrec -= 1
                    except queue.Empty:
                        pass

                # Now we are done fetching the data from the DB, we
                # just need to assemble the converted results
                while (nrec != 0):
                    try:
                        reslist.append(qOut.get(True, 0.1))
                        nrec -= 1
                    except queue.Empty:
                        # continue looping unless the endEvent was set
                        # which should happen in the case of the crash
                        # of the converter thread
                        if endEvent.is_set():
                            raise SqlUtilException('Child thread failed')
                endEvent.set()
            except BaseException:
                endEvent.set()
                if proc is not None:
                    # notice that here the timeout is larger than the timeout
                    # in the converter thread. A thread cannot be
                    # terminated; the converter exits on its own once it
                    # sees that endEvent is set
                    proc.join(0.2)
                raise
            if proc is not None:
                proc.join()
            if not reslist:
                nCols = len(desc)
                res = numpy.array([],
                                  dtype=numpy.dtype([('a%d' % i, 'f')
                                                     for i in range(nCols)]))
            else:
                res = numpy.concatenate(reslist)

        elif driver in ('sqlite3', 'duckdb'):
            tups = cur.fetchall()
            colNames = [_tmp[0] for _tmp in cur.description]
            if len(tups) > 0:
                res = numpy.rec.array(tups)

        if res is None:
            # no rows from sqlite3/duckdb: one empty list per column.
            # We do not return from here, so that the cursor/connection
            # are closed and asDict is honoured below
            res = [[] for _ in colNames]
        else:
            res = [res[tmp] for tmp in res.dtype.names]

    except BaseException:
        failure_cleanup(conn, connSupplied)
        raise

    cur.close()

    if not connSupplied:
        conn.close()

    if asDict:
        resDict = collections.OrderedDict()
        repeats = {}
        for _n, _v in zip(colNames, res):
            if _n in resDict:
                curn = _n + '_%d' % (repeats[_n])
                repeats[_n] += 1
                warnings.warn(('Column name %s is repeated in the output, ' +
                               'new name %s assigned') % (_n, curn))
            else:
                repeats[_n] = 1
                curn = _n
            resDict[curn] = _v
        res = resDict
    return res
1✔
467

468

469
def execute(query,
            params=None,
            db='wsdb',
            driver="psycopg",
            user=None,
            password=None,
            host=None,
            conn=None,
            preamb=None,
            timeout=None,
            noCommit=False,
            port=None):
    """
    Execute a given SQL command without returning the results

    Parameters
    ----------
    query: string
        The query or command you are executing
    params: tuple, optional
        Optional parameters of your query
    db : string
        Database name
    driver : string
        Driver for the DB connection ('psycopg', 'sqlite3' or 'duckdb')
    user : string, optional
        user name for the DB connection
    password : string, optional
        DB connection password
    host : string, optional
        Hostname of the database
    conn : object, optional
        An existing connection to use. It is not closed by this function.
    preamb : string, optional
        SQL code to be executed before the query
    timeout : integer, optional
        Connection timeout (used for sqlite3 connections)
    noCommit: bool
        By default execute() will commit your command.
        If you say noCommit, the commit won't be issued.
    port : integer, optional
        Port of the database

    Examples
    --------
    >>> sqlutil.execute('drop table mytab', conn=conn)
    >>> sqlutil.execute('create table mytab (a int)', db='mydb')

    """
    connSupplied = (conn is not None)
    if not connSupplied:
        conn = getConnection(db=db,
                             driver=driver,
                             user=user,
                             password=password,
                             host=host,
                             port=port,
                             timeout=timeout)
    try:
        cur = getCursor(conn, driver=driver, preamb=preamb, notNamed=True)
        if params is not None:
            cur.execute(query, params)
        else:
            # sqlite3 doesn't like params here...
            cur.execute(query)
    except BaseException:
        # roll back, and close the connection if it is our own
        failure_cleanup(conn, connSupplied)
        raise
    cur.close()
    if not noCommit:
        conn.commit()
    if not connSupplied:
        conn.close()  # do not close if we were given the connection
1✔
534

535

536
def __create_schema(tableName, arrays, names, temp=False):
    """
    Return the 'create table' statement for a table whose columns
    correspond to the given arrays.

    Parameters
    ----------
    tableName : string
        Name of the table
    arrays : sequence of numpy arrays
        The column data; only the dtypes are used. For an object array
        the dtype of its first element is used and the column is
        declared as an array column.
    names : sequence of strings
        Column names (they are double-quoted in the statement)
    temp : bool
        If true a temporary table is declared
    """
    sql_types = {
        np.int32: 'integer',
        np.int64: 'bigint',
        np.uint64: 'bigint',
        np.int16: 'smallint',
        np.uint8: 'smallint',
        np.int8: 'smallint',
        np.float32: 'real',
        np.float64: 'double precision',
        np.bytes_: 'varchar',
        np.str_: 'varchar',
        np.bool_: 'boolean',
        np.datetime64: 'timestamp',
    }
    header = 'create %s table %s ' % ('temporary' if temp else '', tableName)
    columns = []
    for column, colname in zip(arrays, names):
        if column.dtype.type == np.object_:
            # array
            sql_type = sql_types[column[0].dtype.type] + '[]'
        else:
            sql_type = sql_types[column.dtype.type]
        columns.append(f'"{colname}" {sql_type}')
    return header + '(' + ','.join(columns) + ')'
1✔
558

559

560
def __print_arrays(arrays, f, delimiter=' '):
    """
    print the input arrays into the open file separated by a delimiter

    Parameters
    ----------
    arrays : sequence of numpy arrays
        The columns (all of the same length). Object columns are
        expected to hold arrays and are written as {v1,v2,...}
    f : object
        Output object with a write() method accepting bytes
    delimiter : string
        Separator between the columns

    Raises
    ------
    RuntimeError
        If the dtype of a column is not supported
    """
    format_dict = {
        np.int32: '%d',
        np.int64: '%d',
        np.int16: '%d',
        np.int8: '%d',
        np.uint8: '%d',
        np.float32: '%.18e',
        np.float64: '%.18e',
        np.bytes_: '%s',
        np.str_: '%s',
        np.datetime64: '%s',
        np.bool_: '%d',
    }
    fmts = []
    array_mode = False
    for x in arrays:
        if x.dtype.type in format_dict:
            fmts.append(format_dict[x.dtype.type])
        elif x.dtype.type == np.object_:
            array_mode = True
            fmts.append(None)
        else:
            raise RuntimeError('Unsupported column type %s. Please report' %
                               (x.dtype.type))
    recarr = np.rec.fromarrays(arrays)
    if not array_mode:
        np.savetxt(f, recarr, fmt=fmts, delimiter=delimiter)
        return

    def array_literal(value):
        # The elements are joined explicitly rather than going through
        # np.array2string, which abbreviates arrays longer than the
        # numpy print threshold with '...' and pads the elements
        if np.ndim(value) == 0:
            return str(value)
        return '{' + ','.join(array_literal(_) for _ in value) + '}'

    # this is really slow
    delim = delimiter.encode()
    fields = recarr.dtype.names
    for row in recarr:
        cells = []
        for fmt, field in zip(fmts, fields):
            if fmt is None:
                cells.append(array_literal(row[field]).encode())
            else:
                cells.append(str(row[field]).encode())
        f.write(delim.join(cells) + b'\n')
1✔
607

608

609
def failure_cleanup(conn, connSupplied):
    """
    Best-effort cleanup after a failure: roll the connection back and,
    unless the connection was supplied by the caller, close it.
    Exceptions raised by the rollback/close are ignored.
    """
    # do not close if we were given the connection
    steps = ('rollback', ) if connSupplied else ('rollback', 'close')
    for step in steps:
        try:
            getattr(conn, step)()
        except Exception:
            pass
×
619

620

621
def upload(tableName,
           arrays,
           names=None,
           db="wsdb",
           driver="psycopg",
           user=None,
           password=None,
           host=None,
           conn=None,
           preamb=None,
           timeout=None,
           noCommit=False,
           temp=False,
           analyze=False,
           createTable=True,
           delimiter='|',
           port=None):
    """
    Upload the data stored in the tuple of arrays in the DB

    Parameters
    ----------
    tableName : string
        The name of the table where the data will be uploaded
    arrays : tuple
        Tuple of arrays that will be columns of the new table
        If names are not specified, this parameter can be a pandas
        DataFrame, an astropy table or a dictionary
    names : tuple
        Column names. The characters ' -()[]<>' in the names are
        replaced by underscores (with a warning)
    db: string
         Database name
    driver: string
         Python database driver "psycopg",
    user: string,
    password: string
    host: string
    conn: object
         SQL connection
    preamb: string
         The string/query to be executed before your command
    timeout : integer
         Connection timeout passed to getConnection
    noCommit: bool
         If true, the commit is not executed and the table will go away
         after the disconnect
    temp: bool
         If true a temporary table will be created
    analyze: bool
         if True, the table will be analyzed after the upload
    createTable: bool
         if True the table will be created before uploading (default)
    delimiter: string
         the string used for delimiting the input data when ingesting into
         the db (default is |)
    port : integer, optional
         Port of the database

    Examples
    --------
    >>> x = np.arange(10)
    >>> y = x**.5
    >>> sqlutilpy.upload('mytable', (x, y), ('xcol', 'ycol'))

    >>> T = astropy.table.Table({'x':[1, 2, 3], 'y':['a', 'b', 'c']})
    >>> sqlutilpy.upload('mytable', T)
    """
    # The input is validated before opening the connection, so that
    # a bad input cannot leave a connection open
    if names is None:
        # we assume that we were given a table
        if atpy is not None and isinstance(arrays, atpy.Table):
            names = list(arrays.columns)
        elif pandas is not None and isinstance(arrays, pandas.DataFrame):
            names = list(arrays.columns)
        elif isinstance(arrays, dict):
            names = list(arrays.keys())
        else:
            raise Exception('you either need to give astropy '
                            'table/pandas/dictionary or provide a separate '
                            'list of arrays and their names')
        arrays = [arrays[_] for _ in names]

    arrays = [np.asarray(_) for _ in arrays]
    if len(arrays) != len(names):
        raise RuntimeError('The column names list must have the same '
                           'length as array list')
    repl_table = str.maketrans({_: '_' for _ in ' -()[]<>'})
    fixed_names = []
    for name in names:
        fixed_name = name.translate(repl_table)
        if fixed_name != name:
            warnings.warn('''Renamed column '%s' to '%s' ''' %
                          (name, fixed_name))
        fixed_names.append(fixed_name)

    connSupplied = (conn is not None)
    if not connSupplied:
        conn = getConnection(db=db,
                             driver=driver,
                             user=user,
                             password=password,
                             host=host,
                             port=port,
                             timeout=timeout)
    try:
        cur = getCursor(conn, driver=driver, preamb=preamb, notNamed=True)
        if createTable:
            query1 = __create_schema(tableName, arrays, fixed_names, temp=temp)
            cur.execute(query1)
        nsplit = 100000  # number of rows sent per copy command
        N = len(arrays[0]) if arrays else 0
        # NOTE(review): the names are double-quoted in the create
        # statement, but not here; these will not match for names with
        # upper case letters -- confirm which behaviour is intended
        names = ','.join(fixed_names)
        for i in range(0, N, nsplit):
            with cur.copy(f'''copy {tableName}({names}) from stdin
                        with delimiter '{delimiter}' ''') as copy:
                __print_arrays([_[i:i + nsplit] for _ in arrays],
                               copy,
                               delimiter=delimiter)
        if analyze:
            cur.execute('analyze %s' % tableName)
    except BaseException:
        failure_cleanup(conn, connSupplied)
        raise
    cur.close()
    if not noCommit:
        conn.commit()
    if not connSupplied:
        conn.close()  # do not close if we were given the connection
1✔
763

764

765
def local_join(query,
               tableName,
               arrays,
               names,
               db=None,
               driver="psycopg",
               user=None,
               password=None,
               host=None,
               port=None,
               conn=None,
               preamb=None,
               timeout=None,
               strLength=STRLEN_DEFAULT,
               intNullVal=-9999,
               asDict=False):
    """
    Join your local data in python with the data in the database
    This command first uploads the data in the DB creating a temporary table
    and then runs a user specified query that can use your local data.
    The transaction is rolled back afterwards.

    Parameters
    ----------
    query : String with the query to be executed
    tableName : The name of the temporary table that is going to be created
    arrays : The tuple with list of arrays with the data to be loaded in the DB
    names : The tuple with the column names for the user table
    db, driver, user, password, host, port, timeout :
        Connection parameters (see getConnection). They are used only if
        conn is not given
    conn : The connection object to the DB (optional)
    preamb : SQL code to be executed before the query
    strLength, intNullVal, asDict : The same as in get()

    Returns
    -------
    ret : The result of get() for the query

    Examples
    --------
    This will extract the rows from the table sometable matching
    to the provided array x

    >>> x = np.arange(10)
    >>> y = x**.5
    >>> sqlutilpy.local_join('''
    ... SELECT s.* FROM mytable AS m LEFT JOIN sometable AS s
    ... ON s.x = m.x ORDER BY m.x''',
    ... 'mytable', (x, y), ('x', 'y'))
    """

    connSupplied = (conn is not None)
    if not connSupplied:
        conn = getConnection(db=db,
                             driver=driver,
                             user=user,
                             password=password,
                             host=host,
                             timeout=timeout,
                             port=port)
    try:
        # the driver is forwarded, so that the cursors are created for
        # the same driver as the connection
        upload(tableName,
               arrays,
               names,
               conn=conn,
               driver=driver,
               noCommit=True,
               temp=True,
               analyze=True)
        res = get(query,
                  conn=conn,
                  driver=driver,
                  preamb=preamb,
                  strLength=strLength,
                  asDict=asDict,
                  intNullVal=intNullVal)
    except BaseException:
        failure_cleanup(conn, connSupplied)
        raise

    # the upload was not committed; discard the temporary table
    conn.rollback()

    if not connSupplied:
        conn.close()
    return res
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc