IBM / unitxt / build 19876733021

02 Dec 2025 11:19PM UTC coverage: 80.896% (+0.007%) from 80.889%

Pull Request #1954: Fix duplicate-column sorting issue in Text2SQL evaluation utils
Merge ba3c7acda into f24c1be29 (via github / web-flow)

1607 of 2006 branches covered (80.11%); branch coverage is included in the aggregate %
10948 of 13514 relevant lines covered (81.01%)
0.81 hits per line

Source file: src/unitxt/text2sql_utils.py (73.4% of lines covered)

import functools
import glob
import hashlib
import json
import os
import re
import time
from abc import ABC, abstractmethod
from collections import Counter
from dataclasses import dataclass
from functools import lru_cache
from typing import Any, List, Optional, Tuple

import numpy as np
import pandas as pd
import requests
from huggingface_hub import snapshot_download
from requests.exceptions import ConnectionError, ReadTimeout

from .logging_utils import get_logger
from .types import SQLDatabase

try:
    import sqlite3
except ImportError:
    sqlite3 = None


logger = get_logger()

# Check if caching is enabled via environment variable
CACHE_LOCATION = os.getenv("UNITXT_CACHE_LOCATION")

# Set max cache size to 10GB or the value of env var MAX_CACHE_SIZE
MAX_CACHE_SIZE = os.getenv("MAX_CACHE_SIZE", 10 * 1024**3)

_cache_instance = None


class DatabaseConnector(ABC):
    """Abstract base class for database connectors."""

    def __init__(self, db_config: SQLDatabase):
        self.db_config = db_config
        self.databases_folder = os.path.join(
            os.environ.get("UNITXT_CACHE_LOCATION", "cache/text2sql"), "databases"
        )
        os.makedirs(self.databases_folder, exist_ok=True)

    @abstractmethod
    def get_table_schema(
        self,
    ) -> str:
        """Abstract method to get database schema."""
        pass

    @abstractmethod
    def execute_query(self, query: str) -> Any:
        """Abstract method to execute a query against the database."""
        pass


@lru_cache(maxsize=128)
def execute_query_local(db_path: str, query: str) -> Any:
    """Executes a query against the SQLite database."""
    conn = None  # Initialize conn to None outside the try block
    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()
        cursor.execute(query)
        return cursor.fetchall(), None
    except sqlite3.Error as e:
        logger.info(f"Error executing SQL: {e}")
        return None, f"Error executing SQL: {e}"
    finally:
        if conn:
            conn.close()

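# Illustrative usage sketch for execute_query_local (hypothetical helper, not
# part of the module): build a throwaway SQLite file and query it. Results are
# memoized per (db_path, query) pair by the @lru_cache decorator above.
def _example_execute_query_local():
    import tempfile

    with tempfile.NamedTemporaryFile(suffix=".sqlite", delete=False) as tmp:
        db_path = tmp.name
    conn = sqlite3.connect(db_path)
    conn.execute("CREATE TABLE users (id INTEGER, name TEXT)")
    conn.execute("INSERT INTO users VALUES (1, 'Ada')")
    conn.commit()
    conn.close()

    rows, error = execute_query_local(db_path, "SELECT name FROM users")
    assert rows == [("Ada",)] and error is None
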
class LocalSQLiteConnector(DatabaseConnector):
    """Database connector for SQLite databases."""

    def __init__(self, db_config: SQLDatabase):
        super().__init__(db_config)
        db_id = self.db_config.get("db_id")
        if not db_id:
            raise ValueError("db_id is required for SQLiteConnector.")
        self.db_path = self.get_db_file_path(db_id)
        self.conn: sqlite3.Connection = sqlite3.connect(self.db_path)
        self.cursor: sqlite3.Cursor = self.conn.cursor()

    def download_database(self, db_id):
        """Downloads the database from huggingface if needed."""
        done_file_path = os.path.join(self.databases_folder, "download_done")
        if "bird/" in db_id:
            if not os.path.exists(done_file_path):
                snapshot_download(
                    repo_id="premai-io/birdbench",
                    repo_type="dataset",
                    local_dir=self.databases_folder,
                    force_download=False,
                    allow_patterns="*validation*",
                )
                open(os.path.join(self.databases_folder, "download_done"), "w").close()
        else:
            raise NotImplementedError(
                f"current local db: {db_id} is not supported, only bird"
            )

    def get_db_file_path(self, db_id):
        """Gets the local path of a downloaded database file."""
        self.download_database(db_id)
        db_id = db_id.split("/")[-1]

        db_file_pattern = os.path.join(self.databases_folder, "**", db_id + ".sqlite")
        db_file_paths = glob.glob(db_file_pattern, recursive=True)

        if not db_file_paths:
            raise FileNotFoundError(f"Database file {db_id} not found.")
        if len(db_file_paths) > 1:
            raise FileExistsError(f"More than one files matched for {db_id}")
        return db_file_paths[0]

    def get_table_schema(
        self,
    ) -> str:
        """Extracts schema from an SQLite database."""
        self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        tables: list[tuple[str]] = self.cursor.fetchall()
        schemas: dict[str, str] = {}

        for table in tables:
            if isinstance(table, tuple):
                table = table[0]
            if table == "sqlite_sequence":
                continue
            sql_query: str = (
                f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table}';"
            )
            self.cursor.execute(sql_query)
            schema_prompt: str = self.cursor.fetchone()[0]

            schemas[table] = schema_prompt

        schema_prompt: str = "\n\n".join(list(schemas.values()))
        return schema_prompt

    def execute_query(self, query: str) -> Any:
        """Executes a query against the SQLite database."""
        return execute_query_local(self.db_path, query)


class InMemoryDatabaseConnector(DatabaseConnector):
    """Database connector for mocking databases with in-memory data structures."""

    def __init__(self, db_config: SQLDatabase):
        super().__init__(db_config)
        self.tables = db_config.get("data", None)

        if not self.tables:
            raise ValueError("data is required for InMemoryDatabaseConnector.")

    def get_table_schema(
        self,
        select_tables: Optional[List[str]] = None,
    ) -> str:
        """Generates a mock schema from the tables structure."""
        schemas = {}
        for table_name, table_data in self.tables.items():
            if select_tables and table_name.lower() not in select_tables:
                continue
            columns = ", ".join([f"`{col}` TEXT" for col in table_data["columns"]])
            schema = f"CREATE TABLE `{table_name}` ({columns});"

            schemas[table_name] = schema

        return "\n\n".join(list(schemas.values()))

    def execute_query(self, query: str) -> Any:
        """Simulates executing a query against the mock database."""
        # Initialize in-memory database from the 'tables' dictionary
        conn = sqlite3.connect(":memory:")
        cursor = conn.cursor()
        logger.debug("Running SQL query over in-memory DB")

        # Create tables and insert data from the 'db' dictionary
        for table_name, table_data in self.tables.items():
            columns = table_data["columns"]
            rows = table_data["rows"]

            # Create table
            cursor.execute(f"CREATE TABLE {table_name} ({', '.join(columns)})")

            # Insert data
            placeholders = ", ".join(["?"] * len(columns))
            cursor.executemany(
                f"INSERT INTO {table_name} VALUES ({placeholders})", rows
            )

        try:
            cursor.execute(query)
            return cursor.fetchall(), None
        except sqlite3.Error as e:
            logger.info(f"Error executing SQL: {e}")
            return None, f"Error executing SQL: {e}"
        finally:
            conn.close()

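# Illustrative usage sketch for InMemoryDatabaseConnector (hypothetical helper,
# not part of the module): the connector only needs a db_config whose "data"
# entry maps table names to their "columns" and "rows".
def _example_in_memory_connector():
    db_config = {
        "db_type": "in_memory",
        "db_id": None,
        "data": {
            "users": {
                "columns": ["id", "name"],
                "rows": [[1, "Ada"], [2, "Grace"]],
            }
        },
    }
    connector = InMemoryDatabaseConnector(db_config)
    print(connector.get_table_schema())  # CREATE TABLE `users` (`id` TEXT, `name` TEXT);
    rows, error = connector.execute_query("SELECT name FROM users WHERE id = 2")
    assert rows == [("Grace",)] and error is None
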
def get_cache():
    """Returns a singleton cache instance, initializing it if necessary."""
    global _cache_instance
    if _cache_instance is None:
        _cache_instance = Cache()
    return _cache_instance


def generate_cache_key(*args, **kwargs):
    """Generate a stable hashable cache key for various input types.

    :param args: Positional arguments of the function.
    :param kwargs: Keyword arguments of the function.
    :return: A hashed key as a string.
    """
    try:
        # Convert args and kwargs to a JSON string (sorted to ensure consistency)
        serialized = json.dumps(
            {"args": args, "kwargs": kwargs}, sort_keys=True, default=str
        )
    except TypeError:
        # Fallback for non-serializable objects
        serialized = repr((args, kwargs))

    # Hash the serialized data
    return hashlib.md5(serialized.encode()).hexdigest()

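# Illustrative sketch for generate_cache_key (hypothetical helper, not part of
# the module): the key is an MD5 hex digest over the JSON-serialized arguments,
# so identical arguments always yield the same key.
def _example_generate_cache_key():
    key1 = generate_cache_key("sql_request", "https://api.example.com", query="SELECT 1")
    key2 = generate_cache_key("sql_request", "https://api.example.com", query="SELECT 1")
    assert key1 == key2
    assert len(key1) == 32  # MD5 hex digest length
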
class Cache:
    """A class that provides disk-based caching functionality for a given function."""

    def __init__(self):
        """Initializes the cache.

        If `CACHE_LOCATION` (os.getenv("UNITXT_CACHE_LOCATION")) is set, a disk-based
        cache is created using `diskcache`.

        Args:
            None

        Returns:
            None
        """
        if CACHE_LOCATION:
            try:
                import diskcache

                # Ensure the cache directory exists
                os.makedirs(CACHE_LOCATION, exist_ok=True)

                # Create a global diskcache Cache instance
                self.cache = diskcache.Cache(CACHE_LOCATION, size_limit=MAX_CACHE_SIZE)
                logger.info(f"Caching enabled at {CACHE_LOCATION}")
            except ImportError as e:
                raise ImportError(
                    "UNITXT_CACHE_LOCATION is set, but diskcache is not installed.\n"
                    "Please install diskcache `pip install diskcache` "
                    "or unset UNITXT_CACHE_LOCATION."
                ) from e
        else:
            self.cache = None  # Disable caching

    def get_or_set(self, key, compute_fn, no_cache=False, refresh=False):
        if not self.cache or no_cache:
            logger.info(f"Bypassing cache for key: {key}")
            return compute_fn()

        if refresh and key in self.cache:
            logger.info(f"Refreshing cache for key: {key}")
            del self.cache[key]

        if key in self.cache:
            logger.info(f"Cache hit for key: {key}")
            return self.cache[key]

        logger.info(f"Cache miss for key: {key}. Computing value...")
        result = compute_fn()

        if result and not (
            isinstance(result, tuple) and len(result) == 2 and result[0] is None
        ):
            self.cache[key] = result
            logger.info(f"Stored result in cache for key: {key}")
        else:
            logger.info(f"None result. Bypassing caching for key: {key}")

        return result

    async def async_get_or_set(self, key, compute_fn, no_cache=False, refresh=False):
        if not self.cache or no_cache:
            logger.info(f"Bypassing cache for key: {key}")
            return await compute_fn()

        if refresh and key in self.cache:
            logger.info(f"Refreshing cache for key: {key}")
            del self.cache[key]

        if key in self.cache:
            logger.info(f"Cache hit for key: {key}")
            return self.cache[key]

        logger.info(f"Cache miss for key: {key}. Computing value asynchronously...")
        result = await compute_fn()
        self.cache[key] = result
        logger.info(f"Stored result in cache for key: {key}")
        return result

    def memoize(self, key_func=generate_cache_key, no_cache=False, refresh=False):
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                if not self.cache or no_cache:
                    logger.info(f"Bypassing cache for function: {func.__name__}")
                    return func(*args, **kwargs)

                key = key_func(func.__name__, *args, **kwargs)

                if refresh and key in self.cache:
                    logger.info(
                        f"Refreshing cache for function: {func.__name__}, key: {key}"
                    )
                    del self.cache[key]

                if key in self.cache:
                    logger.info(f"Cache hit for function: {func.__name__}, key: {key}")
                    return self.cache[key]

                logger.info(
                    f"Cache miss for function: {func.__name__}, key: {key}. Computing value..."
                )
                result = func(*args, **kwargs)
                self.cache[key] = result
                logger.info(
                    f"Stored result in cache for function: {func.__name__}, key: {key}"
                )
                return result

            return wrapper

        return decorator

    def async_memoize(self, key_func=generate_cache_key, no_cache=False, refresh=False):
        def decorator(func):
            @functools.wraps(func)
            async def wrapper(*args, **kwargs):
                if no_cache:
                    logger.info(f"Bypassing cache for async function: {func.__name__}")
                    return await func(*args, **kwargs)

                key = key_func(func.__name__, *args, **kwargs)

                if refresh and key in self.cache:
                    logger.info(
                        f"Refreshing cache for async function: {func.__name__}, key: {key}"
                    )
                    del self.cache[key]

                if key in self.cache:
                    logger.info(
                        f"Cache hit for async function: {func.__name__}, key: {key}"
                    )
                    return self.cache[key]

                logger.info(
                    f"Cache miss for async function: {func.__name__}, key: {key}. Computing value..."
                )
                result = await func(*args, **kwargs)
                self.cache[key] = result
                logger.info(
                    f"Stored result in cache for async function: {func.__name__}, key: {key}"
                )
                return result

            return wrapper

        return decorator

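# Illustrative sketch for Cache.get_or_set (hypothetical helper, not part of
# the module): with UNITXT_CACHE_LOCATION unset, self.cache is None and every
# call falls through to compute_fn; with it set (and diskcache installed), the
# computed value is stored on disk and replayed on later calls.
def _example_cache_get_or_set():
    cache = get_cache()
    key = generate_cache_key("expensive_step", n=42)
    value = cache.get_or_set(key, lambda: 42 * 2)
    return value  # 84, served from disk on later calls when caching is enabled
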
@lru_cache(maxsize=128)
def execute_query_remote(
    api_url: str,
    database_id: str,
    api_key: str,
    query: str,
    retryable_exceptions: tuple = (ConnectionError, ReadTimeout),
    max_retries: int = 3,
    retry_delay: int = 5,  # seconds
    timeout: int = 30,  # seconds
) -> (Optional[dict], str):
    """Executes a query against the remote database, with retries for certain exceptions."""
    headers = {
        "Content-Type": "application/json",
        "accept": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    retries = 0
    while retries <= max_retries:
        try:
            response = requests.post(
                f"{api_url}/sql",
                headers=headers,
                json={"sql": query, "dataSourceId": database_id},
                verify=False,
                timeout=timeout,
            )
            response.raise_for_status()
            return response.json(), None

        except retryable_exceptions as e:
            retries += 1
            logger.warning(
                f"Attempt {retries} failed with error: {e}. Retrying in {retry_delay} seconds."
            )
            if retries <= max_retries:
                time.sleep(retry_delay)
            else:
                logger.error(f"Max retries ({max_retries}) exceeded for query: {query}")
                return (
                    None,
                    f"Max retries ({max_retries}) exceeded for query: {query} - Error: {e!s}",
                )

        except requests.exceptions.HTTPError as e:
            if e.response.status_code >= 500:
                retries += 1
                logger.warning(
                    f"Server error, attempt {retries} failed with error: {e}. Retrying in {retry_delay} seconds."
                )
                if retries <= max_retries:
                    time.sleep(retry_delay)
                else:
                    logger.error(
                        f"Max retries ({max_retries}) exceeded for query: {query}"
                    )
                    return (
                        None,
                        f"Max retries ({max_retries}) exceeded for query: {query} - Error: {e!s}",
                    )
            else:
                logger.error(f"HTTP Error on attempt {retries}: {e}")
                return (
                    None,
                    f"HTTP Error on attempt {retries}: {e}",
                )

        except Exception as e:
            logger.error(f"Unexpected error on attempt {retries}: {e}")
            return (None, f"Unexpected error on attempt {retries}: {e}")

    return None, "Unknown Error in SQL execution"


class RemoteDatabaseConnector(DatabaseConnector):
    """Database connector for remote databases accessed via HTTP."""

    def __init__(self, db_config: SQLDatabase):
        super().__init__(db_config)

        assert db_config[
            "db_id"
        ], "db_id must be in db_config for RemoteDatabaseConnector"
        self.api_url, self.database_id = (
            db_config["db_id"].split(",")[0],
            db_config["db_id"].split("db_id=")[-1].split(",")[0],
        )

        if not self.api_url or not self.database_id:
            raise ValueError(
                "Both 'api_url' and 'database_id' are required for RemoteDatabaseConnector."
            )

        self.api_key = os.getenv("SQL_API_KEY", None)
        if not self.api_key:
            raise ValueError(
                "The environment variable 'SQL_API_KEY' must be set to use the RemoteDatabaseConnector."
            )

        self.headers = {
            "Content-Type": "application/json",
            "accept": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

        self.timeout = 30

    def get_table_schema(
        self,
    ) -> str:
        """Retrieves the schema of a database."""
        cur_api_url = f"{self.api_url}/datasources/{self.database_id}"
        response = requests.get(
            cur_api_url,
            headers=self.headers,
            verify=False,
            timeout=self.timeout,
        )
        if response.status_code == 200:
            schema = response.json()["schema"]
        else:
            raise OSError(f"Could not fetch schema from {cur_api_url}")

        schema_text = ""
        for table in schema["tables"]:
            schema_text += f"Table: {table['name'] if 'name' in table else table['table_name']} has columns: {[col['name'] if 'name' in col else col['column_name'] for col in table['columns']]}\n"

        return schema_text

    def execute_query(self, query: str) -> Any:
        """Executes a query against the remote database, with retries for certain exceptions."""
        cache = get_cache()

        cache_key = generate_cache_key(
            "sql_request", self.api_url, self.database_id, query
        )
        return cache.get_or_set(
            cache_key,
            lambda: execute_query_remote(
                api_url=self.api_url,
                database_id=self.database_id,
                api_key=self.api_key,
                query=query,
                timeout=self.timeout,
            ),
        )


def get_db_connector(db_type: str):
    """Creates and returns the appropriate DatabaseConnector instance based on db_type."""
    if db_type == "local":
        connector = LocalSQLiteConnector
    elif db_type == "in_memory":
        connector = InMemoryDatabaseConnector
    elif db_type == "remote":
        connector = RemoteDatabaseConnector

    else:
        raise ValueError(f"Unsupported database type: {db_type}")

    return connector

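# Illustrative sketch for get_db_connector (hypothetical helper, not part of
# the module): note that the function returns a connector *class*, which the
# caller then instantiates with a db_config.
def _example_get_db_connector():
    connector_cls = get_db_connector("in_memory")
    connector = connector_cls(
        {
            "db_type": "in_memory",
            "db_id": None,
            "data": {"t": {"columns": ["x"], "rows": [[1], [2]]}},
        }
    )
    rows, _ = connector.execute_query("SELECT COUNT(*) FROM t")
    return rows  # [(2,)]
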
@dataclass
class SQLNonExecutionMetricResult:
    sqlglot_validity: int  # Whether SQL parses with sqlglot
    sqlparse_validity: int  # Whether SQL parses with sqlparse
    sqlglot_equivalence: int  # Semantic equivalence using sqlglot AST
    sqlglot_optimized_equivalence: int  # Equivalence after optimization via sqlglot
    sqlparse_equivalence: int  # Equivalence using sqlparse AST
    sql_exact_match: int  # Exact string match of predicted and gold SQL
    sql_syntactic_equivalence: int  # Any of the above equivalence conditions hold


def is_sqlglot_parsable(sql: str, db_type="sqlite") -> bool:
    """Returns True if sqlglot does not encounter any error, False otherwise."""
    from sqlglot import parse

    if not sql.strip():
        return False
    if db_type == "db2":
        db_type = "postgres"  ## TODO: temporary until sqlglot adds support for db2
    try:
        parse(sql, read=db_type)
        return True
    except Exception as e:
        logger.debug(f"SQL query could not parse: {e}")
        return False


def is_sqlparse_parsable(sql: str) -> bool:
    """Returns True if sqlparse does not encounter any error, False otherwise."""
    from sqlparse import parse
    from sqlparse.tokens import Error

    if not sql.strip():
        return False
    try:
        statements = parse(sql)
        for statement in statements:
            for token in statement.tokens:
                if token.ttype == Error:
                    return False
        return True
    except Exception as e:
        logger.debug(f"SQL query could not parse: {e}")
        return False


def sqlglot_optimized_equivalence(expected: str, generated: str) -> int:
    """Checks if SQL queries are equivalent using SQLGlot parsing, so we don't run them."""
    from sqlglot import diff, parse_one
    from sqlglot.optimizer import optimize

    try:
        t_diff = diff(
            optimize(parse_one(expected.lower()).sql(pretty=True)),
            optimize(parse_one(generated.lower()).sql(pretty=True)),
        )
        sql_diff = sum(0 if (e.__class__.__name__ == "Keep") else 1 for e in t_diff)

        return 1 if sql_diff == 0 else 0
    except Exception as e:
        logger.debug(f"Error parsing SQL for comparison: {e}")
        return False


def extract_select_columns(statement):
    """Parse SQL using sqlparse and extract columns."""
    from sqlparse.sql import Identifier, IdentifierList
    from sqlparse.tokens import DML, Keyword

    columns = []
    select_seen = False
    for token in statement.tokens:
        if token.ttype is DML and token.value.upper() == "SELECT":
            select_seen = True
            continue
        if select_seen:
            if token.ttype is Keyword and token.value.upper() in (
                "FROM",
                "WHERE",
                "GROUP",
                "HAVING",
                "ORDER",
                "LIMIT",
            ):
                break
            if isinstance(token, IdentifierList):
                for identifier in token.get_identifiers():
                    columns.append(strip_alias(identifier.value))
            elif isinstance(token, Identifier):
                columns.append(strip_alias(token.value))
            else:
                val = token.value.strip()
                if val:
                    columns.append(strip_alias(val))
    return frozenset(columns)


def strip_alias(col: str) -> str:
    """Remove any AS alias from a column."""
    col = col.strip()
    upper = col.upper()
    if " AS " in upper:
        return col[: upper.index(" AS ")].strip()
    parts_alias = col.split()
    if len(parts_alias) > 1:
        return " ".join(parts_alias[:-1])
    return col


def collect_clause(statement, clause_keyword):
    """Parse SQL statement and collect clauses."""
    from sqlparse.tokens import Keyword

    found = False
    collected = []
    for token in statement.tokens:
        tvalue = token.value.upper()
        if token.ttype is Keyword:
            if tvalue.startswith(clause_keyword):
                found = True
                continue
            if found and tvalue in (
                "FROM",
                "WHERE",
                "GROUP",
                "HAVING",
                "ORDER",
                "LIMIT",
            ):
                break
        if found:
            collected.append(token.value)
    return " ".join(collected).strip()


def extract_select_info(sql: str):
    """Parse SQL using sqlparse and return a dict of extracted columns and clauses."""
    from sqlparse import parse
    from sqlparse.tokens import DML

    statements = parse(sql)
    if len(statements) != 1:
        return None
    stmt = statements[0]
    if not any(t.ttype is DML and t.value.upper() == "SELECT" for t in stmt.tokens):
        return None
    parts = {
        "columns": None,
        "from": "",
        "where": "",
        "group": "",
        "having": "",
        "order": "",
    }
    columns = extract_select_columns(stmt)
    if not columns:
        columns = frozenset()
    parts["columns"] = columns
    parts["from"] = collect_clause(stmt, "FROM")
    parts["where"] = collect_clause(stmt, "WHERE")
    parts["group"] = collect_clause(stmt, "GROUP")
    parts["having"] = collect_clause(stmt, "HAVING")
    parts["order"] = collect_clause(stmt, "ORDER")
    return parts


def sqlparse_queries_equivalent(sql1: str, sql2: str) -> bool:
    """Returns True if both SQL queries are naively considered equivalent."""
    try:
        info1 = extract_select_info(sql1)
        info2 = extract_select_info(sql2)
        if not info1 or not info2:
            return False
        if info1["columns"] != info2["columns"]:
            return False
        for k in ["from", "where", "group", "having", "order"]:
            if info1[k].replace(" ", "").upper() != info2[k].replace(" ", "").upper():
                return False
        return True
    except Exception as e:
        logger.debug(f"Error parsing SQL query for comparison: {e}")
        return False

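# Illustrative sketch for sqlparse_queries_equivalent (hypothetical helper, not
# part of the module): the SELECT list is compared as an unordered set of
# columns with aliases stripped, and the remaining clauses are compared after
# whitespace and case normalization.
def _example_sqlparse_equivalence():
    assert sqlparse_queries_equivalent(
        "SELECT name, age FROM people WHERE age > 30",
        "select age, name from people where age > 30",
    )
    assert not sqlparse_queries_equivalent(
        "SELECT name FROM people",
        "SELECT name FROM people WHERE age > 30",
    )
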
def sqlglot_parsed_queries_equivalent(sql1: str, sql2: str, dialect: str = "") -> bool:
    """Return True if two SQL queries match after parsing with SQLGlot."""
    from sqlglot import exp, parse_one

    try:
        ast1 = parse_one(sql1, read=dialect)
        ast2 = parse_one(sql2, read=dialect)
    except:
        return False
    if not (isinstance(ast1, exp.Select) and isinstance(ast2, exp.Select)):
        return False

    def normalized_select_columns(select_expr: exp.Select):
        cols = []
        for item in select_expr.expressions:
            copy_item = item.copy()
            copy_item.set("alias", None)
            cols.append(copy_item.sql(dialect=dialect, normalize=True))
        return frozenset(cols)

    if normalized_select_columns(ast1) != normalized_select_columns(ast2):
        return False

    def normalized_clause(expr: exp.Expression, key: str):
        clause = expr.args.get(key)
        return clause.sql(dialect=dialect, normalize=True) if clause else ""

    for clause_key in ("from", "where", "group", "having", "order"):
        if normalized_clause(ast1, clause_key) != normalized_clause(ast2, clause_key):
            return False

    return True

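# Illustrative sketch for sqlglot_parsed_queries_equivalent (hypothetical
# helper, not part of the module): SELECT aliases are dropped before the
# normalized clause-by-clause comparison, so an alias-only difference should
# still count as equivalent, while a changed WHERE clause should not.
def _example_sqlglot_parsed_equivalence():
    assert sqlglot_parsed_queries_equivalent(
        "SELECT name AS n FROM people WHERE age > 30",
        "SELECT name FROM people WHERE age > 30",
        dialect="sqlite",
    )
    assert not sqlglot_parsed_queries_equivalent(
        "SELECT name FROM people WHERE age > 30",
        "SELECT name FROM people WHERE age > 40",
        dialect="sqlite",
    )
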
def sql_exact_match(sql1: str, sql2: str) -> bool:
    """Return True if two SQL strings match after very basic normalization."""

    def normalize_sql(s: str) -> str:
        s = s.strip().rstrip(";")
        s = re.sub(r"\s+", " ", s)
        return s.upper()

    return normalize_sql(sql1) == normalize_sql(sql2)


@dataclass
class SQLExecutionResult:
    execution_accuracy: int  # Whether the predicted and gold SQL results match exactly
    non_empty_execution_accuracy: (
        int  # Same as execution_accuracy but only if gold is non-empty
    )
    subset_non_empty_execution_accuracy: (
        int  # Whether predicted is a subset of gold or vice versa, non-empty only
    )
    execution_accuracy_bird: (
        int  # Whether the predicted SQL matches gold using BIRD evaluation logic
    )
    non_empty_gold_df: int  # Whether the gold SQL produced a non-empty dataframe
    gold_sql_runtime: float  # Time taken to execute the gold SQL
    predicted_sql_runtime: float  # Time taken to execute the predicted SQL
    pred_to_gold_runtime_ratio: float  # Ratio of predicted runtime to gold runtime
    gold_error: int  # Whether the gold SQL had an execution error
    predicted_error: int  # Whether the predicted SQL had an execution error
    gold_df_json: str  # JSON representation of the gold SQL result dataframe
    predicted_df_json: str  # JSON representation of the predicted SQL result dataframe
    error_message: str  # Error message from predicted execution if any


def compare_dfs_ignore_colnames_ordered_rows(
    df1: pd.DataFrame, df2: pd.DataFrame
) -> bool:
    if df1.shape != df2.shape:
        return False
    df1_sorted_rows = np.array([np.sort(row) for row in df1.values.astype(str)])
    df2_sorted_rows = np.array([np.sort(row) for row in df2.values.astype(str)])
    return np.array_equal(df1_sorted_rows, df2_sorted_rows)


def compare_dfs_ignore_colnames_unordered_rows(
    df1: pd.DataFrame, df2: pd.DataFrame
) -> bool:
    if df1.shape != df2.shape:
        return False
    df1_sorted = np.sort(np.sort(df1.values.astype(str), axis=1), axis=0)
    df2_sorted = np.sort(np.sort(df2.values.astype(str), axis=1), axis=0)
    return np.array_equal(df1_sorted, df2_sorted)


def compare_dfs_ignore_colnames_subset(
    df1: pd.DataFrame, df2: pd.DataFrame, ignore_row_order: bool = True
) -> bool:
    """Checks if the smaller of the two DataFrames is likely a subset of the other.

    Subset comparison is column-based, to support Text2SQL evaluation for when the
    predicted SQL dataframe has missing or additional columns. Each row is treated as
    a multiset of (stringified) values, and the function checks if every row in the
    smaller DataFrame (by column count) is a multiset subset of the corresponding row
    in the larger DataFrame. When ground truth SQL does not have ORDER BY,
    ignore_row_order can be set to True to ignore the order of rows. In this case,
    column values are sorted before comparison. This means that there could be cases
    where the dataframes have the exact same number of rows and column values after
    sort are the same, but the dataframes are not actually a subset of each other.
    This is unlikely to happen in practice, but the score is not guaranteed to be
    100% accurate and may overestimate the accuracy.

    Args:
        df1 (pd.DataFrame): The first DataFrame to compare.
        df2 (pd.DataFrame): The second DataFrame to compare.
        ignore_row_order (bool, optional): If True, ignores the order of rows by
            sorting them before comparison. Defaults to True.

    Returns:
        bool: True if the smaller DataFrame (column-wise) is likely a subset of the
            larger one, False otherwise.
    """

    def row_to_multiset(row):
        return Counter(str(x) for x in row)

    def rows_to_multisets(df):
        return [row_to_multiset(row) for row in df.values]

    def sort_df(df):
        sorted_df = df.copy()
        for i in range(len(sorted_df.columns)):
            sorted_df.iloc[:, i] = (
                sorted_df.iloc[:, i].astype(str).sort_values(ignore_index=True)
            )
        return sorted_df

    if df1.empty or df2.empty or len(df1) != len(df2):
        return False

    df1.columns = range(df1.shape[1])
    df2.columns = range(df2.shape[1])
    subset_df, superset_df = (df1, df2) if df1.shape[1] <= df2.shape[1] else (df2, df1)

    if ignore_row_order:
        subset_df = sort_df(subset_df)
        superset_df = sort_df(superset_df)

    subset_rows = rows_to_multisets(subset_df)
    superset_rows = rows_to_multisets(superset_df)

    for r1, r2 in zip(subset_rows, superset_rows):
        if not all(r1[k] <= r2.get(k, 0) for k in r1):
            return False
    return True

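# Illustrative sketch for compare_dfs_ignore_colnames_subset (hypothetical
# helper, not part of the module): a prediction with one extra column still
# passes, because each gold row's values form a multiset subset of the
# corresponding (column-sorted) predicted row.
def _example_subset_comparison():
    gold = pd.DataFrame({"name": ["Ada", "Grace"]})
    pred = pd.DataFrame({"name": ["Grace", "Ada"], "age": [37, 41]})
    assert compare_dfs_ignore_colnames_subset(gold, pred, ignore_row_order=True)
    assert not compare_dfs_ignore_colnames_subset(
        gold, pd.DataFrame({"name": ["Marie", "Ada"]}), ignore_row_order=True
    )
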
def compare_dfs_bird_eval_logic(df1: pd.DataFrame, df2: pd.DataFrame):
    """Check if two SQL query result sets are exactly equal, as in BIRD evaluation.

    This function checks if the set of rows returned by the predicted SQL query
    (`predicted_res`) is exactly equal to the set of rows returned by the ground truth
    SQL query (`ground_truth_res`). This is the logic used in the original BIRD
    evaluation code:
    https://github.com/AlibabaResearch/DAMO-ConvAI/blob/main/bird/llm/src/evaluation.py.
    """
    df1_set = {tuple(row) for row in df1.values.astype(str)}
    df2_set = {tuple(row) for row in df2.values.astype(str)}
    return int(df1_set == df2_set)


def compare_result_dfs(
    gold_df: pd.DataFrame, pred_df: pd.DataFrame, gold_sql: str
) -> Tuple[int, int, int]:
    """Compares two DataFrames representing SQL query results.

    Args:
        gold_df (pd.DataFrame): The ground truth DataFrame.
        pred_df (pd.DataFrame): The predicted DataFrame.
        gold_sql (str): The ground truth SQL query string.

    Returns:
        Tuple[int, int, int]: A tuple containing:
            - match (int): 1 if the predicted DataFrame matches the gold DataFrame,
              0 otherwise.
            - non_empty_match (int): 1 if both DataFrames are non-empty and match,
              0 otherwise.
            - subset_match (int): 1 if the predicted DataFrame is a subset or
              superset of the gold DataFrame, 0 otherwise.

    Notes:
        - The comparison ignores column names.
        - Row order is considered only if 'ORDER BY' is present in the SQL query.
    """
    subset_match = 0
    non_empty_match = 0
    if "ORDER BY" in gold_sql.upper():
        match = int(compare_dfs_ignore_colnames_ordered_rows(pred_df, gold_df))
        if not gold_df.empty and not pred_df.empty:
            non_empty_match = match
            if compare_dfs_ignore_colnames_subset(
                gold_df, pred_df, ignore_row_order=False
            ):
                subset_match = 1
    else:
        match = int(compare_dfs_ignore_colnames_unordered_rows(pred_df, gold_df))
        if not gold_df.empty and not pred_df.empty:
            non_empty_match = match
            if compare_dfs_ignore_colnames_subset(
                gold_df, pred_df, ignore_row_order=True
            ):
                subset_match = 1
    return match, non_empty_match, subset_match


def run_query(
    sql: str, connector, sql_timeout: float
) -> Tuple[Optional[pd.DataFrame], float, str]:
    """Executes a SQL query using the provided connector with a timeout.

    Args:
        sql (str): The SQL query string to execute.
        connector: An object with an `execute_query` method that executes the SQL
            query.
        sql_timeout (float): The maximum time in seconds to allow for query
            execution.

    Returns:
        Tuple[Optional[pd.DataFrame], float, str]:
            - A pandas DataFrame containing the query results, or None if an error
              occurred.
            - The duration in seconds taken to execute the query. 0.0 if an error.
            - An error message string if an error occurred, otherwise an empty
              string.

    Notes:
        - If the SQL string is empty or only whitespace, returns immediately with a
          message.
        - If the query execution exceeds the timeout, returns a timeout error
          message.
        - Any other exceptions are caught and returned as error messages.
    """
    import time

    from func_timeout import func_timeout
    from func_timeout.exceptions import FunctionTimedOut

    if not sql.strip():
        return None, 0.0, "No SQL query found in the prediction."

    try:
        start = time.perf_counter()
        result, error = func_timeout(sql_timeout, connector.execute_query, args=(sql,))
        duration = time.perf_counter() - start
        if isinstance(result, dict) and "results" in result:
            result = result["results"]
        if error:
            return None, duration, error
        return pd.DataFrame(result), duration, ""
    except FunctionTimedOut as e:
        return None, 0.0, f"Timeout: {e}"
    except Exception as e:
        return None, 0.0, f"Error: {e}"


def get_sql_execution_results(
    predicted_sql: str, gold_sql: str, connector, sql_timeout: float
) -> SQLExecutionResult:
    """Execute and compare predicted and gold SQL queries, returning execution metrics.

    Args:
        predicted_sql (str): The SQL query predicted by the model.
        gold_sql (str): The reference (gold) SQL query.
        connector: Database connector object used to execute the queries.
        sql_timeout (float): Maximum time (in seconds) allowed for query execution.

    Returns:
        SQLExecutionResult: An object containing various execution metrics, including:
            - execution_accuracy (int): 1 if predicted and gold queries produce
              equivalent results, else 0.
            - non_empty_execution_accuracy (int): 1 if both queries produce non-empty
              and equivalent results, else 0.
            - subset_non_empty_execution_accuracy (int): 1 if predicted results are a
              subset or superset of gold results and non-empty, else 0. Subset
              comparison is column-based. This means that the predicted SQL dataframe
              can have missing or additional columns compared to the gold SQL dataframe.
            - execution_accuracy_bird (int): 1 if results match according to BIRD
              evaluation logic, else 0.
            - non_empty_gold_df (int): 1 if the gold query result is non-empty, else 0.
            - gold_sql_runtime (float): Execution time for the gold SQL query.
            - predicted_sql_runtime (float): Execution time for the predicted SQL query.
            - pred_to_gold_runtime_ratio (float): Ratio of predicted to gold query
              runtimes.
            - gold_error (int): 1 if the gold query failed, else 0.
            - predicted_error (int): 1 if the predicted query failed, else 0.
            - gold_df_json (str): JSON representation of the gold query result
              DataFrame.
            - predicted_df_json (str): JSON representation of the predicted query
              result DataFrame.
            - error_message (str): Error message if any query failed, else empty
              string.

    Notes:
        - If the gold query fails, the function returns early with error details.
        - If the predicted query is identical or SQL-equivalent to the gold query,
          results are considered correct without execution.
        - Otherwise, both queries are executed and their results compared using
          multiple metrics.
    """
    gold_df, gold_runtime, gold_error_msg = run_query(gold_sql, connector, sql_timeout)
    gold_error = int(bool(gold_error_msg))

    if gold_error:
        return SQLExecutionResult(
            execution_accuracy=0,
            non_empty_execution_accuracy=0,
            subset_non_empty_execution_accuracy=0,
            execution_accuracy_bird=0,
            non_empty_gold_df=0,
            gold_sql_runtime=gold_runtime,
            predicted_sql_runtime=0,
            pred_to_gold_runtime_ratio=0,
            gold_error=gold_error,
            predicted_error=0,
            gold_df_json="",
            predicted_df_json="",
            error_message=gold_error_msg,
        )

    non_empty_gold_df = int(not gold_df.empty)
    if predicted_sql.strip().lower() == gold_sql.strip().lower():
        return SQLExecutionResult(
            execution_accuracy=1,
            non_empty_execution_accuracy=non_empty_gold_df,
            subset_non_empty_execution_accuracy=non_empty_gold_df,
            execution_accuracy_bird=1,
            non_empty_gold_df=non_empty_gold_df,
            gold_sql_runtime=gold_runtime,
            predicted_sql_runtime=0,
            pred_to_gold_runtime_ratio=0,
            gold_error=0,
            predicted_error=0,
            gold_df_json=gold_df.to_json(),
            predicted_df_json=gold_df.to_json(),
            error_message="",
        )

    try:
        if sqlglot_optimized_equivalence(gold_sql, predicted_sql):
            return SQLExecutionResult(
                execution_accuracy=1,
                non_empty_execution_accuracy=non_empty_gold_df,
                subset_non_empty_execution_accuracy=non_empty_gold_df,
                execution_accuracy_bird=1,
                non_empty_gold_df=non_empty_gold_df,
                gold_sql_runtime=gold_runtime,
                predicted_sql_runtime=0,
                pred_to_gold_runtime_ratio=0,
                gold_error=0,
                predicted_error=0,
                gold_df_json=gold_df.to_json(),
                predicted_df_json=gold_df.to_json(),
                error_message="",
            )
    except Exception as e:
        logger.info(f"Could not check SQL equivalence: {e}")

    pred_df, pred_runtime, pred_error_msg = run_query(
        predicted_sql, connector, sql_timeout
    )
    pred_error = 1 if pred_error_msg else 0

    if pred_df is None:
        return SQLExecutionResult(
            execution_accuracy=0,
            non_empty_execution_accuracy=0,
            subset_non_empty_execution_accuracy=0,
            execution_accuracy_bird=0,
            non_empty_gold_df=non_empty_gold_df,
            gold_sql_runtime=gold_runtime,
            predicted_sql_runtime=pred_runtime,
            pred_to_gold_runtime_ratio=(pred_runtime / gold_runtime)
            if gold_runtime > 0
            else 0,
            gold_error=0,
            predicted_error=pred_error,
            gold_df_json=gold_df.to_json(),
            predicted_df_json="",
            error_message=pred_error_msg,
        )

    match, non_empty_match, subset_match = compare_result_dfs(
        gold_df, pred_df, gold_sql
    )
    bird_match = compare_dfs_bird_eval_logic(gold_df, pred_df)

    return SQLExecutionResult(
        execution_accuracy=match,
        non_empty_execution_accuracy=non_empty_match,
        subset_non_empty_execution_accuracy=subset_match,
        execution_accuracy_bird=bird_match,
        non_empty_gold_df=non_empty_gold_df,
        gold_sql_runtime=gold_runtime,
        predicted_sql_runtime=pred_runtime,
        pred_to_gold_runtime_ratio=(pred_runtime / gold_runtime)
        if gold_runtime > 0
        else 0,
        gold_error=0,
        predicted_error=0,
        gold_df_json=gold_df.to_json(),
        predicted_df_json=pred_df.to_json(),
        error_message=pred_error_msg,
    )

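# Illustrative end-to-end sketch for get_sql_execution_results (hypothetical
# helper and data, not part of the module): evaluate one prediction against a
# gold query over the in-memory connector. Requires the optional func_timeout
# dependency used by run_query.
def _example_get_sql_execution_results():
    connector = InMemoryDatabaseConnector(
        {
            "db_type": "in_memory",
            "db_id": None,
            "data": {
                "people": {
                    "columns": ["name", "age"],
                    "rows": [["Ada", 36], ["Grace", 45]],
                }
            },
        }
    )
    result = get_sql_execution_results(
        predicted_sql="SELECT name FROM people WHERE age > 40",
        gold_sql="SELECT name FROM people WHERE age >= 45",
        connector=connector,
        sql_timeout=10.0,
    )
    return result.execution_accuracy  # 1: both queries return the same rows
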
def replace_select_clause(
    source_query: str, target_query: str, dialect: str = "postgres"
) -> str:
    """Replaces the SELECT clause of the target SQL query with the SELECT clause from the source SQL query.

    Args:
        source_query (str): SQL query whose SELECT clause will be used.
        target_query (str): SQL query whose SELECT clause will be replaced.
        dialect (str): SQL dialect for parsing and rendering (default: "postgres").

    Returns:
        str: A new SQL query with the SELECT clause of `target_query` replaced by that of `source_query`.

    Raises:
        ValueError: If either query is not a valid SELECT statement.

    Example:
        >>> replace_select_clause(
        ...     "SELECT id FROM employees",
        ...     "SELECT name FROM employees WHERE age > 30"
        ... )
        "SELECT id FROM employees WHERE age > 30"
    """
    from sqlglot import exp, parse_one

    if not dialect:
        dialect = "postgres"

    # Parse queries using the specified dialect
    source_ast = parse_one(source_query, read=dialect)
    target_ast = parse_one(target_query, read=dialect)

    if not isinstance(source_ast, exp.Select) or not isinstance(target_ast, exp.Select):
        raise ValueError("Both queries must be valid SELECT statements.")

    # Replace SELECT expressions in the target with those from the source
    target_ast.set("expressions", source_ast.expressions)

    # Return the updated SQL string using the dialect
    return target_ast.sql(dialect=dialect)


def extract_sql_from_text(text: str) -> str:
    """Extracts the first SQL query from the given text.

    Priority:
    1. SQL inside fenced blocks like ```sql ... ```
    2. SQL starting on a new line or after a colon/label
    3. SQL without semicolon

    Returns:
        The SQL query string, or an empty string if not found.
    """
    # 1. Look for fenced SQL code block
    fenced_block_pattern = re.compile(r"```sql\s+(.*?)```", re.IGNORECASE | re.DOTALL)
    match = fenced_block_pattern.search(text)
    if match:
        return match.group(1).strip()

    # 2. Inline SQL with semicolon
    sql_keywords = r"(?:SELECT|INSERT|UPDATE|DELETE|WITH)\s+"
    sql_start = (
        r"(?:^|\n|:\s*)"  # Start of string, newline, or colon label like "Just run:"
    )
    sql_pattern = re.compile(
        rf"{sql_start}({sql_keywords}.*?;)", re.IGNORECASE | re.DOTALL
    )
    match = sql_pattern.search(text)
    if match:
        return match.group(1).strip()

    # 3. Inline SQL without semicolon
    fallback_pattern = re.compile(
        rf"{sql_start}({sql_keywords}.*)", re.IGNORECASE | re.DOTALL
    )
    fallback_match = fallback_pattern.search(text)
    if fallback_match:
        return fallback_match.group(1).strip()

    return ""

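# Illustrative sketch for extract_sql_from_text (hypothetical helper, not part
# of the module): fenced ```sql blocks are tried first; text with no SQL at all
# yields an empty string.
def _example_extract_sql_from_text():
    answer = (
        "Here is the query you asked for:\n"
        "```sql\n"
        "SELECT name FROM people WHERE age > 30;\n"
        "```\n"
        "It filters on the age column."
    )
    assert extract_sql_from_text(answer) == "SELECT name FROM people WHERE age > 30;"
    assert extract_sql_from_text("no query here") == ""
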
ALL_DIALECTS = [
    "Athena",
    "BigQuery",
    "ClickHouse",
    "Databricks",
    "Doris",
    "Drill",
    "Druid",
    "DuckDB",
    "Hive",
    "Materialize",
    "MySQL",
    "Oracle",
    "Postgres",
    "Presto",
    "PRQL",
    "Redshift",
    "RisingWave",
    "Snowflake",
    "Spark",
    "Spark2",
    "SQLite",
    "StarRocks",
    "Tableau",
    "Teradata",
    "Trino",
    "TSQL",
]