• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

segasai / sqlutilpy / 14420409876

12 Apr 2025 02:14PM UTC coverage: 91.711% (-0.02%) from 91.732%
14420409876

push

github

segasai
get rid of obsolete numpy code

697 of 760 relevant lines covered (91.71%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.51
src/sqlutilpy/sqlutil.py
1
"""Sqlutilpy module to access SQL databases
2
"""
3
from __future__ import print_function
1✔
4
import numpy
1✔
5
import numpy as np
1✔
6
import psycopg
1✔
7
import threading
1✔
8
import collections
1✔
9
import warnings
1✔
10

11
try:
1✔
12
    import astropy.table as atpy
1✔
13
except ImportError:
×
14
    # astropy is not installed
15
    atpy = None
×
16
try:
1✔
17
    import pandas
1✔
18
except ImportError:
×
19
    # pandas is not installed
20
    pandas = None
×
21

22
import queue
1✔
23

24
# Not referenced anywhere else in this module; presumably kept for
# external users -- TODO confirm before removing.
_WAIT_SELECT_TIMEOUT = 10
# Default minimum width of the fixed-length string columns built by
# get()/local_join() (widened when the first value is longer).
STRLEN_DEFAULT = 20
26

27

28
class config:
    """Module-wide tunables.

    ``arraysize`` is assigned to the named (server-side) psycopg cursor
    created by getCursor(), i.e. the number of rows per fetchmany() batch.
    """
    arraysize: int = 100000
30

31

32
class SqlUtilException(Exception):
    """Error raised by sqlutilpy itself (e.g. unknown driver or PG type)."""
34

35

36
def getConnection(db=None,
                  driver=None,
                  user=None,
                  password=None,
                  host=None,
                  port=None,
                  timeout=None):
    """
    Obtain the connection object to the DB.
    It may be useful to avoid reconnecting to the DB repeatedly.

    Parameters
    ----------

    db : string
        The name of the database (in case of PostgreSQL) or filename in
        case of sqlite/duckdb db
    driver :  string
        The db driver: 'psycopg', 'sqlite3' or 'duckdb'
        ('psycopg2' is accepted as a deprecated alias of 'psycopg')
    user : string, optional
        Username (PostgreSQL only)
    password: string, optional
        Password (PostgreSQL only)
    host : string, optional
        Host-name (PostgreSQL only)
    port : integer
        Connection port (by default 5432 for PostgreSQL)
    timeout : integer
        Connection timeout for sqlite (defaults to 5)

    Returns
    -------
    conn : object
         Database Connection

    Raises
    ------
    SqlUtilException
        If the driver is not recognized.
    """
    if driver == 'psycopg2':
        warnings.warn(
            'psycopg2 driver is not supported anymore using psycopg instead')
        driver = 'psycopg'
    if driver == 'psycopg':
        # Forward only the options that were actually given, presumably so
        # that libpq's own defaults (e.g. PG* environment variables) apply
        # to everything else.
        candidates = (('dbname', db), ('host', host), ('port', port),
                      ('user', user), ('password', password))
        conn_dict = {key: val for key, val in candidates if val is not None}
        conn = psycopg.connect(**conn_dict)
    elif driver == 'sqlite3':
        import sqlite3
        if timeout is None:
            timeout = 5
        conn = sqlite3.connect(db, timeout=timeout)
    elif driver == 'duckdb':
        import duckdb
        conn = duckdb.connect(db)
    else:
        # Same exception type as getCursor() uses for an unknown driver;
        # it subclasses Exception, so existing handlers keep working.
        raise SqlUtilException("Unknown driver")
    return conn
100

101

102
def getCursor(conn, driver=None, preamb=None, notNamed=False):
    """
    Retrieve the database cursor.

    For sqlite3/duckdb a plain cursor is returned after running the
    optional preamble.  For psycopg the preamble (or, when no preamble is
    given, ``set cursor_tuple_fraction TO 1``) is executed on a plain
    cursor.  That plain cursor is returned when ``notNamed`` is true;
    otherwise a named cursor using ``config.arraysize`` is returned.
    Unknown drivers raise SqlUtilException.
    """
    if driver == 'psycopg2':
        warnings.warn(
            'psycopg2 driver is not supported anymore using psycopg instead')
        driver = 'psycopg'

    if driver in ('sqlite3', 'duckdb'):
        plain = conn.cursor()
        if preamb is not None:
            plain.execute(preamb)
        return plain

    if driver != 'psycopg':
        raise SqlUtilException('unrecognized driver')

    setup = conn.cursor()
    # Without the cursor_tuple_fraction setting PG may decide to execute
    # a different plan.
    setup.execute('set cursor_tuple_fraction TO 1'
                  if preamb is None else preamb)
    if notNamed:
        return setup
    named = conn.cursor(name='sqlutilcursor')
    named.arraysize = config.arraysize
    return named
133

134

135
def __fromrecords(recList, dtype=None, intNullVal=None):
    """
    Convert a list of rows into a numpy record array.

    This function was taken from np.core.records and updated to
    support conversion of null integers to intNullVal.

    Parameters
    ----------
    recList : list of tuples (or lists)
        The rows to convert.
    dtype : numpy.dtype
        Structured dtype describing the columns.
    intNullVal : int
        Value stored instead of ``None`` in signed-integer columns.

    Returns
    -------
    numpy.recarray
    """
    descr = np.dtype((np.record, dtype))
    try:
        return np.array(recList, dtype=descr).view(np.recarray)
    except TypeError:
        # Either the rows are lists instead of tuples, or an integer column
        # holds None; fall back to assigning row by row.
        pass

    out = np.recarray((len(recList), ), descr)
    start = 0
    try:
        for start in range(out.size):
            out[start] = tuple(recList[start])
        return out
    except TypeError:
        # row `start` could not be stored as-is; convert from here on
        pass

    def _int_or_null(value):
        return intNullVal if value is None else value

    def _identity(value):
        return value

    # One converter per column: NULL -> intNullVal for signed-int columns.
    convs = tuple(
        _int_or_null if np.issubdtype(dtype.fields[name][0], np.signedinteger)
        else _identity for name in dtype.names)

    for k in range(start, out.size):
        row = recList[k]
        try:
            out[k] = tuple(row)
        except TypeError:
            out[k] = tuple(conv(val) for conv, val in zip(convs, row))
    return out
178

179

180
def __converter(qIn, qOut, endEvent, dtype, intNullVal):
    """ Convert the input stream of tuples into numpy arrays.

    Runs until ``endEvent`` is set, polling ``qIn`` every 0.1 s.  Each
    converted batch is put on ``qOut``.  On a conversion failure it
    prints a message, sets ``endEvent`` so the fetching side can notice,
    and re-raises.
    """
    while True:
        if endEvent.is_set():
            return
        try:
            batch = qIn.get(block=True, timeout=0.1)
        except queue.Empty:
            continue
        try:
            converted = __fromrecords(batch, dtype=dtype,
                                      intNullVal=intNullVal)
        except Exception:
            print('Failed to convert input data into array')
            endEvent.set()
            raise
        qOut.put(converted)
194

195

196
def __getDType(row, typeCodes, strLength):
    """
    Build the numpy structured dtype for a query result.

    Parameters
    ----------
    row : sequence
        The first row of the result (``None`` entries for NULLs).
    typeCodes : sequence of int
        PostgreSQL type OIDs of the columns.
    strLength : int
        Minimum width of fixed-length string columns.

    Returns
    -------
    numpy.dtype with fields ``a0``, ``a1``, ...

    Raises
    ------
    SqlUtilException
        For a type OID not listed in the mapping below.
    """
    pgTypeHash = {
        16: bool,
        18: str,
        19: str,  # name type used in information schema
        20: 'i8',
        21: 'i2',
        23: 'i4',
        700: 'f4',
        701: 'f8',
        1000: bool,
        1005: 'i2',
        1007: 'i4',
        1016: 'i8',
        1021: 'f4',
        1022: 'f8',
        1700: 'f8',  # numeric
        1114: '<M8[us]',  # timestamp
        1184: '<M8[us]',  # timestamp with timezone
        1082: '<M8[us]',  # date
        25: '|U%d',
        1042: '|U%d',  # character()
        1043: '|U%d'  # varchar
    }
    strTypes = (25, 1042, 1043)
    # Columns that are always stored as python objects: the array types
    # and the 18/19 string types.  Previously these were only switched to
    # object dtype when the *first* value was non-NULL; a leading NULL
    # produced a scalar dtype (or a zero-width unicode field that silently
    # emptied the strings).
    objTypes = (18, 19, 1000, 1005, 1007, 1016, 1021, 1022)

    fields = []
    for i, (curv, curt) in enumerate(zip(row, typeCodes)):
        if curt not in pgTypeHash:
            raise SqlUtilException('Unknown PG type %d' % curt)
        pgType = pgTypeHash[curt]
        if curt in strTypes:
            # if the first string value is longer than strLength use that
            # as the width; a NULL first value just uses strLength
            curmax = strLength if curv is None else max(strLength, len(curv))
            pgType = pgType % (curmax, )
        elif curt in objTypes:
            pgType = 'O'
        else:
            # defensive fallback: any sized value goes into an object column
            try:
                len(curv)
                pgType = 'O'
            except TypeError:
                pass
        fields.append(('a%d' % i, pgType))
    return numpy.dtype(fields)
248

249

250
def _put_batch(qIn, tups, endEvent):
    """Hand a fetched batch to the converter thread without hanging.

    ``qIn`` has capacity one, so a plain blocking ``put`` would wait
    forever if the converter thread had died; poll instead and give up
    once the converter has signalled failure through ``endEvent``.
    """
    while True:
        try:
            qIn.put(tups, True, 0.1)
            return
        except queue.Full:
            if endEvent.is_set():
                raise SqlUtilException('Child thread failed')


def _fetch_psycopg(cur, strLength, intNullVal):
    """Fetch all rows from an executed psycopg cursor as one record array.

    Batches returned by ``cur.fetchmany()`` are converted to numpy record
    arrays by a single background thread (``__converter``) while this
    thread keeps fetching.  Using exactly one converter keeps the batches
    in fetch order.

    Returns
    -------
    res : numpy structured array
        Fields ``a0``, ``a1``, ...; when there are no rows it is empty and
        every field is float32.
    colNames : list of str
        Column names from the cursor description.
    """
    qIn = queue.Queue(1)
    qOut = queue.Queue()
    endEvent = threading.Event()
    nrec = 0  # batches sent to the converter minus batches received back
    reslist = []
    proc = None
    desc = None
    typeCodes = None
    colNames = []
    try:
        while True:
            if endEvent.is_set():
                # the converter thread failed on an earlier batch
                raise SqlUtilException('Child thread failed')
            tups = cur.fetchmany()
            if typeCodes is None:
                desc = cur.description
                typeCodes = [_tmp.type_code for _tmp in desc]
                colNames = [_tmp.name for _tmp in desc]
            if not tups:
                if proc is None:
                    # No rows at all: still reject unsupported column types
                    # exactly as before.
                    __getDType([None] * len(typeCodes), typeCodes, strLength)
                break
            if proc is None:
                # dtype is derived from the first row of the first batch
                dtype = __getDType(tups[0], typeCodes, strLength)
            _put_batch(qIn, tups, endEvent)
            if proc is None:
                # Start the converter only once.  (Keying this on nrec == 0
                # could start extra converters once results had been
                # received, which could scramble the row order.)
                proc = threading.Thread(target=__converter,
                                        args=(qIn, qOut, endEvent, dtype,
                                              intNullVal))
                proc.start()
            nrec += 1
            # Collect one converted batch if it is already available
            try:
                reslist.append(qOut.get(False))
                nrec -= 1
            except queue.Empty:
                pass

        # Done fetching: wait for the batches still being converted
        while nrec != 0:
            try:
                reslist.append(qOut.get(True, 0.1))
                nrec -= 1
            except queue.Empty:
                # the converter sets endEvent when it crashes
                if endEvent.is_set():
                    raise SqlUtilException('Child thread failed')
        endEvent.set()
    except BaseException:
        endEvent.set()
        if proc is not None:
            # the converter polls endEvent every 0.1 s, so this is enough
            # for it to exit unless it is stuck in a conversion
            proc.join(0.2)
        raise
    if proc is not None:
        proc.join()
    if not reslist:
        res = numpy.array([],
                          dtype=numpy.dtype([('a%d' % i, 'f')
                                             for i in range(len(desc))]))
    else:
        res = numpy.concatenate(reslist)
    return res, colNames


def _columns_to_dict(colNames, columns):
    """Pair column names with column data, renaming repeated names.

    A repeated name ``x`` becomes ``x_1``, ``x_2``, ... (skipping names
    already present, so nothing is overwritten) with a warning for each.
    """
    resDict = collections.OrderedDict()
    repeats = {}
    for name, values in zip(colNames, columns):
        curn = name
        if name in resDict:
            cnt = repeats.get(name, 1)
            curn = '%s_%d' % (name, cnt)
            while curn in resDict:
                cnt += 1
                curn = '%s_%d' % (name, cnt)
            repeats[name] = cnt + 1
            warnings.warn(('Column name %s is repeated in the output, ' +
                           'new name %s assigned') % (name, curn))
        else:
            repeats[name] = 1
        resDict[curn] = values
    return resDict


def get(query,
        params=None,
        db="wsdb",
        driver="psycopg",
        user=None,
        password=None,
        host=None,
        preamb=None,
        conn=None,
        port=None,
        strLength=STRLEN_DEFAULT,
        timeout=None,
        notNamed=False,
        asDict=False,
        intNullVal=-9999):
    '''
    Executes the sql query and returns the list or dictionary
    with the numpy arrays.

    Parameters
    ----------
    query : string
        Query you want to execute. It can include placeholders for the
        query parameters in the driver's style (e.g. %s for psycopg,
        ? for sqlite3)
    params : tuple
        Query parameters
    conn : object
        The connection object to the DB (optional) to avoid reconnecting
    asDict : boolean
        Flag whether to retrieve the results as a dictionary with column
        names as keys
    strLength : integer
        The minimum width of string columns (psycopg only). The width is
        the larger of strLength and the length of the first value; longer
        strings in later rows are truncated to that width.
    intNullVal : integer, optional
        All the integer columns with nulls will have null replaced by
        this value
    db : string
        The name of the database
    driver : string, optional
        The sql driver to be used (psycopg, sqlite3 or duckdb)
    user : string, optional
        User name for the DB connection
    password : string, optional
        DB connection password
    host : string, optional
        Hostname of the database
    port : integer, optional
        Port of the database
    preamb : string
        SQL code to be executed before the query

    Returns
    -------
    ret : list or dictionary
        By default you get a list with numpy arrays for each column
        in your query (empty lists for sqlite3/duckdb queries returning
        no rows).
        If you specified asDict keyword then you get an ordered dictionary
        with your columns; repeated column names get _1, _2, ... suffixes.

    Examples
    --------
    >>> a, b, c = sqlutilpy.get('select ra,dec,d25 from rc3')

    You can also use the parameters in your query:

    >>> a, b = sqlutilpy.get('select ra,dec from rc3 where name=%s',
    ...                      ("NGC 3166",))
    '''
    connSupplied = (conn is not None)
    if not connSupplied:
        conn = getConnection(db=db,
                             driver=driver,
                             user=user,
                             password=password,
                             host=host,
                             port=port,
                             timeout=timeout)
    try:
        cur = getCursor(conn, driver=driver, preamb=preamb, notNamed=notNamed)

        if params is None:
            cur.execute(query)
        else:
            cur.execute(query, params)

        if driver == 'psycopg2':
            warnings.warn('psycopg2 driver is not supported anymore. '
                          'We using psycopg instead')
            driver = 'psycopg'
        if driver == 'psycopg':
            recs, colNames = _fetch_psycopg(cur, strLength, intNullVal)
            res = [recs[tmp] for tmp in recs.dtype.names]
        elif driver in ('sqlite3', 'duckdb'):
            tups = cur.fetchall()
            colNames = [_tmp[0] for _tmp in cur.description]
            if len(tups) > 0:
                recs = numpy.rec.array(tups)
                res = [recs[tmp] for tmp in recs.dtype.names]
            else:
                # Fall through to the normal cleanup instead of returning
                # early (which leaked the cursor/connection and ignored
                # asDict); use distinct list objects per column.
                res = [[] for _ in colNames]
    except BaseException:
        failure_cleanup(conn, connSupplied)
        raise

    cur.close()

    if not connSupplied:
        conn.close()

    if asDict:
        res = _columns_to_dict(colNames, res)
    return res
467

468

469
def execute(query,
            params=None,
            db='wsdb',
            driver="psycopg",
            user=None,
            password=None,
            host=None,
            conn=None,
            preamb=None,
            timeout=None,
            noCommit=False,
            port=None):
    """
    Execute a given SQL command without returning the results

    Parameters
    ----------
    query: string
        The query or command you are executing
    params: tuple, optional
        Optional parameters of your query
    db : string
        Database name
    driver : string
        Driver for the DB connection ('psycopg', 'sqlite3' or 'duckdb')
    user : string, optional
        user name for the DB connection
    password : string, optional
        DB connection password
    host : string, optional
        Hostname of the database
    port : integer, optional
        Port of the database
    conn : object, optional
        Existing connection to use (it is not closed afterwards)
    preamb : string, optional
        SQL code to be executed before the command
    timeout : integer, optional
        Connection timeout (sqlite only)
    noCommit: bool
        By default execute() will commit your command.
        If you say noCommit, the commit won't be issued.

    Examples
    --------
    >>> sqlutil.execute('drop table mytab', conn=conn)
    >>> sqlutil.execute('create table mytab (a int)', db='mydb')

    """
    connSupplied = (conn is not None)
    if not connSupplied:
        conn = getConnection(db=db,
                             driver=driver,
                             user=user,
                             password=password,
                             host=host,
                             port=port,
                             timeout=timeout)
    try:
        cur = getCursor(conn, driver=driver, preamb=preamb, notNamed=True)
        if params is not None:
            cur.execute(query, params)
        else:
            # sqlite3 doesn't like params here...
            cur.execute(query)
        cur.close()
        # commit inside the try so that a failing commit still rolls back
        # and closes a connection we opened ourselves
        if not noCommit:
            conn.commit()
    except BaseException:
        failure_cleanup(conn, connSupplied)
        raise
    if not connSupplied:
        conn.close()  # do not close if we were given the connection
534

535

536
def __create_schema(tableName, arrays, names, temp=False):
    """
    Build the ``create [temporary] table`` statement for the given columns.

    Column names are double-quoted; the table name is used verbatim.
    Object columns are treated as array columns whose element type is
    taken from their first entry.

    Raises
    ------
    KeyError
        For numpy column types without an SQL mapping.
    """
    type_map = {
        np.int8: 'smallint',
        np.uint8: 'smallint',
        np.int16: 'smallint',
        np.uint16: 'integer',
        np.int32: 'integer',
        np.uint32: 'bigint',
        np.int64: 'bigint',
        # NOTE(review): uint64 values above 2**63-1 do not fit in bigint
        np.uint64: 'bigint',
        np.float32: 'real',
        np.float64: 'double precision',
        np.bytes_: 'varchar',
        np.str_: 'varchar',
        np.bool_: 'boolean',
        np.datetime64: 'timestamp',
    }
    prefix = 'create %s table %s ' % ('temporary' if temp else '', tableName)
    columns = []
    for arr, name in zip(arrays, names):
        if arr.dtype.type == np.object_:
            # array column; asarray also copes with plain python lists
            sqltype = type_map[np.asarray(arr[0]).dtype.type] + '[]'
        else:
            sqltype = type_map[arr.dtype.type]
        columns.append(f'"{name}" {sqltype}')
    return prefix + '(' + ','.join(columns) + ')'
558

559

560
def __print_arrays(arrays, f, delimiter=' '):
    """
    Print the input arrays into the open file separated by a delimiter.

    One line per row.  When no column has object dtype, the rows are
    written with ``np.savetxt``.  Otherwise each row is written
    field-by-field as encoded bytes, with object (array) columns rendered
    as ``{v1,v2,...}``.

    NOTE(review): values are not escaped (backslash, newline, delimiter),
    so strings containing them will corrupt a COPY text stream.

    Raises
    ------
    RuntimeError
        For column types without a known output format.
    """
    format_dict = {
        np.int8: '%d', np.int16: '%d', np.int32: '%d', np.int64: '%d',
        np.uint8: '%d', np.uint16: '%d', np.uint32: '%d', np.uint64: '%d',
        np.float32: '%.18e', np.float64: '%.18e',
        np.str_: '%s',
        np.datetime64: '%s', np.bool_: '%d'}
    columns = []
    fmts = []
    array_mode = False
    names = []
    for i, x in enumerate(arrays):
        names.append('f%d' % i)
        if x.dtype.type == np.bytes_:
            # %-formatting bytes yields "b'...'"; write the text instead
            x = np.char.decode(x, 'utf-8')
        columns.append(x)
        if x.dtype.type in format_dict:
            fmts.append(format_dict[x.dtype.type])
        elif x.dtype.type == np.object_:
            array_mode = True
            fmts.append(None)
        else:
            raise RuntimeError(
                'Unsupported column type %s. Please report' %
                (x.dtype.type))
    recarr = np.rec.fromarrays(columns)
    if not array_mode:
        np.savetxt(f, recarr, fmt=fmts, delimiter=delimiter)
    else:
        # this is really slow
        for row in recarr:
            for i, field in enumerate(names):
                if i != 0:
                    f.write(delimiter.encode())
                if fmts[i] is None:
                    curstr = np.array2string(
                        row[field],
                        max_line_width=np.inf,
                        threshold=None,
                        separator=',',
                        formatter={'all': lambda x: str(x)})
                    # formatter is needed because otherwise there is
                    # whitespace padding
                    curstr = '{' + curstr[1:-1] + '}'
                    f.write(curstr.encode())
                else:
                    f.write(str(row[field]).encode())
            f.write(b'\n')
607

608

609
def failure_cleanup(conn, connSupplied):
    """Best-effort cleanup after a failed operation.

    Always attempts ``conn.rollback()``; additionally attempts
    ``conn.close()`` when the connection was opened by us (i.e. not
    supplied by the caller).  Any exception from either step is ignored.
    """
    steps = ('rollback', ) if connSupplied else ('rollback', 'close')
    for step in steps:
        try:
            getattr(conn, step)()
        except Exception:
            pass
619

620

621
def upload(tableName,
           arrays,
           names=None,
           db="wsdb",
           driver="psycopg",
           user=None,
           password=None,
           host=None,
           conn=None,
           preamb=None,
           timeout=None,
           noCommit=False,
           temp=False,
           analyze=False,
           createTable=True,
           delimiter='|',
           port=None):
    """
    Upload the data stored in the tuple of arrays in the DB

    Parameters
    ----------
    tableName : string
        The name of the table where the data will be uploaded
    arrays : tuple
        Tuple of arrays that will be columns of the new table.
        If names are not specified, this parameter can be a pandas
        DataFrame, an astropy table or a dictionary of arrays
    names : tuple
        Column names (characters ' -()[]<>' are replaced by '_')
    db: string
         Database name
    driver: string
         Python database driver "psycopg",
    user: string,
    password: string
    host: string
    port: integer, optional
         Port of the database
    conn: object
         SQL connection
    preamb: string
         The string/query to be executed before your command
    noCommit: bool
         If true, the commit is not executed and the table will go away
         after the disconnect
    temp: bool
         If true a temporary table will be created
    analyze: bool
         if True, the table will be analyzed after the upload
    createTable: bool
         if True the table will be created before uploading (default)
    delimiter: string
         the string used for delimiting the input data when ingesting into
         the db (default is |)

    Raises
    ------
    Exception
        If names is None and arrays is not a table/DataFrame/dict.
    RuntimeError
        If the number of names differs from the number of arrays.

    Examples
    --------
    >>> x = np.arange(10)
    >>> y = x**.5
    >>> sqlutilpy.upload('mytable', (x, y), ('xcol', 'ycol'))

    >>> T = astropy.table.Table({'x': [1, 2, 3], 'y': ['a', 'b', 'c']})
    >>> sqlutilpy.upload('mytable', T)
    """
    # Validate everything before connecting so that bad input cannot leak
    # a freshly opened connection.
    if names is None:
        # we assume that we were given a table
        if atpy is not None and isinstance(arrays, atpy.Table):
            names = arrays.columns
        elif pandas is not None and isinstance(arrays, pandas.DataFrame):
            names = arrays.columns
        elif isinstance(arrays, dict):
            names = arrays.keys()
        else:
            raise Exception('you either need to give astropy '
                            'table/pandas/dictionary or provide a separate '
                            'list of arrays and their names')
        arrays = [arrays[_] for _ in names]
    names = list(names)

    arrays = [np.asarray(_) for _ in arrays]
    if len(arrays) != len(names):
        raise RuntimeError('The column names list must have the same '
                           'length as array list')
    repl_table = str.maketrans(' -()[]<>', '_' * 8)
    fixed_names = []
    for name in names:
        # str() also accepts non-string labels (e.g. pandas int columns)
        fixed_name = str(name).translate(repl_table)
        if fixed_name != str(name):
            warnings.warn('''Renamed column '%s' to '%s' ''' %
                          (name, fixed_name))
        fixed_names.append(fixed_name)
    names = fixed_names

    connSupplied = (conn is not None)
    if not connSupplied:
        conn = getConnection(db=db,
                             driver=driver,
                             user=user,
                             password=password,
                             host=host,
                             port=port,
                             timeout=timeout)
    try:
        cur = getCursor(conn, driver=driver, preamb=preamb, notNamed=True)
        if createTable:
            cur.execute(__create_schema(tableName, arrays, names, temp=temp))
            # the create statement quotes the column names, so COPY must
            # quote them too (mixed case / special characters)
            colList = ','.join('"%s"' % _ for _ in names)
        else:
            # pre-existing table: keep unquoted names (PG case folding)
            colList = ','.join(names)
        nsplit = 100000
        N = len(arrays[0]) if arrays else 0
        for i in range(0, N, nsplit):
            with cur.copy(f'''copy {tableName}({colList}) from stdin
                    with delimiter '{delimiter}' ''') as copy:
                __print_arrays([_[i:i + nsplit] for _ in arrays],
                               copy,
                               delimiter=delimiter)
        if analyze:
            cur.execute('analyze %s' % tableName)
        cur.close()
        if not noCommit:
            conn.commit()
    except BaseException:
        failure_cleanup(conn, connSupplied)
        raise
    if not connSupplied:
        conn.close()  # do not close if we were given the connection
763

764

765
def local_join(query,
               tableName,
               arrays,
               names,
               db=None,
               driver="psycopg",
               user=None,
               password=None,
               host=None,
               port=None,
               conn=None,
               preamb=None,
               timeout=None,
               strLength=STRLEN_DEFAULT,
               intNullVal=-9999,
               asDict=False,
               params=None):
    """
    Join your local data in python with the data in the database
    This command first uploads the data in the DB creating a temporary table
    and then runs a user specified query that can use your local data.
    The transaction is rolled back afterwards (dropping the temporary
    table), also when an existing connection is supplied.

    Parameters
    ----------
    query : String with the query to be executed
    tableName : The name of the temporary table that is going to be created
    arrays : The tuple with list of arrays with the data to be loaded in the DB
    names : The tuple with the column names for the user table
    params : Optional parameters of the query (passed to get())

    Examples
    --------
    This will extract the rows from the table sometable matching
    to the provided array x

    >>> x = np.arange(10)
    >>> y = x**.5
    >>> sqlutilpy.local_join('''
    ... SELECT s.* FROM mytable AS m LEFT JOIN sometable AS s
    ... ON s.x = m.x ORDER BY m.x''',
    ... 'mytable', (x, y), ('x', 'y'))
    """

    connSupplied = (conn is not None)
    if not connSupplied:
        conn = getConnection(db=db,
                             driver=driver,
                             user=user,
                             password=password,
                             host=host,
                             timeout=timeout,
                             port=port)
    try:
        upload(tableName,
               arrays,
               names,
               conn=conn,
               noCommit=True,
               temp=True,
               analyze=True)
        res = get(query,
                  params=params,
                  conn=conn,
                  preamb=preamb,
                  strLength=strLength,
                  asDict=asDict,
                  intNullVal=intNullVal)
    except BaseException:
        failure_cleanup(conn, connSupplied)
        raise

    # discard the temporary table (and anything else in this transaction)
    conn.rollback()

    if not connSupplied:
        conn.close()
    return res
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc