• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

georgia-tech-db / eva / d8e82e0e-8b3d-4d22-9a86-fe9f65c8b264

28 Oct 2023 10:50PM UTC coverage: 67.392% (-9.6%) from 76.956%
d8e82e0e-8b3d-4d22-9a86-fe9f65c8b264

push

circle-ci

xzdandy
Use table for side by side display

8763 of 13003 relevant lines covered (67.39%)

0.67 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

0.0
/evadb/third_party/databases/postgres/postgres_handler.py
1
# coding=utf-8
2
# Copyright 2018-2023 EvaDB
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
import numpy as np
×
16
import pandas as pd
×
17
import psycopg2
×
18

19
from evadb.third_party.databases.types import (
×
20
    DBHandler,
21
    DBHandlerResponse,
22
    DBHandlerStatus,
23
)
24

25

26
class PostgresHandler(DBHandler):
    """DBHandler for PostgreSQL databases, implemented on top of psycopg2."""

    # information_schema.columns.data_type -> Python type.
    _PRIMITIVE_TYPE_MAPPING = {
        "integer": int,
        "bigint": int,
        "smallint": int,
        "numeric": float,
        "real": float,
        "double precision": float,
        "character": str,
        "character varying": str,
        "text": str,
        "boolean": bool,
        # Add more mappings as needed
    }

    # information_schema.columns.udt_name -> Python type. Only consulted when
    # data_type is "USER-DEFINED" (user defined types constructed by a
    # Postgres extension).
    _USER_DEFINED_TYPE_MAPPING = {
        "vector": np.ndarray,
    }

    def __init__(self, name: str, **kwargs):
        """
        Initialize the handler.
        Args:
            name (str): name of the DB handler instance
            **kwargs: arbitrary keyword arguments for establishing the connection.
                The keys read are: host, port, user, password, database.
        """
        super().__init__(name)
        self.host = kwargs.get("host")
        self.port = kwargs.get("port")
        self.user = kwargs.get("user")
        self.password = kwargs.get("password")
        self.database = kwargs.get("database")
        self.connection = None

    def connect(self) -> DBHandlerStatus:
        """
        Open a psycopg2 connection using the stored connection parameters.

        The connection is switched to autocommit mode, so statements run via
        execute_native_query take effect without an explicit commit.

        Returns:
            DBHandlerStatus: status=True on success; otherwise status=False
            with the psycopg2 error message.
        """
        try:
            self.connection = psycopg2.connect(
                host=self.host,
                port=self.port,
                user=self.user,
                password=self.password,
                database=self.database,
            )
            self.connection.autocommit = True
            return DBHandlerStatus(status=True)
        except psycopg2.Error as e:
            return DBHandlerStatus(status=False, error=str(e))

    def disconnect(self):
        """
        Close the existing connection, if any, and forget it.

        Resetting ``self.connection`` to None makes check_connection() and the
        query methods report "Not connected to the database." afterwards.
        """
        if self.connection:
            try:
                self.connection.close()
            finally:
                self.connection = None

    def get_sqlalchmey_uri(self) -> str:
        """
        Build a SQLAlchemy connection URI for this database.

        The user name and password are URL-encoded so that characters such as
        ``@``, ``:`` or ``/`` in them do not corrupt the URI. The (misspelled)
        method name is kept unchanged so existing callers keep working.

        Returns:
            str: ``postgresql+psycopg2://user:password@host:port/database``
        """
        from urllib.parse import quote_plus

        user = quote_plus(str(self.user))
        password = quote_plus(str(self.password))
        return f"postgresql+psycopg2://{user}:{password}@{self.host}:{self.port}/{self.database}"

    def check_connection(self) -> DBHandlerStatus:
        """
        Check connection to the handler.

        Returns:
            DBHandlerStatus: status=True if a connection object exists and
            psycopg2 does not report it as closed; otherwise status=False with
            an error message.
        """
        # psycopg2's ``connection.closed`` is 0 while open and nonzero once the
        # connection has been closed or is broken.
        if self.connection is not None and not self.connection.closed:
            return DBHandlerStatus(status=True)
        return DBHandlerStatus(status=False, error="Not connected to the database.")

    def get_tables(self) -> DBHandlerResponse:
        """
        Return the list of tables in the database.

        System schemas (information_schema, pg_catalog) are excluded.

        Returns:
            DBHandlerResponse: data is a DataFrame with a ``table_name`` column,
            or data=None with an error message.
        """
        if not self.connection:
            return DBHandlerResponse(data=None, error="Not connected to the database.")

        try:
            query = "SELECT table_name FROM information_schema.tables WHERE table_schema NOT IN ('information_schema', 'pg_catalog')"
            tables_df = pd.read_sql_query(query, self.connection)
            return DBHandlerResponse(data=tables_df)
        except psycopg2.Error as e:
            return DBHandlerResponse(data=None, error=str(e))

    def get_columns(self, table_name: str) -> DBHandlerResponse:
        """
        Returns the list of columns for the given table.
        Args:
            table_name (str): name of the table whose columns are to be retrieved.
        Returns:
            DBHandlerResponse: data is a DataFrame with columns ``name``,
            ``dtype`` (mapped to a Python type) and ``udt_name``, in table
            definition order; or data=None with an error message.
        Raises:
            Exception: propagated from _pg_to_python_types for unsupported
            column types.
        """
        if not self.connection:
            return DBHandlerResponse(data=None, error="Not connected to the database.")

        try:
            # table_name is passed as a bound parameter rather than formatted
            # into the SQL text, so it cannot inject SQL. ORDER BY makes the
            # column order deterministic (table definition order).
            # NOTE(review): the filter is not schema-qualified, so same-named
            # tables in different schemas have their columns combined.
            query = (
                "SELECT column_name as name, data_type as dtype, udt_name "
                "FROM information_schema.columns "
                "WHERE table_name = %s "
                "ORDER BY ordinal_position"
            )
            columns_df = pd.read_sql_query(
                query, self.connection, params=(table_name,)
            )
            # A comprehension (rather than DataFrame.apply(axis=1)) also works
            # when the query returns no rows.
            columns_df["dtype"] = [
                self._pg_to_python_types(pg_type, udt_name)
                for pg_type, udt_name in zip(
                    columns_df["dtype"], columns_df["udt_name"]
                )
            ]
            return DBHandlerResponse(data=columns_df)
        except psycopg2.Error as e:
            return DBHandlerResponse(data=None, error=str(e))

    def _fetch_results_as_df(self, cursor):
        """
        Convert the result of the last statement executed on ``cursor`` into a DataFrame.

        Per the DB-API, ``cursor.description`` is None for statements that
        return no rows (e.g. CREATE, INSERT). For those a one-row
        ``{"status": ["success"]}`` frame is returned. Otherwise all rows are
        fetched and the column names are lower-cased.
        Reference to Postgres API: https://www.psycopg.org/docs/cursor.html#cursor.description

        Args:
            cursor: a psycopg2 cursor on which execute() has been called.
        Returns:
            pd.DataFrame
        """
        if cursor.description is None:
            return pd.DataFrame({"status": ["success"]})
        rows = cursor.fetchall()
        return pd.DataFrame(
            rows, columns=[desc[0].lower() for desc in cursor.description]
        )

    def execute_native_query(self, query_string: str) -> DBHandlerResponse:
        """
        Executes the native query on the database.
        Args:
            query_string (str): query in native format
        Returns:
            DBHandlerResponse: data is the result DataFrame (see
            _fetch_results_as_df), or data=None with the psycopg2 error message.
        """
        if not self.connection:
            return DBHandlerResponse(data=None, error="Not connected to the database.")

        try:
            # The cursor context manager closes the cursor on exit, including
            # when execute() or fetching raises.
            with self.connection.cursor() as cursor:
                cursor.execute(query_string)
                return DBHandlerResponse(data=self._fetch_results_as_df(cursor))
        except psycopg2.Error as e:
            return DBHandlerResponse(data=None, error=str(e))

    def _pg_to_python_types(self, pg_type: str, udt_name: str):
        """
        Map a Postgres column type to the Python type used to represent it.

        Args:
            pg_type (str): the ``data_type`` value from information_schema.columns.
            udt_name (str): the ``udt_name`` value; only consulted when
                pg_type is "USER-DEFINED".
        Returns:
            type: the corresponding Python type (e.g. int, str, np.ndarray).
        Raises:
            Exception: if the type has no known mapping.
        """
        if pg_type in self._PRIMITIVE_TYPE_MAPPING:
            return self._PRIMITIVE_TYPE_MAPPING[pg_type]
        if pg_type == "USER-DEFINED" and udt_name in self._USER_DEFINED_TYPE_MAPPING:
            return self._USER_DEFINED_TYPE_MAPPING[udt_name]
        raise Exception(
            f"Unsupported column {pg_type} encountered in the postgres table. Please raise a feature request!"
        )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc