IBM / unitxt / build 16348455941
17 Jul 2025 02:51PM UTC, coverage: 81.218% (-0.02% from 81.238%)

Pull Request #1872: Bench and models
Merge e0de9e73d into 051a17617 (github / web-flow)

1555 of 1927 branches covered (80.7%)

Branch coverage included in aggregate %.

10592 of 13029 relevant lines covered (81.3%)

0.81 hits per line

Source File
src/unitxt/text2sql_utils.py: 72.8% covered
1
import functools
1✔
2
import glob
1✔
3
import hashlib
1✔
4
import json
1✔
5
import os
1✔
6
import re
1✔
7
import time
1✔
8
from abc import ABC, abstractmethod
1✔
9
from collections import Counter
1✔
10
from dataclasses import dataclass
1✔
11
from functools import lru_cache
1✔
12
from typing import Any, List, Optional, Tuple
1✔
13

14
import numpy as np
1✔
15
import pandas as pd
1✔
16
import requests
1✔
17
from huggingface_hub import snapshot_download
1✔
18
from requests.exceptions import ConnectionError, ReadTimeout
1✔
19

20
from .logging_utils import get_logger
1✔
21
from .types import SQLDatabase
1✔
22

23
try:
1✔
24
    import sqlite3
1✔
25
except ImportError:
×
26
    sqlite3 = None
×
27

28

29
logger = get_logger()
30

31
# Check if caching is enabled via environment variable
32
CACHE_LOCATION = os.getenv("UNITXT_CACHE_LOCATION")
1✔
33

34
# Set max cache size to 10GB or the value of env var MAX_CACHE_SIZE
35
MAX_CACHE_SIZE = int(os.getenv("MAX_CACHE_SIZE", 10 * 1024**3))
1✔
36

37
_cache_instance = None
1✔
38

39

40
class DatabaseConnector(ABC):
1✔
41
    """Abstract base class for database connectors."""
42

43
    def __init__(self, db_config: SQLDatabase):
1✔
44
        self.db_config = db_config
1✔
45
        self.databases_folder = os.path.join(
1✔
46
            os.environ.get("UNITXT_CACHE_LOCATION", "cache/text2sql"), "databases"
47
        )
48
        os.makedirs(self.databases_folder, exist_ok=True)
1✔
49

50
    @abstractmethod
1✔
51
    def get_table_schema(
1✔
52
        self,
53
    ) -> str:
54
        """Abstract method to get database schema."""
55
        pass
56

57
    @abstractmethod
1✔
58
    def execute_query(self, query: str) -> Any:
1✔
59
        """Abstract method to execute a query against the database."""
60
        pass
61

62

63
@lru_cache(maxsize=128)
1✔
64
def execute_query_local(db_path: str, query: str) -> Any:
1✔
65
    """Executes a query against the SQLite database."""
66
    conn = None  # Initialize conn to None outside the try block
1✔
67
    try:
1✔
68
        conn = sqlite3.connect(db_path)
1✔
69
        cursor = conn.cursor()
1✔
70
        cursor.execute(query)
1✔
71
        return cursor.fetchall(), None
1✔
72
    except sqlite3.Error as e:
1✔
73
        logger.info(f"Error executing SQL: {e}")
74
        return None, f"Error executing SQL: {e}"
1✔
75
    finally:
76
        if conn:
1✔
77
            conn.close()
1✔
78

79

80
class LocalSQLiteConnector(DatabaseConnector):
1✔
81
    """Database connector for SQLite databases."""
82

83
    def __init__(self, db_config: SQLDatabase):
1✔
84
        super().__init__(db_config)
1✔
85
        db_id = self.db_config.get("db_id")
1✔
86
        if not db_id:
1✔
87
            raise ValueError("db_id is required for SQLiteConnector.")
88
        self.db_path = self.get_db_file_path(db_id)
1✔
89
        self.conn: sqlite3.Connection = sqlite3.connect(self.db_path)
1✔
90
        self.cursor: sqlite3.Cursor = self.conn.cursor()
1✔
91

92
    def download_database(self, db_id):
1✔
93
        """Downloads the database from huggingface if needed."""
94
        done_file_path = os.path.join(self.databases_folder, "download_done")
1✔
95
        if "bird/" in db_id:
1✔
96
            if not os.path.exists(done_file_path):
×
97
                snapshot_download(
×
98
                    repo_id="premai-io/birdbench",
99
                    repo_type="dataset",
100
                    local_dir=self.databases_folder,
101
                    force_download=False,
102
                    allow_patterns="*validation*",
103
                )
104
                open(os.path.join(self.databases_folder, "download_done"), "w").close()
×
105
        else:
106
            raise NotImplementedError(
1✔
107
                f"current local db: {db_id} is not supported, only bird"
108
            )
109

110
    def get_db_file_path(self, db_id):
1✔
111
        """Gets the local path of a downloaded database file."""
112
        self.download_database(db_id)
1✔
113
        db_id = db_id.split("/")[-1]
×
114

115
        db_file_pattern = os.path.join(self.databases_folder, "**", db_id + ".sqlite")
×
116
        db_file_paths = glob.glob(db_file_pattern, recursive=True)
×
117

118
        if not db_file_paths:
×
119
            raise FileNotFoundError(f"Database file {db_id} not found.")
×
120
        if len(db_file_paths) > 1:
×
121
            raise FileExistsError(f"More than one file matched for {db_id}")
×
122
        return db_file_paths[0]
×
123

124
    def get_table_schema(
1✔
125
        self,
126
    ) -> str:
127
        """Extracts schema from an SQLite database."""
128
        self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
×
129
        tables: list[tuple[str]] = self.cursor.fetchall()
×
130
        schemas: dict[str, str] = {}
×
131

132
        for table in tables:
×
133
            if isinstance(table, tuple):
×
134
                table = table[0]
×
135
            if table == "sqlite_sequence":
×
136
                continue
×
137
            sql_query: str = (
×
138
                f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table}';"
139
            )
140
            self.cursor.execute(sql_query)
×
141
            schema_prompt: str = self.cursor.fetchone()[0]
×
142

143
            schemas[table] = schema_prompt
×
144

145
        schema_prompt: str = "\n\n".join(list(schemas.values()))
×
146
        return schema_prompt
×
147

148
    def execute_query(self, query: str) -> Any:
1✔
149
        """Executes a query against the SQLite database."""
150
        return execute_query_local(self.db_path, query)
1✔
151

152

153
class InMemoryDatabaseConnector(DatabaseConnector):
1✔
154
    """Database connector for mocking databases with in-memory data structures."""
155

156
    def __init__(self, db_config: SQLDatabase):
1✔
157
        super().__init__(db_config)
1✔
158
        self.tables = db_config.get("data", None)
1✔
159

160
        if not self.tables:
1✔
161
            raise ValueError("data is required for InMemoryDatabaseConnector.")
162

163
    def get_table_schema(
1✔
164
        self,
165
        select_tables: Optional[List[str]] = None,
166
    ) -> str:
167
        """Generates a mock schema from the tables structure."""
168
        schemas = {}
1✔
169
        for table_name, table_data in self.tables.items():
1✔
170
            if select_tables and table_name.lower() not in select_tables:
1✔
171
                continue
1✔
172
            columns = ", ".join([f"`{col}` TEXT" for col in table_data["columns"]])
1✔
173
            schema = f"CREATE TABLE `{table_name}` ({columns});"
1✔
174

175
            schemas[table_name] = schema
1✔
176

177
        return "\n\n".join(list(schemas.values()))
1✔
178

179
    def execute_query(self, query: str) -> Any:
1✔
180
        """Simulates executing a query against the mock database."""
181
        # Initialize in-memory database from the 'tables' dictionary
182
        conn = sqlite3.connect(":memory:")
1✔
183
        cursor = conn.cursor()
1✔
184
        logger.debug("Running SQL query over in-memory DB")
185

186
        # Create tables and insert data from the 'db' dictionary
187
        for table_name, table_data in self.tables.items():
1✔
188
            columns = table_data["columns"]
1✔
189
            rows = table_data["rows"]
1✔
190

191
            # Create table
192
            cursor.execute(f"CREATE TABLE {table_name} ({', '.join(columns)})")
1✔
193

194
            # Insert data
195
            placeholders = ", ".join(["?"] * len(columns))
1✔
196
            cursor.executemany(
1✔
197
                f"INSERT INTO {table_name} VALUES ({placeholders})", rows
198
            )
199

200
        try:
1✔
201
            cursor.execute(query)
1✔
202
            return cursor.fetchall(), None
1✔
203
        except sqlite3.Error as e:
1✔
204
            logger.info(f"Error executing SQL: {e}")
205
            return None, f"Error executing SQL: {e}"
1✔
206
        finally:
207
            conn.close()
1✔
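
A minimal usage sketch of the in-memory connector. The db_config below is a plain dict standing in for SQLDatabase, and the "users" table is invented for illustration:

# Sketch: build an in-memory SQLite DB from a small dict and query it.
db_config = {
    "db_id": None,
    "data": {
        "users": {
            "columns": ["id", "name"],
            "rows": [[1, "Alice"], [2, "Bob"]],
        }
    },
}
connector = InMemoryDatabaseConnector(db_config)
print(connector.get_table_schema())
# CREATE TABLE `users` (`id` TEXT, `name` TEXT);
rows, error = connector.execute_query("SELECT name FROM users WHERE id = 1")
# rows == [('Alice',)], error is None
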
208

209

210
def get_cache():
1✔
211
    """Returns a singleton cache instance, initializing it if necessary."""
212
    global _cache_instance
213
    if _cache_instance is None:
1✔
214
        _cache_instance = Cache()
1✔
215
    return _cache_instance
1✔
216

217

218
def generate_cache_key(*args, **kwargs):
1✔
219
    """Generate a stable hashable cache key for various input types.
220

221
    :param args: Positional arguments of the function.
222
    :param kwargs: Keyword arguments of the function.
223
    :return: A hashed key as a string.
224
    """
225
    try:
1✔
226
        # Convert args and kwargs to a JSON string (sorted to ensure consistency)
227
        serialized = json.dumps(
1✔
228
            {"args": args, "kwargs": kwargs}, sort_keys=True, default=str
229
        )
230
    except TypeError:
×
231
        # Fallback for non-serializable objects
232
        serialized = repr((args, kwargs))
×
233

234
    # Hash the serialized data
235
    return hashlib.md5(serialized.encode()).hexdigest()
1✔
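
For illustration, keys are stable across keyword-argument order because the payload is serialized with sort_keys=True; the URL below is a made-up example:

k1 = generate_cache_key("sql_request", "https://api.example.com", query="SELECT 1", timeout=30)
k2 = generate_cache_key("sql_request", "https://api.example.com", timeout=30, query="SELECT 1")
assert k1 == k2       # same positional args and kwargs produce the same key
assert len(k1) == 32  # MD5 hex digest
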
236

237

238
class Cache:
1✔
239
    """A class that provides disk-based caching functionality for a given function."""
240

241
    def __init__(self):
1✔
242
        """Initializes the cache.
243

244
        If `CACHE_LOCATION` (os.getenv("UNITXT_CACHE_LOCATION")) is set, a disk-based
245
        cache is created using `diskcache`.
246

247
        Args:
248
            None
249

250
        Returns:
251
            None
252
        """
253
        if CACHE_LOCATION:
1✔
254
            try:
×
255
                import diskcache
×
256

257
                # Ensure the cache directory exists
258
                os.makedirs(CACHE_LOCATION, exist_ok=True)
×
259

260
                # Create a global diskcache Cache instance
261
                self.cache = diskcache.Cache(CACHE_LOCATION, size_limit=MAX_CACHE_SIZE)
×
262
                logger.info(f"Caching enabled at {CACHE_LOCATION}")
263
            except ImportError as e:
×
264
                raise ImportError(
×
265
                    "UNITXT_CACHE_LOCATION is set, but diskcache is not installed.\n"
266
                    "Please install diskcache `pip install diskcache` "
267
                    "or unset UNITXT_CACHE_LOCATION."
268
                ) from e
269
        else:
270
            self.cache = None  # Disable caching
1✔
271

272
    def get_or_set(self, key, compute_fn, no_cache=False, refresh=False):
1✔
273
        if not self.cache or no_cache:
1✔
274
            logger.info(f"Bypassing cache for key: {key}")
275
            return compute_fn()
1✔
276

277
        if refresh and key in self.cache:
×
278
            logger.info(f"Refreshing cache for key: {key}")
279
            del self.cache[key]
×
280

281
        if key in self.cache:
×
282
            logger.info(f"Cache hit for key: {key}")
283
            return self.cache[key]
×
284

285
        logger.info(f"Cache miss for key: {key}. Computing value...")
286
        result = compute_fn()
×
287

288
        if result and not (
×
289
            isinstance(result, tuple) and len(result) == 2 and result[0] is None
290
        ):
291
            self.cache[key] = result
×
292
            logger.info(f"Stored result in cache for key: {key}")
293
        else:
294
            logger.info(f"None result. Bypassing caching for key: {key}")
295

296
        return result
×
297

298
    async def async_get_or_set(self, key, compute_fn, no_cache=False, refresh=False):
1✔
299
        if not self.cache or no_cache:
×
300
            logger.info(f"Bypassing cache for key: {key}")
301
            return await compute_fn()
×
302

303
        if refresh and key in self.cache:
×
304
            logger.info(f"Refreshing cache for key: {key}")
305
            del self.cache[key]
×
306

307
        if key in self.cache:
×
308
            logger.info(f"Cache hit for key: {key}")
309
            return self.cache[key]
×
310

311
        logger.info(f"Cache miss for key: {key}. Computing value asynchronously...")
312
        result = await compute_fn()
×
313
        self.cache[key] = result
×
314
        logger.info(f"Stored result in cache for key: {key}")
315
        return result
×
316

317
    def memoize(self, key_func=generate_cache_key, no_cache=False, refresh=False):
1✔
318
        def decorator(func):
×
319
            @functools.wraps(func)
×
320
            def wrapper(*args, **kwargs):
×
321
                if not self.cache or no_cache:
×
322
                    logger.info(f"Bypassing cache for function: {func.__name__}")
323
                    return func(*args, **kwargs)
×
324

325
                key = key_func(func.__name__, *args, **kwargs)
×
326

327
                if refresh and key in self.cache:
×
328
                    logger.info(
329
                        f"Refreshing cache for function: {func.__name__}, key: {key}"
330
                    )
331
                    del self.cache[key]
×
332

333
                if key in self.cache:
×
334
                    logger.info(f"Cache hit for function: {func.__name__}, key: {key}")
335
                    return self.cache[key]
×
336

337
                logger.info(
338
                    f"Cache miss for function: {func.__name__}, key: {key}. Computing value..."
339
                )
340
                result = func(*args, **kwargs)
×
341
                self.cache[key] = result
×
342
                logger.info(
343
                    f"Stored result in cache for function: {func.__name__}, key: {key}"
344
                )
345
                return result
×
346

347
            return wrapper
×
348

349
        return decorator
×
350

351
    def async_memoize(self, key_func=generate_cache_key, no_cache=False, refresh=False):
1✔
352
        def decorator(func):
×
353
            @functools.wraps(func)
×
354
            async def wrapper(*args, **kwargs):
×
355
                if no_cache:
×
356
                    logger.info(f"Bypassing cache for async function: {func.__name__}")
357
                    return await func(*args, **kwargs)
×
358

359
                key = key_func(func.__name__, *args, **kwargs)
×
360

361
                if refresh and key in self.cache:
×
362
                    logger.info(
363
                        f"Refreshing cache for async function: {func.__name__}, key: {key}"
364
                    )
365
                    del self.cache[key]
×
366

367
                if key in self.cache:
×
368
                    logger.info(
369
                        f"Cache hit for async function: {func.__name__}, key: {key}"
370
                    )
371
                    return self.cache[key]
×
372

373
                logger.info(
374
                    f"Cache miss for async function: {func.__name__}, key: {key}. Computing value..."
375
                )
376
                result = await func(*args, **kwargs)
×
377
                self.cache[key] = result
×
378
                logger.info(
379
                    f"Stored result in cache for async function: {func.__name__}, key: {key}"
380
                )
381
                return result
×
382

383
            return wrapper
×
384

385
        return decorator
×
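
A short sketch of the cache in use. It assumes UNITXT_CACHE_LOCATION was exported (and diskcache installed) before this module was imported; otherwise self.cache is None and every call simply runs compute_fn:

cache = get_cache()  # singleton Cache instance
key = generate_cache_key("expensive_step", n=42)
value = cache.get_or_set(key, lambda: sum(range(100)))  # computed on a miss (or when caching is disabled)
value = cache.get_or_set(key, lambda: sum(range(100)))  # served from the disk cache on a hit
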
386

387

388
@lru_cache(maxsize=128)
1✔
389
def execute_query_remote(
1✔
390
    api_url: str,
391
    database_id: str,
392
    api_key: str,
393
    query: str,
394
    retryable_exceptions: tuple = (ConnectionError, ReadTimeout),
395
    max_retries: int = 3,
396
    retry_delay: int = 5,  # seconds
397
    timeout: int = 30,  # seconds
398
) -> Tuple[Optional[dict], str]:
399
    """Executes a query against the remote database, with retries for certain exceptions."""
400
    headers = {
1✔
401
        "Content-Type": "application/json",
402
        "accept": "application/json",
403
        "Authorization": f"Bearer {api_key}",
404
    }
405
    retries = 0
1✔
406
    while retries <= max_retries:
1✔
407
        try:
1✔
408
            response = requests.post(
1✔
409
                f"{api_url}/sql",
410
                headers=headers,
411
                json={"sql": query, "dataSourceId": database_id},
412
                verify=False,
413
                timeout=timeout,
414
            )
415
            response.raise_for_status()
×
416
            return response.json(), None
×
417

418
        except retryable_exceptions as e:
1✔
419
            retries += 1
×
420
            logger.warning(
421
                f"Attempt {retries} failed with error: {e}. Retrying in {retry_delay} seconds."
422
            )
423
            if retries <= max_retries:
×
424
                time.sleep(retry_delay)
×
425
            else:
426
                logger.error(f"Max retries ({max_retries}) exceeded for query: {query}")
427
                return (
×
428
                    None,
429
                    f"Max retries ({max_retries}) exceeded for query: {query} - Error: {e!s}",
430
                )
431

432
        except requests.exceptions.HTTPError as e:
1✔
433
            if e.response.status_code >= 500:
×
434
                retries += 1
×
435
                logger.warning(
436
                    f"Server error, attempt {retries} failed with error: {e}. Retrying in {retry_delay} seconds."
437
                )
438
                if retries <= max_retries:
×
439
                    time.sleep(retry_delay)
×
440
                else:
441
                    logger.error(
442
                        f"Max retries ({max_retries}) exceeded for query: {query}"
443
                    )
444
                    return (
×
445
                        None,
446
                        f"Max retries ({max_retries}) exceeded for query: {query} - Error: {e!s}",
447
                    )
448
            else:
449
                logger.error(f"HTTP Error on attempt {retries}: {e}")
450
                return (
×
451
                    None,
452
                    f"HTTP Error on attempt {retries}: {e}",
453
                )
454

455
        except Exception as e:
456
            logger.error(f"Unexpected error on attempt {retries}: {e}")
457
            return (None, f"Unexpected error on attempt {retries}: {e}")
458

459
    return None, "Unknown Error in SQL execution"
×
460

461

462
class RemoteDatabaseConnector(DatabaseConnector):
1✔
463
    """Database connector for remote databases accessed via HTTP."""
464

465
    def __init__(self, db_config: SQLDatabase):
1✔
466
        super().__init__(db_config)
1✔
467

468
        assert db_config[
1✔
469
            "db_id"
470
        ], "db_id must be in db_config for RemoteDatabaseConnector"
471
        self.api_url, self.database_id = (
1✔
472
            db_config["db_id"].split(",")[0],
473
            db_config["db_id"].split("db_id=")[-1].split(",")[0],
474
        )
475

476
        if not self.api_url or not self.database_id:
1✔
477
            raise ValueError(
478
                "Both 'api_url' and 'database_id' are required for RemoteDatabaseConnector."
479
            )
480

481
        self.api_key = os.getenv("SQL_API_KEY", None)
1✔
482
        if not self.api_key:
1✔
483
            raise ValueError(
484
                "The environment variable 'SQL_API_KEY' must be set to use the RemoteDatabaseConnector."
485
            )
486

487
        self.headers = {
1✔
488
            "Content-Type": "application/json",
489
            "accept": "application/json",
490
            "Authorization": f"Bearer {self.api_key}",
491
        }
492

493
        self.timeout = 30
1✔
494

495
    def get_table_schema(
1✔
496
        self,
497
    ) -> str:
498
        """Retrieves the schema of a database."""
499
        cur_api_url = f"{self.api_url}/datasources/{self.database_id}"
1✔
500
        response = requests.get(
1✔
501
            cur_api_url,
502
            headers=self.headers,
503
            verify=False,
504
            timeout=self.timeout,
505
        )
506
        if response.status_code == 200:
×
507
            schema = response.json()["schema"]
×
508
        else:
509
            raise OSError(f"Could not fetch schema from {cur_api_url}")
×
510

511
        schema_text = ""
×
512
        for table in schema["tables"]:
×
513
            schema_text += f"Table: {table['name'] if 'name' in table else table['table_name']} has columns: {[col['name'] if 'name' in col else col['column_name'] for col in table['columns']]}\n"
×
514

515
        return schema_text
×
516

517
    def execute_query(self, query: str) -> Any:
1✔
518
        """Executes a query against the remote database, with retries for certain exceptions."""
519
        cache = get_cache()
1✔
520

521
        cache_key = generate_cache_key(
1✔
522
            "sql_request", self.api_url, self.database_id, query
523
        )
524
        return cache.get_or_set(
1✔
525
            cache_key,
526
            lambda: execute_query_remote(
527
                api_url=self.api_url,
528
                database_id=self.database_id,
529
                api_key=self.api_key,
530
                query=query,
531
                timeout=self.timeout,
532
            ),
533
        )
534

535

536
def get_db_connector(db_type: str):
1✔
537
    """Creates and returns the appropriate DatabaseConnector instance based on db_type."""
538
    if db_type == "local":
1✔
539
        connector = LocalSQLiteConnector
×
540
    elif db_type == "in_memory":
1✔
541
        connector = InMemoryDatabaseConnector
1✔
542
    elif db_type == "remote":
1✔
543
        connector = RemoteDatabaseConnector
×
544

545
    else:
546
        raise ValueError(f"Unsupported database type: {db_type}")
547

548
    return connector
1✔
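
Note that the factory returns the connector class, not an instance; the caller instantiates it with a db_config. A toy sketch (table "t" is invented, and the dict stands in for SQLDatabase):

connector_cls = get_db_connector("in_memory")  # -> InMemoryDatabaseConnector
connector = connector_cls({"data": {"t": {"columns": ["x"], "rows": [[1]]}}})
rows, error = connector.execute_query("SELECT x FROM t")  # ([(1,)], None)
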
549

550

551
@dataclass
1✔
552
class SQLNonExecutionMetricResult:
1✔
553
    sqlglot_validity: int  # Whether SQL parses with sqlglot
1✔
554
    sqlparse_validity: int  # Whether SQL parses with sqlparse
1✔
555
    sqlglot_equivalence: int  # Semantic equivalence using sqlglot AST
1✔
556
    sqlglot_optimized_equivalence: int  # Equivalence after optimization via sqlglot
1✔
557
    sqlparse_equivalence: int  # Equivalence using sqlparse AST
1✔
558
    sql_exact_match: int  # Exact string match of predicted and gold SQL
1✔
559
    sql_syntactic_equivalence: int  # Any of the above equivalence conditions hold
1✔
560

561

562
def is_sqlglot_parsable(sql: str, db_type="sqlite") -> bool:
1✔
563
    """Returns True if sqlglot does not encounter any error, False otherwise."""
564
    from sqlglot import parse
1✔
565

566
    if not sql.strip():
1✔
567
        return False
×
568
    if db_type == "db2":
1✔
569
        db_type = "postgres"  ## TODO: temporary until sqlglot adds support for db2
×
570
    try:
1✔
571
        parse(sql, read=db_type)
1✔
572
        return True
1✔
573
    except Exception as e:
574
        logger.debug(f"SQL query could not parse: {e}")
575
        return False
576

577

578
def is_sqlparse_parsable(sql: str) -> bool:
1✔
579
    """Returns True if sqlparse does not encounter any error, False otherwise."""
580
    from sqlparse import parse
1✔
581
    from sqlparse.tokens import Error
1✔
582

583
    if not sql.strip():
1✔
584
        return False
×
585
    try:
1✔
586
        statements = parse(sql)
1✔
587
        for statement in statements:
1✔
588
            for token in statement.tokens:
1✔
589
                if token.ttype == Error:
1✔
590
                    return False
×
591
        return True
1✔
592
    except Exception as e:
593
        logger.debug(f"SQL query could not parse: {e}")
594
        return False
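
Both validity checks are purely syntactic; for example (query text is illustrative):

assert is_sqlglot_parsable("SELECT id FROM users WHERE age > 30")
assert is_sqlparse_parsable("SELECT id FROM users WHERE age > 30")
assert not is_sqlglot_parsable("   ")  # blank input is rejected before parsing
assert not is_sqlparse_parsable("")    # likewise for the sqlparse variant
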
595

596

597
def sqlglot_optimized_equivalence(expected: str, generated: str) -> int:
1✔
598
    """Checks if SQL queries are equivalent using SQLGlot parsing, so we don't run them."""
599
    from sqlglot import diff, parse_one
1✔
600
    from sqlglot.optimizer import optimize
1✔
601

602
    try:
1✔
603
        t_diff = diff(
1✔
604
            optimize(parse_one(expected.lower()).sql(pretty=True)),
605
            optimize(parse_one(generated.lower()).sql(pretty=True)),
606
        )
607
        sql_diff = sum(0 if (e.__class__.__name__ == "Keep") else 1 for e in t_diff)
1✔
608

609
        return 1 if sql_diff == 0 else 0
1✔
610
    except Exception as e:
611
        logger.debug(f"Error parsing SQL for comparison: {e}")
612
        return False
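
As a sketch, formatting and casing differences alone should not break this equivalence check (queries are illustrative; exact behaviour on heavier rewrites depends on sqlglot's optimizer):

same = sqlglot_optimized_equivalence(
    "SELECT name FROM employees WHERE age > 30",
    "select name   from employees where age>30",
)
# Expected to be 1: after lowercasing, parsing and optimizing, the two ASTs diff to zero edits.
# If the optimizer cannot process a query, the helper falls back to a falsy return value.
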
613

614

615
def extract_select_columns(statement):
1✔
616
    """Parse SQL using sqlparse and extract columns."""
617
    from sqlparse.sql import Identifier, IdentifierList
1✔
618
    from sqlparse.tokens import DML, Keyword
1✔
619

620
    columns = []
1✔
621
    select_seen = False
1✔
622
    for token in statement.tokens:
1✔
623
        if token.ttype is DML and token.value.upper() == "SELECT":
1✔
624
            select_seen = True
1✔
625
            continue
1✔
626
        if select_seen:
1✔
627
            if token.ttype is Keyword and token.value.upper() in (
1✔
628
                "FROM",
629
                "WHERE",
630
                "GROUP",
631
                "HAVING",
632
                "ORDER",
633
                "LIMIT",
634
            ):
635
                break
1✔
636
            if isinstance(token, IdentifierList):
1✔
637
                for identifier in token.get_identifiers():
1✔
638
                    columns.append(strip_alias(identifier.value))
1✔
639
            elif isinstance(token, Identifier):
1✔
640
                columns.append(strip_alias(token.value))
1✔
641
            else:
642
                val = token.value.strip()
1✔
643
                if val:
1✔
644
                    columns.append(strip_alias(val))
1✔
645
    return frozenset(columns)
1✔
646

647

648
def strip_alias(col: str) -> str:
1✔
649
    """Remove any AS alias from a column."""
650
    col = col.strip()
1✔
651
    upper = col.upper()
1✔
652
    if " AS " in upper:
1✔
653
        return col[: upper.index(" AS ")].strip()
1✔
654
    parts_alias = col.split()
1✔
655
    if len(parts_alias) > 1:
1✔
656
        return " ".join(parts_alias[:-1])
×
657
    return col
1✔
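
Aliases are stripped whether they use AS or a bare trailing alias, e.g.:

assert strip_alias("t.salary AS total") == "t.salary"
assert strip_alias("t.salary total") == "t.salary"
assert strip_alias("salary") == "salary"
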
658

659

660
def collect_clause(statement, clause_keyword):
1✔
661
    """Parse SQL statement and collect clauses."""
662
    from sqlparse.tokens import Keyword
1✔
663

664
    found = False
1✔
665
    collected = []
1✔
666
    for token in statement.tokens:
1✔
667
        tvalue = token.value.upper()
1✔
668
        if token.ttype is Keyword:
1✔
669
            if tvalue.startswith(clause_keyword):
1✔
670
                found = True
1✔
671
                continue
1✔
672
            if found and tvalue in (
1✔
673
                "FROM",
674
                "WHERE",
675
                "GROUP",
676
                "HAVING",
677
                "ORDER",
678
                "LIMIT",
679
            ):
680
                break
×
681
        if found:
1✔
682
            collected.append(token.value)
1✔
683
    return " ".join(collected).strip()
1✔
684

685

686
def extract_select_info(sql: str):
1✔
687
    """Parse SQL using sqlparse and return a dict of extracted columns and clauses."""
688
    from sqlparse import parse
1✔
689
    from sqlparse.tokens import DML
1✔
690

691
    statements = parse(sql)
1✔
692
    if len(statements) != 1:
1✔
693
        return None
×
694
    stmt = statements[0]
1✔
695
    if not any(t.ttype is DML and t.value.upper() == "SELECT" for t in stmt.tokens):
1✔
696
        return None
×
697
    parts = {
1✔
698
        "columns": None,
699
        "from": "",
700
        "where": "",
701
        "group": "",
702
        "having": "",
703
        "order": "",
704
    }
705
    columns = extract_select_columns(stmt)
1✔
706
    if not columns:
1✔
707
        columns = frozenset()
×
708
    parts["columns"] = columns
1✔
709
    parts["from"] = collect_clause(stmt, "FROM")
1✔
710
    parts["where"] = collect_clause(stmt, "WHERE")
1✔
711
    parts["group"] = collect_clause(stmt, "GROUP")
1✔
712
    parts["having"] = collect_clause(stmt, "HAVING")
1✔
713
    parts["order"] = collect_clause(stmt, "ORDER")
1✔
714
    return parts
1✔
715

716

717
def sqlparse_queries_equivalent(sql1: str, sql2: str) -> bool:
1✔
718
    """Returns True if both SQL queries are naively considered equivalent."""
719
    try:
1✔
720
        info1 = extract_select_info(sql1)
1✔
721
        info2 = extract_select_info(sql2)
1✔
722
        if not info1 or not info2:
1✔
723
            return False
×
724
        if info1["columns"] != info2["columns"]:
1✔
725
            return False
1✔
726
        for k in ["from", "where", "group", "having", "order"]:
1✔
727
            if info1[k].replace(" ", "").upper() != info2[k].replace(" ", "").upper():
1✔
728
                return False
1✔
729
        return True
×
730
    except Exception as e:
731
        logger.debug(f"Errpr parsing SQL query for comparison: {e}")
732
        return False
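
The sqlparse-based check compares the SELECT column set plus the remaining clauses after stripping whitespace and case, so column order and spacing do not matter (illustrative queries):

assert sqlparse_queries_equivalent(
    "SELECT name, age FROM employees WHERE age > 30",
    "SELECT age, name  FROM employees WHERE age>30",
)
assert not sqlparse_queries_equivalent(
    "SELECT name FROM employees",
    "SELECT salary FROM employees",
)
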
733

734

735
def sqlglot_parsed_queries_equivalent(sql1: str, sql2: str, dialect: str = "") -> bool:
1✔
736
    """Return True if two SQL queries match after parsing with SQLGlot."""
737
    from sqlglot import exp, parse_one
1✔
738

739
    try:
1✔
740
        ast1 = parse_one(sql1, read=dialect)
1✔
741
        ast2 = parse_one(sql2, read=dialect)
1✔
742
    except:
×
743
        return False
×
744
    if not (isinstance(ast1, exp.Select) and isinstance(ast2, exp.Select)):
1✔
745
        return False
×
746

747
    def normalized_select_columns(select_expr: exp.Select):
1✔
748
        cols = []
1✔
749
        for item in select_expr.expressions:
1✔
750
            copy_item = item.copy()
1✔
751
            copy_item.set("alias", None)
1✔
752
            cols.append(copy_item.sql(dialect=dialect, normalize=True))
1✔
753
        return frozenset(cols)
1✔
754

755
    if normalized_select_columns(ast1) != normalized_select_columns(ast2):
1✔
756
        return False
1✔
757

758
    def normalized_clause(expr: exp.Expression, key: str):
1✔
759
        clause = expr.args.get(key)
1✔
760
        return clause.sql(dialect=dialect, normalize=True) if clause else ""
1✔
761

762
    for clause_key in ("from", "where", "group", "having", "order"):
1✔
763
        if normalized_clause(ast1, clause_key) != normalized_clause(ast2, clause_key):
1✔
764
            return False
×
765

766
    return True
1✔
767

768

769
def sql_exact_match(sql1: str, sql2: str) -> bool:
1✔
770
    """Return True if two SQL strings match after very basic normalization."""
771

772
    def normalize_sql(s: str) -> str:
1✔
773
        s = s.strip().rstrip(";")
1✔
774
        s = re.sub(r"\s+", " ", s)
1✔
775
        return s.upper()
1✔
776

777
    return normalize_sql(sql1) == normalize_sql(sql2)
1✔
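
Exact match tolerates only trivial differences: trailing semicolons, repeated whitespace, and case. For example:

assert sql_exact_match("select id  from t;", "SELECT id FROM t")
assert not sql_exact_match("SELECT id FROM t", "SELECT id FROM t WHERE id > 0")
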
778

779

780
@dataclass
1✔
781
class SQLExecutionResult:
1✔
782
    execution_accuracy: int  # Whether the predicted and gold SQL results match exactly
1✔
783
    non_empty_execution_accuracy: (
1✔
784
        int  # Same as execution_accuracy but only if gold is non-empty
785
    )
786
    subset_non_empty_execution_accuracy: (
1✔
787
        int  # Whether predicted is a subset of gold or vice versa, non-empty only
788
    )
789
    execution_accuracy_bird: (
1✔
790
        int  # Whether the predicted SQL matches gold using BIRD evaluation logic
791
    )
792
    non_empty_gold_df: int  # Whether the gold SQL produced a non-empty dataframe
1✔
793
    gold_sql_runtime: float  # Time taken to execute the gold SQL
1✔
794
    predicted_sql_runtime: float  # Time taken to execute the predicted SQL
1✔
795
    pred_to_gold_runtime_ratio: float  # Ratio of predicted runtime to gold runtime
1✔
796
    gold_error: int  # Whether the gold SQL had an execution error
1✔
797
    predicted_error: int  # Whether the predicted SQL had an execution error
1✔
798
    gold_df_json: str  # JSON representation of the gold SQL result dataframe
1✔
799
    predicted_df_json: str  # JSON representation of the predicted SQL result dataframe
1✔
800
    error_message: str  # Error message from predicted execution if any
1✔
801

802

803
def compare_dfs_ignore_colnames_ordered_rows(
1✔
804
    df1: pd.DataFrame, df2: pd.DataFrame
805
) -> bool:
806
    if df1.shape != df2.shape:
×
807
        return False
×
808
    df1_sorted_rows = np.array([np.sort(row) for row in df1.values.astype(str)])
×
809
    df2_sorted_rows = np.array([np.sort(row) for row in df2.values.astype(str)])
×
810
    return np.array_equal(df1_sorted_rows, df2_sorted_rows)
×
811

812

813
def compare_dfs_ignore_colnames_unordered_rows(
1✔
814
    df1: pd.DataFrame, df2: pd.DataFrame
815
) -> bool:
816
    if df1.shape != df2.shape:
1✔
817
        return False
×
818
    df1_sorted = np.sort(np.sort(df1.values.astype(str), axis=1), axis=0)
1✔
819
    df2_sorted = np.sort(np.sort(df2.values.astype(str), axis=1), axis=0)
1✔
820
    return np.array_equal(df1_sorted, df2_sorted)
1✔
821

822

823
def compare_dfs_ignore_colnames_subset(
1✔
824
    df1: pd.DataFrame, df2: pd.DataFrame, ignore_row_order: bool = True
825
) -> bool:
826
    """Checks if the smaller of the two DataFrames is likely a subset of the other.
827

828
    Subset comparison is column-based, to support Text2SQL evaluation for when the
829
    predicted SQL dataframe has missing or additional columns. Each row is treated as
830
    a multiset of (stringified) values, and the function checks if every row in the
831
    smaller DataFrame (by column count) is a multiset subset of the corresponding row
832
    in the larger DataFrame. When ground truth SQL does not have ORDER BY,
833
    ignore_row_order can be set to True to ignore the order of rows. In this case,
834
    column values are sorted before comparison. This means that there could be cases
835
    where the dataframes have the exact same number of rows and column values after
836
    sort are the same, but the dataframes are not actually a subset of each other.
837
    This is unlikely to happen in practice, but the score is not guaranteed to be
838
    100% accurate and may overestimate the accuracy.
839

840
    Args:
841
        df1 (pd.DataFrame): The first DataFrame to compare.
842
        df2 (pd.DataFrame): The second DataFrame to compare.
843
        ignore_row_order (bool, optional): If True, ignores the order of rows by
844
            sorting them before comparison. Defaults to True.
845

846
    Returns:
847
        bool: True if the smaller DataFrame (column-wise) is likely a subset of the
848
            larger one, False otherwise.
849
    """
850

851
    def row_to_multiset(row):
1✔
852
        return Counter(str(x) for x in row)
1✔
853

854
    def rows_to_multisets(df):
1✔
855
        return [row_to_multiset(row) for row in df.values]
1✔
856

857
    def sort_df(df):
1✔
858
        sorted_df = df.copy()
1✔
859
        for col in sorted_df.columns:
1✔
860
            sorted_df[col] = sorted_df[col].astype(str).sort_values(ignore_index=True)
1✔
861
        return sorted_df
1✔
862

863
    if df1.empty or df2.empty or len(df1) != len(df2):
1✔
864
        return False
×
865

866
    subset_df, superset_df = (df1, df2) if df1.shape[1] <= df2.shape[1] else (df2, df1)
1✔
867

868
    if ignore_row_order:
1✔
869
        subset_df = sort_df(subset_df)
1✔
870
        superset_df = sort_df(superset_df)
1✔
871

872
    subset_rows = rows_to_multisets(subset_df)
1✔
873
    superset_rows = rows_to_multisets(superset_df)
1✔
874

875
    for r1, r2 in zip(subset_rows, superset_rows):
1✔
876
        if not all(r1[k] <= r2.get(k, 0) for k in r1):
1✔
877
            return False
×
878
    return True
1✔
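
A small illustration of the column-wise subset logic described above, on toy data: a prediction missing one column still counts as long as its rows line up with the gold rows:

gold = pd.DataFrame({"id": [1, 2], "name": ["Alice", "Bob"], "age": [30, 25]})
pred = pd.DataFrame({0: [2, 1], 1: ["Bob", "Alice"]})  # fewer columns, different row order
assert compare_dfs_ignore_colnames_subset(gold, pred, ignore_row_order=True)
assert not compare_dfs_ignore_colnames_subset(gold, pd.DataFrame(), ignore_row_order=True)  # empty -> False
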
879

880

881
def compare_dfs_bird_eval_logic(df1: pd.DataFrame, df2: pd.DataFrame):
1✔
882
    """Check if two SQL query result sets are exactly equal, as in BIRD evaluation.
883

884
    This function checks if the set of rows returned by the predicted SQL query
885
    (`predicted_res`) is exactly equal to the set of rows returned by the ground truth
886
    SQL query (`ground_truth_res`). This is the logic used in the original BIRD
887
    evaluation code:
888
    https://github.com/AlibabaResearch/DAMO-ConvAI/blob/main/bird/llm/src/evaluation.py.
889
    """
890
    df1_set = {tuple(row) for row in df1.values.astype(str)}
1✔
891
    df2_set = {tuple(row) for row in df2.values.astype(str)}
1✔
892
    return int(df1_set == df2_set)
1✔
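
Because the BIRD comparison converts each result to a set of stringified rows, row order and duplicate rows are ignored (toy data):

a = pd.DataFrame([[1, "x"], [2, "y"]])
b = pd.DataFrame([[2, "y"], [1, "x"], [1, "x"]])  # reordered, with a duplicate row
assert compare_dfs_bird_eval_logic(a, b) == 1
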
893

894

895
def compare_result_dfs(
1✔
896
    gold_df: pd.DataFrame, pred_df: pd.DataFrame, gold_sql: str
897
) -> Tuple[int, int, int]:
898
    """Compares two DataFrames representing SQL query results.
899

900
    Args:
901
        gold_df (pd.DataFrame): The ground truth DataFrame.
902
        pred_df (pd.DataFrame): The predicted DataFrame.
903
        gold_sql (str): The ground truth SQL query string.
904

905
    Returns:
906
        Tuple[int, int, int]: A tuple containing:
907
            - match (int): 1 if the predicted DataFrame matches the gold DataFrame
908
            - non_empty_match (int): 1 if both DataFrames are non-empty and match,
909
              0 otherwise.
910
            - subset_match (int): 1 if the predicted DataFrame is a subset or
911
              superset of the gold DataFrame.
912

913
    Notes:
914
        - The comparison ignores column names.
915
        - Row order is considered only if 'ORDER BY' is present in the SQL query.
916
    """
917
    subset_match = 0
1✔
918
    non_empty_match = 0
1✔
919
    if "ORDER BY" in gold_sql.upper():
1✔
920
        match = int(compare_dfs_ignore_colnames_ordered_rows(pred_df, gold_df))
×
921
        if not gold_df.empty and not pred_df.empty:
×
922
            non_empty_match = match
×
923
            if compare_dfs_ignore_colnames_subset(
×
924
                gold_df, pred_df, ignore_row_order=False
925
            ):
926
                subset_match = 1
×
927
    else:
928
        match = int(compare_dfs_ignore_colnames_unordered_rows(pred_df, gold_df))
1✔
929
        if not gold_df.empty and not pred_df.empty:
1✔
930
            non_empty_match = match
1✔
931
            if compare_dfs_ignore_colnames_subset(
1✔
932
                gold_df, pred_df, ignore_row_order=True
933
            ):
934
                subset_match = 1
1✔
935
    return match, non_empty_match, subset_match
1✔
936

937

938
def run_query(
1✔
939
    sql: str, connector, sql_timeout: float
940
) -> Tuple[Optional[pd.DataFrame], float, str]:
941
    """Executes a SQL query using the provided connector with a timeout.
942

943
    Args:
944
        sql (str): The SQL query string to execute.
945
        connector: An object with an `execute_query` method that executes the SQL
946
            query.
947
        sql_timeout (float): The maximum time in seconds to allow for query
948
            execution.
949

950
    Returns:
951
        Tuple[Optional[pd.DataFrame], float, str]:
952
            - A pandas DataFrame containing the query results, or None if an error
953
              occurred.
954
            - The duration in seconds taken to execute the query. 0.0 if an error.
955
            - An error message string if an error occurred, otherwise an empty
956
              string.
957

958
    Notes:
959
        - If the SQL string is empty or only whitespace, returns immediately with a
960
          message.
961
        - If the query execution exceeds the timeout, returns a timeout error
962
          message.
963
        - Any other exceptions are caught and returned as error messages.
964
    """
965
    import time
1✔
966

967
    from func_timeout import func_timeout
1✔
968
    from func_timeout.exceptions import FunctionTimedOut
1✔
969

970
    if not sql.strip():
1✔
971
        return None, 0.0, "No SQL query found in the prediction."
×
972

973
    try:
1✔
974
        start = time.perf_counter()
1✔
975
        result, error = func_timeout(sql_timeout, connector.execute_query, args=(sql,))
1✔
976
        duration = time.perf_counter() - start
1✔
977
        if isinstance(result, dict) and "results" in result:
1✔
978
            result = result["results"]
×
979
        if error:
1✔
980
            return None, duration, error
1✔
981
        return pd.DataFrame(result), duration, ""
1✔
982
    except FunctionTimedOut as e:
×
983
        return None, 0.0, f"Timeout: {e}"
×
984
    except Exception as e:
985
        return None, 0.0, f"Error: {e}"
986

987

988
def get_sql_execution_results(
1✔
989
    predicted_sql: str, gold_sql: str, connector, sql_timeout: float
990
) -> SQLExecutionResult:
991
    """Execute and compare predicted and gold SQL queries, returning execution metrics.
992

993
    Args:
994
        predicted_sql (str): The SQL query predicted by the model.
995
        gold_sql (str): The reference (gold) SQL query.
996
        connector: Database connector object used to execute the queries.
997
        sql_timeout (float): Maximum time (in seconds) allowed for query execution.
998

999
    Returns:
1000
        SQLExecutionResult: An object containing various execution metrics, including:
1001
            - execution_accuracy (int): 1 if predicted and gold queries produce
1002
              equivalent results, else 0.
1003
            - non_empty_execution_accuracy (int): 1 if both queries produce non-empty
1004
              and equivalent results, else 0.
1005
            - subset_non_empty_execution_accuracy (int): 1 if predicted results are a
1006
              subset or superset of gold results and non-empty, else 0. Subset
1007
              comparison is column-based. This means that the predicted SQL dataframe
1008
              can have missing or additional columns compared to the gold SQL dataframe.
1009
            - execution_accuracy_bird (int): 1 if results match according to BIRD
1010
              evaluation logic, else 0.
1011
            - non_empty_gold_df (int): 1 if the gold query result is non-empty, else 0.
1012
            - gold_sql_runtime (float): Execution time for the gold SQL query.
1013
            - predicted_sql_runtime (float): Execution time for the predicted SQL query.
1014
            - pred_to_gold_runtime_ratio (float): Ratio of predicted to gold query
1015
              runtimes.
1016
            - gold_error (int): 1 if the gold query failed, else 0.
1017
            - predicted_error (int): 1 if the predicted query failed, else 0.
1018
            - gold_df_json (str): JSON representation of the gold query result
1019
              DataFrame.
1020
            - predicted_df_json (str): JSON representation of the predicted query
1021
              result DataFrame.
1022
            - error_message (str): Error message if any query failed, else empty
1023
              string.
1024

1025
    Notes:
1026
        - If the gold query fails, the function returns early with error details.
1027
        - If the predicted query is identical or SQL-equivalent to the gold query,
1028
          results are considered correct without execution.
1029
        - Otherwise, both queries are executed and their results compared using
1030
          multiple metrics.
1031
    """
1032
    gold_df, gold_runtime, gold_error_msg = run_query(gold_sql, connector, sql_timeout)
1✔
1033
    gold_error = int(bool(gold_error_msg))
1✔
1034

1035
    if gold_error:
1✔
1036
        return SQLExecutionResult(
×
1037
            execution_accuracy=0,
1038
            non_empty_execution_accuracy=0,
1039
            subset_non_empty_execution_accuracy=0,
1040
            execution_accuracy_bird=0,
1041
            non_empty_gold_df=0,
1042
            gold_sql_runtime=gold_runtime,
1043
            predicted_sql_runtime=0,
1044
            pred_to_gold_runtime_ratio=0,
1045
            gold_error=gold_error,
1046
            predicted_error=0,
1047
            gold_df_json="",
1048
            predicted_df_json="",
1049
            error_message=gold_error_msg,
1050
        )
1051

1052
    non_empty_gold_df = int(not gold_df.empty)
1✔
1053
    if predicted_sql.strip().lower() == gold_sql.strip().lower():
1✔
1054
        return SQLExecutionResult(
×
1055
            execution_accuracy=1,
1056
            non_empty_execution_accuracy=non_empty_gold_df,
1057
            subset_non_empty_execution_accuracy=non_empty_gold_df,
1058
            execution_accuracy_bird=1,
1059
            non_empty_gold_df=non_empty_gold_df,
1060
            gold_sql_runtime=gold_runtime,
1061
            predicted_sql_runtime=0,
1062
            pred_to_gold_runtime_ratio=0,
1063
            gold_error=0,
1064
            predicted_error=0,
1065
            gold_df_json=gold_df.to_json(),
1066
            predicted_df_json=gold_df.to_json(),
1067
            error_message="",
1068
        )
1069

1070
    try:
1✔
1071
        if sqlglot_optimized_equivalence(gold_sql, predicted_sql):
1✔
1072
            return SQLExecutionResult(
1✔
1073
                execution_accuracy=1,
1074
                non_empty_execution_accuracy=non_empty_gold_df,
1075
                subset_non_empty_execution_accuracy=non_empty_gold_df,
1076
                execution_accuracy_bird=1,
1077
                non_empty_gold_df=non_empty_gold_df,
1078
                gold_sql_runtime=gold_runtime,
1079
                predicted_sql_runtime=0,
1080
                pred_to_gold_runtime_ratio=0,
1081
                gold_error=0,
1082
                predicted_error=0,
1083
                gold_df_json=gold_df.to_json(),
1084
                predicted_df_json=gold_df.to_json(),
1085
                error_message="",
1086
            )
1087
    except Exception as e:
1088
        logger.info(f"Could not check SQL equivalence: {e}")
1089

1090
    pred_df, pred_runtime, pred_error_msg = run_query(
1✔
1091
        predicted_sql, connector, sql_timeout
1092
    )
1093
    pred_error = 1 if pred_error_msg else 0
1✔
1094

1095
    if pred_df is None:
1✔
1096
        return SQLExecutionResult(
1✔
1097
            execution_accuracy=0,
1098
            non_empty_execution_accuracy=0,
1099
            subset_non_empty_execution_accuracy=0,
1100
            execution_accuracy_bird=0,
1101
            non_empty_gold_df=non_empty_gold_df,
1102
            gold_sql_runtime=gold_runtime,
1103
            predicted_sql_runtime=pred_runtime,
1104
            pred_to_gold_runtime_ratio=(pred_runtime / gold_runtime)
1105
            if gold_runtime > 0
1106
            else 0,
1107
            gold_error=0,
1108
            predicted_error=pred_error,
1109
            gold_df_json=gold_df.to_json(),
1110
            predicted_df_json="",
1111
            error_message=pred_error_msg,
1112
        )
1113

1114
    match, non_empty_match, subset_match = compare_result_dfs(
1✔
1115
        gold_df, pred_df, gold_sql
1116
    )
1117
    bird_match = compare_dfs_bird_eval_logic(gold_df, pred_df)
1✔
1118

1119
    return SQLExecutionResult(
1✔
1120
        execution_accuracy=match,
1121
        non_empty_execution_accuracy=non_empty_match,
1122
        subset_non_empty_execution_accuracy=subset_match,
1123
        execution_accuracy_bird=bird_match,
1124
        non_empty_gold_df=non_empty_gold_df,
1125
        gold_sql_runtime=gold_runtime,
1126
        predicted_sql_runtime=pred_runtime,
1127
        pred_to_gold_runtime_ratio=(pred_runtime / gold_runtime)
1128
        if gold_runtime > 0
1129
        else 0,
1130
        gold_error=0,
1131
        predicted_error=0,
1132
        gold_df_json=gold_df.to_json(),
1133
        predicted_df_json=pred_df.to_json(),
1134
        error_message=pred_error_msg,
1135
    )
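
An end-to-end sketch against the in-memory connector. It assumes the optional func_timeout dependency used by run_query is installed; the table and queries are invented:

connector = InMemoryDatabaseConnector(
    {"data": {"users": {"columns": ["id", "name"], "rows": [[1, "Alice"], [2, "Bob"]]}}}
)
result = get_sql_execution_results(
    predicted_sql="SELECT name FROM users WHERE id = 1",
    gold_sql="SELECT name FROM users WHERE id = 1",
    connector=connector,
    sql_timeout=10.0,
)
print(result.execution_accuracy, result.predicted_error)  # 1 0
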
1136

1137

1138
def replace_select_clause(
1✔
1139
    source_query: str, target_query: str, dialect: str = "postgres"
1140
) -> str:
1141
    """Replaces the SELECT clause of the target SQL query with the SELECT clause from the source SQL query.
1142

1143
    Args:
1144
        source_query (str): SQL query whose SELECT clause will be used.
1145
        target_query (str): SQL query whose SELECT clause will be replaced.
1146
        dialect (str): SQL dialect for parsing and rendering (default: "postgres").
1147

1148
    Returns:
1149
        str: A new SQL query with the SELECT clause of `target_query` replaced by that of `source_query`.
1150

1151
    Raises:
1152
        ValueError: If either query is not a valid SELECT statement.
1153

1154
    Example:
1155
        >>> replace_select_clause(
1156
        ...     "SELECT id FROM employees",
1157
        ...     "SELECT name FROM employees WHERE age > 30"
1158
        ... )
1159
        "SELECT id FROM employees WHERE age > 30"
1160
    """
1161
    from sqlglot import exp, parse_one
×
1162

1163
    if not dialect:
×
1164
        dialect = "postgres"
×
1165

1166
    # Parse queries using the specified dialect
1167
    source_ast = parse_one(source_query, read=dialect)
×
1168
    target_ast = parse_one(target_query, read=dialect)
×
1169

1170
    if not isinstance(source_ast, exp.Select) or not isinstance(target_ast, exp.Select):
×
1171
        raise ValueError("Both queries must be valid SELECT statements.")
1172

1173
    # Replace SELECT expressions in the target with those from the source
1174
    target_ast.set("expressions", source_ast.expressions)
×
1175

1176
    # Return the updated SQL string using the dialect
1177
    return target_ast.sql(dialect=dialect)
×
1178

1179

1180
def extract_sql_from_text(text: str) -> str:
1✔
1181
    """Extracts the first SQL query from the given text.
1182

1183
    Priority:
1184
    1. SQL inside fenced blocks like ```sql ... ```
1185
    2. SQL starting on a new line or after a colon/label
1186
    3. SQL without semicolon
1187

1188
    Returns:
1189
        The SQL query string, or an empty string if not found.
1190
    """
1191
    # 1. Look for fenced SQL code block
1192
    fenced_block_pattern = re.compile(r"```sql\s+(.*?)```", re.IGNORECASE | re.DOTALL)
1✔
1193
    match = fenced_block_pattern.search(text)
1✔
1194
    if match:
1✔
1195
        return match.group(1).strip()
×
1196

1197
    # 2. Inline SQL with semicolon
1198
    sql_keywords = r"(?:SELECT|INSERT|UPDATE|DELETE|WITH)\s+"
1✔
1199
    sql_start = (
1✔
1200
        r"(?:^|\n|:\s*)"  # Start of string, newline, or colon label like "Just run:"
1201
    )
1202
    sql_pattern = re.compile(
1✔
1203
        rf"{sql_start}({sql_keywords}.*?;)", re.IGNORECASE | re.DOTALL
1204
    )
1205
    match = sql_pattern.search(text)
1✔
1206
    if match:
1✔
1207
        return match.group(1).strip()
×
1208

1209
    # 3. Inline SQL without semicolon
1210
    fallback_pattern = re.compile(
1✔
1211
        rf"{sql_start}({sql_keywords}.*)", re.IGNORECASE | re.DOTALL
1212
    )
1213
    fallback_match = fallback_pattern.search(text)
1✔
1214
    if fallback_match:
1✔
1215
        return fallback_match.group(1).strip()
1✔
1216

1217
    return ""
×
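
Extraction prefers a fenced ```sql block, then falls back to inline statements (example strings are invented):

text = "Here is the query:\n```sql\nSELECT id FROM users;\n```\nHope that helps."
assert extract_sql_from_text(text) == "SELECT id FROM users;"
assert extract_sql_from_text("Just run: SELECT 1;") == "SELECT 1;"
assert extract_sql_from_text("No query here.") == ""
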
1218

1219

1220
ALL_DIALECTS = [
1✔
1221
    "Athena",
1222
    "BigQuery",
1223
    "ClickHouse",
1224
    "Databricks",
1225
    "Doris",
1226
    "Drill",
1227
    "Druid",
1228
    "DuckDB",
1229
    "Hive",
1230
    "Materialize",
1231
    "MySQL",
1232
    "Oracle",
1233
    "Postgres",
1234
    "Presto",
1235
    "PRQL",
1236
    "Redshift",
1237
    "RisingWave",
1238
    "Snowflake",
1239
    "Spark",
1240
    "Spark2",
1241
    "SQLite",
1242
    "StarRocks",
1243
    "Tableau",
1244
    "Teradata",
1245
    "Trino",
1246
    "TSQL",
1247
]