IBM / unitxt — build 14254434031
Pull Request #1702: Text2sql metrics fixes (merge e8159c3c1 into 4905b2e6c)

03 Apr 2025 11:06PM UTC — coverage: 80.217% (-0.09%) from 80.304%
1581 of 1965 branches covered (80.46%); branch coverage is included in the aggregate %.
9894 of 12340 relevant lines covered (80.18%)
0.8 hits per line

Source File: src/unitxt/sql_utils.py (64.46% covered)
import functools
import glob
import hashlib
import json
import os
import re
import sqlite3
import time
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Any, List, Optional

import requests
from huggingface_hub import snapshot_download
from requests.exceptions import ConnectionError, ReadTimeout

from .logging_utils import get_logger
from .types import SQLDatabase

logger = get_logger()

# Check if caching is enabled via environment variable
CACHE_LOCATION = os.getenv("UNITXT_CACHE_LOCATION")

# Set max cache size to 10GB or the value of env var MAX_CACHE_SIZE
MAX_CACHE_SIZE = os.getenv("MAX_CACHE_SIZE", 10 * 1024**3)

_cache_instance = None


class DatabaseConnector(ABC):
    """Abstract base class for database connectors."""

    def __init__(self, db_config: SQLDatabase):
        self.db_config = db_config
        self.databases_folder = os.path.join(
            os.environ.get("UNITXT_CACHE_LOCATION", "cache/text2sql"), "databases"
        )
        os.makedirs(self.databases_folder, exist_ok=True)

    @abstractmethod
    def get_table_schema(
        self,
    ) -> str:
        """Abstract method to get database schema."""
        pass

    @abstractmethod
    def execute_query(self, query: str) -> Any:
        """Abstract method to execute a query against the database."""
        pass

@lru_cache(maxsize=128)
def execute_query_local(db_path: str, query: str) -> Any:
    """Executes a query against the SQLite database."""
    conn = None  # Initialize conn to None outside the try block
    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()
        cursor.execute(query)
        return cursor.fetchall(), None
    except sqlite3.Error as e:
        logger.info(f"Error executing SQL: {e}")
        return None, f"Error executing SQL: {e}"
    finally:
        if conn:
            conn.close()


class LocalSQLiteConnector(DatabaseConnector):
    """Database connector for SQLite databases."""

    def __init__(self, db_config: SQLDatabase):
        super().__init__(db_config)
        db_id = self.db_config.get("db_id")
        if not db_id:
            raise ValueError("db_id is required for SQLiteConnector.")
        self.db_path = self.get_db_file_path(db_id)
        self.conn: sqlite3.Connection = sqlite3.connect(self.db_path)
        self.cursor: sqlite3.Cursor = self.conn.cursor()

    def download_database(self, db_id):
        """Downloads the database from huggingface if needed."""
        done_file_path = os.path.join(self.databases_folder, "download_done")
        if "bird/" in db_id:
            if not os.path.exists(done_file_path):
                snapshot_download(
                    repo_id="premai-io/birdbench",
                    repo_type="dataset",
                    local_dir=self.databases_folder,
                    force_download=False,
                    allow_patterns="*validation*",
                )
                open(os.path.join(self.databases_folder, "download_done"), "w").close()
        else:
            raise NotImplementedError(
                f"current local db: {db_id} is not supported, only bird"
            )

    def get_db_file_path(self, db_id):
        """Gets the local path of a downloaded database file."""
        self.download_database(db_id)
        db_id = db_id.split("/")[-1]

        db_file_pattern = os.path.join(self.databases_folder, "**", db_id + ".sqlite")
        db_file_paths = glob.glob(db_file_pattern, recursive=True)

        if not db_file_paths:
            raise FileNotFoundError(f"Database file {db_id} not found.")
        if len(db_file_paths) > 1:
            raise FileExistsError(f"More than one file matched for {db_id}")
        return db_file_paths[0]

    def get_table_schema(
        self,
    ) -> str:
        """Extracts schema from an SQLite database."""
        self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        tables: list[tuple[str]] = self.cursor.fetchall()
        schemas: dict[str, str] = {}

        for table in tables:
            if isinstance(table, tuple):
                table = table[0]
            if table == "sqlite_sequence":
                continue
            sql_query: str = (
                f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table}';"
            )
            self.cursor.execute(sql_query)
            schema_prompt: str = self.cursor.fetchone()[0]

            schemas[table] = schema_prompt

        schema_prompt: str = "\n\n".join(list(schemas.values()))
        return schema_prompt

    def execute_query(self, query: str) -> Any:
        """Executes a query against the SQLite database."""
        return execute_query_local(self.db_path, query)

class InMemoryDatabaseConnector(DatabaseConnector):
    """Database connector for mocking databases with in-memory data structures."""

    def __init__(self, db_config: SQLDatabase):
        super().__init__(db_config)
        self.tables = db_config.get("data", None)

        if not self.tables:
            raise ValueError("data is required for InMemoryDatabaseConnector.")

    def get_table_schema(
        self,
        select_tables: Optional[List[str]] = None,
    ) -> str:
        """Generates a mock schema from the tables structure."""
        schemas = {}
        for table_name, table_data in self.tables.items():
            if select_tables and table_name.lower() not in select_tables:
                continue
            columns = ", ".join([f"`{col}` TEXT" for col in table_data["columns"]])
            schema = f"CREATE TABLE `{table_name}` ({columns});"

            schemas[table_name] = schema

        return "\n\n".join(list(schemas.values()))

    def execute_query(self, query: str) -> Any:
        """Simulates executing a query against the mock database."""
        # Initialize in-memory database from the 'tables' dictionary
        conn = sqlite3.connect(":memory:")
        cursor = conn.cursor()
        logger.debug("Running SQL query over in-memory DB")

        # Create tables and insert data from the 'db' dictionary
        for table_name, table_data in self.tables.items():
            columns = table_data["columns"]
            rows = table_data["rows"]

            # Create table
            cursor.execute(f"CREATE TABLE {table_name} ({', '.join(columns)})")

            # Insert data
            placeholders = ", ".join(["?"] * len(columns))
            cursor.executemany(
                f"INSERT INTO {table_name} VALUES ({placeholders})", rows
            )

        try:
            cursor.execute(query)
            return cursor.fetchall(), None
        except sqlite3.Error as e:
            logger.info(f"Error executing SQL: {e}")
            return None, f"Error executing SQL: {e}"
        finally:
            conn.close()
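A minimal usage sketch for the in-memory connector above (illustrative, not part of sql_utils.py). It assumes the module is importable as unitxt.sql_utils and passes a plain dict in place of SQLDatabase; the "data" layout of per-table "columns" and "rows" and the (rows, error) return convention are taken from the code above.

    from unitxt.sql_utils import InMemoryDatabaseConnector

    # Hypothetical toy config: "data" maps table names to their columns and rows.
    db_config = {
        "data": {
            "users": {
                "columns": ["id", "name"],
                "rows": [[1, "Ada"], [2, "Linus"]],
            }
        }
    }

    connector = InMemoryDatabaseConnector(db_config)
    print(connector.get_table_schema())
    # CREATE TABLE `users` (`id` TEXT, `name` TEXT);

    rows, error = connector.execute_query("SELECT name FROM users WHERE id = 1")
    # On success: rows == [("Ada",)] and error is None; on a SQLite error the
    # tuple is (None, "Error executing SQL: ..."), mirroring execute_query_local.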

def get_cache():
    """Returns a singleton cache instance, initializing it if necessary."""
    global _cache_instance
    if _cache_instance is None:
        _cache_instance = Cache()
    return _cache_instance


def generate_cache_key(*args, **kwargs):
    """Generate a stable hashable cache key for various input types.

    :param args: Positional arguments of the function.
    :param kwargs: Keyword arguments of the function.
    :return: A hashed key as a string.
    """
    try:
        # Convert args and kwargs to a JSON string (sorted to ensure consistency)
        serialized = json.dumps(
            {"args": args, "kwargs": kwargs}, sort_keys=True, default=str
        )
    except TypeError:
        # Fallback for non-serializable objects
        serialized = repr((args, kwargs))

    # Hash the serialized data
    return hashlib.md5(serialized.encode()).hexdigest()
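A small illustration of the cache-key helper above (not from the repository): arguments are JSON-serialized with sort_keys=True and hashed with MD5, so equal inputs always map to the same 32-character key and keyword order does not matter.

    from unitxt.sql_utils import generate_cache_key

    k1 = generate_cache_key("sql_request", db="db1", query="SELECT 1")
    k2 = generate_cache_key("sql_request", query="SELECT 1", db="db1")
    assert k1 == k2 and len(k1) == 32   # kwargs are sorted before hashing
    # Arguments that json.dumps cannot handle fall back to repr((args, kwargs)).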

class Cache:
    """A class that provides disk-based caching functionality for a given function."""

    def __init__(self):
        """Initializes the cache.

        If `CACHE_LOCATION` (os.getenv("UNITXT_CACHE_LOCATION")) is set, a disk-based
        cache is created using `diskcache`.

        Args:
            None

        Returns:
            None
        """
        if CACHE_LOCATION:
            try:
                import diskcache

                # Ensure the cache directory exists
                os.makedirs(CACHE_LOCATION, exist_ok=True)

                # Create a global diskcache Cache instance
                self.cache = diskcache.Cache(CACHE_LOCATION, size_limit=MAX_CACHE_SIZE)
                logger.info(f"Caching enabled at {CACHE_LOCATION}")
            except ImportError as e:
                raise ImportError(
                    "UNITXT_CACHE_LOCATION is set, but diskcache is not installed.\n"
                    "Please install diskcache `pip install diskcache` "
                    "or unset UNITXT_CACHE_LOCATION."
                ) from e
        else:
            self.cache = None  # Disable caching

    def get_or_set(self, key, compute_fn, no_cache=False, refresh=False):
        if not self.cache or no_cache:
            logger.info(f"Bypassing cache for key: {key}")
            return compute_fn()

        if refresh and key in self.cache:
            logger.info(f"Refreshing cache for key: {key}")
            del self.cache[key]

        if key in self.cache:
            logger.info(f"Cache hit for key: {key}")
            return self.cache[key]

        logger.info(f"Cache miss for key: {key}. Computing value...")
        result = compute_fn()

        if result and not (
            isinstance(result, tuple) and len(result) == 2 and result[0] is None
        ):
            self.cache[key] = result
            logger.info(f"Stored result in cache for key: {key}")
        else:
            logger.info(f"None result. Bypassing caching for key: {key}")

        return result

    async def async_get_or_set(self, key, compute_fn, no_cache=False, refresh=False):
        if not self.cache or no_cache:
            logger.info(f"Bypassing cache for key: {key}")
            return await compute_fn()

        if refresh and key in self.cache:
            logger.info(f"Refreshing cache for key: {key}")
            del self.cache[key]

        if key in self.cache:
            logger.info(f"Cache hit for key: {key}")
            return self.cache[key]

        logger.info(f"Cache miss for key: {key}. Computing value asynchronously...")
        result = await compute_fn()
        self.cache[key] = result
        logger.info(f"Stored result in cache for key: {key}")
        return result

    def memoize(self, key_func=generate_cache_key, no_cache=False, refresh=False):
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                if not self.cache or no_cache:
                    logger.info(f"Bypassing cache for function: {func.__name__}")
                    return func(*args, **kwargs)

                key = key_func(func.__name__, *args, **kwargs)

                if refresh and key in self.cache:
                    logger.info(
                        f"Refreshing cache for function: {func.__name__}, key: {key}"
                    )
                    del self.cache[key]

                if key in self.cache:
                    logger.info(f"Cache hit for function: {func.__name__}, key: {key}")
                    return self.cache[key]

                logger.info(
                    f"Cache miss for function: {func.__name__}, key: {key}. Computing value..."
                )
                result = func(*args, **kwargs)
                self.cache[key] = result
                logger.info(
                    f"Stored result in cache for function: {func.__name__}, key: {key}"
                )
                return result

            return wrapper

        return decorator

    def async_memoize(self, key_func=generate_cache_key, no_cache=False, refresh=False):
        def decorator(func):
            @functools.wraps(func)
            async def wrapper(*args, **kwargs):
                if no_cache:
                    logger.info(f"Bypassing cache for async function: {func.__name__}")
                    return await func(*args, **kwargs)

                key = key_func(func.__name__, *args, **kwargs)

                if refresh and key in self.cache:
                    logger.info(
                        f"Refreshing cache for async function: {func.__name__}, key: {key}"
                    )
                    del self.cache[key]

                if key in self.cache:
                    logger.info(
                        f"Cache hit for async function: {func.__name__}, key: {key}"
                    )
                    return self.cache[key]

                logger.info(
                    f"Cache miss for async function: {func.__name__}, key: {key}. Computing value..."
                )
                result = await func(*args, **kwargs)
                self.cache[key] = result
                logger.info(
                    f"Stored result in cache for async function: {func.__name__}, key: {key}"
                )
                return result

            return wrapper

        return decorator
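A hedged sketch of how the Cache above might be wired up (illustrative; the cache directory path is an assumption). Because CACHE_LOCATION is read once at import time, UNITXT_CACHE_LOCATION has to be set before unitxt.sql_utils is imported, and disk caching additionally needs the diskcache package installed. Note from the code that get_or_set() deliberately refuses to store falsy results and (None, error) tuples, so failed SQL executions are recomputed rather than cached.

    import os

    os.environ["UNITXT_CACHE_LOCATION"] = "/tmp/unitxt_cache"  # assumed path

    from unitxt.sql_utils import generate_cache_key, get_cache

    cache = get_cache()  # singleton; disk-backed only because the env var is set

    key = generate_cache_key("sql_request", "https://api.example.com", "db1", "SELECT 1")
    result = cache.get_or_set(key, lambda: ("expensive result", None))
    # A second call with the same key returns the stored tuple without recomputing.

    @cache.memoize()
    def normalize(sql: str) -> str:
        return " ".join(sql.split())  # stands in for any expensive, deterministic function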

@lru_cache(maxsize=128)
def execute_query_remote(
    api_url: str,
    database_id: str,
    api_key: str,
    query: str,
    retryable_exceptions: tuple = (ConnectionError, ReadTimeout),
    max_retries: int = 3,
    retry_delay: int = 5,  # seconds
    timeout: int = 30,  # seconds
) -> (Optional[dict], str):
    """Executes a query against the remote database, with retries for certain exceptions."""
    headers = {
        "Content-Type": "application/json",
        "accept": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    retries = 0
    while retries <= max_retries:
        try:
            response = requests.post(
                f"{api_url}/sql",
                headers=headers,
                json={"sql": query, "dataSourceId": database_id},
                verify=False,
                timeout=timeout,
            )
            response.raise_for_status()
            return response.json(), None

        except retryable_exceptions as e:
            retries += 1
            logger.warning(
                f"Attempt {retries} failed with error: {e}. Retrying in {retry_delay} seconds."
            )
            if retries <= max_retries:
                time.sleep(retry_delay)
            else:
                logger.error(f"Max retries ({max_retries}) exceeded for query: {query}")
                return (
                    None,
                    f"Max retries ({max_retries}) exceeded for query: {query} - Error: {e!s}",
                )

        except requests.exceptions.HTTPError as e:
            if e.response.status_code >= 500:
                retries += 1
                logger.warning(
                    f"Server error, attempt {retries} failed with error: {e}. Retrying in {retry_delay} seconds."
                )
                if retries <= max_retries:
                    time.sleep(retry_delay)
                else:
                    logger.error(
                        f"Max retries ({max_retries}) exceeded for query: {query}"
                    )
                    return (
                        None,
                        f"Max retries ({max_retries}) exceeded for query: {query} - Error: {e!s}",
                    )
            else:
                logger.error(f"HTTP Error on attempt {retries}: {e}")
                return (
                    None,
                    f"HTTP Error on attempt {retries}: {e}",
                )

        except Exception as e:
            logger.error(f"Unexpected error on attempt {retries}: {e}")
            return (None, f"Unexpected error on attempt {retries}: {e}")

    return None, "Unknown Error in SQL execution"

class RemoteDatabaseConnector(DatabaseConnector):
    """Database connector for remote databases accessed via HTTP."""

    def __init__(self, db_config: SQLDatabase):
        super().__init__(db_config)

        assert db_config[
            "db_id"
        ], "db_id must be in db_config for RemoteDatabaseConnector"
        self.api_url, self.database_id = (
            db_config["db_id"].split(",")[0],
            db_config["db_id"].split("db_id=")[-1].split(",")[0],
        )

        if not self.api_url or not self.database_id:
            raise ValueError(
                "Both 'api_url' and 'database_id' are required for RemoteDatabaseConnector."
            )

        self.api_key = os.getenv("SQL_API_KEY", None)
        if not self.api_key:
            raise ValueError(
                "The environment variable 'SQL_API_KEY' must be set to use the RemoteDatabaseConnector."
            )

        self.headers = {
            "Content-Type": "application/json",
            "accept": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

        self.timeout = 30

    def get_table_schema(
        self,
    ) -> str:
        """Retrieves the schema of a database."""
        cur_api_url = f"{self.api_url}/datasources/{self.database_id}"
        response = requests.get(
            cur_api_url,
            headers=self.headers,
            verify=False,
            timeout=self.timeout,
        )
        if response.status_code == 200:
            schema = response.json()["schema"]
        else:
            raise OSError(f"Could not fetch schema from {cur_api_url}")

        schema_text = ""
        for table in schema["tables"]:
            schema_text += f"Table: {table['name'] if 'name' in table else table['table_name']} has columns: {[col['name'] if 'name' in col else col['column_name'] for col in table['columns']]}\n"

        return schema_text

    def execute_query(self, query: str) -> Any:
        """Executes a query against the remote database, with retries for certain exceptions."""
        cache = get_cache()

        cache_key = generate_cache_key(
            "sql_request", self.api_url, self.database_id, query
        )
        return cache.get_or_set(
            cache_key,
            lambda: execute_query_remote(
                api_url=self.api_url,
                database_id=self.database_id,
                api_key=self.api_key,
                query=query,
                timeout=self.timeout,
            ),
        )


def get_db_connector(db_type: str):
    """Creates and returns the appropriate DatabaseConnector instance based on db_type."""
    if db_type == "local":
        connector = LocalSQLiteConnector
    elif db_type == "in_memory":
        connector = InMemoryDatabaseConnector
    elif db_type == "remote":
        connector = RemoteDatabaseConnector

    else:
        raise ValueError(f"Unsupported database type: {db_type}")

    return connector
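Note that get_db_connector() returns the connector class rather than an instance, so callers instantiate the result with a db_config. A brief sketch with hypothetical values (the "in_memory" case needs no network access or downloads):

    from unitxt.sql_utils import get_db_connector

    connector_cls = get_db_connector("in_memory")        # InMemoryDatabaseConnector
    connector = connector_cls({"data": {"t": {"columns": ["x"], "rows": [[1]]}}})
    rows, error = connector.execute_query("SELECT x FROM t")

    # "local" expects a BIRD-style db_id (e.g. "bird/<database_name>") and downloads
    # the benchmark SQLite files on first use; "remote" additionally requires the
    # SQL_API_KEY environment variable (see RemoteDatabaseConnector above).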

def is_sqlglot_parsable(sql: str, db_type="sqlite") -> bool:
    """Returns True if sqlglot does not encounter any error, False otherwise."""
    from sqlglot import parse

    if not sql.strip():
        return False
    if db_type == "db2":
        db_type = "postgres"  ## TODO: temporary until sqlglot adds support for db2
    try:
        parse(sql, read=db_type)
        return True
    except Exception as e:
        logger.debug(f"SQL query could not parse: {e}")
        return False


def is_sqlparse_parsable(sql: str) -> bool:
    """Returns True if sqlparse does not encounter any error, False otherwise."""
    from sqlparse import parse
    from sqlparse.tokens import Error

    if not sql.strip():
        return False
    try:
        statements = parse(sql)
        for statement in statements:
            for token in statement.tokens:
                if token.ttype == Error:
                    return False
        return True
    except Exception as e:
        logger.debug(f"SQL query could not parse: {e}")
        return False
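A quick illustration of the two parsability guards above (not from the repository); both reject blank input and otherwise only report whether their parser accepts the text:

    from unitxt.sql_utils import is_sqlglot_parsable, is_sqlparse_parsable

    sql = "SELECT name FROM users WHERE id = 1"
    is_sqlglot_parsable(sql)                  # True; parsed with the default "sqlite" dialect
    is_sqlglot_parsable(sql, db_type="db2")   # parsed as "postgres", per the fallback above
    is_sqlparse_parsable(sql)                 # True; sqlparse emits no Error tokens
    is_sqlparse_parsable("   ")               # False; blank input is rejected up front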

def sqlglot_optimized_equivalence(expected: str, generated: str) -> int:
    """Checks if SQL queries are equivalent using SQLGlot parsing, so we don't run them."""
    from sqlglot import diff, parse_one
    from sqlglot.optimizer import optimize

    try:
        t_diff = diff(
            optimize(parse_one(expected.lower()).sql(pretty=True)),
            optimize(parse_one(generated.lower()).sql(pretty=True)),
        )
        sql_diff = sum(0 if (e.__class__.__name__ == "Keep") else 1 for e in t_diff)

        return 1 if sql_diff == 0 else 0
    except Exception as e:
        logger.debug(f"Error parsing SQL for comparison: {e}")
        return False

def extract_select_columns(statement):
    """Parse SQL using sqlparse and extract columns."""
    from sqlparse.sql import Identifier, IdentifierList
    from sqlparse.tokens import DML, Keyword

    columns = []
    select_seen = False
    for token in statement.tokens:
        if token.ttype is DML and token.value.upper() == "SELECT":
            select_seen = True
            continue
        if select_seen:
            if token.ttype is Keyword and token.value.upper() in (
                "FROM",
                "WHERE",
                "GROUP",
                "HAVING",
                "ORDER",
                "LIMIT",
            ):
                break
            if isinstance(token, IdentifierList):
                for identifier in token.get_identifiers():
                    columns.append(strip_alias(identifier.value))
            elif isinstance(token, Identifier):
                columns.append(strip_alias(token.value))
            else:
                val = token.value.strip()
                if val:
                    columns.append(strip_alias(val))
    return frozenset(columns)


def strip_alias(col: str) -> str:
    """Remove any AS alias from a column."""
    col = col.strip()
    upper = col.upper()
    if " AS " in upper:
        return col[: upper.index(" AS ")].strip()
    parts_alias = col.split()
    if len(parts_alias) > 1:
        return " ".join(parts_alias[:-1])
    return col


def collect_clause(statement, clause_keyword):
    """Parse SQL statement and collect clauses."""
    from sqlparse.tokens import Keyword

    found = False
    collected = []
    for token in statement.tokens:
        tvalue = token.value.upper()
        if token.ttype is Keyword:
            if tvalue.startswith(clause_keyword):
                found = True
                continue
            if found and tvalue in (
                "FROM",
                "WHERE",
                "GROUP",
                "HAVING",
                "ORDER",
                "LIMIT",
            ):
                break
        if found:
            collected.append(token.value)
    return " ".join(collected).strip()


def extract_select_info(sql: str):
    """Parse SQL using sqlparse and return a dict of extracted columns and clauses."""
    from sqlparse import parse
    from sqlparse.tokens import DML

    statements = parse(sql)
    if len(statements) != 1:
        return None
    stmt = statements[0]
    if not any(t.ttype is DML and t.value.upper() == "SELECT" for t in stmt.tokens):
        return None
    parts = {
        "columns": None,
        "from": "",
        "where": "",
        "group": "",
        "having": "",
        "order": "",
    }
    columns = extract_select_columns(stmt)
    if not columns:
        columns = frozenset()
    parts["columns"] = columns
    parts["from"] = collect_clause(stmt, "FROM")
    parts["where"] = collect_clause(stmt, "WHERE")
    parts["group"] = collect_clause(stmt, "GROUP")
    parts["having"] = collect_clause(stmt, "HAVING")
    parts["order"] = collect_clause(stmt, "ORDER")
    return parts


def sqlparse_queries_equivalent(sql1: str, sql2: str) -> bool:
    """Return True if both SQL queries are naively considered equivalent."""
    try:
        info1 = extract_select_info(sql1)
        info2 = extract_select_info(sql2)
        if not info1 or not info2:
            return False
        if info1["columns"] != info2["columns"]:
            return False
        for k in ["from", "where", "group", "having", "order"]:
            if info1[k].replace(" ", "").upper() != info2[k].replace(" ", "").upper():
                return False
        return True
    except Exception as e:
        logger.debug(f"Error parsing SQL query for comparison: {e}")
        return False

def sqlglot_parsed_queries_equivalent(sql1: str, sql2: str, dialect: str = "") -> bool:
    """Return True if two SELECT queries have the same normalized columns and clauses per sqlglot."""
    from sqlglot import exp, parse_one

    try:
        ast1 = parse_one(sql1, read=dialect)
        ast2 = parse_one(sql2, read=dialect)
    except:
        return False
    if not (isinstance(ast1, exp.Select) and isinstance(ast2, exp.Select)):
        return False

    def normalized_select_columns(select_expr: exp.Select):
        cols = []
        for item in select_expr.expressions:
            copy_item = item.copy()
            copy_item.set("alias", None)
            cols.append(copy_item.sql(dialect=dialect, normalize=True))
        return frozenset(cols)

    if normalized_select_columns(ast1) != normalized_select_columns(ast2):
        return False

    def normalized_clause(expr: exp.Expression, key: str):
        clause = expr.args.get(key)
        return clause.sql(dialect=dialect, normalize=True) if clause else ""

    for clause_key in ("from", "where", "group", "having", "order"):
        if normalized_clause(ast1, clause_key) != normalized_clause(ast2, clause_key):
            return False

    return True


def sql_exact_match(sql1: str, sql2: str) -> bool:
    """Return True if two SQL strings match after very basic normalization."""

    def normalize_sql(s: str) -> str:
        s = s.strip().rstrip(";")
        s = re.sub(r"\s+", " ", s)
        return s.upper()

    return normalize_sql(sql1) == normalize_sql(sql2)
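Finally, a sketch contrasting the equivalence checks defined in this module on a pair of queries that differ only in letter case and spacing (illustrative, not part of sql_utils.py; the sqlglot- and sqlparse-based verdicts depend on the installed library versions, so the comments describe the intended contract rather than guaranteed output):

    from unitxt.sql_utils import (
        sql_exact_match,
        sqlglot_optimized_equivalence,
        sqlglot_parsed_queries_equivalent,
        sqlparse_queries_equivalent,
    )

    gold = "SELECT name FROM users WHERE id = 1"
    pred = "select name  from users where id = 1"

    sql_exact_match(gold, pred)                    # True: case and extra whitespace are normalized away
    sqlparse_queries_equivalent(gold, pred)        # expected True: same columns and clauses, naively compared
    sqlglot_parsed_queries_equivalent(gold, pred)  # expected True: same normalized SELECT list and clauses
    sqlglot_optimized_equivalence(gold, pred)      # 1 when the optimized ASTs differ nowhere, else 0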