• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

IBM / unitxt / 13287936324

12 Feb 2025 02:41PM UTC coverage: 80.957% (-0.2%) from 81.124%
13287936324

Pull #1604

github

web-flow
Merge 3c0280796 into d7200b518
Pull Request #1604: Text2sql execution accuracy metric updates

1499 of 1850 branches covered (81.03%)

Branch coverage included in aggregate %.

9520 of 11761 relevant lines covered (80.95%)

0.81 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

68.69
src/unitxt/db_utils.py
1
import glob
1✔
2
import os
1✔
3
import sqlite3
1✔
4
import time
1✔
5
from abc import ABC, abstractmethod
1✔
6
from functools import lru_cache
1✔
7
from typing import Any, List, Optional
1✔
8

9
import requests
1✔
10
from huggingface_hub import snapshot_download
1✔
11
from requests.exceptions import ConnectionError, ReadTimeout
1✔
12

13
from .logging_utils import get_logger
1✔
14
from .types import SQLDatabase
1✔
15

16
# Module-level logger shared by every connector and helper in this file.
logger = get_logger()
1✔
17

18

19
class DatabaseConnector(ABC):
    """Common base for the database connectors used by the text2sql metrics.

    Subclasses must provide schema extraction and query execution; the base
    class only stores the configuration and prepares the on-disk cache folder.
    """

    def __init__(self, db_config: SQLDatabase):
        self.db_config = db_config
        # Databases are cached under UNITXT_TEXT2SQL_CACHE, falling back to a
        # local "cache/text2sql" directory when the variable is unset.
        cache_root = os.environ.get("UNITXT_TEXT2SQL_CACHE", "cache/text2sql")
        self.databases_folder = os.path.join(cache_root, "databases")
        os.makedirs(self.databases_folder, exist_ok=True)

    @abstractmethod
    def get_table_schema(self) -> str:
        """Return a textual description of the database schema."""
        pass

    @abstractmethod
    def execute_query(self, query: str) -> Any:
        """Run *query* against the database and return its result."""
        pass
40

41

42
@lru_cache(maxsize=128)
def execute_query_local(db_path: str, query: str) -> Any:
    """Run *query* on the SQLite database file at *db_path*.

    Returns a ``(rows, error)`` pair: ``(fetched_rows, None)`` on success,
    or ``(None, error_message)`` when SQLite rejects the query.  Results
    are memoized per ``(db_path, query)`` pair via ``lru_cache``.
    """
    connection = None  # defined before the try so the finally can see it
    try:
        connection = sqlite3.connect(db_path)
        rows = connection.cursor().execute(query).fetchall()
        return rows, None
    except sqlite3.Error as e:
        logger.info(f"Error executing SQL: {e}")
        return None, f"Error executing SQL: {e}"
    finally:
        if connection:
            connection.close()
57

58

59
class LocalSQLiteConnector(DatabaseConnector):
    """Connector that serves queries from a locally cached SQLite file."""

    def __init__(self, db_config: SQLDatabase):
        super().__init__(db_config)
        db_id = self.db_config.get("db_id")
        if not db_id:
            raise ValueError("db_id is required for SQLiteConnector.")
        self.db_path = self.get_db_file_path(db_id)
        self.conn: sqlite3.Connection = sqlite3.connect(self.db_path)
        self.cursor: sqlite3.Cursor = self.conn.cursor()

    def download_database(self, db_id):
        """Fetch the database snapshot from the HuggingFace hub if needed."""
        done_file_path = os.path.join(self.databases_folder, "download_done")
        if "bird/" not in db_id:
            # Only the BIRD benchmark databases are mirrored locally.
            raise NotImplementedError(
                f"current local db: {db_id} is not supported, only bird"
            )
        if not os.path.exists(done_file_path):
            snapshot_download(
                repo_id="premai-io/birdbench",
                repo_type="dataset",
                local_dir=self.databases_folder,
                force_download=False,
                allow_patterns="*validation*",
            )
            # Touch a marker file so subsequent runs skip the download.
            open(os.path.join(self.databases_folder, "download_done"), "w").close()

    def get_db_file_path(self, db_id):
        """Locate the downloaded ``<db_id>.sqlite`` file inside the cache."""
        self.download_database(db_id)
        base_name = db_id.split("/")[-1]

        pattern = os.path.join(self.databases_folder, "**", base_name + ".sqlite")
        matches = glob.glob(pattern, recursive=True)

        if not matches:
            raise FileNotFoundError(f"Database file {base_name} not found.")
        if len(matches) > 1:
            raise FileExistsError(f"More than one files matched for {base_name}")
        return matches[0]

    def get_table_schema(self) -> str:
        """Extract the CREATE TABLE statements of every user table."""
        self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        schemas: dict[str, str] = {}

        for row in self.cursor.fetchall():
            table_name = row[0] if isinstance(row, tuple) else row
            # sqlite_sequence is SQLite bookkeeping, not part of the schema.
            if table_name == "sqlite_sequence":
                continue
            self.cursor.execute(
                f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table_name}';"
            )
            schemas[table_name] = self.cursor.fetchone()[0]

        return "\n\n".join(list(schemas.values()))

    def execute_query(self, query: str) -> Any:
        """Run *query* via the cached module-level executor."""
        return execute_query_local(self.db_path, query)
130

131

132
class InMemoryDatabaseConnector(DatabaseConnector):
    """Database connector for mocking databases with in-memory data structures.

    ``db_config["data"]`` maps each table name to a dict with ``"columns"``
    (list of column names) and ``"rows"`` (list of value tuples).
    """

    def __init__(self, db_config: SQLDatabase):
        super().__init__(db_config)
        self.tables = db_config.get("data", None)

        if not self.tables:
            raise ValueError("data is required for InMemoryDatabaseConnector.")

    def get_table_schema(
        self,
        select_tables: Optional[List[str]] = None,
    ) -> str:
        """Generates a mock schema from the tables structure.

        Args:
            select_tables: optional list of lower-cased table names to keep;
                when given, tables not in the list are skipped.

        Returns:
            The CREATE TABLE statements joined by blank lines.
        """
        schemas = {}
        for table_name, table_data in self.tables.items():
            if select_tables and table_name.lower() not in select_tables:
                continue
            # The mock schema only needs column names; every column is TEXT.
            columns = ", ".join([f"`{col}` TEXT" for col in table_data["columns"]])
            schema = f"CREATE TABLE `{table_name}` ({columns});"

            schemas[table_name] = schema

        return "\n\n".join(list(schemas.values()))

    def execute_query(self, query: str) -> Any:
        """Simulates executing a query against the mock database.

        Builds a fresh in-memory SQLite database from ``self.tables`` on
        every call, then runs *query* on it.  Returns ``(rows, None)`` on
        success or ``(None, error_message)`` on a SQL error.
        """
        # Initialize in-memory database from the 'tables' dictionary
        conn = sqlite3.connect(":memory:")
        logger.debug("Running SQL query over in-memory DB")

        # FIX: the whole body is guarded so the connection is always closed;
        # previously a failure while creating/populating tables leaked it.
        try:
            cursor = conn.cursor()
            # Create tables and insert data from the 'db' dictionary
            for table_name, table_data in self.tables.items():
                columns = table_data["columns"]
                rows = table_data["rows"]

                cursor.execute(f"CREATE TABLE {table_name} ({', '.join(columns)})")

                placeholders = ", ".join(["?"] * len(columns))
                cursor.executemany(
                    f"INSERT INTO {table_name} VALUES ({placeholders})", rows
                )

            try:
                cursor.execute(query)
                return cursor.fetchall(), None
            except sqlite3.Error as e:
                logger.info(f"Error executing SQL: {e}")
                return None, f"Error executing SQL: {e}"
        finally:
            conn.close()
187

188

189
@lru_cache(maxsize=128)
def execute_query_remote(
    api_url: str,
    database_id: str,
    api_key: str,
    query: str,
    retryable_exceptions: tuple = (ConnectionError, ReadTimeout),
    max_retries: int = 3,
    retry_delay: int = 5,  # seconds
    timeout: int = 30,  # seconds
) -> "tuple[Optional[dict], Optional[str]]":
    """Executes a query against the remote database, with retries for certain exceptions.

    Args:
        api_url: base URL of the SQL service (``POST {api_url}/sql``).
        database_id: datasource identifier sent as ``dataSourceId``.
        api_key: bearer token for the Authorization header.
        query: SQL text to execute.
        retryable_exceptions: exception types that trigger a retry.
        max_retries: number of retries after the initial attempt.
        retry_delay: seconds slept between attempts.
        timeout: per-request timeout in seconds.

    Returns:
        ``(json_response, None)`` on success, ``(None, error_message)`` on failure.

    NOTE: results — including failures — are memoized by ``lru_cache``, so a
    failed query is not re-attempted by a later identical call in this process.
    """
    # FIX: the return annotation was the tuple literal `(Optional[dict], str)`,
    # which is not a valid type; it is now a proper tuple[...] annotation.
    headers = {
        "Content-Type": "application/json",
        "accept": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    retries = 0
    while retries <= max_retries:
        try:
            response = requests.post(
                f"{api_url}/sql",
                headers=headers,
                json={"sql": query, "dataSourceId": database_id},
                verify=True,
                timeout=timeout,
            )
            response.raise_for_status()
            return response.json(), None

        except retryable_exceptions as e:
            retries += 1
            logger.warning(
                f"Attempt {retries} failed with error: {e}. Retrying in {retry_delay} seconds."
            )
            if retries > max_retries:
                logger.error(f"Max retries ({max_retries}) exceeded for query: {query}")
                return (
                    None,
                    f"Max retries ({max_retries}) exceeded for query: {query} - Error: {e!s}",
                )
            time.sleep(retry_delay)

        except requests.exceptions.HTTPError as e:
            if e.response.status_code < 500:
                # Client errors (4xx) are permanent; do not retry.
                logger.error(f"HTTP Error on attempt {retries}: {e}")
                return (
                    None,
                    f"HTTP Error on attempt {retries}: {e}",
                )
            # Server errors (5xx) may be transient; retry like connection errors.
            retries += 1
            logger.warning(
                f"Server error, attempt {retries} failed with error: {e}. Retrying in {retry_delay} seconds."
            )
            if retries > max_retries:
                logger.error(
                    f"Max retries ({max_retries}) exceeded for query: {query}"
                )
                return (
                    None,
                    f"Max retries ({max_retries}) exceeded for query: {query} - Error: {e!s}",
                )
            time.sleep(retry_delay)

        except Exception as e:
            logger.error(f"Unexpected error on attempt {retries}: {e}")
            return (None, f"Unexpected error on attempt {retries}: {e}")

    return None, "Unknown Error in SQL execution"
261

262

263
class RemoteDatabaseConnector(DatabaseConnector):
    """Database connector for remote databases accessed via HTTP.

    ``db_config["db_id"]`` encodes both endpoint and datasource, e.g.
    ``"https://host/api,db_id=<id>"``: the API URL is everything before the
    first comma, and the datasource id follows the ``db_id=`` marker.
    The ``SQL_API_KEY`` environment variable must hold the bearer token.
    """

    def __init__(self, db_config: SQLDatabase):
        super().__init__(db_config)

        # FIX: this was an `assert`, which is silently stripped under
        # `python -O`; validation must survive optimized runs.
        if not db_config.get("db_id"):
            raise ValueError(
                "db_id must be in db_config for RemoteDatabaseConnector"
            )
        raw_db_id = db_config["db_id"]
        self.api_url = raw_db_id.split(",")[0]
        self.database_id = raw_db_id.split("db_id=")[-1].split(",")[0]

        if not self.api_url or not self.database_id:
            raise ValueError(
                "Both 'api_url' and 'database_id' are required for RemoteDatabaseConnector."
            )

        self.api_key = os.getenv("SQL_API_KEY", None)
        if not self.api_key:
            raise ValueError(
                "The environment variable 'SQL_API_KEY' must be set to use the RemoteDatabaseConnector."
            )

        self.headers = {
            "Content-Type": "application/json",
            "accept": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

        # Request timeout (seconds) shared by schema fetches and queries.
        self.timeout = 30

    def get_table_schema(self) -> str:
        """Retrieves the schema of a database.

        Raises:
            OSError: when the datasource endpoint does not answer with HTTP 200.
        """
        cur_api_url = f"{self.api_url}/datasource/{self.database_id}"
        response = requests.get(
            cur_api_url,
            headers=self.headers,
            verify=True,
            timeout=self.timeout,
        )
        if response.status_code != 200:
            raise OSError(f"Could not fetch schema from {cur_api_url}")
        schema = response.json()["schema"]

        schema_text = ""
        for table in schema["tables"]:
            schema_text += f"Table: {table['table_name']} has columns: {[col['column_name'] for col in table['columns']]}\n"

        return schema_text

    def execute_query(self, query: str) -> Any:
        """Executes a query against the remote database, with retries for certain exceptions."""
        return execute_query_remote(
            api_url=self.api_url,
            database_id=self.database_id,
            api_key=self.api_key,
            query=query,
            timeout=self.timeout,
        )
327

328

329
def get_db_connector(db_type: str):
    """Return the DatabaseConnector subclass matching *db_type*.

    Args:
        db_type: one of ``"local"``, ``"in_memory"``, or ``"remote"``.

    Raises:
        ValueError: when *db_type* names no known connector.
    """
    if db_type == "local":
        return LocalSQLiteConnector
    if db_type == "in_memory":
        return InMemoryDatabaseConnector
    if db_type == "remote":
        return RemoteDatabaseConnector
    raise ValueError(f"Unsupported database type: {db_type}")
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc