
IBM / unitxt, build 16290496146
15 Jul 2025 10:13AM UTC, coverage: 81.321% (+0.07%) from 81.249%
Pull Request #1861: Fix compatibility with datasets 4.0 (merge 216fef411 into 4e5433a39)
1571 of 1939 branches covered (81.02%); branch coverage is included in the aggregate percentage
10615 of 13046 relevant lines covered (81.37%), 0.81 hits per line

Source file: src/unitxt/text2sql_utils.py (73.51% covered)
import functools
import glob
import hashlib
import json
import os
import re
import sqlite3
import time
from abc import ABC, abstractmethod
from collections import Counter
from dataclasses import dataclass
from functools import lru_cache
from typing import Any, List, Optional, Tuple

import numpy as np
import pandas as pd
import requests
from huggingface_hub import snapshot_download
from requests.exceptions import ConnectionError, ReadTimeout

from .logging_utils import get_logger
from .types import SQLDatabase

logger = get_logger()

# Check if caching is enabled via environment variable
CACHE_LOCATION = os.getenv("UNITXT_CACHE_LOCATION")

# Set max cache size to 10GB or the value of env var MAX_CACHE_SIZE
MAX_CACHE_SIZE = os.getenv("MAX_CACHE_SIZE", 10 * 1024**3)

_cache_instance = None


class DatabaseConnector(ABC):
    """Abstract base class for database connectors."""

    def __init__(self, db_config: SQLDatabase):
        self.db_config = db_config
        self.databases_folder = os.path.join(
            os.environ.get("UNITXT_CACHE_LOCATION", "cache/text2sql"), "databases"
        )
        os.makedirs(self.databases_folder, exist_ok=True)

    @abstractmethod
    def get_table_schema(
        self,
    ) -> str:
        """Abstract method to get database schema."""
        pass

    @abstractmethod
    def execute_query(self, query: str) -> Any:
        """Abstract method to execute a query against the database."""
        pass


@lru_cache(maxsize=128)
def execute_query_local(db_path: str, query: str) -> Any:
    """Executes a query against the SQLite database."""
    conn = None  # Initialize conn to None outside the try block
    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()
        cursor.execute(query)
        return cursor.fetchall(), None
    except sqlite3.Error as e:
        logger.info(f"Error executing SQL: {e}")
        return None, f"Error executing SQL: {e}"
    finally:
        if conn:
            conn.close()


class LocalSQLiteConnector(DatabaseConnector):
    """Database connector for SQLite databases."""

    def __init__(self, db_config: SQLDatabase):
        super().__init__(db_config)
        db_id = self.db_config.get("db_id")
        if not db_id:
            raise ValueError("db_id is required for SQLiteConnector.")
        self.db_path = self.get_db_file_path(db_id)
        self.conn: sqlite3.Connection = sqlite3.connect(self.db_path)
        self.cursor: sqlite3.Cursor = self.conn.cursor()

    def download_database(self, db_id):
        """Downloads the database from huggingface if needed."""
        done_file_path = os.path.join(self.databases_folder, "download_done")
        if "bird/" in db_id:
            if not os.path.exists(done_file_path):
                snapshot_download(
                    repo_id="premai-io/birdbench",
                    repo_type="dataset",
                    local_dir=self.databases_folder,
                    force_download=False,
                    allow_patterns="*validation*",
                )
                open(os.path.join(self.databases_folder, "download_done"), "w").close()
        else:
            raise NotImplementedError(
                f"current local db: {db_id} is not supported, only bird"
            )

    def get_db_file_path(self, db_id):
        """Gets the local path of a downloaded database file."""
        self.download_database(db_id)
        db_id = db_id.split("/")[-1]

        db_file_pattern = os.path.join(self.databases_folder, "**", db_id + ".sqlite")
        db_file_paths = glob.glob(db_file_pattern, recursive=True)

        if not db_file_paths:
            raise FileNotFoundError(f"Database file {db_id} not found.")
        if len(db_file_paths) > 1:
            raise FileExistsError(f"More than one file matched for {db_id}")
        return db_file_paths[0]

    def get_table_schema(
        self,
    ) -> str:
        """Extracts schema from an SQLite database."""
        self.cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        tables: list[tuple[str]] = self.cursor.fetchall()
        schemas: dict[str, str] = {}

        for table in tables:
            if isinstance(table, tuple):
                table = table[0]
            if table == "sqlite_sequence":
                continue
            sql_query: str = (
                f"SELECT sql FROM sqlite_master WHERE type='table' AND name='{table}';"
            )
            self.cursor.execute(sql_query)
            schema_prompt: str = self.cursor.fetchone()[0]

            schemas[table] = schema_prompt

        schema_prompt: str = "\n\n".join(list(schemas.values()))
        return schema_prompt

    def execute_query(self, query: str) -> Any:
        """Executes a query against the SQLite database."""
        return execute_query_local(self.db_path, query)


class InMemoryDatabaseConnector(DatabaseConnector):
    """Database connector for mocking databases with in-memory data structures."""

    def __init__(self, db_config: SQLDatabase):
        super().__init__(db_config)
        self.tables = db_config.get("data", None)

        if not self.tables:
            raise ValueError("data is required for InMemoryDatabaseConnector.")

    def get_table_schema(
        self,
        select_tables: Optional[List[str]] = None,
    ) -> str:
        """Generates a mock schema from the tables structure."""
        schemas = {}
        for table_name, table_data in self.tables.items():
            if select_tables and table_name.lower() not in select_tables:
                continue
            columns = ", ".join([f"`{col}` TEXT" for col in table_data["columns"]])
            schema = f"CREATE TABLE `{table_name}` ({columns});"

            schemas[table_name] = schema

        return "\n\n".join(list(schemas.values()))

    def execute_query(self, query: str) -> Any:
        """Simulates executing a query against the mock database."""
        # Initialize in-memory database from the 'tables' dictionary
        conn = sqlite3.connect(":memory:")
        cursor = conn.cursor()
        logger.debug("Running SQL query over in-memory DB")

        # Create tables and insert data from the 'db' dictionary
        for table_name, table_data in self.tables.items():
            columns = table_data["columns"]
            rows = table_data["rows"]

            # Create table
            cursor.execute(f"CREATE TABLE {table_name} ({', '.join(columns)})")

            # Insert data
            placeholders = ", ".join(["?"] * len(columns))
            cursor.executemany(
                f"INSERT INTO {table_name} VALUES ({placeholders})", rows
            )

        try:
            cursor.execute(query)
            return cursor.fetchall(), None
        except sqlite3.Error as e:
            logger.info(f"Error executing SQL: {e}")
            return None, f"Error executing SQL: {e}"
        finally:
            conn.close()
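A minimal usage sketch of the in-memory connector above, assuming the module is importable as unitxt.text2sql_utils and that db_config is a plain dict carrying the "data" mapping of table names to columns and rows (the only key execute_query reads):

# Example usage (not part of text2sql_utils.py).
from unitxt.text2sql_utils import InMemoryDatabaseConnector

db_config = {
    "db_id": None,
    "db_type": "in_memory",
    "data": {
        "users": {
            "columns": ["id", "name", "age"],
            "rows": [[1, "Alice", 34], [2, "Bob", 29]],
        }
    },
}

connector = InMemoryDatabaseConnector(db_config)
print(connector.get_table_schema())
# CREATE TABLE `users` (`id` TEXT, `name` TEXT, `age` TEXT);

rows, error = connector.execute_query("SELECT name FROM users WHERE age > 30")
# rows should be [("Alice",)] and error None; on a malformed query, rows is None
# and error carries the sqlite3 message instead.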

def get_cache():
    """Returns a singleton cache instance, initializing it if necessary."""
    global _cache_instance
    if _cache_instance is None:
        _cache_instance = Cache()
    return _cache_instance


def generate_cache_key(*args, **kwargs):
    """Generate a stable hashable cache key for various input types.

    :param args: Positional arguments of the function.
    :param kwargs: Keyword arguments of the function.
    :return: A hashed key as a string.
    """
    try:
        # Convert args and kwargs to a JSON string (sorted to ensure consistency)
        serialized = json.dumps(
            {"args": args, "kwargs": kwargs}, sort_keys=True, default=str
        )
    except TypeError:
        # Fallback for non-serializable objects
        serialized = repr((args, kwargs))

    # Hash the serialized data
    return hashlib.md5(serialized.encode()).hexdigest()
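Because the key is an MD5 digest of a sorted JSON dump of the arguments, keyword order does not change the key, and non-JSON-serializable inputs fall back to repr. A small sketch, assuming the same import path as above:

# Example usage (not part of text2sql_utils.py).
from unitxt.text2sql_utils import generate_cache_key

k1 = generate_cache_key("sql_request", "https://host/api", db="bird", split="validation")
k2 = generate_cache_key("sql_request", "https://host/api", split="validation", db="bird")
assert k1 == k2      # kwargs are sorted before hashing
assert len(k1) == 32  # hex digest of MD5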

class Cache:
    """A class that provides disk-based caching functionality for a given function."""

    def __init__(self):
        """Initializes the cache.

        If `CACHE_LOCATION` (os.getenv("UNITXT_CACHE_LOCATION")) is set, a disk-based
        cache is created using `diskcache`.

        Args:
            None

        Returns:
            None
        """
        if CACHE_LOCATION:
            try:
                import diskcache

                # Ensure the cache directory exists
                os.makedirs(CACHE_LOCATION, exist_ok=True)

                # Create a global diskcache Cache instance
                self.cache = diskcache.Cache(CACHE_LOCATION, size_limit=MAX_CACHE_SIZE)
                logger.info(f"Caching enabled at {CACHE_LOCATION}")
            except ImportError as e:
                raise ImportError(
                    "UNITXT_CACHE_LOCATION is set, but diskcache is not installed.\n"
                    "Please install diskcache `pip install diskcache` "
                    "or unset UNITXT_CACHE_LOCATION."
                ) from e
        else:
            self.cache = None  # Disable caching

    def get_or_set(self, key, compute_fn, no_cache=False, refresh=False):
        if not self.cache or no_cache:
            logger.info(f"Bypassing cache for key: {key}")
            return compute_fn()

        if refresh and key in self.cache:
            logger.info(f"Refreshing cache for key: {key}")
            del self.cache[key]

        if key in self.cache:
            logger.info(f"Cache hit for key: {key}")
            return self.cache[key]

        logger.info(f"Cache miss for key: {key}. Computing value...")
        result = compute_fn()

        if result and not (
            isinstance(result, tuple) and len(result) == 2 and result[0] is None
        ):
            self.cache[key] = result
            logger.info(f"Stored result in cache for key: {key}")
        else:
            logger.info(f"None result. Bypassing caching for key: {key}")

        return result

    async def async_get_or_set(self, key, compute_fn, no_cache=False, refresh=False):
        if not self.cache or no_cache:
            logger.info(f"Bypassing cache for key: {key}")
            return await compute_fn()

        if refresh and key in self.cache:
            logger.info(f"Refreshing cache for key: {key}")
            del self.cache[key]

        if key in self.cache:
            logger.info(f"Cache hit for key: {key}")
            return self.cache[key]

        logger.info(f"Cache miss for key: {key}. Computing value asynchronously...")
        result = await compute_fn()
        self.cache[key] = result
        logger.info(f"Stored result in cache for key: {key}")
        return result

    def memoize(self, key_func=generate_cache_key, no_cache=False, refresh=False):
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                if not self.cache or no_cache:
                    logger.info(f"Bypassing cache for function: {func.__name__}")
                    return func(*args, **kwargs)

                key = key_func(func.__name__, *args, **kwargs)

                if refresh and key in self.cache:
                    logger.info(
                        f"Refreshing cache for function: {func.__name__}, key: {key}"
                    )
                    del self.cache[key]

                if key in self.cache:
                    logger.info(f"Cache hit for function: {func.__name__}, key: {key}")
                    return self.cache[key]

                logger.info(
                    f"Cache miss for function: {func.__name__}, key: {key}. Computing value..."
                )
                result = func(*args, **kwargs)
                self.cache[key] = result
                logger.info(
                    f"Stored result in cache for function: {func.__name__}, key: {key}"
                )
                return result

            return wrapper

        return decorator

    def async_memoize(self, key_func=generate_cache_key, no_cache=False, refresh=False):
        def decorator(func):
            @functools.wraps(func)
            async def wrapper(*args, **kwargs):
                if no_cache:
                    logger.info(f"Bypassing cache for async function: {func.__name__}")
                    return await func(*args, **kwargs)

                key = key_func(func.__name__, *args, **kwargs)

                if refresh and key in self.cache:
                    logger.info(
                        f"Refreshing cache for async function: {func.__name__}, key: {key}"
                    )
                    del self.cache[key]

                if key in self.cache:
                    logger.info(
                        f"Cache hit for async function: {func.__name__}, key: {key}"
                    )
                    return self.cache[key]

                logger.info(
                    f"Cache miss for async function: {func.__name__}, key: {key}. Computing value..."
                )
                result = await func(*args, **kwargs)
                self.cache[key] = result
                logger.info(
                    f"Stored result in cache for async function: {func.__name__}, key: {key}"
                )
                return result

            return wrapper

        return decorator
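A note on the design: get_cache() returns a process-wide singleton, and Cache.get_or_set only persists values when UNITXT_CACHE_LOCATION points at a diskcache directory; otherwise it simply calls the compute function. A minimal sketch, assuming the unitxt.text2sql_utils import path and that UNITXT_CACHE_LOCATION is unset:

# Example usage (not part of text2sql_utils.py).
from unitxt.text2sql_utils import generate_cache_key, get_cache

cache = get_cache()
key = generate_cache_key("expensive_step", n=42)
value = cache.get_or_set(key, lambda: 42 * 2)  # caching disabled, so the lambda runs on every call

# Setting UNITXT_CACHE_LOCATION (before the module is imported) and installing diskcache
# makes the same call store the result on disk and reuse it on later hits; results of the
# form (None, error) are deliberately never cached.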

@lru_cache(maxsize=128)
def execute_query_remote(
    api_url: str,
    database_id: str,
    api_key: str,
    query: str,
    retryable_exceptions: tuple = (ConnectionError, ReadTimeout),
    max_retries: int = 3,
    retry_delay: int = 5,  # seconds
    timeout: int = 30,  # seconds
) -> (Optional[dict], str):
    """Executes a query against the remote database, with retries for certain exceptions."""
    headers = {
        "Content-Type": "application/json",
        "accept": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    retries = 0
    while retries <= max_retries:
        try:
            response = requests.post(
                f"{api_url}/sql",
                headers=headers,
                json={"sql": query, "dataSourceId": database_id},
                verify=False,
                timeout=timeout,
            )
            response.raise_for_status()
            return response.json(), None

        except retryable_exceptions as e:
            retries += 1
            logger.warning(
                f"Attempt {retries} failed with error: {e}. Retrying in {retry_delay} seconds."
            )
            if retries <= max_retries:
                time.sleep(retry_delay)
            else:
                logger.error(f"Max retries ({max_retries}) exceeded for query: {query}")
                return (
                    None,
                    f"Max retries ({max_retries}) exceeded for query: {query} - Error: {e!s}",
                )

        except requests.exceptions.HTTPError as e:
            if e.response.status_code >= 500:
                retries += 1
                logger.warning(
                    f"Server error, attempt {retries} failed with error: {e}. Retrying in {retry_delay} seconds."
                )
                if retries <= max_retries:
                    time.sleep(retry_delay)
                else:
                    logger.error(
                        f"Max retries ({max_retries}) exceeded for query: {query}"
                    )
                    return (
                        None,
                        f"Max retries ({max_retries}) exceeded for query: {query} - Error: {e!s}",
                    )
            else:
                logger.error(f"HTTP Error on attempt {retries}: {e}")
                return (
                    None,
                    f"HTTP Error on attempt {retries}: {e}",
                )

        except Exception as e:
            logger.error(f"Unexpected error on attempt {retries}: {e}")
            return (None, f"Unexpected error on attempt {retries}: {e}")

    return None, "Unknown Error in SQL execution"


class RemoteDatabaseConnector(DatabaseConnector):
    """Database connector for remote databases accessed via HTTP."""

    def __init__(self, db_config: SQLDatabase):
        super().__init__(db_config)

        assert db_config[
            "db_id"
        ], "db_id must be in db_config for RemoteDatabaseConnector"
        self.api_url, self.database_id = (
            db_config["db_id"].split(",")[0],
            db_config["db_id"].split("db_id=")[-1].split(",")[0],
        )

        if not self.api_url or not self.database_id:
            raise ValueError(
                "Both 'api_url' and 'database_id' are required for RemoteDatabaseConnector."
            )

        self.api_key = os.getenv("SQL_API_KEY", None)
        if not self.api_key:
            raise ValueError(
                "The environment variable 'SQL_API_KEY' must be set to use the RemoteDatabaseConnector."
            )

        self.headers = {
            "Content-Type": "application/json",
            "accept": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

        self.timeout = 30

    def get_table_schema(
        self,
    ) -> str:
        """Retrieves the schema of a database."""
        cur_api_url = f"{self.api_url}/datasources/{self.database_id}"
        response = requests.get(
            cur_api_url,
            headers=self.headers,
            verify=False,
            timeout=self.timeout,
        )
        if response.status_code == 200:
            schema = response.json()["schema"]
        else:
            raise OSError(f"Could not fetch schema from {cur_api_url}")

        schema_text = ""
        for table in schema["tables"]:
            schema_text += f"Table: {table['name'] if 'name' in table else table['table_name']} has columns: {[col['name'] if 'name' in col else col['column_name'] for col in table['columns']]}\n"

        return schema_text

    def execute_query(self, query: str) -> Any:
        """Executes a query against the remote database, with retries for certain exceptions."""
        cache = get_cache()

        cache_key = generate_cache_key(
            "sql_request", self.api_url, self.database_id, query
        )
        return cache.get_or_set(
            cache_key,
            lambda: execute_query_remote(
                api_url=self.api_url,
                database_id=self.database_id,
                api_key=self.api_key,
                query=query,
                timeout=self.timeout,
            ),
        )
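Judging from the constructor's parsing above, the remote connector appears to pack both the endpoint and the datasource id into the single db_id string, in the form "<api_url>,db_id=<database_id>", and it requires SQL_API_KEY in the environment. A hedged sketch with placeholder values (not a real service):

# Example usage (not part of text2sql_utils.py); URL and id are hypothetical placeholders.
import os
from unitxt.text2sql_utils import RemoteDatabaseConnector

os.environ["SQL_API_KEY"] = "<your-token>"  # read by the constructor

db_config = {
    "db_id": "https://sql-service.example.com/api/v1,db_id=datasource-1234",
    "db_type": "remote",
}
connector = RemoteDatabaseConnector(db_config)
# connector.api_url      -> "https://sql-service.example.com/api/v1"
# connector.database_id  -> "datasource-1234"
# get_table_schema() and execute_query() then issue HTTP calls to
# {api_url}/datasources/{database_id} and {api_url}/sql respectively.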

def get_db_connector(db_type: str):
    """Creates and returns the appropriate DatabaseConnector instance based on db_type."""
    if db_type == "local":
        connector = LocalSQLiteConnector
    elif db_type == "in_memory":
        connector = InMemoryDatabaseConnector
    elif db_type == "remote":
        connector = RemoteDatabaseConnector

    else:
        raise ValueError(f"Unsupported database type: {db_type}")

    return connector
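Note that get_db_connector returns the connector class, not an instance, so the caller instantiates it with the db_config. A short sketch reusing the in-memory config shape from earlier:

# Example usage (not part of text2sql_utils.py).
from unitxt.text2sql_utils import get_db_connector

connector_cls = get_db_connector("in_memory")  # InMemoryDatabaseConnector
connector = connector_cls(
    {"db_id": None, "data": {"t": {"columns": ["x"], "rows": [[1]]}}}
)
# get_db_connector("sqlserver") would raise ValueError: Unsupported database type.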

@dataclass
class SQLNonExecutionMetricResult:
    sqlglot_validity: int  # Whether SQL parses with sqlglot
    sqlparse_validity: int  # Whether SQL parses with sqlparse
    sqlglot_equivalence: int  # Semantic equivalence using sqlglot AST
    sqlglot_optimized_equivalence: int  # Equivalence after optimization via sqlglot
    sqlparse_equivalence: int  # Equivalence using sqlparse AST
    sql_exact_match: int  # Exact string match of predicted and gold SQL
    sql_syntactic_equivalence: int  # Any of the above equivalence conditions hold


def is_sqlglot_parsable(sql: str, db_type="sqlite") -> bool:
    """Returns True if sqlglot does not encounter any error, False otherwise."""
    from sqlglot import parse

    if not sql.strip():
        return False
    if db_type == "db2":
        db_type = "postgres"  ## TODO: temporary until sqlglot adds support for db2
    try:
        parse(sql, read=db_type)
        return True
    except Exception as e:
        logger.debug(f"SQL query could not parse: {e}")
        return False


def is_sqlparse_parsable(sql: str) -> bool:
    """Returns True if sqlparse does not encounter any error, False otherwise."""
    from sqlparse import parse
    from sqlparse.tokens import Error

    if not sql.strip():
        return False
    try:
        statements = parse(sql)
        for statement in statements:
            for token in statement.tokens:
                if token.ttype == Error:
                    return False
        return True
    except Exception as e:
        logger.debug(f"SQL query could not parse: {e}")
        return False


def sqlglot_optimized_equivalence(expected: str, generated: str) -> int:
    """Checks if SQL queries are equivalent using SQLGlot parsing, so we don't run them."""
    from sqlglot import diff, parse_one
    from sqlglot.optimizer import optimize

    try:
        t_diff = diff(
            optimize(parse_one(expected.lower()).sql(pretty=True)),
            optimize(parse_one(generated.lower()).sql(pretty=True)),
        )
        sql_diff = sum(0 if (e.__class__.__name__ == "Keep") else 1 for e in t_diff)

        return 1 if sql_diff == 0 else 0
    except Exception as e:
        logger.debug(f"Error parsing SQL for comparison: {e}")
        return False


def extract_select_columns(statement):
    """Parse SQL using sqlparse and extract columns."""
    from sqlparse.sql import Identifier, IdentifierList
    from sqlparse.tokens import DML, Keyword

    columns = []
    select_seen = False
    for token in statement.tokens:
        if token.ttype is DML and token.value.upper() == "SELECT":
            select_seen = True
            continue
        if select_seen:
            if token.ttype is Keyword and token.value.upper() in (
                "FROM",
                "WHERE",
                "GROUP",
                "HAVING",
                "ORDER",
                "LIMIT",
            ):
                break
            if isinstance(token, IdentifierList):
                for identifier in token.get_identifiers():
                    columns.append(strip_alias(identifier.value))
            elif isinstance(token, Identifier):
                columns.append(strip_alias(token.value))
            else:
                val = token.value.strip()
                if val:
                    columns.append(strip_alias(val))
    return frozenset(columns)


def strip_alias(col: str) -> str:
    """Remove any AS alias from a column."""
    col = col.strip()
    upper = col.upper()
    if " AS " in upper:
        return col[: upper.index(" AS ")].strip()
    parts_alias = col.split()
    if len(parts_alias) > 1:
        return " ".join(parts_alias[:-1])
    return col
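strip_alias drops either an explicit AS alias or a trailing bare alias, which is how the column sets compared below stay alias-insensitive. A few concrete cases, assuming the helpers above are importable:

# Example usage (not part of text2sql_utils.py).
from unitxt.text2sql_utils import strip_alias

strip_alias("t.name AS customer")  # -> "t.name"
strip_alias("count(*) total")      # -> "count(*)"  (last whitespace-separated token treated as an alias)
strip_alias("price")               # -> "price"     (a single token is returned unchanged)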

def collect_clause(statement, clause_keyword):
    """Parse SQL statement and collect clauses."""
    from sqlparse.tokens import Keyword

    found = False
    collected = []
    for token in statement.tokens:
        tvalue = token.value.upper()
        if token.ttype is Keyword:
            if tvalue.startswith(clause_keyword):
                found = True
                continue
            if found and tvalue in (
                "FROM",
                "WHERE",
                "GROUP",
                "HAVING",
                "ORDER",
                "LIMIT",
            ):
                break
        if found:
            collected.append(token.value)
    return " ".join(collected).strip()


def extract_select_info(sql: str):
    """Parse SQL using sqlparse and return a dict of extracted columns and clauses."""
    from sqlparse import parse
    from sqlparse.tokens import DML

    statements = parse(sql)
    if len(statements) != 1:
        return None
    stmt = statements[0]
    if not any(t.ttype is DML and t.value.upper() == "SELECT" for t in stmt.tokens):
        return None
    parts = {
        "columns": None,
        "from": "",
        "where": "",
        "group": "",
        "having": "",
        "order": "",
    }
    columns = extract_select_columns(stmt)
    if not columns:
        columns = frozenset()
    parts["columns"] = columns
    parts["from"] = collect_clause(stmt, "FROM")
    parts["where"] = collect_clause(stmt, "WHERE")
    parts["group"] = collect_clause(stmt, "GROUP")
    parts["having"] = collect_clause(stmt, "HAVING")
    parts["order"] = collect_clause(stmt, "ORDER")
    return parts


def sqlparse_queries_equivalent(sql1: str, sql2: str) -> bool:
    """Returns True if both SQL queries are naively considered equivalent."""
    try:
        info1 = extract_select_info(sql1)
        info2 = extract_select_info(sql2)
        if not info1 or not info2:
            return False
        if info1["columns"] != info2["columns"]:
            return False
        for k in ["from", "where", "group", "having", "order"]:
            if info1[k].replace(" ", "").upper() != info2[k].replace(" ", "").upper():
                return False
        return True
    except Exception as e:
        logger.debug(f"Error parsing SQL query for comparison: {e}")
        return False
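Putting the sqlparse-based helpers together: two SELECTs are treated as naively equivalent when their column sets match exactly (on the raw identifier text, aliases stripped) and the remaining clauses match after removing whitespace and upper-casing. A hedged example, assuming sqlparse is installed:

# Example usage (not part of text2sql_utils.py).
from unitxt.text2sql_utils import sqlparse_queries_equivalent

gold = "SELECT name FROM users WHERE age > 30"
pred = "select name  from users where age>30"
sqlparse_queries_equivalent(gold, pred)
# expected True: same column set, clauses differ only in case and whitespace

sqlparse_queries_equivalent(gold, "SELECT NAME FROM users WHERE age > 30")
# expected False: column comparison is case-sensitive, so "name" != "NAME"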

def sqlglot_parsed_queries_equivalent(sql1: str, sql2: str, dialect: str = "") -> bool:
    """Return True if two SQL queries match after parsing with SQLGlot."""
    from sqlglot import exp, parse_one

    try:
        ast1 = parse_one(sql1, read=dialect)
        ast2 = parse_one(sql2, read=dialect)
    except:
        return False
    if not (isinstance(ast1, exp.Select) and isinstance(ast2, exp.Select)):
        return False

    def normalized_select_columns(select_expr: exp.Select):
        cols = []
        for item in select_expr.expressions:
            copy_item = item.copy()
            copy_item.set("alias", None)
            cols.append(copy_item.sql(dialect=dialect, normalize=True))
        return frozenset(cols)

    if normalized_select_columns(ast1) != normalized_select_columns(ast2):
        return False

    def normalized_clause(expr: exp.Expression, key: str):
        clause = expr.args.get(key)
        return clause.sql(dialect=dialect, normalize=True) if clause else ""

    for clause_key in ("from", "where", "group", "having", "order"):
        if normalized_clause(ast1, clause_key) != normalized_clause(ast2, clause_key):
            return False

    return True


def sql_exact_match(sql1: str, sql2: str) -> bool:
    """Return True if two SQL strings match after very basic normalization."""

    def normalize_sql(s: str) -> str:
        s = s.strip().rstrip(";")
        s = re.sub(r"\s+", " ", s)
        return s.upper()

    return normalize_sql(sql1) == normalize_sql(sql2)
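sql_exact_match only strips a trailing semicolon, collapses whitespace, and upper-cases, so it tolerates formatting differences but nothing else:

# Example usage (not part of text2sql_utils.py).
from unitxt.text2sql_utils import sql_exact_match

sql_exact_match("select name\n  from users;", "SELECT name FROM users")   # True
sql_exact_match("SELECT name FROM users", "SELECT name FROM customers")   # False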

@dataclass
class SQLExecutionResult:
    execution_accuracy: int  # Whether the predicted and gold SQL results match exactly
    non_empty_execution_accuracy: (
        int  # Same as execution_accuracy but only if gold is non-empty
    )
    subset_non_empty_execution_accuracy: (
        int  # Whether predicted is a subset of gold or vice versa, non-empty only
    )
    execution_accuracy_bird: (
        int  # Whether the predicted SQL matches gold using BIRD evaluation logic
    )
    non_empty_gold_df: int  # Whether the gold SQL produced a non-empty dataframe
    gold_sql_runtime: float  # Time taken to execute the gold SQL
    predicted_sql_runtime: float  # Time taken to execute the predicted SQL
    pred_to_gold_runtime_ratio: float  # Ratio of predicted runtime to gold runtime
    gold_error: int  # Whether the gold SQL had an execution error
    predicted_error: int  # Whether the predicted SQL had an execution error
    gold_df_json: str  # JSON representation of the gold SQL result dataframe
    predicted_df_json: str  # JSON representation of the predicted SQL result dataframe
    error_message: str  # Error message from predicted execution if any


def compare_dfs_ignore_colnames_ordered_rows(
    df1: pd.DataFrame, df2: pd.DataFrame
) -> bool:
    if df1.shape != df2.shape:
        return False
    df1_sorted_rows = np.array([np.sort(row) for row in df1.values.astype(str)])
    df2_sorted_rows = np.array([np.sort(row) for row in df2.values.astype(str)])
    return np.array_equal(df1_sorted_rows, df2_sorted_rows)


def compare_dfs_ignore_colnames_unordered_rows(
    df1: pd.DataFrame, df2: pd.DataFrame
) -> bool:
    if df1.shape != df2.shape:
        return False
    df1_sorted = np.sort(np.sort(df1.values.astype(str), axis=1), axis=0)
    df2_sorted = np.sort(np.sort(df2.values.astype(str), axis=1), axis=0)
    return np.array_equal(df1_sorted, df2_sorted)


def compare_dfs_ignore_colnames_subset(
    df1: pd.DataFrame, df2: pd.DataFrame, ignore_row_order: bool = True
) -> bool:
    """Checks if the smaller of the two DataFrames is likely a subset of the other.

    Subset comparison is column-based, to support Text2SQL evaluation for when the
    predicted SQL dataframe has missing or additional columns. Each row is treated as
    a multiset of (stringified) values, and the function checks if every row in the
    smaller DataFrame (by column count) is a multiset subset of the corresponding row
    in the larger DataFrame. When ground truth SQL does not have ORDER BY,
    ignore_row_order can be set to True to ignore the order of rows. In this case,
    column values are sorted before comparison. This means that there could be cases
    where the dataframes have the exact same number of rows and column values after
    sort are the same, but the dataframes are not actually a subset of each other.
    This is unlikely to happen in practice, but the score is not guaranteed to be
    100% accurate and may overestimate the accuracy.

    Args:
        df1 (pd.DataFrame): The first DataFrame to compare.
        df2 (pd.DataFrame): The second DataFrame to compare.
        ignore_row_order (bool, optional): If True, ignores the order of rows by
            sorting them before comparison. Defaults to True.

    Returns:
        bool: True if the smaller DataFrame (column-wise) is likely a subset of the
            larger one, False otherwise.
    """

    def row_to_multiset(row):
        return Counter(str(x) for x in row)

    def rows_to_multisets(df):
        return [row_to_multiset(row) for row in df.values]

    def sort_df(df):
        sorted_df = df.copy()
        for col in sorted_df.columns:
            sorted_df[col] = sorted_df[col].astype(str).sort_values(ignore_index=True)
        return sorted_df

    if df1.empty or df2.empty or len(df1) != len(df2):
        return False

    subset_df, superset_df = (df1, df2) if df1.shape[1] <= df2.shape[1] else (df2, df1)

    if ignore_row_order:
        subset_df = sort_df(subset_df)
        superset_df = sort_df(superset_df)

    subset_rows = rows_to_multisets(subset_df)
    superset_rows = rows_to_multisets(superset_df)

    for r1, r2 in zip(subset_rows, superset_rows):
        if not all(r1[k] <= r2.get(k, 0) for k in r1):
            return False
    return True
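The subset check above is what backs the subset_non_empty_execution_accuracy metric: it tolerates a predicted result with extra (or missing) columns as long as the row values line up. A small sketch, assuming pandas is installed:

# Example usage (not part of text2sql_utils.py).
import pandas as pd
from unitxt.text2sql_utils import (
    compare_dfs_ignore_colnames_subset,
    compare_dfs_ignore_colnames_unordered_rows,
)

gold = pd.DataFrame({"name": ["Alice", "Bob"]})
pred = pd.DataFrame({"name": ["Bob", "Alice"], "age": [29, 34]})  # extra column, different row order

compare_dfs_ignore_colnames_unordered_rows(gold, pred)  # False: shapes differ
compare_dfs_ignore_colnames_subset(gold, pred)          # True: gold's values appear row-for-row in pred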

def compare_dfs_bird_eval_logic(df1: pd.DataFrame, df2: pd.DataFrame):
    """Check if two SQL query result sets are exactly equal, as in BIRD evaluation.

    This function checks if the set of rows returned by the predicted SQL query
    (`predicted_res`) is exactly equal to the set of rows returned by the ground truth
    SQL query (`ground_truth_res`). This is the logic used in the original BIRD
    evaluation code:
    https://github.com/AlibabaResearch/DAMO-ConvAI/blob/main/bird/llm/src/evaluation.py.
    """
    df1_set = {tuple(row) for row in df1.values.astype(str)}
    df2_set = {tuple(row) for row in df2.values.astype(str)}
    return int(df1_set == df2_set)
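Because the BIRD logic compares sets of stringified rows, duplicate rows and row order are ignored, which can differ from the shape-sensitive DataFrame checks above. For instance:

# Example usage (not part of text2sql_utils.py).
import pandas as pd
from unitxt.text2sql_utils import compare_dfs_bird_eval_logic

gold = pd.DataFrame([[1], [1], [2]])  # duplicate rows collapse in a set
pred = pd.DataFrame([[2], [1]])

compare_dfs_bird_eval_logic(gold, pred)  # 1: both reduce to {("1",), ("2",)}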

def compare_result_dfs(
    gold_df: pd.DataFrame, pred_df: pd.DataFrame, gold_sql: str
) -> Tuple[int, int, int]:
    """Compares two DataFrames representing SQL query results.

    Args:
        gold_df (pd.DataFrame): The ground truth DataFrame.
        pred_df (pd.DataFrame): The predicted DataFrame.
        gold_sql (str): The ground truth SQL query string.

    Returns:
        Tuple[int, int, int]: A tuple containing:
            - match (int): 1 if the predicted DataFrame matches the gold DataFrame
            - non_empty_match (int): 1 if both DataFrames are non-empty and match,
              0 otherwise.
            - subset_match (int): 1 if the predicted DataFrame is a subset or
              superset of the gold DataFrame.

    Notes:
        - The comparison ignores column names.
        - Row order is considered only if 'ORDER BY' is present in the SQL query.
    """
    subset_match = 0
    non_empty_match = 0
    if "ORDER BY" in gold_sql.upper():
        match = int(compare_dfs_ignore_colnames_ordered_rows(pred_df, gold_df))
        if not gold_df.empty and not pred_df.empty:
            non_empty_match = match
            if compare_dfs_ignore_colnames_subset(
                gold_df, pred_df, ignore_row_order=False
            ):
                subset_match = 1
    else:
        match = int(compare_dfs_ignore_colnames_unordered_rows(pred_df, gold_df))
        if not gold_df.empty and not pred_df.empty:
            non_empty_match = match
            if compare_dfs_ignore_colnames_subset(
                gold_df, pred_df, ignore_row_order=True
            ):
                subset_match = 1
    return match, non_empty_match, subset_match


def run_query(
    sql: str, connector, sql_timeout: float
) -> Tuple[Optional[pd.DataFrame], float, str]:
    """Executes a SQL query using the provided connector with a timeout.

    Args:
        sql (str): The SQL query string to execute.
        connector: An object with an `execute_query` method that executes the SQL
            query.
        sql_timeout (float): The maximum time in seconds to allow for query
            execution.

    Returns:
        Tuple[Optional[pd.DataFrame], float, str]:
            - A pandas DataFrame containing the query results, or None if an error
              occurred.
            - The duration in seconds taken to execute the query. 0.0 if an error.
            - An error message string if an error occurred, otherwise an empty
              string.

    Notes:
        - If the SQL string is empty or only whitespace, returns immediately with a
          message.
        - If the query execution exceeds the timeout, returns a timeout error
          message.
        - Any other exceptions are caught and returned as error messages.
    """
    import time

    from func_timeout import func_timeout
    from func_timeout.exceptions import FunctionTimedOut

    if not sql.strip():
        return None, 0.0, "No SQL query found in the prediction."

    try:
        start = time.perf_counter()
        result, error = func_timeout(sql_timeout, connector.execute_query, args=(sql,))
        duration = time.perf_counter() - start
        if isinstance(result, dict) and "results" in result:
            result = result["results"]
        if error:
            return None, duration, error
        return pd.DataFrame(result), duration, ""
    except FunctionTimedOut as e:
        return None, 0.0, f"Timeout: {e}"
    except Exception as e:
        return None, 0.0, f"Error: {e}"


def get_sql_execution_results(
    predicted_sql: str, gold_sql: str, connector, sql_timeout: float
) -> SQLExecutionResult:
    """Execute and compare predicted and gold SQL queries, returning execution metrics.

    Args:
        predicted_sql (str): The SQL query predicted by the model.
        gold_sql (str): The reference (gold) SQL query.
        connector: Database connector object used to execute the queries.
        sql_timeout (float): Maximum time (in seconds) allowed for query execution.

    Returns:
        SQLExecutionResult: An object containing various execution metrics, including:
            - execution_accuracy (int): 1 if predicted and gold queries produce
              equivalent results, else 0.
            - non_empty_execution_accuracy (int): 1 if both queries produce non-empty
              and equivalent results, else 0.
            - subset_non_empty_execution_accuracy (int): 1 if predicted results are a
              subset or superset of gold results and non-empty, else 0. Subset
              comparison is column-based. This means that the predicted SQL dataframe
              can have missing or additional columns compared to the gold SQL dataframe.
            - execution_accuracy_bird (int): 1 if results match according to BIRD
              evaluation logic, else 0.
            - non_empty_gold_df (int): 1 if the gold query result is non-empty, else 0.
            - gold_sql_runtime (float): Execution time for the gold SQL query.
            - predicted_sql_runtime (float): Execution time for the predicted SQL query.
            - pred_to_gold_runtime_ratio (float): Ratio of predicted to gold query
              runtimes.
            - gold_error (int): 1 if the gold query failed, else 0.
            - predicted_error (int): 1 if the predicted query failed, else 0.
            - gold_df_json (str): JSON representation of the gold query result
              DataFrame.
            - predicted_df_json (str): JSON representation of the predicted query
              result DataFrame.
            - error_message (str): Error message if any query failed, else empty
              string.

    Notes:
        - If the gold query fails, the function returns early with error details.
        - If the predicted query is identical or SQL-equivalent to the gold query,
          results are considered correct without execution.
        - Otherwise, both queries are executed and their results compared using
          multiple metrics.
    """
    gold_df, gold_runtime, gold_error_msg = run_query(gold_sql, connector, sql_timeout)
    gold_error = int(bool(gold_error_msg))

    if gold_error:
        return SQLExecutionResult(
            execution_accuracy=0,
            non_empty_execution_accuracy=0,
            subset_non_empty_execution_accuracy=0,
            execution_accuracy_bird=0,
            non_empty_gold_df=0,
            gold_sql_runtime=gold_runtime,
            predicted_sql_runtime=0,
            pred_to_gold_runtime_ratio=0,
            gold_error=gold_error,
            predicted_error=0,
            gold_df_json="",
            predicted_df_json="",
            error_message=gold_error_msg,
        )

    non_empty_gold_df = int(not gold_df.empty)
    if predicted_sql.strip().lower() == gold_sql.strip().lower():
        return SQLExecutionResult(
            execution_accuracy=1,
            non_empty_execution_accuracy=non_empty_gold_df,
            subset_non_empty_execution_accuracy=non_empty_gold_df,
            execution_accuracy_bird=1,
            non_empty_gold_df=non_empty_gold_df,
            gold_sql_runtime=gold_runtime,
            predicted_sql_runtime=0,
            pred_to_gold_runtime_ratio=0,
            gold_error=0,
            predicted_error=0,
            gold_df_json=gold_df.to_json(),
            predicted_df_json=gold_df.to_json(),
            error_message="",
        )

    try:
        if sqlglot_optimized_equivalence(gold_sql, predicted_sql):
            return SQLExecutionResult(
                execution_accuracy=1,
                non_empty_execution_accuracy=non_empty_gold_df,
                subset_non_empty_execution_accuracy=non_empty_gold_df,
                execution_accuracy_bird=1,
                non_empty_gold_df=non_empty_gold_df,
                gold_sql_runtime=gold_runtime,
                predicted_sql_runtime=0,
                pred_to_gold_runtime_ratio=0,
                gold_error=0,
                predicted_error=0,
                gold_df_json=gold_df.to_json(),
                predicted_df_json=gold_df.to_json(),
                error_message="",
            )
    except Exception as e:
        logger.info(f"Could not check SQL equivalence: {e}")

    pred_df, pred_runtime, pred_error_msg = run_query(
        predicted_sql, connector, sql_timeout
    )
    pred_error = 1 if pred_error_msg else 0

    if pred_df is None:
        return SQLExecutionResult(
            execution_accuracy=0,
            non_empty_execution_accuracy=0,
            subset_non_empty_execution_accuracy=0,
            execution_accuracy_bird=0,
            non_empty_gold_df=non_empty_gold_df,
            gold_sql_runtime=gold_runtime,
            predicted_sql_runtime=pred_runtime,
            pred_to_gold_runtime_ratio=(pred_runtime / gold_runtime)
            if gold_runtime > 0
            else 0,
            gold_error=0,
            predicted_error=pred_error,
            gold_df_json=gold_df.to_json(),
            predicted_df_json="",
            error_message=pred_error_msg,
        )

    match, non_empty_match, subset_match = compare_result_dfs(
        gold_df, pred_df, gold_sql
    )
    bird_match = compare_dfs_bird_eval_logic(gold_df, pred_df)

    return SQLExecutionResult(
        execution_accuracy=match,
        non_empty_execution_accuracy=non_empty_match,
        subset_non_empty_execution_accuracy=subset_match,
        execution_accuracy_bird=bird_match,
        non_empty_gold_df=non_empty_gold_df,
        gold_sql_runtime=gold_runtime,
        predicted_sql_runtime=pred_runtime,
        pred_to_gold_runtime_ratio=(pred_runtime / gold_runtime)
        if gold_runtime > 0
        else 0,
        gold_error=0,
        predicted_error=0,
        gold_df_json=gold_df.to_json(),
        predicted_df_json=pred_df.to_json(),
        error_message=pred_error_msg,
    )
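An end-to-end sketch of the execution metrics, reusing the in-memory connector from earlier. func_timeout must be installed for run_query, and the sqlglot equivalence shortcut is wrapped in a try/except, so the comparison falls back to executing both queries:

# Example usage (not part of text2sql_utils.py).
from unitxt.text2sql_utils import InMemoryDatabaseConnector, get_sql_execution_results

connector = InMemoryDatabaseConnector({
    "db_id": None,
    "data": {
        "users": {
            "columns": ["id", "name", "age"],
            "rows": [[1, "Alice", 34], [2, "Bob", 29]],
        }
    },
})

result = get_sql_execution_results(
    predicted_sql="SELECT name FROM users WHERE age >= 31",
    gold_sql="SELECT name FROM users WHERE age > 30",
    connector=connector,
    sql_timeout=10.0,
)
# Both queries return [("Alice",)], so result.execution_accuracy and
# result.execution_accuracy_bird are expected to be 1; runtimes and the JSON
# dumps of both result frames are carried in the dataclass as well.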

def replace_select_clause(
    source_query: str, target_query: str, dialect: str = "postgres"
) -> str:
    """Replaces the SELECT clause of the target SQL query with the SELECT clause from the source SQL query.

    Args:
        source_query (str): SQL query whose SELECT clause will be used.
        target_query (str): SQL query whose SELECT clause will be replaced.
        dialect (str): SQL dialect for parsing and rendering (default: "postgres").

    Returns:
        str: A new SQL query with the SELECT clause of `target_query` replaced by that of `source_query`.

    Raises:
        ValueError: If either query is not a valid SELECT statement.

    Example:
        >>> replace_select_clause(
        ...     "SELECT id FROM employees",
        ...     "SELECT name FROM employees WHERE age > 30"
        ... )
        "SELECT id FROM employees WHERE age > 30"
    """
    from sqlglot import exp, parse_one

    if not dialect:
        dialect = "postgres"

    # Parse queries using the specified dialect
    source_ast = parse_one(source_query, read=dialect)
    target_ast = parse_one(target_query, read=dialect)

    if not isinstance(source_ast, exp.Select) or not isinstance(target_ast, exp.Select):
        raise ValueError("Both queries must be valid SELECT statements.")

    # Replace SELECT expressions in the target with those from the source
    target_ast.set("expressions", source_ast.expressions)

    # Return the updated SQL string using the dialect
    return target_ast.sql(dialect=dialect)


def extract_sql_from_text(text: str) -> str:
    """Extracts the first SQL query from the given text.

    Priority:
    1. SQL inside fenced blocks like ```sql ... ```
    2. SQL starting on a new line or after a colon/label
    3. SQL without semicolon

    Returns:
        The SQL query string, or an empty string if not found.
    """
    # 1. Look for fenced SQL code block
    fenced_block_pattern = re.compile(r"```sql\s+(.*?)```", re.IGNORECASE | re.DOTALL)
    match = fenced_block_pattern.search(text)
    if match:
        return match.group(1).strip()

    # 2. Inline SQL with semicolon
    sql_keywords = r"(?:SELECT|INSERT|UPDATE|DELETE|WITH)\s+"
    sql_start = (
        r"(?:^|\n|:\s*)"  # Start of string, newline, or colon label like "Just run:"
    )
    sql_pattern = re.compile(
        rf"{sql_start}({sql_keywords}.*?;)", re.IGNORECASE | re.DOTALL
    )
    match = sql_pattern.search(text)
    if match:
        return match.group(1).strip()

    # 3. Inline SQL without semicolon
    fallback_pattern = re.compile(
        rf"{sql_start}({sql_keywords}.*)", re.IGNORECASE | re.DOTALL
    )
    fallback_match = fallback_pattern.search(text)
    if fallback_match:
        return fallback_match.group(1).strip()

    return ""
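extract_sql_from_text is what turns a free-form model response into the predicted SQL fed to the metrics above. A short example of the second pattern (inline SQL ending with a semicolon), assuming the same import path:

# Example usage (not part of text2sql_utils.py).
from unitxt.text2sql_utils import extract_sql_from_text

response = "Sure, just run: SELECT name FROM users WHERE age > 30; and you are done."
extract_sql_from_text(response)
# -> "SELECT name FROM users WHERE age > 30;"
# Fenced sql code blocks take priority; if neither pattern matches, an empty string is returned.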

ALL_DIALECTS = [
    "Athena",
    "BigQuery",
    "ClickHouse",
    "Databricks",
    "Doris",
    "Drill",
    "Druid",
    "DuckDB",
    "Hive",
    "Materialize",
    "MySQL",
    "Oracle",
    "Postgres",
    "Presto",
    "PRQL",
    "Redshift",
    "RisingWave",
    "Snowflake",
    "Spark",
    "Spark2",
    "SQLite",
    "StarRocks",
    "Tableau",
    "Teradata",
    "Trino",
    "TSQL",
]