• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

segasai / sqlutilpy / 14365889562

09 Apr 2025 07:58PM UTC coverage: 91.732% (+0.1%) from 91.612%
14365889562

push

github

segasai
thread terminate does not exist

699 of 762 relevant lines covered (91.73%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.53
src/sqlutilpy/sqlutil.py
1
"""Sqlutilpy module to access SQL databases
2
"""
3
from __future__ import print_function
1✔
4
import numpy
1✔
5
import numpy as np
1✔
6
import psycopg
1✔
7
import threading
1✔
8
import collections
1✔
9
import warnings
1✔
10
from numpy.core import numeric as sb
1✔
11

12
try:
1✔
13
    import astropy.table as atpy
1✔
14
except ImportError:
×
15
    # astropy is not installed
16
    atpy = None
×
17
try:
1✔
18
    import pandas
1✔
19
except ImportError:
×
20
    # pandas is not installed
21
    pandas = None
×
22

23
import queue
1✔
24

25
# Default for the strLength parameter of get() and local_join(): the
# minimum width given to text columns when building numpy arrays.
STRLEN_DEFAULT = 20
# NOTE(review): not referenced by any code visible in this module.
_WAIT_SELECT_TIMEOUT = 10
27

28

29
class config:
    """Module-wide settings, read at call time (so they may be changed at runtime)."""

    # Copied by getCursor() onto the arraysize of the named psycopg cursor,
    # i.e. the number of rows get() receives per fetchmany() batch.
    arraysize = 100_000
31

32

33
class SqlUtilException(Exception):
    """Error raised by sqlutil for problems it detects itself, such as an
    unrecognized driver or an unsupported PostgreSQL column type."""
35

36

37
def getConnection(db=None,
                  driver=None,
                  user=None,
                  password=None,
                  host=None,
                  port=None,
                  timeout=None):
    """
    Obtain the connection object to the DB.
    It may be useful to avoid reconnecting to the DB repeatedly.

    Parameters
    ----------

    db : string
        The name of the database (in case of PostgreSQL) or filename in
        case of sqlite/duckdb db
    driver :  string
        The db driver ('psycopg', 'sqlite3' or 'duckdb'). 'psycopg2' is
        still accepted: a warning is issued and 'psycopg' is used instead.
    user : string, optional
        Username (psycopg only)
    password: string, optional
        Password (psycopg only)
    host : string, optional
        Host-name (psycopg only)
    port : integer
        Connection port (psycopg only)
    timeout : integer
        Connection timeout for sqlite (5 if not given)

    Returns
    -------
    conn : object
         Database Connection

    Raises
    ------
    SqlUtilException
        If the driver is not recognized.
    """
    if driver == 'psycopg2':
        warnings.warn(
            'psycopg2 driver is not supported anymore using psycopg instead')
        driver = 'psycopg'
    if driver == 'psycopg':
        # Only the parameters that were actually given (not None) are
        # forwarded to psycopg.connect().
        candidates = (('dbname', db), ('host', host), ('port', port),
                      ('user', user), ('password', password))
        conn_dict = {key: val for key, val in candidates if val is not None}
        conn = psycopg.connect(**conn_dict)
    elif driver == 'sqlite3':
        import sqlite3
        if timeout is None:
            timeout = 5
        conn = sqlite3.connect(db, timeout=timeout)
    elif driver == 'duckdb':
        import duckdb
        conn = duckdb.connect(db)
    else:
        # SqlUtilException (a subclass of Exception) for consistency with
        # getCursor(); callers catching Exception are unaffected.
        raise SqlUtilException('Unknown driver %r' % (driver, ))
    return conn
101

102

103
def getCursor(conn, driver=None, preamb=None, notNamed=False):
    """
    Retrieve the database cursor

    Parameters
    ----------
    conn : object
        Open DB connection (e.g. from getConnection())
    driver : string
        'psycopg', 'sqlite3' or 'duckdb'. 'psycopg2' is accepted with a
        warning and treated as 'psycopg'.
    preamb : string, optional
        SQL executed before the cursor is returned. For psycopg, when no
        preamble is given, 'set cursor_tuple_fraction TO 1' is executed.
    notNamed : bool
        psycopg only: if True an ordinary cursor is returned; otherwise
        a named cursor 'sqlutilcursor' with arraysize config.arraysize.

    Returns
    -------
    cursor object

    Raises
    ------
    SqlUtilException
        If the driver is not recognized.
    """
    if driver == 'psycopg2':
        warnings.warn(
            'psycopg2 driver is not supported anymore using psycopg instead')
        driver = 'psycopg'
    if driver == 'psycopg':
        cur = conn.cursor()
        if preamb is not None:
            cur.execute(preamb)
        else:
            # this is required because otherwise PG may decide to execute a
            # different plan
            cur.execute('set cursor_tuple_fraction TO 1')
        if notNamed:
            return cur
        # The plain cursor was only needed to run the statement above;
        # close it rather than leaking it, and hand out a named cursor.
        cur.close()
        cur = conn.cursor(name='sqlutilcursor')
        cur.arraysize = config.arraysize
    elif driver in ('sqlite3', 'duckdb'):
        cur = conn.cursor()
        if preamb is not None:
            cur.execute(preamb)
    else:
        raise SqlUtilException('unrecognized driver')
    return cur
134

135

136
def __fromrecords(recList, dtype=None, intNullVal=None):
    """
    Convert a list of rows into a ``numpy.recarray``.

    This function was taken from np.core.records and updated to
    support conversion of null integers: a ``None`` found in an integer
    column is replaced by ``intNullVal`` instead of failing.

    Parameters
    ----------
    recList : list
        Rows (tuples or lists), one value per field of ``dtype``.
    dtype : numpy.dtype
        Structured dtype of the result.
    intNullVal : int, optional
        Replacement for ``None`` in integer columns.

    Returns
    -------
    numpy.recarray
    """
    descr = np.dtype((np.record, dtype))
    try:
        # Fast path: numpy converts all rows in one call
        retval = np.array(recList, dtype=descr)
    except TypeError:
        # Slow path (list of lists instead of list of tuples, or None in
        # an integer column): fill the array row by row.
        nrows = len(recList)
        _array = np.recarray((nrows, ), descr)
        k = 0
        try:
            for k in range(nrows):
                _array[k] = tuple(recList[k])
        except TypeError:
            # Row k could not be stored as is. From here on, rows that fail
            # are retried with None replaced by intNullVal in integer columns.
            isInt = tuple(np.issubdtype(descr.fields[_name][0], np.integer)
                          for _name in descr.names)

            def convF(row):
                return tuple(intNullVal if (_int and _val is None) else _val
                             for _val, _int in zip(row, isInt))

            for k in range(k, nrows):
                try:
                    _array[k] = tuple(recList[k])
                except TypeError:
                    _array[k] = convF(recList[k])
        return _array
    return retval.view(numpy.recarray)
179

180

181
def __converter(qIn, qOut, endEvent, dtype, intNullVal):
    """
    Worker loop run in a separate thread by get().

    Takes batches of row tuples from ``qIn``, converts each into a record
    array with ``__fromrecords`` and puts the result on ``qOut``, until
    ``endEvent`` is set. If a conversion fails, a message is printed,
    ``endEvent`` is set (so the producer can notice) and the exception is
    re-raised, which ends the thread.
    """
    while True:
        if endEvent.is_set():
            return
        # poll with a short timeout so that endEvent is re-checked often
        try:
            batch = qIn.get(True, 0.1)
        except queue.Empty:
            continue
        try:
            converted = __fromrecords(batch,
                                      dtype=dtype,
                                      intNullVal=intNullVal)
        except Exception:
            print('Failed to convert input data into array')
            endEvent.set()
            raise
        else:
            qOut.put(converted)
195

196

197
def __getDType(row, typeCodes, strLength):
    """
    Work out the numpy structured dtype for rows of a PostgreSQL result.

    Parameters
    ----------
    row : sequence
        A representative row (first row of a batch, or all None for an
        empty result). Its values fix the width of text columns and
        whether a column holds sequences (stored as Python objects).
    typeCodes : sequence of int
        PostgreSQL type OIDs of the columns.
    strLength : int
        Minimum width of text columns.

    Returns
    -------
    numpy.dtype
        Structured dtype with fields named a0, a1, ...

    Raises
    ------
    SqlUtilException
        If a type code is not supported.
    """
    pgTypeHash = {
        16: bool,
        18: str,
        19: str,  # name type used in information schema
        20: 'i8',
        21: 'i2',
        23: 'i4',
        700: 'f4',
        701: 'f8',
        # array types: bool[], int2[], int4[], int8[], float4[], float8[]
        1000: bool,
        1005: 'i2',
        1007: 'i4',
        1016: 'i8',
        1021: 'f4',
        1022: 'f8',
        1700: 'f8',  # numeric
        1114: '<M8[us]',  # timestamp
        1184: '<M8[us]',  # timestamp with timezone
        1082: '<M8[us]',  # date
        25: '|U%d',  # text
        1042: '|U%d',  # character()
        1043: '|U%d'  # varchar
    }
    strTypes = (25, 1042, 1043)

    pgTypes = []
    for i, (curv, curt) in enumerate(zip(row, typeCodes)):
        if curt not in pgTypeHash:
            raise SqlUtilException('Unknown PG type %d' % curt)
        pgType = pgTypeHash[curt]
        if curt in strTypes:
            # if the first string value is longer than strLength use that
            # as the width; a null first value just gets strLength
            width = strLength if curv is None else max(strLength, len(curv))
            pgType = pgType % (width, )
        else:
            # values that have a length (e.g. arrays) are stored as objects
            try:
                len(curv)
            except TypeError:
                pass
            else:
                pgType = 'O'
        pgTypes.append(('a%d' % i, pgType))
    return numpy.dtype(pgTypes)
249

250

251
def get(query,
        params=None,
        db="wsdb",
        driver="psycopg",
        user=None,
        password=None,
        host=None,
        preamb=None,
        conn=None,
        port=None,
        strLength=STRLEN_DEFAULT,
        timeout=None,
        notNamed=False,
        asDict=False,
        intNullVal=-9999):
    '''
    Executes the sql query and returns the list or dictionary
    with the numpy arrays.

    Parameters
    ----------
    query : string
        Query you want to execute; it can include placeholders for the
        query parameters (placeholder syntax is the driver's, e.g. %s for
        psycopg, ? for sqlite3)
    params : tuple
        Query parameters
    conn : object
        The connection object to the DB (optional) to avoid reconnecting
    asDict : boolean
        Flag whether to retrieve the results as a dictionary with column
        names as keys
    strLength : integer
        Minimum width of string columns (psycopg driver). The width is the
        larger of strLength and the length of the value in the first row;
        longer strings are truncated to that width.
    intNullVal : integer, optional
        All the integer columns with nulls will have null replaced by
        this value
    db : string
        The name of the database
    driver : string, optional
        The sql driver to be used (psycopg, sqlite3 or duckdb)
    user : string, optional
        User name for the DB connection
    password : string, optional
        DB connection password
    host : string, optional
        Hostname of the database
    port : integer, optional
        Port of the database
    timeout : integer, optional
        Connection timeout (passed to getConnection)
    notNamed : boolean
        psycopg only: use an ordinary cursor instead of a named one
    preamb : string
        SQL code to be executed before the query

    Returns
    -------
    ret : list or dictionary
        By default you get a list with numpy arrays for each column
        in your query (empty lists for an empty sqlite3/duckdb result).
        If you specified asDict keyword then you get an ordered dictionary
        with your columns; repeated column names get a _1, _2, ... suffix.

    Examples
    --------
    >>> a, b, c = sqlutilpy.get('select ra,dec,d25 from rc3')

    You can also use the parameters in your query:

    >>> a, b = sqlutilpy.get('select ra,dec from rc3 where name=%s',
    ...                      ("NGC 3166",))
    '''

    connSupplied = (conn is not None)
    if not connSupplied:
        conn = getConnection(db=db,
                             driver=driver,
                             user=user,
                             password=password,
                             host=host,
                             port=port,
                             timeout=timeout)
    try:
        cur = getCursor(conn, driver=driver, preamb=preamb, notNamed=notNamed)

        if params is None:
            cur.execute(query)
        else:
            cur.execute(query, params)

        colNames = []
        if driver == 'psycopg2':
            warnings.warn('psycopg2 driver is not supported anymore. '
                          'We using psycopg instead')
            driver = 'psycopg'
        if driver == 'psycopg':
            qIn = queue.Queue(1)
            qOut = queue.Queue()
            endEvent = threading.Event()
            # number of batches sent to the converter thread minus the
            # number of converted batches received back
            nrec = 0
            reslist = []
            proc = None
            typeCodes = None
            try:
                while True:
                    # the converter sets endEvent when it fails
                    if endEvent.is_set():
                        raise SqlUtilException('Child thread failed')
                    # Retrieve the next batch of rows from the cursor
                    tups = cur.fetchmany()
                    desc = cur.description
                    if typeCodes is None:
                        typeCodes = [_tmp.type_code for _tmp in desc]
                        colNames = [_tmp.name for _tmp in desc]

                    if not tups:
                        # No more data
                        if proc is None:
                            # empty result: still reject unknown column types
                            __getDType([None] * len(typeCodes), typeCodes,
                                       strLength)
                        break

                    if proc is None:
                        # Launch the converter exactly once; the first row
                        # of the first batch fixes the dtype.
                        dtype = __getDType(tups[0], typeCodes, strLength)
                        proc = threading.Thread(target=__converter,
                                                args=(qIn, qOut, endEvent,
                                                      dtype, intNullVal))
                        proc.start()

                    # Send the new batch for conversion. qIn holds a single
                    # batch, so use a timed put: a blocking put would hang
                    # forever if the converter thread has died.
                    while True:
                        try:
                            qIn.put(tups, True, 0.1)
                            break
                        except queue.Full:
                            if endEvent.is_set():
                                raise SqlUtilException('Child thread failed')
                    nrec += 1

                    # Try to retrieve one processed batch without waiting
                    try:
                        reslist.append(qOut.get(False))
                        nrec -= 1
                    except queue.Empty:
                        pass

                # Done fetching; collect the remaining converted batches
                while nrec != 0:
                    try:
                        reslist.append(qOut.get(True, 0.1))
                        nrec -= 1
                    except queue.Empty:
                        # keep waiting unless the converter thread crashed
                        if endEvent.is_set():
                            raise SqlUtilException('Child thread failed')
                endEvent.set()
            except BaseException:
                endEvent.set()
                if proc is not None:
                    # the converter re-checks endEvent every 0.1s; threads
                    # cannot be killed, so if it is still alive after this
                    # wait we just carry on
                    proc.join(0.2)
                raise
            if proc is not None:
                proc.join()
            if reslist:
                recs = numpy.concatenate(reslist)
            else:
                recs = numpy.array([],
                                   dtype=numpy.dtype([('a%d' % i, 'f')
                                                      for i in range(len(desc))]))
            res = [recs[_name] for _name in recs.dtype.names]

        elif driver in ('sqlite3', 'duckdb'):
            tups = cur.fetchall()
            colNames = [_tmp[0] for _tmp in cur.description]
            if len(tups) > 0:
                recs = numpy.rec.array(tups)
                res = [recs[_name] for _name in recs.dtype.names]
            else:
                # Empty result: one empty list per column. No early return,
                # so the cursor/connection are still closed below.
                res = [[] for _ in colNames]

    except BaseException:
        failure_cleanup(conn, connSupplied)
        raise

    cur.close()

    if not connSupplied:
        conn.close()

    if asDict:
        resDict = collections.OrderedDict()
        repeats = {}
        for _n, _v in zip(colNames, res):
            curn = _n
            if _n in resDict:
                # pick the next unused suffix; a generated name may itself
                # clash with a real column name (e.g. columns a, a, a_1)
                k = repeats.get(_n, 1)
                curn = '%s_%d' % (_n, k)
                while curn in resDict:
                    k += 1
                    curn = '%s_%d' % (_n, k)
                repeats[_n] = k + 1
                warnings.warn(('Column name %s is repeated in the output, ' +
                               'new name %s assigned') % (_n, curn))
            else:
                repeats[_n] = 1
            resDict[curn] = _v
        res = resDict
    return res
468

469

470
def execute(query,
            params=None,
            db='wsdb',
            driver="psycopg",
            user=None,
            password=None,
            host=None,
            conn=None,
            preamb=None,
            timeout=None,
            noCommit=False,
            port=None):
    """
    Execute a given SQL command without returning the results

    Parameters
    ----------
    query: string
        The query or command you are executing
    params: tuple, optional
        Optional parameters of your query
    db : string
        Database name
    driver : string
        Driver for the DB connection ('psycopg', 'sqlite3' or 'duckdb')
    user : string, optional
        user name for the DB connection
    password : string, optional
        DB connection password
    host : string, optional
        Hostname of the database
    port : integer, optional
        Port of the database
    conn : object, optional
        Existing connection; it is not closed by this function
    preamb : string, optional
        SQL executed before the command
    timeout : integer, optional
        Connection timeout (passed to getConnection)
    noCommit: bool
        By default execute() will commit your command.
        If you say noCommit, the commit won't be issued.

    Examples
    --------
    >>> sqlutil.execute('drop table mytab', conn=conn)
    >>> sqlutil.execute('create table mytab (a int)', db='mydb')

    """
    connSupplied = (conn is not None)
    if not connSupplied:
        conn = getConnection(db=db,
                             driver=driver,
                             user=user,
                             password=password,
                             host=host,
                             port=port,
                             timeout=timeout)
    try:
        cur = getCursor(conn, driver=driver, preamb=preamb, notNamed=True)
        if params is not None:
            cur.execute(query, params)
        else:
            # sqlite3 doesn't like params here...
            cur.execute(query)
        cur.close()
        # commit inside the try so a failing commit still triggers the
        # rollback/close cleanup instead of leaking the connection
        if not noCommit:
            conn.commit()
    except BaseException:
        failure_cleanup(conn, connSupplied)
        raise
    if not connSupplied:
        conn.close()  # do not close if we were given the connection
535

536

537
def __create_schema(tableName, arrays, names, temp=False):
    """
    Build the ``create table`` statement for the given columns.

    Column names are double-quoted (so their case is preserved). An object
    column is assumed to hold numpy arrays: the element type of its first
    entry gives the SQL array type.

    Parameters
    ----------
    tableName : string
        Table name (inserted verbatim into the SQL)
    arrays : sequence of numpy arrays
        Column data, used only for their dtypes
    names : sequence of strings
        Column names
    temp : bool
        Create a temporary table

    Returns
    -------
    string
        The SQL statement.

    Raises
    ------
    KeyError
        If a column dtype has no SQL equivalent here.
    """
    sqlTypes = {
        np.int8: 'smallint',
        np.int16: 'smallint',
        np.int32: 'integer',
        np.int64: 'bigint',
        np.uint8: 'smallint',
        np.uint16: 'integer',
        np.uint32: 'bigint',
        # NOTE: uint64 values above 2**63-1 do not fit into bigint
        np.uint64: 'bigint',
        np.float32: 'real',
        np.float64: 'double precision',
        np.bytes_: 'varchar',
        np.str_: 'varchar',
        np.bool_: 'boolean',
        np.datetime64: 'timestamp',
    }
    prefix = 'create %s table %s ' % ('temporary' if temp else '', tableName)
    columns = []
    for arr, name in zip(arrays, names):
        if arr.dtype.type == np.object_:
            # array column
            sqltype = sqlTypes[arr[0].dtype.type] + '[]'
        else:
            sqltype = sqlTypes[arr.dtype.type]
        columns.append(f'"{name}" {sqltype}')
    return prefix + '(' + ','.join(columns) + ')'
559

560

561
def __print_arrays(arrays, f, delimiter=' '):
    """
    print the input arrays into the open file separated by a delimiter

    One line is written per row, with the columns separated by
    ``delimiter``. Object columns (assumed to hold numpy arrays) are
    written as ``{v1,v2,...}`` literals; that path writes bytes to ``f``.

    Raises
    ------
    RuntimeError
        If a column has an unsupported dtype.
    """
    format_dict = {
        np.int8: '%d',
        np.int16: '%d',
        np.int32: '%d',
        np.int64: '%d',
        # unsigned types, matching what __create_schema accepts
        np.uint8: '%d',
        np.uint16: '%d',
        np.uint32: '%d',
        np.uint64: '%d',
        np.float32: '%.18e',
        np.float64: '%.18e',
        # NOTE(review): '%s' on numpy bytes_ values yields "b'...'" text;
        # bytes columns probably need decoding first -- confirm.
        np.bytes_: '%s',
        np.str_: '%s',
        np.datetime64: '%s',
        np.bool_: '%d',
    }
    fmts = []
    array_mode = False
    names = []
    for i, x in enumerate(arrays):
        names.append('f%d' % i)
        if x.dtype.type in format_dict:
            fmts.append(format_dict[x.dtype.type])
        elif x.dtype.type == np.object_:
            array_mode = True
            fmts.append(None)
        else:
            raise RuntimeError(
                'Unsupported column type %s. Please report' %
                (x.dtype.type))
    recarr = np.rec.fromarrays(arrays)
    if not array_mode:
        np.savetxt(f, recarr, fmt=fmts, delimiter=delimiter)
        return
    # this is really slow
    sep = delimiter.encode()
    for row in recarr:
        for i, field in enumerate(names):
            if i != 0:
                f.write(sep)
            if fmts[i] is None:
                # formatter is needed because otherwise there is
                # whitespace padding
                curstr = np.array2string(
                    row[field],
                    max_line_width=np.inf,
                    threshold=None,
                    separator=',',
                    formatter={'all': lambda x: str(x)})
                curstr = '{' + curstr[1:-1] + '}'
                f.write(curstr.encode())
            else:
                f.write(str(row[field]).encode())
        f.write(b'\n')
608

609

610
def failure_cleanup(conn, connSupplied):
    """
    Best-effort cleanup after an error.

    Rolls back the current transaction and, when the connection was opened
    by sqlutil itself (``connSupplied`` is false), closes it as well. Any
    exception raised by these steps is ignored.
    """
    # a connection handed in by the caller is left open
    steps = ('rollback', ) if connSupplied else ('rollback', 'close')
    for step in steps:
        try:
            getattr(conn, step)()
        except Exception:
            pass
620

621

622
def upload(tableName,
           arrays,
           names=None,
           db="wsdb",
           driver="psycopg",
           user=None,
           password=None,
           host=None,
           conn=None,
           preamb=None,
           timeout=None,
           noCommit=False,
           temp=False,
           analyze=False,
           createTable=True,
           delimiter='|',
           port=None):
    """
    Upload the data stored in the tuple of arrays in the DB

    Parameters
    ----------
    tableName : string
        The name of the table where the data will be uploaded
    arrays : tuple
        Tuple of arrays that will be columns of the new table.
        If names are not specified, this parameter can be a pandas
        DataFrame, an astropy table or a dictionary of arrays
    names : tuple, optional
        Column names. The characters ' -()[]<>' are replaced by '_'
        (with a warning)
    db: string
         Database name
    driver: string
         Python database driver "psycopg" (the data is sent with COPY)
    user: string, optional
    password: string, optional
    host: string, optional
    port: integer, optional
         Port of the database
    conn: object
         SQL connection
    preamb: string
         The string/query to be executed before your command
    timeout: integer, optional
         Connection timeout (passed to getConnection)
    noCommit: bool
         If true, the commit is not executed and the table will go away
         after the disconnect
    temp: bool
         If true a temporary table will be created
    analyze: bool
         if True, the table will be analyzed after the upload
    createTable: bool
         if True the table will be created before uploading (default)
    delimiter: string
         the string used for delimiting the input data when ingesting into
         the db (default is |)

    Notes
    -----
    tableName and delimiter are inserted into the SQL text verbatim, so
    they must come from a trusted source.

    Examples
    --------
    >>> x = np.arange(10)
    >>> y = x**.5
    >>> sqlutilpy.upload('mytable', (x, y), ('xcol', 'ycol'))

    >>> T = astropy.table.Table({'x': [1, 2, 3], 'y': ['a', 'b', 'c']})
    >>> sqlutilpy.upload('mytable', T)
    """
    # Validate and normalize the input before connecting, so that bad
    # input does not leave an opened connection behind.
    if names is None:
        # we assume that we were given a table
        if atpy is not None and isinstance(arrays, atpy.Table):
            names = list(arrays.columns)
        elif pandas is not None and isinstance(arrays, pandas.DataFrame):
            names = list(arrays.columns)
        elif isinstance(arrays, dict):
            names = list(arrays.keys())
        else:
            raise Exception('you either need to give astropy '
                            'table/pandas/dictionary or provide a separate '
                            'list of arrays and their names')
        arrays = [arrays[_] for _ in names]
    names = list(names)

    arrays = [np.asarray(_) for _ in arrays]
    if len(arrays) != len(names):
        raise RuntimeError('The column names list must have the same '
                           'length as array list')
    # characters that are replaced by '_' in column names
    repl = str.maketrans({_c: '_' for _c in ' -()[]<>'})
    fixed_names = []
    for name in names:
        fixed_name = (name + '').translate(repl)
        if fixed_name != name:
            warnings.warn('''Renamed column '%s' to '%s' ''' %
                          (name, fixed_name))
        fixed_names.append(fixed_name)
    names = fixed_names

    connSupplied = (conn is not None)
    if not connSupplied:
        conn = getConnection(db=db,
                             driver=driver,
                             user=user,
                             password=password,
                             host=host,
                             port=port,
                             timeout=timeout)
    try:
        cur = getCursor(conn, driver=driver, preamb=preamb, notNamed=True)
        if createTable:
            cur.execute(__create_schema(tableName, arrays, names, temp=temp))
            # __create_schema double-quotes the column names, so they must
            # be quoted here too: unquoted identifiers are case-folded by
            # PostgreSQL and would not match mixed-case columns.
            colList = ','.join('"%s"' % _ for _ in names)
        else:
            # existing table: keep the names unquoted as before
            colList = ','.join(names)
        nsplit = 100000
        N = len(arrays[0]) if len(arrays) > 0 else 0
        for i in range(0, N, nsplit):
            with cur.copy(f'''copy {tableName}({colList}) from stdin
                    with delimiter '{delimiter}' ''') as copy:
                __print_arrays([_[i:i + nsplit] for _ in arrays],
                               copy,
                               delimiter=delimiter)
        if analyze:
            cur.execute('analyze %s' % tableName)
        cur.close()
        if not noCommit:
            conn.commit()
    except BaseException:
        failure_cleanup(conn, connSupplied)
        raise
    if not connSupplied:
        conn.close()  # do not close if we were given the connection
764

765

766
def local_join(query,
               tableName,
               arrays,
               names,
               db=None,
               driver="psycopg",
               user=None,
               password=None,
               host=None,
               port=None,
               conn=None,
               preamb=None,
               timeout=None,
               strLength=STRLEN_DEFAULT,
               intNullVal=-9999,
               asDict=False):
    """
    Join your local data in python with the data in the database
    This command first uploads the data in the DB creating a temporary table
    and then runs a user specified query that can use your local data.
    The temporary table is discarded (transaction rolled back) afterwards.

    Parameters
    ----------
    query : String with the query to be executed
    tableName : The name of the temporary table that is going to be created
    arrays : The tuple with list of arrays with the data to be loaded in the DB
    names : The tuple with the column names for the user table
    db, driver, user, password, host, port, timeout : connection settings
        (see getConnection); ignored when conn is given
    conn : existing connection (not closed by this function)
    preamb, strLength, intNullVal, asDict : passed to get()

    Returns
    -------
    The result of get() for the query.

    Examples
    --------
    This will extract the rows from the table sometable matching
    to the provided array x

    >>> x = np.arange(10)
    >>> y = x**.5
    >>> sqlutilpy.local_join('''
    ... SELECT s.* FROM mytable AS m LEFT JOIN sometable AS s
    ... ON s.x = m.x ORDER BY m.x''',
    ... 'mytable', (x, y), ('x', 'y'))
    """

    connSupplied = (conn is not None)
    if not connSupplied:
        conn = getConnection(db=db,
                             driver=driver,
                             user=user,
                             password=password,
                             host=host,
                             timeout=timeout,
                             port=port)
    try:
        # the temporary table is never committed; it only lives in the
        # current transaction
        upload(tableName,
               arrays,
               names,
               conn=conn,
               driver=driver,
               noCommit=True,
               temp=True,
               analyze=True)
        res = get(query,
                  conn=conn,
                  driver=driver,
                  preamb=preamb,
                  strLength=strLength,
                  asDict=asDict,
                  intNullVal=intNullVal)
    except BaseException:
        failure_cleanup(conn, connSupplied)
        raise

    # drop the temporary table by rolling back the transaction
    conn.rollback()

    if not connSupplied:
        conn.close()
    return res
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc