IBM / unitxt / build 13946131191

19 Mar 2025 12:08PM UTC · coverage: 80.216% (-0.5%) from 80.739%

push · github · web-flow

Text2sql metrics update and optional caching (#1672)

* caching and other fixes for text2sql metrics

Signed-off-by: Oktie Hassanzadeh <hassanzadeh@us.ibm.com>

* text2sql caching ruff fixes

Signed-off-by: Oktie Hassanzadeh <hassanzadeh@us.ibm.com>

* diskcache dependency

Signed-off-by: Oktie Hassanzadeh <hassanzadeh@us.ibm.com>

* text2sql: making caching optional

Signed-off-by: Oktie Hassanzadeh <hassanzadeh@us.ibm.com>

* text2sql optional caching: ruff fixes

Signed-off-by: Oktie Hassanzadeh <hassanzadeh@us.ibm.com>

* text2sql: test for ORDER BY queries

Signed-off-by: Oktie Hassanzadeh <hassanzadeh@us.ibm.com>

* text2sql metric fixing import loc

Signed-off-by: Oktie Hassanzadeh <hassanzadeh@us.ibm.com>

* fixing cache initialization

Signed-off-by: Oktie Hassanzadeh <hassanzadeh@us.ibm.com>

* merging cach_utils into sql_utils

Signed-off-by: Oktie Hassanzadeh <hassanzadeh@us.ibm.com>

* fix formatting

---------

Signed-off-by: Oktie Hassanzadeh <hassanzadeh@us.ibm.com>
Co-authored-by: Elron Bandel <elronbandel@gmail.com>
Co-authored-by: Yotam Perlitz <perlitz@gmail.com>
Co-authored-by: Yotam Perlitz <y.perlitz@ibm.com>
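
The optional caching introduced above is driven by environment variables read in src/unitxt/sql_utils.py. A minimal usage sketch, assuming the optional diskcache dependency is installed, that the package is importable as unitxt, and that the directory and query values below are placeholders:

import os

# The module reads UNITXT_CACHE_LOCATION at import time, so set it before importing;
# leaving it unset keeps caching disabled (hypothetical path below).
os.environ["UNITXT_CACHE_LOCATION"] = "/tmp/unitxt_cache"

from unitxt.sql_utils import generate_cache_key, get_cache

cache = get_cache()  # singleton Cache; backed by diskcache only when the env var is set
key = generate_cache_key("sql_request", "https://sql.example.com", "db-123", "SELECT 1")

# On a miss the lambda is evaluated and stored; on a hit the cached value is returned.
result = cache.get_or_set(key, lambda: ("computed result placeholder", None))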

1569 of 1948 branches covered (80.54%)

Branch coverage included in aggregate %.

9820 of 12250 relevant lines covered (80.16%)

0.8 hits per line
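
Consistent with the note above, the headline 80.216% appears to fold branch hits into the line figure: (9820 + 1569) / (12250 + 1948) = 11389 / 14198 ≈ 80.216%.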

Source File: src/unitxt/sql_utils.py (64.85% covered)

import functools
import glob
import hashlib
import json
import os
import re
import sqlite3
import time
from abc import ABC, abstractmethod
from functools import lru_cache
from typing import Any, List, Optional

import requests
from huggingface_hub import snapshot_download
from requests.exceptions import ConnectionError, ReadTimeout

from .logging_utils import get_logger
from .types import SQLDatabase

logger = get_logger()

# Check if caching is enabled via environment variable
CACHE_LOCATION = os.getenv("UNITXT_CACHE_LOCATION")

# Set max cache size to 10GB or the value of env var MAX_CACHE_SIZE
MAX_CACHE_SIZE = os.getenv("MAX_CACHE_SIZE", 10 * 1024**3)

_cache_instance = None


class DatabaseConnector(ABC):
    """Abstract base class for database connectors."""

    def __init__(self, db_config: SQLDatabase):
        self.db_config = db_config
        self.databases_folder = os.path.join(
            os.environ.get("UNITXT_CACHE_LOCATION", "cache/text2sql"), "databases"
        )
        os.makedirs(self.databases_folder, exist_ok=True)

    @abstractmethod
    def get_table_schema(
        self,
    ) -> str:
        """Abstract method to get database schema."""
        pass

    @abstractmethod
    def execute_query(self, query: str) -> Any:
        """Abstract method to execute a query against the database."""
        pass


@lru_cache(maxsize=128)
def execute_query_local(db_path: str, query: str) -> Any:
    """Executes a query against the SQLite database."""
    conn = None  # Initialize conn to None outside the try block
    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()
        cursor.execute(query)
        return cursor.fetchall(), None
    except sqlite3.Error as e:
        logger.info(f"Error executing SQL: {e}")
        return None, f"Error executing SQL: {e}"
    finally:
        if conn:
            conn.close()


class LocalSQLiteConnector(DatabaseConnector):
    """Database connector for SQLite databases."""

    def __init__(self, db_config: SQLDatabase):
        super().__init__(db_config)
        db_id = self.db_config.get("db_id")
        if not db_id:
            raise ValueError("db_id is required for SQLiteConnector.")
        self.db_path = self.get_db_file_path(db_id)
        self.conn: sqlite3.Connection = sqlite3.connect(self.db_path)
        self.cursor: sqlite3.Cursor = self.conn.cursor()

    def download_database(self, db_id):
        """Downloads the database from huggingface if needed."""
        done_file_path = os.path.join(self.databases_folder, "download_done")
        if "bird/" in db_id:
            if not os.path.exists(done_file_path):
                snapshot_download(
                    repo_id="premai-io/birdbench",
                    repo_type="dataset",
                    local_dir=self.databases_folder,
                    force_download=False,
                    allow_patterns="*validation*",
                )
                open(os.path.join(self.databases_folder, "download_done"), "w").close()
        else:
            raise NotImplementedError(
                f"current local db: {db_id} is not supported, only bird"
            )

    def get_db_file_path(self, db_id):
        """Gets the local path of a downloaded database file."""
        self.download_database(db_id)
        db_id = db_id.split("/")[-1]

        db_file_pattern = os.path.join(self.databases_folder, "**", db_id + ".sqlite")
        db_file_paths = glob.glob(db_file_pattern, recursive=True)

        if not db_file_paths:
            raise FileNotFoundError(f"Database file {db_id} not found.")
        if len(db_file_paths) > 1:
            raise FileExistsError(f"More than one file matched for {db_id}")
        return db_file_paths[0]

    def get_table_schema(
        self,
    ) -> str:
        """Extracts schema from an SQLite database."""
        self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        tables: list[tuple[str]] = self.cursor.fetchall()
        schemas: dict[str, str] = {}

        for table in tables:
            if isinstance(table, tuple):
                table = table[0]
            if table == "sqlite_sequence":
                continue
            sql_query: str = (
                f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table}';"
            )
            self.cursor.execute(sql_query)
            schema_prompt: str = self.cursor.fetchone()[0]

            schemas[table] = schema_prompt

        schema_prompt: str = "\n\n".join(list(schemas.values()))
        return schema_prompt

    def execute_query(self, query: str) -> Any:
        """Executes a query against the SQLite database."""
        return execute_query_local(self.db_path, query)


class InMemoryDatabaseConnector(DatabaseConnector):
    """Database connector for mocking databases with in-memory data structures."""

    def __init__(self, db_config: SQLDatabase):
        super().__init__(db_config)
        self.tables = db_config.get("data", None)

        if not self.tables:
            raise ValueError("data is required for InMemoryDatabaseConnector.")

    def get_table_schema(
        self,
        select_tables: Optional[List[str]] = None,
    ) -> str:
        """Generates a mock schema from the tables structure."""
        schemas = {}
        for table_name, table_data in self.tables.items():
            if select_tables and table_name.lower() not in select_tables:
                continue
            columns = ", ".join([f"`{col}` TEXT" for col in table_data["columns"]])
            schema = f"CREATE TABLE `{table_name}` ({columns});"

            schemas[table_name] = schema

        return "\n\n".join(list(schemas.values()))

    def execute_query(self, query: str) -> Any:
        """Simulates executing a query against the mock database."""
        # Initialize in-memory database from the 'tables' dictionary
        conn = sqlite3.connect(":memory:")
        cursor = conn.cursor()
        logger.debug("Running SQL query over in-memory DB")

        # Create tables and insert data from the 'db' dictionary
        for table_name, table_data in self.tables.items():
            columns = table_data["columns"]
            rows = table_data["rows"]

            # Create table
            cursor.execute(f"CREATE TABLE {table_name} ({', '.join(columns)})")

            # Insert data
            placeholders = ", ".join(["?"] * len(columns))
            cursor.executemany(
                f"INSERT INTO {table_name} VALUES ({placeholders})", rows
            )

        try:
            cursor.execute(query)
            return cursor.fetchall(), None
        except sqlite3.Error as e:
            logger.info(f"Error executing SQL: {e}")
            return None, f"Error executing SQL: {e}"
        finally:
            conn.close()


def get_cache():
    """Returns a singleton cache instance, initializing it if necessary."""
    global _cache_instance
    if _cache_instance is None:
        _cache_instance = Cache()
    return _cache_instance


def generate_cache_key(*args, **kwargs):
    """Generate a stable hashable cache key for various input types.

    :param args: Positional arguments of the function.
    :param kwargs: Keyword arguments of the function.
    :return: A hashed key as a string.
    """
    try:
        # Convert args and kwargs to a JSON string (sorted to ensure consistency)
        serialized = json.dumps(
            {"args": args, "kwargs": kwargs}, sort_keys=True, default=str
        )
    except TypeError:
        # Fallback for non-serializable objects
        serialized = repr((args, kwargs))

    # Hash the serialized data
    return hashlib.md5(serialized.encode()).hexdigest()


class Cache:
    """A class that provides disk-based caching functionality for a given function."""

    def __init__(self):
        """Initializes the cache.

        If `CACHE_LOCATION` (os.getenv("UNITXT_CACHE_LOCATION")) is set, a disk-based
        cache is created using `diskcache`.

        Args:
            None

        Returns:
            None
        """
        if CACHE_LOCATION:
            try:
                import diskcache

                # Ensure the cache directory exists
                os.makedirs(CACHE_LOCATION, exist_ok=True)

                # Create a global diskcache Cache instance
                self.cache = diskcache.Cache(CACHE_LOCATION, size_limit=MAX_CACHE_SIZE)
                logger.info(f"Caching enabled at {CACHE_LOCATION}")
            except ImportError as e:
                raise ImportError(
                    "UNITXT_CACHE_LOCATION is set, but diskcache is not installed.\n"
                    "Please install diskcache `pip install diskcache` "
                    "or unset UNITXT_CACHE_LOCATION."
                ) from e
        else:
            self.cache = None  # Disable caching

    def get_or_set(self, key, compute_fn, no_cache=False, refresh=False):
        if not self.cache or no_cache:
            logger.info(f"Bypassing cache for key: {key}")
            return compute_fn()

        if refresh and key in self.cache:
            logger.info(f"Refreshing cache for key: {key}")
            del self.cache[key]

        if key in self.cache:
            logger.info(f"Cache hit for key: {key}")
            return self.cache[key]

        logger.info(f"Cache miss for key: {key}. Computing value...")
        result = compute_fn()
        self.cache[key] = result
        logger.info(f"Stored result in cache for key: {key}")
        return result

    async def async_get_or_set(self, key, compute_fn, no_cache=False, refresh=False):
        if not self.cache or no_cache:
            logger.info(f"Bypassing cache for key: {key}")
            return await compute_fn()

        if refresh and key in self.cache:
            logger.info(f"Refreshing cache for key: {key}")
            del self.cache[key]

        if key in self.cache:
            logger.info(f"Cache hit for key: {key}")
            return self.cache[key]

        logger.info(f"Cache miss for key: {key}. Computing value asynchronously...")
        result = await compute_fn()
        self.cache[key] = result
        logger.info(f"Stored result in cache for key: {key}")
        return result

    def memoize(self, key_func=generate_cache_key, no_cache=False, refresh=False):
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                if not self.cache or no_cache:
                    logger.info(f"Bypassing cache for function: {func.__name__}")
                    return func(*args, **kwargs)

                key = key_func(func.__name__, *args, **kwargs)

                if refresh and key in self.cache:
                    logger.info(
                        f"Refreshing cache for function: {func.__name__}, key: {key}"
                    )
                    del self.cache[key]

                if key in self.cache:
                    logger.info(f"Cache hit for function: {func.__name__}, key: {key}")
                    return self.cache[key]

                logger.info(
                    f"Cache miss for function: {func.__name__}, key: {key}. Computing value..."
                )
                result = func(*args, **kwargs)
                self.cache[key] = result
                logger.info(
                    f"Stored result in cache for function: {func.__name__}, key: {key}"
                )
                return result

            return wrapper

        return decorator

    def async_memoize(self, key_func=generate_cache_key, no_cache=False, refresh=False):
        def decorator(func):
            @functools.wraps(func)
            async def wrapper(*args, **kwargs):
                if no_cache:
                    logger.info(f"Bypassing cache for async function: {func.__name__}")
                    return await func(*args, **kwargs)

                key = key_func(func.__name__, *args, **kwargs)

                if refresh and key in self.cache:
                    logger.info(
                        f"Refreshing cache for async function: {func.__name__}, key: {key}"
                    )
                    del self.cache[key]

                if key in self.cache:
                    logger.info(
                        f"Cache hit for async function: {func.__name__}, key: {key}"
                    )
                    return self.cache[key]

                logger.info(
                    f"Cache miss for async function: {func.__name__}, key: {key}. Computing value..."
                )
                result = await func(*args, **kwargs)
                self.cache[key] = result
                logger.info(
                    f"Stored result in cache for async function: {func.__name__}, key: {key}"
                )
                return result

            return wrapper

        return decorator


@lru_cache(maxsize=128)
def execute_query_remote(
    api_url: str,
    database_id: str,
    api_key: str,
    query: str,
    retryable_exceptions: tuple = (ConnectionError, ReadTimeout),
    max_retries: int = 3,
    retry_delay: int = 5,  # seconds
    timeout: int = 30,  # seconds
) -> (Optional[dict], str):
    """Executes a query against the remote database, with retries for certain exceptions."""
    headers = {
        "Content-Type": "application/json",
        "accept": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    retries = 0
    while retries <= max_retries:
        try:
            response = requests.post(
                f"{api_url}/sql",
                headers=headers,
                json={"sql": query, "dataSourceId": database_id},
                verify=False,
                timeout=timeout,
            )
            response.raise_for_status()
            return response.json(), None

        except retryable_exceptions as e:
            retries += 1
            logger.warning(
                f"Attempt {retries} failed with error: {e}. Retrying in {retry_delay} seconds."
            )
            if retries <= max_retries:
                time.sleep(retry_delay)
            else:
                logger.error(f"Max retries ({max_retries}) exceeded for query: {query}")
                return (
                    None,
                    f"Max retries ({max_retries}) exceeded for query: {query} - Error: {e!s}",
                )

        except requests.exceptions.HTTPError as e:
            if e.response.status_code >= 500:
                retries += 1
                logger.warning(
                    f"Server error, attempt {retries} failed with error: {e}. Retrying in {retry_delay} seconds."
                )
                if retries <= max_retries:
                    time.sleep(retry_delay)
                else:
                    logger.error(
                        f"Max retries ({max_retries}) exceeded for query: {query}"
                    )
                    return (
                        None,
                        f"Max retries ({max_retries}) exceeded for query: {query} - Error: {e!s}",
                    )
            else:
                logger.error(f"HTTP Error on attempt {retries}: {e}")
                return (
                    None,
                    f"HTTP Error on attempt {retries}: {e}",
                )

        except Exception as e:
            logger.error(f"Unexpected error on attempt {retries}: {e}")
            return (None, f"Unexpected error on attempt {retries}: {e}")

    return None, "Unknown Error in SQL execution"


class RemoteDatabaseConnector(DatabaseConnector):
    """Database connector for remote databases accessed via HTTP."""

    def __init__(self, db_config: SQLDatabase):
        super().__init__(db_config)

        assert db_config[
            "db_id"
        ], "db_id must be in db_config for RemoteDatabaseConnector"
        self.api_url, self.database_id = (
            db_config["db_id"].split(",")[0],
            db_config["db_id"].split("db_id=")[-1].split(",")[0],
        )

        if not self.api_url or not self.database_id:
            raise ValueError(
                "Both 'api_url' and 'database_id' are required for RemoteDatabaseConnector."
            )

        self.api_key = os.getenv("SQL_API_KEY", None)
        if not self.api_key:
            raise ValueError(
                "The environment variable 'SQL_API_KEY' must be set to use the RemoteDatabaseConnector."
            )

        self.headers = {
            "Content-Type": "application/json",
            "accept": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

        self.timeout = 30

    def get_table_schema(
        self,
    ) -> str:
        """Retrieves the schema of a database."""
        cur_api_url = f"{self.api_url}/datasources/{self.database_id}"
        response = requests.get(
            cur_api_url,
            headers=self.headers,
            verify=False,
            timeout=self.timeout,
        )
        if response.status_code == 200:
            schema = response.json()["schema"]
        else:
            raise OSError(f"Could not fetch schema from {cur_api_url}")

        schema_text = ""
        for table in schema["tables"]:
            schema_text += f"Table: {table['table_name']} has columns: {[col['column_name'] for col in table['columns']]}\n"

        return schema_text

    def execute_query(self, query: str) -> Any:
        """Executes a query against the remote database, with retries for certain exceptions."""
        cache = get_cache()

        cache_key = generate_cache_key(
            "sql_request", self.api_url, self.database_id, query
        )
        return cache.get_or_set(
            cache_key,
            lambda: execute_query_remote(
                api_url=self.api_url,
                database_id=self.database_id,
                api_key=self.api_key,
                query=query,
                timeout=self.timeout,
            ),
        )


def get_db_connector(db_type: str):
    """Creates and returns the appropriate DatabaseConnector instance based on db_type."""
    if db_type == "local":
        connector = LocalSQLiteConnector
    elif db_type == "in_memory":
        connector = InMemoryDatabaseConnector
    elif db_type == "remote":
        connector = RemoteDatabaseConnector

    else:
        raise ValueError(f"Unsupported database type: {db_type}")

    return connector


def is_sqlglot_parsable(sql: str, db_type="sqlite") -> bool:
    """Returns True if sqlglot does not encounter any error, False otherwise."""
    from sqlglot import parse

    if not sql.strip():
        return False
    if db_type == "db2":
        db_type = "postgres"  ## TODO: temporary until sqlglot adds support for db2
    try:
        parse(sql, read=db_type)
        return True
    except Exception as e:
        logger.debug(f"SQL query could not parse: {e}")
        return False


def is_sqlparse_parsable(sql: str) -> bool:
    """Returns True if sqlparse does not encounter any error, False otherwise."""
    from sqlparse import parse
    from sqlparse.tokens import Error

    if not sql.strip():
        return False
    try:
        statements = parse(sql)
        for statement in statements:
            for token in statement.tokens:
                if token.ttype == Error:
                    return False
        return True
    except Exception as e:
        logger.debug(f"SQL query could not parse: {e}")
        return False


def sqlglot_optimized_equivalence(expected: str, generated: str) -> int:
    """Checks if SQL queries are equivalent using SQLGlot parsing, so we don't run them."""
    from sqlglot import diff, parse_one
    from sqlglot.optimizer import optimize

    try:
        t_diff = diff(
            optimize(parse_one(expected.lower()).sql(pretty=True)),
            optimize(parse_one(generated.lower()).sql(pretty=True)),
        )
        sql_diff = sum(0 if (e.__class__.__name__ == "Keep") else 1 for e in t_diff)

        return 1 if sql_diff == 0 else 0
    except Exception as e:
        logger.debug(f"Error parsing SQL for comparison: {e}")
        return False


def extract_select_columns(statement):
    """Parse SQL using sqlparse and extract columns."""
    from sqlparse.sql import Identifier, IdentifierList
    from sqlparse.tokens import DML, Keyword

    columns = []
    select_seen = False
    for token in statement.tokens:
        if token.ttype is DML and token.value.upper() == "SELECT":
            select_seen = True
            continue
        if select_seen:
            if token.ttype is Keyword and token.value.upper() in (
                "FROM",
                "WHERE",
                "GROUP",
                "HAVING",
                "ORDER",
                "LIMIT",
            ):
                break
            if isinstance(token, IdentifierList):
                for identifier in token.get_identifiers():
                    columns.append(strip_alias(identifier.value))
            elif isinstance(token, Identifier):
                columns.append(strip_alias(token.value))
            else:
                val = token.value.strip()
                if val:
                    columns.append(strip_alias(val))
    return frozenset(columns)


def strip_alias(col: str) -> str:
    """Remove any AS alias from a column."""
    col = col.strip()
    upper = col.upper()
    if " AS " in upper:
        return col[: upper.index(" AS ")].strip()
    parts_alias = col.split()
    if len(parts_alias) > 1:
        return " ".join(parts_alias[:-1])
    return col


def collect_clause(statement, clause_keyword):
    """Parse SQL statement and collect clauses."""
    from sqlparse.tokens import Keyword

    found = False
    collected = []
    for token in statement.tokens:
        tvalue = token.value.upper()
        if token.ttype is Keyword:
            if tvalue.startswith(clause_keyword):
                found = True
                continue
            if found and tvalue in (
                "FROM",
                "WHERE",
                "GROUP",
                "HAVING",
                "ORDER",
                "LIMIT",
            ):
                break
        if found:
            collected.append(token.value)
    return " ".join(collected).strip()


def extract_select_info(sql: str):
    """Parse SQL using sqlparse and return a dict of extracted columns and clauses."""
    from sqlparse import parse
    from sqlparse.tokens import DML

    statements = parse(sql)
    if len(statements) != 1:
        return None
    stmt = statements[0]
    if not any(t.ttype is DML and t.value.upper() == "SELECT" for t in stmt.tokens):
        return None
    parts = {
        "columns": None,
        "from": "",
        "where": "",
        "group": "",
        "having": "",
        "order": "",
    }
    columns = extract_select_columns(stmt)
    if not columns:
        columns = frozenset()
    parts["columns"] = columns
    parts["from"] = collect_clause(stmt, "FROM")
    parts["where"] = collect_clause(stmt, "WHERE")
    parts["group"] = collect_clause(stmt, "GROUP")
    parts["having"] = collect_clause(stmt, "HAVING")
    parts["order"] = collect_clause(stmt, "ORDER")
    return parts


def sqlparse_queries_equivalent(sql1: str, sql2: str) -> bool:
    """Return True if both SQL queries are naively considered equivalent."""
    try:
        info1 = extract_select_info(sql1)
        info2 = extract_select_info(sql2)
        if not info1 or not info2:
            return False
        if info1["columns"] != info2["columns"]:
            return False
        for k in ["from", "where", "group", "having", "order"]:
            if info1[k].replace(" ", "").upper() != info2[k].replace(" ", "").upper():
                return False
        return True
    except Exception as e:
        logger.debug(f"Error parsing SQL query for comparison: {e}")
        return False


def sqlglot_parsed_queries_equivalent(sql1: str, sql2: str, dialect: str = "") -> bool:
    """Return True if two SELECT queries parse to structurally equivalent ASTs via sqlglot."""
    from sqlglot import exp, parse_one

    try:
        ast1 = parse_one(sql1, read=dialect)
        ast2 = parse_one(sql2, read=dialect)
    except:
        return False
    if not (isinstance(ast1, exp.Select) and isinstance(ast2, exp.Select)):
        return False

    def normalized_select_columns(select_expr: exp.Select):
        cols = []
        for item in select_expr.expressions:
            copy_item = item.copy()
            copy_item.set("alias", None)
            cols.append(copy_item.sql(dialect=dialect, normalize=True))
        return frozenset(cols)

    if normalized_select_columns(ast1) != normalized_select_columns(ast2):
        return False

    def normalized_clause(expr: exp.Expression, key: str):
        clause = expr.args.get(key)
        return clause.sql(dialect=dialect, normalize=True) if clause else ""

    for clause_key in ("from", "where", "group", "having", "order"):
        if normalized_clause(ast1, clause_key) != normalized_clause(ast2, clause_key):
            return False

    return True


def sql_exact_match(sql1: str, sql2: str) -> bool:
    """Return True if two SQL strings match after very basic normalization."""

    def normalize_sql(s: str) -> str:
        s = s.strip().rstrip(";")
        s = re.sub(r"\s+", " ", s)
        return s.upper()

    return normalize_sql(sql1) == normalize_sql(sql2)
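
As a quick orientation to the utilities above, here is a hedged sketch of the in-memory connector and the naive equivalence helpers. The table contents and queries are invented for illustration, the config dict includes only the keys this connector actually reads, and the import path assumes the package is importable as unitxt:

from unitxt.sql_utils import (
    InMemoryDatabaseConnector,
    sql_exact_match,
    sqlparse_queries_equivalent,
)

# "data" maps table name -> {"columns": [...], "rows": [...]} (made-up sample).
db_config = {
    "db_id": None,
    "data": {
        "users": {
            "columns": ["id", "name"],
            "rows": [(1, "Ada"), (2, "Grace")],
        }
    },
}

connector = InMemoryDatabaseConnector(db_config)
print(connector.get_table_schema())
# CREATE TABLE `users` (`id` TEXT, `name` TEXT);

rows, error = connector.execute_query("SELECT name FROM users ORDER BY id")
# rows == [('Ada',), ('Grace',)] and error is None

# Purely textual/structural checks; neither helper executes the SQL.
print(sql_exact_match("select name from users;", "SELECT NAME FROM USERS"))  # True
print(sqlparse_queries_equivalent("SELECT name FROM users", "select name from users"))  # True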