• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

segasai / sqlutilpy / 13034557528

29 Jan 2025 03:21PM UTC coverage: 91.612% (+2.7%) from 88.904%
13034557528

push

github

segasai
allow .connect to duckdb

699 of 763 relevant lines covered (91.61%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.53
src/sqlutilpy/sqlutil.py
1
"""Sqlutilpy module to access SQL databases
2
"""
3
from __future__ import print_function
1✔
4
import numpy
1✔
5
import numpy as np
1✔
6
import psycopg
1✔
7
import threading
1✔
8
import collections
1✔
9
import warnings
1✔
10
from numpy.core import numeric as sb
1✔
11

12
try:
1✔
13
    import astropy.table as atpy
1✔
14
except ImportError:
×
15
    # astropy is not installed
16
    atpy = None
×
17
try:
1✔
18
    import pandas
1✔
19
except ImportError:
×
20
    # pandas is not installed
21
    pandas = None
×
22

23
import queue
1✔
24

25
# NOTE(review): not referenced anywhere in the visible part of this module;
# presumably a timeout in seconds -- confirm before relying on it.
_WAIT_SELECT_TIMEOUT = 10
# Default for the ``strLength`` argument of get() and local_join()
STRLEN_DEFAULT = 20
27

28

29
class config:
    """Module-wide settings."""

    # value assigned to the ``arraysize`` attribute of the named psycopg
    # cursor created by getCursor()
    arraysize = 100_000
31

32

33
class SqlUtilException(Exception):
    """Error raised by this module, e.g. for an unrecognized driver or an
    unknown PostgreSQL type code."""
35

36

37
def getConnection(db=None,
                  driver=None,
                  user=None,
                  password=None,
                  host=None,
                  port=None,
                  timeout=None):
    """
    Obtain the connection object to the DB.
    It may be useful to avoid reconnecting to the DB repeatedly.

    Parameters
    ----------

    db : string
        The name of the database (in case of PostgreSQL) or filename in
        case of sqlite/duckdb db
    driver :  string
        The db driver ('psycopg', 'sqlite3' or 'duckdb').
        'psycopg2' is accepted with a warning and treated as 'psycopg'
    user : string, optional
        Username (psycopg only)
    password: string, optional
        Password (psycopg only)
    host : string, optional
        Host-name (psycopg only)
    port : integer
        Connection port (psycopg only)
    timeout : integer
        Connection timeout for sqlite (5 if not specified); it is not
        used by the other drivers

    Returns
    -------
    conn : object
         Database Connection

    Raises
    ------
    SqlUtilException
        If the driver is not one of the supported ones

    """
    if driver == 'psycopg2':
        warnings.warn(
            'psycopg2 driver is not supported anymore using psycopg instead')
        driver = 'psycopg'
    if driver == 'psycopg':
        # only forward the options that were actually given, so that the
        # driver applies its own defaults for the rest
        candidates = {
            'dbname': db,
            'host': host,
            'port': port,
            'user': user,
            'password': password,
        }
        conn_dict = {
            key: val
            for key, val in candidates.items() if val is not None
        }
        return psycopg.connect(**conn_dict)
    if driver == 'sqlite3':
        import sqlite3
        if timeout is None:
            timeout = 5
        return sqlite3.connect(db, timeout=timeout)
    if driver == 'duckdb':
        import duckdb
        return duckdb.connect(db)
    # same exception class as getCursor() uses for a bad driver
    # (it is still an Exception subclass for existing callers)
    raise SqlUtilException('Unknown driver %r' % (driver, ))
101

102

103
def getCursor(conn, driver=None, preamb=None, notNamed=False):
    """
    Retrieve the database cursor

    Parameters
    ----------
    conn : object
        The database connection
    driver : string
        The db driver ('psycopg', 'sqlite3' or 'duckdb').
        'psycopg2' is accepted with a warning and treated as 'psycopg'
    preamb : string, optional
        SQL code executed on the connection before the cursor is returned
    notNamed : bool
        Only used for psycopg. If True the plain cursor is returned,
        otherwise a named cursor ('sqlutilcursor') is created

    Returns
    -------
    cur : object
        The cursor

    Raises
    ------
    SqlUtilException
        If the driver is not recognized
    """
    if driver == 'psycopg2':
        warnings.warn(
            'psycopg2 driver is not supported anymore using psycopg instead')
        driver = 'psycopg'
    if driver == 'psycopg':
        cur = conn.cursor()
        if preamb is not None:
            cur.execute(preamb)
        else:
            # this is required because otherwise PG may decide to execute a
            # different plan
            cur.execute('set cursor_tuple_fraction TO 1')
        if notNamed:
            return cur
        # the plain cursor was only needed for the set-up statement;
        # close it instead of silently dropping it
        cur.close()
        cur = conn.cursor(name='sqlutilcursor')
        cur.arraysize = config.arraysize
    elif driver in ('sqlite3', 'duckdb'):
        cur = conn.cursor()
        if preamb is not None:
            cur.execute(preamb)
    else:
        raise SqlUtilException('unrecognized driver')
    return cur
134

135

136
def __fromrecords(recList, dtype=None, intNullVal=None):
1✔
137
    """
138
    This function was taken from np.core.records and updated to
139
    support conversion null integers to intNullVal
140
    """
141

142
    shape = None
1✔
143
    descr = sb.dtype((np.core.records.record, dtype))
1✔
144
    try:
1✔
145
        retval = sb.array(recList, dtype=descr)
1✔
146
    except TypeError:  # list of lists instead of list of tuples
1✔
147
        shape = (len(recList), )
1✔
148
        _array = np.core.records.recarray(shape, descr)
1✔
149
        try:
1✔
150
            for k in range(_array.size):
1✔
151
                _array[k] = tuple(recList[k])
1✔
152
        except TypeError:
1✔
153
            convs = []
1✔
154
            ncols = len(dtype.fields)
1✔
155
            for _k in dtype.names:
1✔
156
                _v = dtype.fields[_k]
1✔
157
                if _v[0] in [np.int16, np.int32, np.int64]:
1✔
158
                    convs.append(lambda x: intNullVal if x is None else x)
1✔
159
                else:
160
                    convs.append(lambda x: x)
1✔
161
            convs = tuple(convs)
1✔
162

163
            def convF(x):
1✔
164
                return [convs[_](x[_]) for _ in range(ncols)]
1✔
165

166
            for k in range(k, _array.size):
1✔
167
                try:
1✔
168
                    _array[k] = tuple(recList[k])
1✔
169
                except TypeError:
1✔
170
                    _array[k] = tuple(convF(recList[k]))
1✔
171
        return _array
1✔
172
    else:
173
        if shape is not None and retval.shape != shape:
1✔
174
            retval.shape = shape
×
175

176
    res = retval.view(numpy.core.records.recarray)
1✔
177

178
    return res
1✔
179

180

181
def __converter(qIn, qOut, endEvent, dtype, intNullVal):
    """
    Worker loop turning batches of tuples into numpy arrays.

    Batches are taken from qIn, converted with __fromrecords and put
    into qOut until endEvent is set. If a conversion fails, endEvent is
    set and the exception is re-raised.
    """
    while True:
        if endEvent.is_set():
            return
        try:
            # short timeout so that endEvent is checked regularly
            batch = qIn.get(True, 0.1)
        except queue.Empty:
            continue
        try:
            converted = __fromrecords(batch,
                                      dtype=dtype,
                                      intNullVal=intNullVal)
        except Exception:
            print('Failed to convert input data into array')
            endEvent.set()
            raise
        qOut.put(converted)
195

196

197
def __getDType(row, typeCodes, strLength):
    """
    Build the numpy structured dtype for a query result.

    Parameters
    ----------
    row : sequence
        One row of the result (typically the first one); it is used to
        size the string columns and to detect array-like values
    typeCodes : sequence of integers
        The PostgreSQL type codes of the columns
    strLength : integer
        Minimum length of the string columns

    Returns
    -------
    dtype : numpy.dtype
        Structured dtype with fields named a0, a1, ...

    Raises
    ------
    SqlUtilException
        If a type code is not known
    """
    pgTypeHash = {
        16: bool,
        18: str,
        19: str,  # name type used in information schema
        20: 'i8',
        21: 'i2',
        23: 'i4',
        25: '|U%d',
        700: 'f4',
        701: 'f8',
        1000: bool,
        1005: 'i2',
        1007: 'i4',
        1016: 'i8',
        1021: 'f4',
        1022: 'f8',
        1042: '|U%d',  # character()
        1043: '|U%d',  # varchar
        1082: '<M8[us]',  # date
        1114: '<M8[us]',  # timestamp
        1700: 'f8',  # numeric
    }
    strTypes = (25, 1042, 1043)

    pgTypes = []

    for i, (curv, curt) in enumerate(zip(row, typeCodes)):
        if curt not in pgTypeHash:
            raise SqlUtilException('Unknown PG type %d' % curt)
        pgType = pgTypeHash[curt]
        if curt in strTypes:
            if curv is not None:
                # if the first string value is longer than
                # strLength use that as a maximum
                curmax = max(strLength, len(curv))
            else:
                # if the first value is null
                # just use strLength
                curmax = strLength
            pgType = pgType % (curmax, )
        else:
            # anything that has a length (i.e. a list from an array
            # column) is stored as a python object
            try:
                len(curv)
            except TypeError:
                pass
            else:
                pgType = 'O'
        pgTypes.append(('a%d' % i, pgType))
    return numpy.dtype(pgTypes)
248

249

250
def get(query,
        params=None,
        db="wsdb",
        driver="psycopg",
        user=None,
        password=None,
        host=None,
        preamb=None,
        conn=None,
        port=None,
        strLength=STRLEN_DEFAULT,
        timeout=None,
        notNamed=False,
        asDict=False,
        intNullVal=-9999):
    '''
    Executes the sql query and returns the tuple or dictionary
    with the numpy arrays.

    Parameters
    ----------
    query : string
        Query you want to execute, can include placeholders
        referring to query parameters
    params : tuple
        Query parameters
    conn : object
        The connection object to the DB (optional) to avoid reconnecting
    asDict : boolean
        Flag whether to retrieve the results as a dictionary with column
        names as keys
    strLength : integer
        The maximum length of the string.
        Strings will be truncated to this length
    intNullVal : integer, optional
        All the integer columns with nulls will have null replaced by
        this value
    db : string
        The name of the database
    driver : string, optional
        The sql driver to be used (psycopg, sqlite3 or duckdb)
    user : string, optional
        User name for the DB connection
    password : string, optional
        DB connection password
    host : string, optional
        Hostname of the database
    port : integer, optional
        Port of the database
    preamb : string
        SQL code to be executed before the query
    timeout : integer, optional
        Connection timeout (passed to getConnection)
    notNamed : boolean
        If True a plain (not named) cursor is used with psycopg

    Returns
    -------
    ret : List or dictionary
        By default you get a list with numpy arrays for each column
        in your query.
        If you specified asDict keyword then you get an ordered dictionary with
        your columns.
        With sqlite3/duckdb an empty result is returned as an empty list
        per column.

    Examples
    --------
    >>> a, b, c = sqlutilpy.get('select ra,dec,d25 from rc3')

    You can also use the parameters in your query:

    >>> a, b = sqlutilpy.get('select ra,dec from rc3 where name=?', "NGC 3166")
    '''

    connSupplied = (conn is not None)
    if not connSupplied:
        conn = getConnection(db=db,
                             driver=driver,
                             user=user,
                             password=password,
                             host=host,
                             port=port,
                             timeout=timeout)
    try:
        cur = getCursor(conn, driver=driver, preamb=preamb, notNamed=notNamed)

        if params is None:
            cur.execute(query)
        else:
            cur.execute(query, params)

        colNames = []
        if driver == 'psycopg2':
            warnings.warn('psycopg2 driver is not supported anymore. '
                          'We using psycopg instead')
            driver = 'psycopg'
        if driver == 'psycopg':
            qIn = queue.Queue(1)
            qOut = queue.Queue()
            endEvent = threading.Event()
            nrec = 0  # keeps the number of arrays sent to the other thread
            # minus number received
            reslist = []
            proc = None
            typeCodes = None
            desc = None
            try:
                while True:
                    # Iterating over the cursor, retrieving batches of results
                    # and then sending them for conversion
                    tups = cur.fetchmany()

                    if typeCodes is None:
                        desc = cur.description
                        typeCodes = [_tmp.type_code for _tmp in desc]
                        colNames = [_tmp.name for _tmp in desc]

                    # No more data
                    if len(tups) == 0:
                        # the dtype is not needed here, but the call
                        # still rejects unknown column types
                        __getDType([None] * len(typeCodes), typeCodes,
                                   strLength)
                        break

                    # If this is just the start we need to launch the
                    # thread doing the conversion. The check is on proc
                    # rather than on nrec, because nrec can drop back to
                    # zero mid-stream, and a second converter thread
                    # could deliver the batches out of order
                    if proc is None:
                        dtype = __getDType(tups[0], typeCodes, strLength)
                        proc = threading.Thread(target=__converter,
                                                args=(qIn, qOut, endEvent,
                                                      dtype, intNullVal))
                        proc.start()

                    # Send the new batch for conversion. qIn holds a
                    # single batch, so wait with a timeout and give up if
                    # the converter thread has failed instead of
                    # blocking forever
                    while True:
                        if endEvent.is_set():
                            raise Exception('Child thread failed')
                        try:
                            qIn.put(tups, True, 0.1)
                            break
                        except queue.Full:
                            pass

                    # nrec is the number of batches in conversion currently
                    nrec += 1

                    # Try to retrieve one processed batch without waiting
                    # on it
                    try:
                        reslist.append(qOut.get(False))
                        nrec -= 1
                    except queue.Empty:
                        pass

                # Now we are done fetching the data from the DB, we
                # just need to assemble the converted results
                while (nrec != 0):
                    try:
                        reslist.append(qOut.get(True, 0.1))
                        nrec -= 1
                    except queue.Empty:
                        # continue looping unless the endEvent was set
                        # which should happen in the case of the crash
                        # of the converter thread
                        if endEvent.is_set():
                            raise Exception('Child thread failed')
                endEvent.set()
            except BaseException:
                endEvent.set()
                if proc is not None:
                    # notice that here the timeout is larger than the timeout
                    # in the converter thread. A thread cannot be
                    # terminated (threading.Thread has no terminate()),
                    # the converter stops by itself once endEvent is set
                    proc.join(0.2)
                raise
            if proc is not None:
                proc.join()
            if reslist == []:
                nCols = len(desc)
                res = numpy.array([],
                                  dtype=numpy.dtype([('a%d' % i, 'f')
                                                     for i in range(nCols)]))
            else:
                res = numpy.concatenate(reslist)
            res = [res[tmp] for tmp in res.dtype.names]

        elif driver in ('sqlite3', 'duckdb'):
            tups = cur.fetchall()
            colNames = [_tmp[0] for _tmp in cur.description]
            if len(tups) > 0:
                res = np.rec.array(tups)
                res = [res[tmp] for tmp in res.dtype.names]
            else:
                # no early return here, so that the cursor and the
                # connection are closed and asDict is honoured
                res = [[] for _tmp in colNames]

    except BaseException:
        failure_cleanup(conn, connSupplied)
        raise

    cur.close()

    if not connSupplied:
        conn.close()

    if asDict:
        resDict = collections.OrderedDict()
        repeats = {}
        for _n, _v in zip(colNames, res):
            if _n in resDict:
                # repeated column names get a numeric suffix
                curn = _n + '_%d' % (repeats[_n])
                repeats[_n] += 1
                warnings.warn(('Column name %s is repeated in the output, ' +
                               'new name %s assigned') % (_n, curn))
            else:
                repeats[_n] = 1
                curn = _n
            resDict[curn] = _v
        res = resDict
    return res
466

467

468
def execute(query,
            params=None,
            db='wsdb',
            driver="psycopg",
            user=None,
            password=None,
            host=None,
            conn=None,
            preamb=None,
            timeout=None,
            noCommit=False,
            port=None):
    """
    Execute a given SQL command without returning the results

    Parameters
    ----------
    query: string
        The query or command you are executing
    params: tuple, optional
        Optional parameters of your query
    db : string
        Database name
    driver : string
        Driver for the DB connection ('psycopg', 'sqlite3' or 'duckdb')
    user : string, optional
        user name for the DB connection
    password : string, optional
        DB connection password
    host : string, optional
        Hostname of the database
    conn : object, optional
        The connection object to the DB to avoid reconnecting.
        It is not closed by this function
    preamb : string, optional
        SQL code to be executed before the query
    timeout : integer, optional
        Connection timeout (passed to getConnection)
    noCommit: bool
        By default execute() will commit your command.
        If you say noCommit, the commit won't be issued.
    port : integer, optional
        Port of the database

    Examples
    --------
    >>> sqlutil.execute('drop table mytab', conn=conn)
    >>> sqlutil.execute('create table mytab (a int)', db='mydb')

    """
    connSupplied = (conn is not None)
    if not connSupplied:
        conn = getConnection(db=db,
                             driver=driver,
                             user=user,
                             password=password,
                             host=host,
                             port=port,
                             timeout=timeout)
    try:
        cur = getCursor(conn, driver=driver, preamb=preamb, notNamed=True)
        if params is not None:
            cur.execute(query, params)
        else:
            # sqlite3 doesn't like params here...
            cur.execute(query)
        cur.close()
        # the commit is inside the guarded section, so that a failing
        # commit also triggers the rollback/close below
        if not noCommit:
            conn.commit()
    except BaseException:
        failure_cleanup(conn, connSupplied)
        raise
    if not connSupplied:
        conn.close()  # do not close if we were given the connection
533

534

535
def __create_schema(tableName, arrays, names, temp=False):
    """
    Build the 'create table' statement for the given columns.

    Parameters
    ----------
    tableName : string
        Name of the table
    arrays : list of numpy arrays
        The columns; the SQL type is derived from the array dtype.
        For object arrays the type of the first element is used and
        the column becomes an SQL array
    names : list of strings
        Column names (they are double quoted in the statement)
    temp : bool
        If True a temporary table is created

    Returns
    -------
    query : string
        The create table statement

    Raises
    ------
    KeyError
        If the dtype of a column is not supported
    """
    sql_types = {
        np.int8: 'smallint',
        np.uint8: 'smallint',
        np.int16: 'smallint',
        np.uint16: 'integer',
        np.int32: 'integer',
        np.uint32: 'bigint',
        np.int64: 'bigint',
        np.uint64: 'bigint',
        np.float32: 'real',
        np.float64: 'double precision',
        np.bytes_: 'varchar',
        np.str_: 'varchar',
        np.bool_: 'boolean',
        np.datetime64: 'timestamp',
    }
    temp = 'temporary' if temp else ''
    columns = []
    for arr, name in zip(arrays, names):
        curtyp = arr.dtype.type
        if curtyp == np.object_:
            # array
            curotyp = sql_types[arr[0].dtype.type] + '[]'
        else:
            curotyp = sql_types[curtyp]
        columns.append(f'"{name}" {curotyp}')
    return 'create %s table %s ' % (temp,
                                    tableName) + '(' + ','.join(columns) + ')'
557

558

559
def __print_arrays(arrays, f, delimiter=' '):
    """
    print the input arrays into the open file separated by a delimiter

    Parameters
    ----------
    arrays : list of numpy arrays
        The columns to be written, one row per line.
        Object arrays are written as '{v1,v2,...}'
    f : file-like object
        The output opened for binary writing (it receives bytes)
    delimiter : string
        The column separator

    Raises
    ------
    RuntimeError
        If the dtype of a column is not supported
    """
    import sys

    format_dict = {
        np.int8: '%d',
        np.uint8: '%d',
        np.int16: '%d',
        np.uint16: '%d',
        np.int32: '%d',
        np.uint32: '%d',
        np.int64: '%d',
        np.uint64: '%d',
        np.float32: '%.18e',
        np.float64: '%.18e',
        np.bytes_: '%s',
        np.str_: '%s',
        np.datetime64: '%s',
        np.bool_: '%d',
    }
    fmts = []
    array_mode = False
    names = []
    for i, x in enumerate(arrays):
        # the names np.rec.fromarrays() assigns by default
        names.append('f%d' % i)
        if x.dtype.type in format_dict:
            fmts.append(format_dict[x.dtype.type])
        elif x.dtype.type == np.object_:
            array_mode = True
            fmts.append(None)
        else:
            raise RuntimeError('Unsupported column type %s. Please report' %
                               (x.dtype.type))
    recarr = np.rec.fromarrays(arrays)
    if not array_mode:
        np.savetxt(f, recarr, fmt=fmts, delimiter=delimiter)
        return
    sep = delimiter.encode()
    # this is really slow
    for row in recarr:
        for i, field in enumerate(names):
            if i != 0:
                f.write(sep)
            if fmts[i] is None:
                # threshold must be explicitly huge: with the default
                # one numpy abbreviates arrays longer than 1000 elements
                # with '...'.
                # formatter is needed because otherwise there is
                # whitespace padding
                curstr = np.array2string(row[field],
                                         max_line_width=np.inf,
                                         threshold=sys.maxsize,
                                         separator=',',
                                         formatter={'all': str})
                curstr = '{' + curstr[1:-1] + '}'
                f.write(curstr.encode())
            else:
                f.write(str(row[field]).encode())
        f.write(b'\n')
606

607

608
def failure_cleanup(conn, connSupplied):
    """
    Best-effort cleanup after a failed operation.

    The connection is rolled back and, unless it was supplied by the
    caller (connSupplied is True), closed. Errors raised by the
    rollback or the close are ignored.
    """
    # do not close if we were given the connection
    steps = ('rollback', ) if connSupplied else ('rollback', 'close')
    for step in steps:
        try:
            getattr(conn, step)()
        except Exception:
            pass
618

619

620
def upload(tableName,
           arrays,
           names=None,
           db="wsdb",
           driver="psycopg",
           user=None,
           password=None,
           host=None,
           conn=None,
           preamb=None,
           timeout=None,
           noCommit=False,
           temp=False,
           analyze=False,
           createTable=True,
           delimiter='|',
           port=None):
    """
    Upload the data stored in the tuple of arrays in the DB

    Parameters
    ----------
    tableName : string
        The name of the table where the data will be uploaded
    arrays : tuple
        Tuple of arrays that will be columns of the new table
        If names are not specified, this parameter can be a pandas
        DataFrame, an astropy table or a dictionary
    names : tuple
        The column names. The characters ' -()[]<>' in the names are
        replaced by underscores (with a warning)
    db: string
         Database name
    driver: string
         Python database driver "psycopg",
    user: string,
    password: string
    host: string
    conn: object
         SQL connection
    preamb: string
         The string/query to be executed before your command
    timeout: integer
         Connection timeout (passed to getConnection)
    noCommit: bool
         If true, the commit is not executed and the table will go away
         after the disconnect
    temp: bool
         If true a temporary table will be created
    analyze: bool
         if True, the table will be analyzed after the upload
    createTable: bool
         if True the table will be created before uploading (default)
    delimiter: string
         the string used for delimiting the input data when ingesting into
         the db (default is |)
    port: integer, optional
         Port of the database

    Examples
    --------
    >>> x = np.arange(10)
    >>> y = x**.5
    >>> sqlutilpy.upload('mytable', (x, y), ('xcol', 'ycol'))

    >>> T = astropy.Table({'x':[1, 2, 3], 'y':['a', 'b', 'c']})
    >>> sqlutilpy.upload('mytable', T)
    """
    # The input is validated before connecting, so that an invalid
    # input cannot leave a freshly opened connection behind
    if names is None:
        # we assume that we were given a table
        if atpy is not None and isinstance(arrays, atpy.Table):
            names = list(arrays.columns)
        elif pandas is not None and isinstance(arrays, pandas.DataFrame):
            names = list(arrays.columns)
        elif isinstance(arrays, dict):
            names = list(arrays.keys())
        else:
            raise Exception(
                'you either need to give astropy '
                'table/pandas/dictionary or provide a separate list of '
                'arrays and their names')
        arrays = [arrays[_] for _ in names]

    arrays = [np.asarray(_) for _ in arrays]
    if len(arrays) != len(names):
        raise RuntimeError('The column names list must have the same '
                           'length as array list')
    repl_table = str.maketrans({_: '_' for _ in ' -()[]<>'})
    fixed_names = []
    for name in names:
        name = str(name)
        fixed_name = name.translate(repl_table)
        if fixed_name != name:
            warnings.warn('''Renamed column '%s' to '%s' ''' %
                          (name, fixed_name))
        fixed_names.append(fixed_name)
    names = fixed_names

    connSupplied = (conn is not None)
    if not connSupplied:
        conn = getConnection(db=db,
                             driver=driver,
                             user=user,
                             password=password,
                             host=host,
                             port=port,
                             timeout=timeout)
    try:
        cur = getCursor(conn, driver=driver, preamb=preamb, notNamed=True)
        if createTable:
            query1 = __create_schema(tableName, arrays, names, temp=temp)
            cur.execute(query1)
            # the table was created with double quoted column names, so
            # they must be quoted here as well, otherwise names that are
            # not all lower case would not match
            column_list = ','.join('"%s"' % _ for _ in names)
        else:
            column_list = ','.join(names)
        nsplit = 100000  # number of rows sent per copy command
        N = len(arrays[0])
        for i in range(0, N, nsplit):
            with cur.copy(f'''copy {tableName}({column_list}) from stdin
                        with delimiter '{delimiter}' ''') as copy:
                __print_arrays([_[i:i + nsplit] for _ in arrays],
                               copy,
                               delimiter=delimiter)
        if analyze:
            cur.execute('analyze %s' % tableName)
        cur.close()
        if not noCommit:
            conn.commit()
    except BaseException:
        failure_cleanup(conn, connSupplied)
        raise
    if not connSupplied:
        conn.close()  # do not close if we were given the connection
762

763

764
def local_join(query,
               tableName,
               arrays,
               names,
               db=None,
               driver="psycopg",
               user=None,
               password=None,
               host=None,
               port=None,
               conn=None,
               preamb=None,
               timeout=None,
               strLength=STRLEN_DEFAULT,
               intNullVal=-9999,
               asDict=False):
    """
    Join your local data in python with the data in the database
    This command first uploads the data in the DB creating a temporary table
    and then runs a user specified query that can use your local data.
    The transaction is rolled back at the end, so the uploaded table
    does not persist.

    Parameters
    ----------
    query : String with the query to be executed
    tableName : The name of the temporary table that is going to be created
    arrays : The tuple with list of arrays with the data to be loaded in the DB
    names : The tuple with the column names for the user table
    db, driver, user, password, host, port, timeout : connection
        parameters (see getConnection), used if conn is not given
    conn : The connection object to the DB (optional); it is not closed
    preamb : SQL code to be executed before the query
    strLength, intNullVal, asDict : see get()

    Returns
    -------
    ret : the result of get() for the query

    Examples
    --------
    This will extract the rows from the table sometable matching
    to the provided array x

    >>> x = np.arange(10)
    >>> y = x**.5
    >>> sqlutilpy.local_join('''
    ... SELECT s.* FROM mytable AS m LEFT JOIN sometable AS s
    ... ON s.x = m.x ORDER BY m.xcol''',
    ... 'mytable', (x, y), ('x', 'y'))
    """

    connSupplied = (conn is not None)
    if not connSupplied:
        conn = getConnection(db=db,
                             driver=driver,
                             user=user,
                             password=password,
                             host=host,
                             timeout=timeout,
                             port=port)
    try:
        # the driver is passed on, so that upload() and get() treat the
        # connection with the driver it was opened with
        upload(tableName,
               arrays,
               names,
               conn=conn,
               driver=driver,
               noCommit=True,
               temp=True,
               analyze=True)
        res = get(query,
                  conn=conn,
                  driver=driver,
                  preamb=preamb,
                  strLength=strLength,
                  asDict=asDict,
                  intNullVal=intNullVal)
        # discard the uploaded table
        conn.rollback()
    except BaseException:
        failure_cleanup(conn, connSupplied)
        raise

    if not connSupplied:
        conn.close()
    return res
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc