joaoh82 / rust_sqlite / 25040177500

28 Apr 2026 07:37AM UTC coverage: 70.28% (+0.6%) from 69.705%

push

github

web-flow
Phase 7d.2: SQL integration for HNSW (CREATE INDEX, KNN probe) (#50)

Wires the Phase 7d.1 HNSW algorithm into the SQL surface so users
can finally do KNN queries through the engine:

  CREATE TABLE docs (id INTEGER PRIMARY KEY, e VECTOR(384));
  INSERT INTO docs (e) VALUES ([0.1, 0.2, ...]);
  CREATE INDEX ix_e ON docs USING hnsw (e);
  SELECT id FROM docs
  ORDER BY vec_distance_l2(e, [0.5, 0.3, ...])
  LIMIT 10;
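
For reference, the exact answer that the index approximates can be sketched as a brute-force scan. This is a minimal illustration, not the engine's code: `l2_distance`, `brute_force_knn`, and the `(rowid, embedding)` slice are hypothetical stand-ins for the table and for `vec_distance_l2`.

```rust
/// Euclidean (L2) distance between two equal-length vectors.
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "dimension mismatch");
    a.iter()
        .zip(b)
        .map(|(x, y)| (x - y) * (x - y))
        .sum::<f32>()
        .sqrt()
}

/// Exact k-nearest-neighbour search: score every row, sort, keep the first k.
/// `rows` is a hypothetical (rowid, embedding) list standing in for the table.
fn brute_force_knn(rows: &[(i64, Vec<f32>)], query: &[f32], k: usize) -> Vec<i64> {
    let mut scored: Vec<(f32, i64)> = rows
        .iter()
        .map(|(id, v)| (l2_distance(v, query), *id))
        .collect();
    // total_cmp is a total order on f32, so a NaN distance cannot panic the sort.
    scored.sort_by(|a, b| a.0.total_cmp(&b.0));
    scored.into_iter().take(k).map(|(_, id)| id).collect()
}

fn main() {
    let rows: Vec<(i64, Vec<f32>)> = vec![
        (1, vec![0.0, 0.0]),
        (2, vec![1.0, 1.0]),
        (3, vec![0.4, 0.3]),
    ];
    // Equivalent in spirit to ORDER BY vec_distance_l2(e, [0.5, 0.3]) LIMIT 2.
    let nearest = brute_force_knn(&rows, &[0.5, 0.3], 2);
    println!("{nearest:?}");
}
```

An HNSW probe trades this linear scan for a graph walk and may return a slightly different set, which is why recall is measured against the exact result.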

Five integration points:

  1. **Table::hnsw_indexes** — new Vec<HnswIndexEntry> field
     parallel to secondary_indexes. HnswIndexEntry holds
     {name, column_name, index} where index is the in-memory
     HnswIndex from 7d.1. Auto-cloned by deep_clone (Phase 4f
     transaction snapshot path).

  2. **CREATE INDEX … USING hnsw (col)** parser + executor
     branch. Routes through new IndexMethod::{Btree, Hnsw}
     dispatch in execute_create_index. Validates the column is
     VECTOR(N) and rejects UNIQUE on HNSW (no semantic meaning).
     Pre-populates the graph from existing rows. Distance metric
     is L2 by default; cosine and dot are 7d.x follow-ups
     (would need either WITH (metric=…) clause or hnsw_cosine
     / hnsw_dot named methods — not in 7d.2 scope).

  3. **INSERT incremental maintenance** in Table::insert_row.
     After writing a row, calls Table::maintain_hnsw_on_insert
     for any HNSW index whose column matches. Uses a snapshot
     of current row storage as the get_vec source so the
     algorithm's neighbor lookups don't fight with the
     &mut self.hnsw_indexes borrow.

  4. **DELETE / UPDATE refusal** in execute_delete /
     execute_update. HNSW lacks an in-place delete-node
     primitive and we don't yet have a graph-rebuild trigger
     (that's part of 7d.3 with persistence). Fail with a
     helpful message that points at the workaround
     (DROP the index, mutate, re-CREATE) and explicitly
     mentions Phase 7d.3 as the follow-up.

  5. **Query optimizer hook** in execute_select_rows. Before
     ... (continued)
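
The snapshot idea in point 3 can be sketched in isolation. Everything below is an assumption made for illustration: `ToyIndex`, `ToyTable`, and `dist2` are hypothetical and far simpler than the real `HnswIndex` and `Table`. The point is only that cloning row storage first lets the lookup closure read vectors while the index list is mutably borrowed.

```rust
use std::collections::HashMap;

/// Squared L2 distance; enough for ranking neighbours.
fn dist2(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum()
}

/// Hypothetical stand-in for a graph index: it stores ids only and asks the
/// caller-supplied `get_vec` closure whenever it needs a stored vector.
struct ToyIndex {
    ids: Vec<i64>,
}

impl ToyIndex {
    /// Inserts `id` and returns the nearest previously inserted id, if any.
    fn insert<F>(&mut self, id: i64, vector: &[f32], get_vec: F) -> Option<i64>
    where
        F: Fn(i64) -> Vec<f32>,
    {
        let nearest = self.ids.iter().copied().min_by(|a, b| {
            dist2(&get_vec(*a), vector).total_cmp(&dist2(&get_vec(*b), vector))
        });
        self.ids.push(id);
        nearest
    }
}

/// Hypothetical table: row storage plus a list of indexes over it.
struct ToyTable {
    rows: HashMap<i64, Vec<f32>>,
    indexes: Vec<ToyIndex>,
}

impl ToyTable {
    fn insert_row(&mut self, id: i64, vector: Vec<f32>) -> Vec<Option<i64>> {
        self.rows.insert(id, vector.clone());
        // Snapshot row storage before touching the indexes. The closure then
        // reads from the snapshot, so it never needs to borrow the table
        // through a `&self` accessor while `self.indexes` is borrowed mutably.
        let snapshot = self.rows.clone();
        self.indexes
            .iter_mut()
            .map(|ix| {
                ix.insert(id, &vector, |other| {
                    snapshot.get(&other).cloned().unwrap_or_default()
                })
            })
            .collect()
    }
}

fn main() {
    let mut t = ToyTable {
        rows: HashMap::new(),
        indexes: vec![ToyIndex { ids: Vec::new() }],
    };
    t.insert_row(1, vec![0.0, 0.0]);
    t.insert_row(2, vec![10.0, 10.0]);
    let nearest = t.insert_row(3, vec![9.0, 9.0]);
    println!("{nearest:?}");
}
```

The clone costs memory proportional to the table on every insert, so this favours simplicity over speed.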

218 of 251 new or added lines in 4 files covered. (86.85%)

3 existing lines in 1 file now uncovered.

4973 of 7076 relevant lines covered (70.28%)

1.45 hits per line

Source File

74.89
/src/sql/executor.rs
//! Query executors — evaluate parsed SQL statements against the in-memory
//! storage and produce formatted output.

use std::cmp::Ordering;

use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
use sqlparser::ast::{
    AssignmentTarget, BinaryOperator, CreateIndex, Delete, Expr, FromTable, FunctionArg,
    FunctionArgExpr, FunctionArguments, IndexType, ObjectNamePart, Statement, TableFactor,
    TableWithJoins, UnaryOperator, Update,
};

use crate::error::{Result, SQLRiteError};
use crate::sql::db::database::Database;
use crate::sql::db::secondary_index::{IndexOrigin, SecondaryIndex};
use crate::sql::db::table::{DataType, HnswIndexEntry, Table, Value, parse_vector_literal};
use crate::sql::hnsw::{DistanceMetric, HnswIndex};
use crate::sql::parser::select::{OrderByClause, Projection, SelectQuery};

/// Executes a parsed `SelectQuery` against the database and returns a
/// human-readable rendering of the result set (prettytable). Also returns
/// the number of rows produced, for the top-level status message.
/// Structured result of a SELECT: column names in projection order,
/// and each matching row as a `Vec<Value>` aligned with the columns.
/// Phase 5a introduced this so the public `Connection` / `Statement`
/// API has typed rows to yield; the existing `execute_select` that
/// returns pre-rendered text is now a thin wrapper on top.
pub struct SelectResult {
    pub columns: Vec<String>,
    pub rows: Vec<Vec<Value>>,
}

/// Executes a SELECT and returns structured rows. The typed rows are
/// what the new public API streams to callers; the REPL / Tauri app
/// pre-render into a prettytable via `execute_select`.
pub fn execute_select_rows(query: SelectQuery, db: &Database) -> Result<SelectResult> {
    let table = db
        .get_table(query.table_name.clone())
        .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", query.table_name)))?;

    // Resolve projection to a concrete ordered column list.
    let projected_cols: Vec<String> = match &query.projection {
        Projection::All => table.column_names(),
        Projection::Columns(cols) => {
            for c in cols {
                if !table.contains_column(c.to_string()) {
                    return Err(SQLRiteError::Internal(format!(
                        "Column '{c}' does not exist on table '{}'",
                        query.table_name
                    )));
                }
            }
            cols.clone()
        }
    };

    // Collect matching rowids. If the WHERE is the shape `col = literal`
    // and `col` has a secondary index, probe the index for an O(log N)
    // seek; otherwise fall back to the full table scan.
    let matching = match select_rowids(table, query.selection.as_ref())? {
        RowidSource::IndexProbe(rowids) => rowids,
        RowidSource::FullScan => {
            let mut out = Vec::new();
            for rowid in table.rowids() {
                if let Some(expr) = &query.selection {
                    if !eval_predicate(expr, table, rowid)? {
                        continue;
                    }
                }
                out.push(rowid);
            }
            out
        }
    };
    let mut matching = matching;

    // Phase 7c — bounded-heap top-k optimization.
    //
    // The naive "ORDER BY <expr>" path (Phase 7b) sorts every matching
    // rowid: O(N log N) sort_by + a truncate. For KNN queries
    //
    //     SELECT id FROM docs
    //     ORDER BY vec_distance_l2(embedding, [...])
    //     LIMIT 10;
    //
    // N is the table row count and k is the LIMIT. With a bounded
    // max-heap of size k we can find the top-k in O(N log k) — same
    // sort_by-per-row cost on the heap operations, but k is typically
    // 10-100 while N can be millions.
    //
    // Phase 7d.2 — HNSW ANN probe.
    //
    // Even better than the bounded heap: if the ORDER BY expression is
    // exactly `vec_distance_l2(<col>, <bracket-array literal>)` AND
    // `<col>` has an HNSW index attached, skip the linear scan
    // entirely and probe the graph in O(log N). Approximate but
    // typically ≥ 0.95 recall (verified by the recall tests in
    // src/sql/hnsw.rs).
    //
    // We branch in cases:
    //   1. ORDER BY + LIMIT k matches the HNSW probe pattern  → graph probe.
    //   2. ORDER BY + LIMIT k where k < |matching|            → bounded heap (7c).
    //   3. ORDER BY without LIMIT, or LIMIT >= |matching|     → full sort.
    //   4. LIMIT without ORDER BY                              → just truncate.
    match (&query.order_by, query.limit) {
        (Some(order), Some(k)) if try_hnsw_probe(table, &order.expr, k).is_some() => {
            matching = try_hnsw_probe(table, &order.expr, k).unwrap();
        }
        (Some(order), Some(k)) if k < matching.len() => {
            matching = select_topk(&matching, table, order, k)?;
        }
        (Some(order), _) => {
            sort_rowids(&mut matching, table, order)?;
            if let Some(k) = query.limit {
                matching.truncate(k);
            }
        }
        (None, Some(k)) => {
            matching.truncate(k);
        }
        (None, None) => {}
    }

    // Build typed rows. Missing cells surface as `Value::Null` — that
    // maps a column-not-present-for-this-rowid case onto the public
    // `Row::get` → `Option<T>` surface cleanly.
    let mut rows: Vec<Vec<Value>> = Vec::with_capacity(matching.len());
    for rowid in &matching {
        let row: Vec<Value> = projected_cols
            .iter()
            .map(|col| table.get_value(col, *rowid).unwrap_or(Value::Null))
            .collect();
        rows.push(row);
    }

    Ok(SelectResult {
        columns: projected_cols,
        rows,
    })
}

/// Executes a SELECT and returns `(rendered_table, row_count)`. The
/// REPL and Tauri app use this to keep the table-printing behaviour
/// the engine has always shipped. Structured callers use
/// `execute_select_rows` instead.
pub fn execute_select(query: SelectQuery, db: &Database) -> Result<(String, usize)> {
    let result = execute_select_rows(query, db)?;
    let row_count = result.rows.len();

    let mut print_table = PrintTable::new();
    let header_cells: Vec<PrintCell> = result.columns.iter().map(|c| PrintCell::new(c)).collect();
    print_table.add_row(PrintRow::new(header_cells));

    for row in &result.rows {
        let cells: Vec<PrintCell> = row
            .iter()
            .map(|v| PrintCell::new(&v.to_display_string()))
            .collect();
        print_table.add_row(PrintRow::new(cells));
    }

    Ok((print_table.to_string(), row_count))
}

/// Executes a DELETE statement. Returns the number of rows removed.
pub fn execute_delete(stmt: &Statement, db: &mut Database) -> Result<usize> {
    let Statement::Delete(Delete {
        from, selection, ..
    }) = stmt
    else {
        return Err(SQLRiteError::Internal(
            "execute_delete called on a non-DELETE statement".to_string(),
        ));
    };

    let tables = match from {
        FromTable::WithFromKeyword(t) | FromTable::WithoutKeyword(t) => t,
    };
    let table_name = extract_single_table_name(tables)?;

    // Phase 7d.2 limitation: HNSW lacks an in-place delete-node operation.
    // True deletion needs either soft-delete + tombstones or a graph rebuild
    // — both nontrivial. Until 7d.3 lands persistence we don't have a
    // natural rebuild trigger either. So: refuse DELETE on tables carrying
    // any HNSW index, with a message that points at the workaround
    // (DROP the index, DELETE, recreate).
    {
        let table = db.get_table(table_name.clone()).map_err(|_| {
            SQLRiteError::General(format!("DELETE references unknown table '{table_name}'"))
        })?;
        if !table.hnsw_indexes.is_empty() {
            let names: Vec<&str> = table.hnsw_indexes.iter().map(|e| e.name.as_str()).collect();
            return Err(SQLRiteError::NotImplemented(format!(
                "DELETE on tables with HNSW indexes is not supported yet \
                 (Phase 7d.3 follow-up). DROP the index first, then DELETE, then re-CREATE. \
                 Table '{table_name}' currently has: {names:?}"
            )));
        }
    }

    // Compute matching rowids with an immutable borrow, then mutate.
    let matching: Vec<i64> = {
        let table = db
            .get_table(table_name.clone())
            .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
        match select_rowids(table, selection.as_ref())? {
            RowidSource::IndexProbe(rowids) => rowids,
            RowidSource::FullScan => {
                let mut out = Vec::new();
                for rowid in table.rowids() {
                    if let Some(expr) = selection {
                        if !eval_predicate(expr, table, rowid)? {
                            continue;
                        }
                    }
                    out.push(rowid);
                }
                out
            }
        }
    };

    let table = db.get_table_mut(table_name)?;
    for rowid in &matching {
        table.delete_row(*rowid);
    }
    Ok(matching.len())
}

/// Executes an UPDATE statement. Returns the number of rows updated.
pub fn execute_update(stmt: &Statement, db: &mut Database) -> Result<usize> {
    let Statement::Update(Update {
        table,
        assignments,
        from,
        selection,
        ..
    }) = stmt
    else {
        return Err(SQLRiteError::Internal(
            "execute_update called on a non-UPDATE statement".to_string(),
        ));
    };

    if from.is_some() {
        return Err(SQLRiteError::NotImplemented(
            "UPDATE ... FROM is not supported yet".to_string(),
        ));
    }

    let table_name = extract_table_name(table)?;

    // Phase 7d.2 limitation (same shape as DELETE above): we have no
    // in-place UPDATE-an-HNSW-node primitive. UPDATE on a column NOT
    // covered by HNSW is fine in principle, but the simplest MVP is
    // refuse-everything-when-HNSW-is-present. Re-evaluate in 7d.3 once
    // persistence + rebuild is in.
    {
        let tbl = db.get_table(table_name.clone()).map_err(|_| {
            SQLRiteError::General(format!("UPDATE references unknown table '{table_name}'"))
        })?;
        if !tbl.hnsw_indexes.is_empty() {
            let names: Vec<&str> = tbl.hnsw_indexes.iter().map(|e| e.name.as_str()).collect();
            return Err(SQLRiteError::NotImplemented(format!(
                "UPDATE on tables with HNSW indexes is not supported yet \
                 (Phase 7d.3 follow-up). DROP the index first if you need to mutate. \
                 Table '{table_name}' currently has: {names:?}"
            )));
        }
    }

    // Resolve assignment targets to plain column names and verify they exist.
    let mut parsed_assignments: Vec<(String, Expr)> = Vec::with_capacity(assignments.len());
    {
        let tbl = db
            .get_table(table_name.clone())
            .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
        for a in assignments {
            let col = match &a.target {
                AssignmentTarget::ColumnName(name) => name
                    .0
                    .last()
                    .map(|p| p.to_string())
                    .ok_or_else(|| SQLRiteError::Internal("empty column name".to_string()))?,
                AssignmentTarget::Tuple(_) => {
                    return Err(SQLRiteError::NotImplemented(
                        "tuple assignment targets are not supported".to_string(),
                    ));
                }
            };
            if !tbl.contains_column(col.clone()) {
                return Err(SQLRiteError::Internal(format!(
                    "UPDATE references unknown column '{col}'"
                )));
            }
            parsed_assignments.push((col, a.value.clone()));
        }
    }

    // Gather matching rowids + the new values to write for each assignment, under
    // an immutable borrow. Uses the index-probe fast path when the WHERE is
    // `col = literal` on an indexed column.
    let work: Vec<(i64, Vec<(String, Value)>)> = {
        let tbl = db.get_table(table_name.clone())?;
        let matched_rowids: Vec<i64> = match select_rowids(tbl, selection.as_ref())? {
            RowidSource::IndexProbe(rowids) => rowids,
            RowidSource::FullScan => {
                let mut out = Vec::new();
                for rowid in tbl.rowids() {
                    if let Some(expr) = selection {
                        if !eval_predicate(expr, tbl, rowid)? {
                            continue;
                        }
                    }
                    out.push(rowid);
                }
                out
            }
        };
        let mut rows_to_update = Vec::new();
        for rowid in matched_rowids {
            let mut values = Vec::with_capacity(parsed_assignments.len());
            for (col, expr) in &parsed_assignments {
                // UPDATE's RHS is evaluated in the context of the row being updated,
                // so column references on the right resolve to the current row's values.
                let v = eval_expr(expr, tbl, rowid)?;
                values.push((col.clone(), v));
            }
            rows_to_update.push((rowid, values));
        }
        rows_to_update
    };

    let tbl = db.get_table_mut(table_name)?;
    for (rowid, values) in &work {
        for (col, v) in values {
            tbl.set_value(col, *rowid, v.clone())?;
        }
    }
    Ok(work.len())
}

/// Handles `CREATE INDEX [UNIQUE] <name> ON <table> [USING <method>] (<column>)`.
/// Single-column indexes only.
///
/// Two flavours, branching on the optional `USING <method>` clause:
///   - **No USING, or `USING btree`**: regular B-Tree secondary index
///     (Phase 3e). Indexable types: Integer, Text.
///   - **`USING hnsw`**: HNSW ANN index (Phase 7d.2). Indexable types:
///     Vector(N) only. Distance metric is L2 by default; cosine and
///     dot variants are deferred to Phase 7d.x.
///
/// Returns the (possibly synthesized) index name for the status message.
pub fn execute_create_index(stmt: &Statement, db: &mut Database) -> Result<String> {
    let Statement::CreateIndex(CreateIndex {
        name,
        table_name,
        columns,
        using,
        unique,
        if_not_exists,
        predicate,
        ..
    }) = stmt
    else {
        return Err(SQLRiteError::Internal(
            "execute_create_index called on a non-CREATE-INDEX statement".to_string(),
        ));
    };

    if predicate.is_some() {
        return Err(SQLRiteError::NotImplemented(
            "partial indexes (CREATE INDEX ... WHERE) are not supported yet".to_string(),
        ));
    }

    if columns.len() != 1 {
        return Err(SQLRiteError::NotImplemented(format!(
            "multi-column indexes are not supported yet ({} columns given)",
            columns.len()
        )));
    }

    let index_name = name.as_ref().map(|n| n.to_string()).ok_or_else(|| {
        SQLRiteError::NotImplemented(
            "anonymous CREATE INDEX (no name) is not supported — give it a name".to_string(),
        )
    })?;

    // Detect USING <method>. The `using` field on CreateIndex covers the
    // pre-column form `CREATE INDEX … USING hnsw (col)`. (sqlparser also
    // accepts a post-column form `… (col) USING hnsw` and parks that in
    // `index_options`; we don't bother with it — the canonical form is
    // pre-column and matches PG/pgvector convention.)
    let method = match using {
        Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("hnsw") => {
            IndexMethod::Hnsw
        }
        Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("btree") => {
            IndexMethod::Btree
        }
        Some(other) => {
            return Err(SQLRiteError::NotImplemented(format!(
                "CREATE INDEX … USING {other:?} is not supported (try `hnsw` or no USING clause)"
            )));
        }
        None => IndexMethod::Btree,
    };

    let table_name_str = table_name.to_string();
    let column_name = match &columns[0].column.expr {
        Expr::Identifier(ident) => ident.value.clone(),
        Expr::CompoundIdentifier(parts) => parts
            .last()
            .map(|p| p.value.clone())
            .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
        other => {
            return Err(SQLRiteError::NotImplemented(format!(
                "CREATE INDEX only supports simple column references, got {other:?}"
            )));
        }
    };

    // Validate: table exists, column exists, type matches the index method,
    // name is unique across both index kinds. Snapshot (rowid, value) pairs
    // up front under the immutable borrow so the mutable attach later
    // doesn't fight over `self`.
    let (datatype, existing_rowids_and_values): (DataType, Vec<(i64, Value)>) = {
        let table = db.get_table(table_name_str.clone()).map_err(|_| {
            SQLRiteError::General(format!(
                "CREATE INDEX references unknown table '{table_name_str}'"
            ))
        })?;
        if !table.contains_column(column_name.clone()) {
            return Err(SQLRiteError::General(format!(
                "CREATE INDEX references unknown column '{column_name}' on table '{table_name_str}'"
            )));
        }
        let col = table
            .columns
            .iter()
            .find(|c| c.column_name == column_name)
            .expect("we just verified the column exists");

        // Name uniqueness check spans BOTH index kinds — a btree and an
        // hnsw can't share a name.
        if table.index_by_name(&index_name).is_some()
            || table.hnsw_indexes.iter().any(|i| i.name == index_name)
        {
            if *if_not_exists {
                return Ok(index_name);
            }
            return Err(SQLRiteError::General(format!(
                "index '{index_name}' already exists"
            )));
        }
        let datatype = clone_datatype(&col.datatype);

        let mut pairs = Vec::new();
        for rowid in table.rowids() {
            if let Some(v) = table.get_value(&column_name, rowid) {
                pairs.push((rowid, v));
            }
        }
        (datatype, pairs)
    };

    match method {
        IndexMethod::Btree => create_btree_index(
            db,
            &table_name_str,
            &index_name,
            &column_name,
            &datatype,
            *unique,
            &existing_rowids_and_values,
        ),
        IndexMethod::Hnsw => create_hnsw_index(
            db,
            &table_name_str,
            &index_name,
            &column_name,
            &datatype,
            *unique,
            &existing_rowids_and_values,
        ),
    }
}

/// `USING <method>` choices recognized by `execute_create_index`. A
/// missing USING clause defaults to `Btree` so existing CREATE INDEX
/// statements (Phase 3e) keep working unchanged.
#[derive(Debug, Clone, Copy)]
enum IndexMethod {
    Btree,
    Hnsw,
}

/// Builds a Phase 3e B-Tree secondary index and attaches it to the table.
fn create_btree_index(
    db: &mut Database,
    table_name: &str,
    index_name: &str,
    column_name: &str,
    datatype: &DataType,
    unique: bool,
    existing: &[(i64, Value)],
) -> Result<String> {
    let mut idx = SecondaryIndex::new(
        index_name.to_string(),
        table_name.to_string(),
        column_name.to_string(),
        datatype,
        unique,
        IndexOrigin::Explicit,
    )?;

    // Populate from existing rows. UNIQUE violations here mean the
    // existing data already breaks the new index's constraint — a
    // common source of user confusion, so be explicit.
    for (rowid, v) in existing {
        if unique && idx.would_violate_unique(v) {
            return Err(SQLRiteError::General(format!(
                "cannot create UNIQUE index '{index_name}': column '{column_name}' \
                 already contains the duplicate value {}",
                v.to_display_string()
            )));
        }
        idx.insert(v, *rowid)?;
    }

    let table_mut = db.get_table_mut(table_name.to_string())?;
    table_mut.secondary_indexes.push(idx);
    Ok(index_name.to_string())
}

/// Builds a Phase 7d.2 HNSW index and attaches it to the table.
fn create_hnsw_index(
    db: &mut Database,
    table_name: &str,
    index_name: &str,
    column_name: &str,
    datatype: &DataType,
    unique: bool,
    existing: &[(i64, Value)],
) -> Result<String> {
    // HNSW only makes sense on VECTOR columns. Reject anything else
    // with a clear message — this is the most likely user error.
    let dim = match datatype {
        DataType::Vector(d) => *d,
        other => {
            return Err(SQLRiteError::General(format!(
                "USING hnsw requires a VECTOR column; '{column_name}' is {other}"
            )));
        }
    };

    if unique {
        return Err(SQLRiteError::General(
            "UNIQUE has no meaning for HNSW indexes".to_string(),
        ));
    }

    // Build the in-memory graph. Distance metric is L2 by default
    // (Phase 7d.2 doesn't yet expose a knob for picking cosine/dot —
    // see `docs/phase-7-plan.md` for the deferral).
    //
    // Seed: hash the index name so different indexes get different
    // graph topologies, but the same index always gets the same one
    // — useful when debugging recall / index size.
    let seed = hash_str_to_seed(index_name);
    let mut idx = HnswIndex::new(DistanceMetric::L2, seed);

    // Snapshot the (rowid, vector) pairs into a side map so the
    // get_vec closure below can serve them by id without re-borrowing
    // the table (we're already holding `existing` — flatten it).
    let mut vec_map: std::collections::HashMap<i64, Vec<f32>> =
        std::collections::HashMap::with_capacity(existing.len());
    for (rowid, v) in existing {
        match v {
            Value::Vector(vec) => {
                if vec.len() != dim {
                    return Err(SQLRiteError::Internal(format!(
                        "row {rowid} stores a {}-dim vector in column '{column_name}' \
                         declared as VECTOR({dim}) — schema invariant violated",
                        vec.len()
                    )));
                }
                vec_map.insert(*rowid, vec.clone());
            }
            // Non-vector values (theoretical NULL, type coercion bug)
            // get skipped — they wouldn't have a sensible graph
            // position anyway.
            _ => continue,
        }
    }

    for (rowid, _) in existing {
        if let Some(v) = vec_map.get(rowid) {
            let v_clone = v.clone();
            idx.insert(*rowid, &v_clone, |id| {
                vec_map.get(&id).cloned().unwrap_or_default()
            });
        }
    }

    let table_mut = db.get_table_mut(table_name.to_string())?;
    table_mut.hnsw_indexes.push(HnswIndexEntry {
        name: index_name.to_string(),
        column_name: column_name.to_string(),
        index: idx,
    });
    Ok(index_name.to_string())
}

/// Stable, deterministic hash of a string into a u64 RNG seed. FNV-1a;
/// avoids pulling in `std::hash::DefaultHasher` (which is randomized
/// per process).
fn hash_str_to_seed(s: &str) -> u64 {
    let mut h: u64 = 0xCBF29CE484222325;
    for b in s.as_bytes() {
        h ^= *b as u64;
        h = h.wrapping_mul(0x100000001B3);
    }
    h
}

/// Cheap clone helper — `DataType` intentionally doesn't derive `Clone`
/// because the enum has no ergonomic reason to be cloneable elsewhere.
fn clone_datatype(dt: &DataType) -> DataType {
    match dt {
        DataType::Integer => DataType::Integer,
        DataType::Text => DataType::Text,
        DataType::Real => DataType::Real,
        DataType::Bool => DataType::Bool,
        DataType::Vector(dim) => DataType::Vector(*dim),
        DataType::None => DataType::None,
        DataType::Invalid => DataType::Invalid,
    }
}

fn extract_single_table_name(tables: &[TableWithJoins]) -> Result<String> {
    if tables.len() != 1 {
        return Err(SQLRiteError::NotImplemented(
            "multi-table DELETE is not supported yet".to_string(),
        ));
    }
    extract_table_name(&tables[0])
}

fn extract_table_name(twj: &TableWithJoins) -> Result<String> {
    if !twj.joins.is_empty() {
        return Err(SQLRiteError::NotImplemented(
            "JOIN is not supported yet".to_string(),
        ));
    }
    match &twj.relation {
        TableFactor::Table { name, .. } => Ok(name.to_string()),
        _ => Err(SQLRiteError::NotImplemented(
            "only plain table references are supported".to_string(),
        )),
    }
}

/// Tells the executor how to produce its candidate rowid list.
enum RowidSource {
    /// The WHERE was simple enough to probe a secondary index directly.
    /// The `Vec` already contains exactly the rows the index matched;
    /// no further WHERE evaluation is needed (the probe is precise).
    IndexProbe(Vec<i64>),
    /// No applicable index; caller falls back to walking `table.rowids()`
    /// and evaluating the WHERE on each row.
    FullScan,
}

/// Try to satisfy `WHERE` with an index probe. Currently supports the
/// simplest shape: a single `col = literal` (or `literal = col`) where
/// `col` is on a secondary index. AND/OR/range predicates fall back to
/// full scan — those can be layered on later without changing the caller.
fn select_rowids(table: &Table, selection: Option<&Expr>) -> Result<RowidSource> {
    let Some(expr) = selection else {
        return Ok(RowidSource::FullScan);
    };
    let Some((col, literal)) = try_extract_equality(expr) else {
        return Ok(RowidSource::FullScan);
    };
    let Some(idx) = table.index_for_column(&col) else {
        return Ok(RowidSource::FullScan);
    };

    // Convert the literal into a runtime Value. If the literal type doesn't
    // match the column's index we still need correct semantics — evaluate
    // the WHERE against every row. Fall back to full scan.
    let literal_value = match convert_literal(&literal) {
        Ok(v) => v,
        Err(_) => return Ok(RowidSource::FullScan),
    };

    // Index lookup returns the full list of rowids matching this equality
    // predicate. For unique indexes that's at most one; for non-unique it
    // can be many.
    let mut rowids = idx.lookup(&literal_value);
    rowids.sort_unstable();
    Ok(RowidSource::IndexProbe(rowids))
}

/// Recognizes `expr` as a simple equality on a column reference against a
/// literal. Returns `(column_name, literal_value)` if the shape matches;
/// `None` otherwise. Accepts both `col = literal` and `literal = col`.
fn try_extract_equality(expr: &Expr) -> Option<(String, sqlparser::ast::Value)> {
    // Peel off Nested parens so `WHERE (x = 1)` is recognized too.
    let peeled = match expr {
        Expr::Nested(inner) => inner.as_ref(),
        other => other,
    };
    let Expr::BinaryOp { left, op, right } = peeled else {
        return None;
    };
    if !matches!(op, BinaryOperator::Eq) {
        return None;
    }
    let col_from = |e: &Expr| -> Option<String> {
        match e {
            Expr::Identifier(ident) => Some(ident.value.clone()),
            Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
            _ => None,
        }
    };
    let literal_from = |e: &Expr| -> Option<sqlparser::ast::Value> {
        if let Expr::Value(v) = e {
            Some(v.value.clone())
        } else {
            None
        }
    };
    if let (Some(c), Some(l)) = (col_from(left), literal_from(right)) {
        return Some((c, l));
    }
    if let (Some(l), Some(c)) = (literal_from(left), col_from(right)) {
        return Some((c, l));
    }
    None
}

/// Recognizes the HNSW-probable query pattern and probes the graph
/// if a matching index exists.
///
/// Looks for ORDER BY `vec_distance_l2(<col>, <bracket-array literal>)`
/// where the table has an HNSW index attached to `<col>`. On a match,
/// returns the top-k rowids straight from the graph (roughly O(log N)
/// per probe). On any miss (different function name, no matching index,
/// wrong query dimension, etc.) returns `None` and the caller falls
/// through to the bounded-heap brute-force path (7c) or the full sort
/// (7b), preserving correct results regardless of whether the HNSW
/// pathway kicked in.
///
/// Phase 7d.2 caveats:
/// - Only `vec_distance_l2` is recognized. Cosine and dot fall through
///   to brute-force because we don't yet expose a per-index distance
///   knob (deferred to Phase 7d.x; see `docs/phase-7-plan.md`).
/// - Only ASCENDING order makes sense for "k nearest": `ORDER BY
///   vec_distance_l2(...) DESC LIMIT k` would mean "k farthest", which
///   isn't what the index is built for. This function does not inspect
///   the sort direction; the optimizer just skips the probe for DESC
///   and the fallback path handles it correctly (slower).
fn try_hnsw_probe(table: &Table, order_expr: &Expr, k: usize) -> Option<Vec<i64>> {
    if k == 0 {
        return None;
    }

    // Pattern-match: order expr must be a function call vec_distance_l2(a, b).
    let func = match order_expr {
        Expr::Function(f) => f,
        _ => return None,
    };
    let fname = match func.name.0.as_slice() {
        [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
        _ => return None,
    };
    if fname != "vec_distance_l2" {
        return None;
    }

    // Extract the two args as raw Exprs.
    let arg_list = match &func.args {
        FunctionArguments::List(l) => &l.args,
        _ => return None,
    };
    if arg_list.len() != 2 {
        return None;
    }
    let exprs: Vec<&Expr> = arg_list
        .iter()
        .filter_map(|a| match a {
            FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
            _ => None,
        })
        .collect();
    if exprs.len() != 2 {
        return None;
    }

    // One arg must be a column reference (the indexed col); the other
    // must be a bracket-array literal (the query vector). Try both
    // orderings: pgvector's idiom puts the column on the left, but
    // L2 distance is symmetric, so the argument order doesn't matter.
    let (col_name, query_vec) = identify_indexed_arg_and_literal(exprs[0], exprs[1])
        .or_else(|| identify_indexed_arg_and_literal(exprs[1], exprs[0]))?;

    // Find the HNSW index on this column.
    let entry = table
        .hnsw_indexes
        .iter()
        .find(|e| e.column_name == col_name)?;

    // Dimension sanity check: the query vector must match the indexed
    // column's declared dimension. If it doesn't, the brute-force
    // fallback would also error at the vec_distance_l2 dim-check;
    // returning None here lets that path produce the user-visible
    // error message.
    let declared_dim = match table.columns.iter().find(|c| c.column_name == col_name) {
        Some(c) => match &c.datatype {
            DataType::Vector(d) => *d,
            _ => return None,
        },
        None => return None,
    };
    if query_vec.len() != declared_dim {
        return None;
    }

    // Probe the graph. Vectors are looked up from the table's row
    // storage through a closure rather than a `&Table` so the algorithm
    // module stays decoupled from the SQL types.
    let column_for_closure = col_name.clone();
    let table_ref = table;
    let result = entry.index.search(&query_vec, k, |id| {
        match table_ref.get_value(&column_for_closure, id) {
            Some(Value::Vector(v)) => v,
            _ => Vec::new(),
        }
    });
    Some(result)
}

/// Helper for `try_hnsw_probe`: given two function args, identify which
/// one is a bare column identifier (the indexed column) and which is a
/// bracket-array literal (the query vector). Returns
/// `Some((column_name, query_vec))` on a match, `None` otherwise.
fn identify_indexed_arg_and_literal(a: &Expr, b: &Expr) -> Option<(String, Vec<f32>)> {
    let col_name = match a {
        Expr::Identifier(ident) if ident.quote_style.is_none() => ident.value.clone(),
        _ => return None,
    };
    let lit_str = match b {
        Expr::Identifier(ident) if ident.quote_style == Some('[') => {
            format!("[{}]", ident.value)
        }
        _ => return None,
    };
    let v = parse_vector_literal(&lit_str).ok()?;
    Some((col_name, v))
}
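// Illustration only: `parse_vector_literal` is called above but its body is
// not part of this listing. The sketch below shows one way a bracket-array
// string such as "[0.5, 1, -2.25]" could be turned into a `Vec<f32>`. The
// name `parse_bracket_vector_sketch` and the `String` error type are
// assumptions made for this example, not the project's actual API.

```rust
/// Hypothetical stand-in for a vector-literal parser (assumption, not the
/// project's implementation). Accepts "[x, y, ...]" and returns the elements.
fn parse_bracket_vector_sketch(raw: &str) -> Result<Vec<f32>, String> {
    // Require the surrounding brackets; anything else is rejected.
    let inner = raw
        .trim()
        .strip_prefix('[')
        .and_then(|s| s.strip_suffix(']'))
        .ok_or_else(|| format!("expected a bracketed literal, got {raw:?}"))?;

    // "[]" is treated as the empty vector in this sketch.
    if inner.trim().is_empty() {
        return Ok(Vec::new());
    }

    // Split on commas and parse each element as f32, failing on the first
    // element that is not a number.
    inner
        .split(',')
        .map(|tok| {
            tok.trim()
                .parse::<f32>()
                .map_err(|e| format!("bad vector element {tok:?}: {e}"))
        })
        .collect()
}
```

// A caller in the style of `identify_indexed_arg_and_literal` would rebuild
// the "[...]" string from the bracket-quoted identifier and use `.ok()?` to
// turn a parse failure into "no HNSW probe" rather than an error.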

/// One entry in the bounded-heap top-k path. Holds a pre-evaluated
/// sort key + the rowid it came from. The `asc` flag inverts `Ord`
/// so a single `BinaryHeap<HeapEntry>` works for both ASC and DESC
/// without wrapping in `std::cmp::Reverse` at the call site:
///
///   - ASC LIMIT k = "k smallest": natural Ord. Max-heap top is the
///     largest currently kept; new items smaller than top displace.
///   - DESC LIMIT k = "k largest": Ord reversed. Max-heap top is now
///     the smallest currently kept (under reversed Ord, smallest
///     looks largest); new items larger than top displace.
///
/// In both cases the displacement test reduces to "new entry < heap top".
struct HeapEntry {
    key: Value,
    rowid: i64,
    asc: bool,
}

impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for HeapEntry {}

impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for HeapEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        let raw = compare_values(Some(&self.key), Some(&other.key));
        if self.asc { raw } else { raw.reverse() }
    }
}

/// Bounded-heap top-k selection. Returns at most `k` rowids in the
/// caller's desired order (ascending key for `order.ascending`,
/// descending otherwise).
///
/// O(N log k) where N = `matching.len()`. Caller must check
/// `k < matching.len()` for this to be a win: for k ≥ N the
/// `sort_rowids` full-sort path is the same asymptotic cost without
/// the heap overhead.
fn select_topk(
    matching: &[i64],
    table: &Table,
    order: &OrderByClause,
    k: usize,
) -> Result<Vec<i64>> {
    use std::collections::BinaryHeap;

    if k == 0 || matching.is_empty() {
        return Ok(Vec::new());
    }

    let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(k + 1);

    for &rowid in matching {
        let key = eval_expr(&order.expr, table, rowid)?;
        let entry = HeapEntry {
            key,
            rowid,
            asc: order.ascending,
        };

        if heap.len() < k {
            heap.push(entry);
        } else {
            // peek() returns the largest under our direction-aware Ord,
            // i.e. the worst entry currently kept. Displace it iff the
            // new entry is "better" (compares Less).
            if entry < *heap.peek().unwrap() {
                heap.pop();
                heap.push(entry);
            }
        }
    }

    // `into_sorted_vec` returns ascending under our direction-aware Ord:
    //   ASC: ascending by raw key (what we want)
    //   DESC: ascending under reversed Ord = descending by raw key (what
    //         we want for an ORDER BY DESC LIMIT k result)
    Ok(heap
        .into_sorted_vec()
        .into_iter()
        .map(|e| e.rowid)
        .collect())
}
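// Illustration only: a self-contained sketch of the bounded-heap top-k
// technique described above, using plain `i64` keys instead of `Value`.
// The function names are hypothetical. This sketch uses
// `std::cmp::Reverse` for the DESC case, whereas `HeapEntry` folds the
// direction into its own `Ord`; the displacement rule is the same.

```rust
use std::cmp::Reverse;
use std::collections::BinaryHeap;

/// ASC LIMIT k: keep the k smallest keys, returned in ascending order.
fn k_smallest_sketch(items: &[i64], k: usize) -> Vec<i64> {
    if k == 0 {
        return Vec::new();
    }
    let mut heap: BinaryHeap<i64> = BinaryHeap::with_capacity(k);
    for &x in items {
        if heap.len() < k {
            heap.push(x);
        } else {
            // The max-heap top is the worst (largest) key currently kept.
            let displace = heap.peek().map_or(false, |&worst| x < worst);
            if displace {
                heap.pop();
                heap.push(x);
            }
        }
    }
    heap.into_sorted_vec()
}

/// DESC LIMIT k: same loop under reversed ordering, so the heap top is
/// the smallest raw key currently kept. Returned in descending order.
fn k_largest_sketch(items: &[i64], k: usize) -> Vec<i64> {
    if k == 0 {
        return Vec::new();
    }
    let mut heap: BinaryHeap<Reverse<i64>> = BinaryHeap::with_capacity(k);
    for &x in items {
        let entry = Reverse(x);
        if heap.len() < k {
            heap.push(entry);
        } else {
            let displace = heap.peek().map_or(false, |worst| entry < *worst);
            if displace {
                heap.pop();
                heap.push(entry);
            }
        }
    }
    // Ascending under `Reverse` is descending by raw key.
    heap.into_sorted_vec().into_iter().map(|Reverse(x)| x).collect()
}
```

// The heap never holds more than k entries, which is where the O(N log k)
// bound comes from.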

fn sort_rowids(rowids: &mut [i64], table: &Table, order: &OrderByClause) -> Result<()> {
    // Phase 7b: ORDER BY now accepts any expression (column ref,
    // arithmetic, function call, …). Pre-compute the sort key for
    // every rowid up front so the expression is evaluated exactly N
    // times; the comparator's O(N log N) calls then run against
    // pre-evaluated Values instead of re-evaluating the expression on
    // every comparison. Not strictly necessary today, but vital once
    // 7d's HNSW index lands and this same code path could be running
    // tens of millions of distance computations.
    let mut keys: Vec<(i64, Result<Value>)> = rowids
        .iter()
        .map(|r| (*r, eval_expr(&order.expr, table, *r)))
        .collect();

    // Surface the FIRST evaluation error if any. We could be lazy
    // and let sort_by encounter it, but `Ord::cmp` can't return a
    // Result and we'd have to swallow errors silently.
    for (_, k) in &keys {
        if let Err(e) = k {
            return Err(SQLRiteError::General(format!(
                "ORDER BY expression failed: {e}"
            )));
        }
    }

    keys.sort_by(|(_, ka), (_, kb)| {
        // Both unwrap()s are safe: we just verified above that
        // every key Result is Ok.
        let va = ka.as_ref().unwrap();
        let vb = kb.as_ref().unwrap();
        let ord = compare_values(Some(va), Some(vb));
        if order.ascending { ord } else { ord.reverse() }
    });

    // Write the sorted rowids back into the caller's slice.
    for (i, (rowid, _)) in keys.into_iter().enumerate() {
        rowids[i] = rowid;
    }
    Ok(())
}

fn compare_values(a: Option<&Value>, b: Option<&Value>) -> Ordering {
    match (a, b) {
        (None, None) => Ordering::Equal,
        (None, _) => Ordering::Less,
        (_, None) => Ordering::Greater,
        (Some(a), Some(b)) => match (a, b) {
            (Value::Null, Value::Null) => Ordering::Equal,
            (Value::Null, _) => Ordering::Less,
            (_, Value::Null) => Ordering::Greater,
            (Value::Integer(x), Value::Integer(y)) => x.cmp(y),
            (Value::Real(x), Value::Real(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
            (Value::Integer(x), Value::Real(y)) => {
                (*x as f64).partial_cmp(y).unwrap_or(Ordering::Equal)
            }
            (Value::Real(x), Value::Integer(y)) => {
                x.partial_cmp(&(*y as f64)).unwrap_or(Ordering::Equal)
            }
            (Value::Text(x), Value::Text(y)) => x.cmp(y),
            (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
            // Cross-type fallback: stringify and compare; keeps ORDER BY total.
            (x, y) => x.to_display_string().cmp(&y.to_display_string()),
        },
    }
}

/// Returns `true` if the row at `rowid` matches the predicate expression.
pub fn eval_predicate(expr: &Expr, table: &Table, rowid: i64) -> Result<bool> {
    let v = eval_expr(expr, table, rowid)?;
    match v {
        Value::Bool(b) => Ok(b),
        Value::Null => Ok(false), // SQL NULL in a WHERE is treated as false
        Value::Integer(i) => Ok(i != 0),
        other => Err(SQLRiteError::Internal(format!(
            "WHERE clause must evaluate to boolean, got {}",
            other.to_display_string()
        ))),
    }
}

fn eval_expr(expr: &Expr, table: &Table, rowid: i64) -> Result<Value> {
    match expr {
        Expr::Nested(inner) => eval_expr(inner, table, rowid),

        Expr::Identifier(ident) => {
            // Phase 7b: sqlparser parses bracket-array literals like
            // `[0.1, 0.2, 0.3]` as bracket-quoted identifiers (it inherits
            // MSSQL `[name]` syntax). When we see `quote_style == Some('[')`
            // in expression-evaluation position (SELECT projection, WHERE,
            // ORDER BY, function args), parse the bracketed content as a
            // vector literal so the rest of the executor can compare /
            // distance-compute against it. Same trick the INSERT parser
            // uses; the executor needed its own copy because expression
            // eval runs on a different code path.
            if ident.quote_style == Some('[') {
                let raw = format!("[{}]", ident.value);
                let v = parse_vector_literal(&raw)?;
                return Ok(Value::Vector(v));
            }
            Ok(table.get_value(&ident.value, rowid).unwrap_or(Value::Null))
        }

        Expr::CompoundIdentifier(parts) => {
            // Accept `table.col`: we only have one table in scope, so ignore the qualifier.
            let col = parts
                .last()
                .map(|i| i.value.as_str())
                .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?;
            Ok(table.get_value(col, rowid).unwrap_or(Value::Null))
        }

        Expr::Value(v) => convert_literal(&v.value),

        Expr::UnaryOp { op, expr } => {
            let inner = eval_expr(expr, table, rowid)?;
            match op {
                UnaryOperator::Not => match inner {
                    Value::Bool(b) => Ok(Value::Bool(!b)),
                    Value::Null => Ok(Value::Null),
                    other => Err(SQLRiteError::Internal(format!(
                        "NOT applied to non-boolean value: {}",
                        other.to_display_string()
                    ))),
                },
                UnaryOperator::Minus => match inner {
                    Value::Integer(i) => Ok(Value::Integer(-i)),
                    Value::Real(f) => Ok(Value::Real(-f)),
                    Value::Null => Ok(Value::Null),
                    other => Err(SQLRiteError::Internal(format!(
                        "unary minus on non-numeric value: {}",
                        other.to_display_string()
                    ))),
                },
                UnaryOperator::Plus => Ok(inner),
                other => Err(SQLRiteError::NotImplemented(format!(
                    "unary operator {other:?} is not supported"
                ))),
            }
        }

        Expr::BinaryOp { left, op, right } => match op {
            BinaryOperator::And => {
                let l = eval_expr(left, table, rowid)?;
                let r = eval_expr(right, table, rowid)?;
                Ok(Value::Bool(as_bool(&l)? && as_bool(&r)?))
            }
            BinaryOperator::Or => {
                let l = eval_expr(left, table, rowid)?;
                let r = eval_expr(right, table, rowid)?;
                Ok(Value::Bool(as_bool(&l)? || as_bool(&r)?))
            }
            cmp @ (BinaryOperator::Eq
            | BinaryOperator::NotEq
            | BinaryOperator::Lt
            | BinaryOperator::LtEq
            | BinaryOperator::Gt
            | BinaryOperator::GtEq) => {
                let l = eval_expr(left, table, rowid)?;
                let r = eval_expr(right, table, rowid)?;
                // Any comparison involving NULL is unknown → false in a WHERE.
                if matches!(l, Value::Null) || matches!(r, Value::Null) {
                    return Ok(Value::Bool(false));
                }
                let ord = compare_values(Some(&l), Some(&r));
                let result = match cmp {
                    BinaryOperator::Eq => ord == Ordering::Equal,
                    BinaryOperator::NotEq => ord != Ordering::Equal,
                    BinaryOperator::Lt => ord == Ordering::Less,
                    BinaryOperator::LtEq => ord != Ordering::Greater,
                    BinaryOperator::Gt => ord == Ordering::Greater,
                    BinaryOperator::GtEq => ord != Ordering::Less,
                    _ => unreachable!(),
                };
                Ok(Value::Bool(result))
            }
            arith @ (BinaryOperator::Plus
            | BinaryOperator::Minus
            | BinaryOperator::Multiply
            | BinaryOperator::Divide
            | BinaryOperator::Modulo) => {
                let l = eval_expr(left, table, rowid)?;
                let r = eval_expr(right, table, rowid)?;
                eval_arith(arith, &l, &r)
            }
            BinaryOperator::StringConcat => {
                let l = eval_expr(left, table, rowid)?;
                let r = eval_expr(right, table, rowid)?;
                if matches!(l, Value::Null) || matches!(r, Value::Null) {
                    return Ok(Value::Null);
                }
                Ok(Value::Text(format!(
                    "{}{}",
                    l.to_display_string(),
                    r.to_display_string()
                )))
            }
            other => Err(SQLRiteError::NotImplemented(format!(
                "binary operator {other:?} is not supported yet"
            ))),
        },

        // Phase 7b: function-call dispatch. Currently only the three
        // vector-distance functions; this match arm becomes the single
        // place to register more SQL functions later (e.g. abs(),
        // length(), …) without re-touching the rest of the executor.
        //
        // Operator forms (`<->` `<=>` `<#>`) are NOT plumbed here: two
        // of three don't parse natively in sqlparser (we'd need a
        // string-preprocessing pass or a sqlparser fork). Deferred to
        // a follow-up sub-phase; see docs/phase-7-plan.md's "Scope
        // corrections" note.
        Expr::Function(func) => eval_function(func, table, rowid),

        other => Err(SQLRiteError::NotImplemented(format!(
            "unsupported expression in WHERE/projection: {other:?}"
        ))),
    }
}

/// Dispatches an `Expr::Function` to its built-in implementation.
/// Currently only the three vec_distance_* functions; other functions
/// surface as `NotImplemented` errors with the function name in the
/// message so users see what they tried.
fn eval_function(func: &sqlparser::ast::Function, table: &Table, rowid: i64) -> Result<Value> {
    // Function name lives in `name.0[0]` for unqualified calls. Anything
    // qualified (e.g. `pkg.fn(...)`) falls through to NotImplemented.
    let name = match func.name.0.as_slice() {
        [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
        _ => {
            return Err(SQLRiteError::NotImplemented(format!(
                "qualified function names not supported: {:?}",
                func.name
            )));
        }
    };

    match name.as_str() {
        "vec_distance_l2" | "vec_distance_cosine" | "vec_distance_dot" => {
            let (a, b) = extract_two_vector_args(&name, &func.args, table, rowid)?;
            let dist = match name.as_str() {
                "vec_distance_l2" => vec_distance_l2(&a, &b),
                "vec_distance_cosine" => vec_distance_cosine(&a, &b)?,
                "vec_distance_dot" => vec_distance_dot(&a, &b),
                _ => unreachable!(),
            };
            // Widen f32 → f64 for the runtime Value. Vectors are stored
            // as f32 (consistent with industry convention for embeddings),
            // but the executor's numeric type is f64 so distances slot
            // into Value::Real cleanly and can be compared / ordered with
            // other reals via the existing arithmetic + comparison paths.
            Ok(Value::Real(dist as f64))
        }
        other => Err(SQLRiteError::NotImplemented(format!(
            "unknown function: {other}(...)"
        ))),
    }
}

/// Extracts exactly two `Vec<f32>` arguments from a function call,
/// validating arity and that both sides are Vector-typed with matching
/// dimensions. Used by all three vec_distance_* functions.
fn extract_two_vector_args(
    fn_name: &str,
    args: &FunctionArguments,
    table: &Table,
    rowid: i64,
) -> Result<(Vec<f32>, Vec<f32>)> {
    let arg_list = match args {
        FunctionArguments::List(l) => &l.args,
        _ => {
            return Err(SQLRiteError::General(format!(
                "{fn_name}() expects exactly two vector arguments"
            )));
        }
    };
    if arg_list.len() != 2 {
        return Err(SQLRiteError::General(format!(
            "{fn_name}() expects exactly 2 arguments, got {}",
            arg_list.len()
        )));
    }
    let mut out: Vec<Vec<f32>> = Vec::with_capacity(2);
    for (i, arg) in arg_list.iter().enumerate() {
        let expr = match arg {
            FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
            other => {
                return Err(SQLRiteError::NotImplemented(format!(
                    "{fn_name}() argument {i} has unsupported shape: {other:?}"
                )));
            }
        };
        let val = eval_expr(expr, table, rowid)?;
        match val {
            Value::Vector(v) => out.push(v),
            other => {
                return Err(SQLRiteError::General(format!(
                    "{fn_name}() argument {i} is not a vector: got {}",
                    other.to_display_string()
                )));
            }
        }
    }
    let b = out.pop().unwrap();
    let a = out.pop().unwrap();
    if a.len() != b.len() {
        return Err(SQLRiteError::General(format!(
            "{fn_name}(): vector dimensions don't match (lhs={}, rhs={})",
            a.len(),
            b.len()
        )));
    }
    Ok((a, b))
}

/// Euclidean (L2) distance: √(Σ(aᵢ − bᵢ)²).
/// Smaller-is-closer; identical vectors return 0.0.
pub(crate) fn vec_distance_l2(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let mut sum = 0.0f32;
    for i in 0..a.len() {
        let d = a[i] - b[i];
        sum += d * d;
    }
    sum.sqrt()
}

/// Cosine distance: 1 − (a·b) / (‖a‖·‖b‖).
/// Smaller-is-closer; identical (non-zero) vectors return 0.0,
/// orthogonal vectors return 1.0, opposite-direction vectors return 2.0.
///
/// Errors if either vector has zero magnitude: cosine similarity is
/// undefined for the zero vector and silently returning NaN would
/// poison `ORDER BY` ranking. Callers who want the silent-NaN
/// behavior can compute `1.0 + vec_distance_dot(a, b) / (‖a‖·‖b‖)`
/// themselves (the `+` is because `vec_distance_dot` is already negated).
pub(crate) fn vec_distance_cosine(a: &[f32], b: &[f32]) -> Result<f32> {
    debug_assert_eq!(a.len(), b.len());
    let mut dot = 0.0f32;
    let mut norm_a_sq = 0.0f32;
    let mut norm_b_sq = 0.0f32;
    for i in 0..a.len() {
        dot += a[i] * b[i];
        norm_a_sq += a[i] * a[i];
        norm_b_sq += b[i] * b[i];
    }
    let denom = (norm_a_sq * norm_b_sq).sqrt();
    if denom == 0.0 {
        return Err(SQLRiteError::General(
            "vec_distance_cosine() is undefined for zero-magnitude vectors".to_string(),
        ));
    }
    Ok(1.0 - dot / denom)
}

/// Negated dot product: −(a·b).
/// pgvector convention: negated so smaller-is-closer like L2 / cosine.
/// For unit-norm vectors `vec_distance_dot(a, b) == vec_distance_cosine(a, b) - 1`.
pub(crate) fn vec_distance_dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let mut dot = 0.0f32;
    for i in 0..a.len() {
        dot += a[i] * b[i];
    }
    -dot
}
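// Illustration only: a standalone sketch of the three distance formulas
// documented above, written so the unit-norm identity
// `dot distance == cosine distance - 1` can be checked numerically. The
// `*_sketch` names are hypothetical, and `Option` stands in for the crate's
// `Result` type, which is an assumption made to keep the example
// self-contained.

```rust
/// Plain dot product a·b over the common prefix of the two slices.
fn dot_sketch(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Euclidean distance: square root of the summed squared differences.
fn l2_sketch(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(x, y)| (x - y) * (x - y))
        .sum::<f32>()
        .sqrt()
}

/// Cosine distance 1 − (a·b)/(‖a‖·‖b‖); `None` for a zero-magnitude input.
fn cosine_distance_sketch(a: &[f32], b: &[f32]) -> Option<f32> {
    let denom = (dot_sketch(a, a) * dot_sketch(b, b)).sqrt();
    if denom == 0.0 {
        None
    } else {
        Some(1.0 - dot_sketch(a, b) / denom)
    }
}

/// Negated dot product, so that smaller means closer.
fn neg_dot_sketch(a: &[f32], b: &[f32]) -> f32 {
    -dot_sketch(a, b)
}
```

// For the unit vectors a = [0.6, 0.8] and b = [1.0, 0.0], a·b = 0.6, so the
// cosine distance is about 0.4 and the negated dot product is about -0.6,
// which is the cosine distance minus 1.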

/// Evaluates an integer/real arithmetic op. NULL on either side propagates.
/// Mixed Integer/Real promotes to Real. Divide/Modulo by zero → error.
/// Integer ops wrap on overflow, including `i64::MIN / -1`.
fn eval_arith(op: &BinaryOperator, l: &Value, r: &Value) -> Result<Value> {
    if matches!(l, Value::Null) || matches!(r, Value::Null) {
        return Ok(Value::Null);
    }
    match (l, r) {
        (Value::Integer(a), Value::Integer(b)) => match op {
            BinaryOperator::Plus => Ok(Value::Integer(a.wrapping_add(*b))),
            BinaryOperator::Minus => Ok(Value::Integer(a.wrapping_sub(*b))),
            BinaryOperator::Multiply => Ok(Value::Integer(a.wrapping_mul(*b))),
            BinaryOperator::Divide => {
                if *b == 0 {
                    Err(SQLRiteError::General("division by zero".to_string()))
                } else {
                    // wrapping_div avoids the overflow panic on i64::MIN / -1,
                    // matching the wrapping behavior of + - * above.
                    Ok(Value::Integer(a.wrapping_div(*b)))
                }
            }
            BinaryOperator::Modulo => {
                if *b == 0 {
                    Err(SQLRiteError::General("modulo by zero".to_string()))
                } else {
                    // wrapping_rem avoids the overflow panic on i64::MIN % -1.
                    Ok(Value::Integer(a.wrapping_rem(*b)))
                }
            }
            _ => unreachable!(),
        },
        // Anything involving a Real promotes both sides to f64.
        (a, b) => {
            let af = as_number(a)?;
            let bf = as_number(b)?;
            match op {
                BinaryOperator::Plus => Ok(Value::Real(af + bf)),
                BinaryOperator::Minus => Ok(Value::Real(af - bf)),
                BinaryOperator::Multiply => Ok(Value::Real(af * bf)),
                BinaryOperator::Divide => {
                    if bf == 0.0 {
                        Err(SQLRiteError::General("division by zero".to_string()))
                    } else {
                        Ok(Value::Real(af / bf))
                    }
                }
                BinaryOperator::Modulo => {
                    if bf == 0.0 {
                        Err(SQLRiteError::General("modulo by zero".to_string()))
                    } else {
                        Ok(Value::Real(af % bf))
                    }
                }
                _ => unreachable!(),
            }
        }
    }
}
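// Illustration only: integer arithmetic at the i64 boundary. The sketch
// below pairs the division-by-zero check used in the integer arm with
// wrapping division, so the one overflowing quotient (`i64::MIN / -1`)
// wraps instead of panicking. `int_div_sketch` and its `String` error are
// hypothetical names chosen for this example.

```rust
/// Integer division with an explicit zero check and wrapping overflow.
fn int_div_sketch(a: i64, b: i64) -> Result<i64, String> {
    if b == 0 {
        Err("division by zero".to_string())
    } else {
        // i64::MIN / -1 does not fit in i64; wrapping_div returns i64::MIN.
        Ok(a.wrapping_div(b))
    }
}
```

// The same wrap-around applies to addition: `i64::MAX.wrapping_add(1)` is
// `i64::MIN`.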

fn as_number(v: &Value) -> Result<f64> {
    match v {
        Value::Integer(i) => Ok(*i as f64),
        Value::Real(f) => Ok(*f),
        Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
        other => Err(SQLRiteError::General(format!(
            "arithmetic on non-numeric value '{}'",
            other.to_display_string()
        ))),
    }
}

fn as_bool(v: &Value) -> Result<bool> {
    match v {
        Value::Bool(b) => Ok(*b),
        Value::Null => Ok(false),
        Value::Integer(i) => Ok(*i != 0),
        other => Err(SQLRiteError::Internal(format!(
            "expected boolean, got {}",
            other.to_display_string()
        ))),
    }
}
1407
fn convert_literal(v: &sqlparser::ast::Value) -> Result<Value> {
    use sqlparser::ast::Value as AstValue;
    match v {
        AstValue::Number(n, _) => {
            if let Ok(i) = n.parse::<i64>() {
                Ok(Value::Integer(i))
            } else if let Ok(f) = n.parse::<f64>() {
                Ok(Value::Real(f))
            } else {
                Err(SQLRiteError::Internal(format!(
                    "could not parse numeric literal '{n}'"
                )))
            }
        }
        AstValue::SingleQuotedString(s) => Ok(Value::Text(s.clone())),
        AstValue::Boolean(b) => Ok(Value::Bool(*b)),
        AstValue::Null => Ok(Value::Null),
        other => Err(SQLRiteError::NotImplemented(format!(
            "unsupported literal value: {other:?}"
        ))),
    }
}
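
// Illustrative sketch, not part of the SQLRite source: `convert_literal`
// above tries `i64` first and falls back to `f64` for numeric literals.
// The standalone version below shows that ordering in isolation. The
// `Lit` enum and `parse_numeric` are hypothetical names standing in for
// the engine's `Value` type and error handling (`None` replaces the
// `SQLRiteError::Internal` branch).

```rust
/// Hypothetical stand-in for the numeric variants of the engine's `Value`.
#[derive(Debug, PartialEq)]
enum Lit {
    Integer(i64),
    Real(f64),
}

/// Integer-first parsing: text that fits in an `i64` stays an integer,
/// anything else that is a valid float becomes a `Real`.
fn parse_numeric(text: &str) -> Option<Lit> {
    if let Ok(i) = text.parse::<i64>() {
        Some(Lit::Integer(i))
    } else if let Ok(f) = text.parse::<f64>() {
        Some(Lit::Real(f))
    } else {
        None
    }
}
```

// One consequence of this ordering: an integer literal too large for
// `i64` does not fail, it silently becomes a float via the second branch.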

#[cfg(test)]
mod tests {
    use super::*;

    // -----------------------------------------------------------------
    // Phase 7b: Vector distance function math
    // -----------------------------------------------------------------

    /// Float comparison helper: distance results need a small epsilon
    /// because we accumulate sums across many f32 multiplies.
    fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
        (a - b).abs() < eps
    }

    #[test]
    fn vec_distance_l2_identical_is_zero() {
        let v = vec![0.1, 0.2, 0.3];
        assert_eq!(vec_distance_l2(&v, &v), 0.0);
    }

    #[test]
    fn vec_distance_l2_unit_basis_is_sqrt2() {
        // [1, 0] vs [0, 1]: distance = √((1-0)² + (0-1)²) = √2 ≈ 1.414
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(approx_eq(vec_distance_l2(&a, &b), 2.0_f32.sqrt(), 1e-6));
    }

    #[test]
    fn vec_distance_l2_known_value() {
        // [0, 0, 0] vs [3, 4, 0]: √(9 + 16 + 0) = 5 (the classic 3-4-5 triangle).
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![3.0, 4.0, 0.0];
        assert!(approx_eq(vec_distance_l2(&a, &b), 5.0, 1e-6));
    }

    #[test]
    fn vec_distance_cosine_identical_is_zero() {
        let v = vec![0.1, 0.2, 0.3];
        let d = vec_distance_cosine(&v, &v).unwrap();
        assert!(approx_eq(d, 0.0, 1e-6), "cos(v,v) = {d}, expected ≈ 0");
    }

    #[test]
    fn vec_distance_cosine_orthogonal_is_one() {
        // Two orthogonal unit vectors should have cosine distance = 1.0
        // (cosine similarity = 0 → distance = 1 - 0 = 1).
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(approx_eq(vec_distance_cosine(&a, &b).unwrap(), 1.0, 1e-6));
    }

    #[test]
    fn vec_distance_cosine_opposite_is_two() {
        // a and -a have cosine similarity = -1 → distance = 1 - (-1) = 2.
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![-1.0, 0.0, 0.0];
        assert!(approx_eq(vec_distance_cosine(&a, &b).unwrap(), 2.0, 1e-6));
    }

    #[test]
    fn vec_distance_cosine_zero_magnitude_errors() {
        // Cosine is undefined for the zero vector: return an error rather than NaN.
        let a = vec![0.0, 0.0];
        let b = vec![1.0, 0.0];
        let err = vec_distance_cosine(&a, &b).unwrap_err();
        assert!(format!("{err}").contains("zero-magnitude"));
    }

    #[test]
    fn vec_distance_dot_negates() {
        // a·b = 1*4 + 2*5 + 3*6 = 32. Negated → -32.
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        assert!(approx_eq(vec_distance_dot(&a, &b), -32.0, 1e-6));
    }

    #[test]
    fn vec_distance_dot_orthogonal_is_zero() {
        // Orthogonal vectors have dot product 0 → negated is also 0.
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert_eq!(vec_distance_dot(&a, &b), 0.0);
    }

    #[test]
    fn vec_distance_dot_unit_norm_matches_cosine_minus_one() {
        // For unit-norm vectors: dot(a,b) = cos(a,b)
        // → -dot(a,b) = -cos(a,b) = (1 - cos(a,b)) - 1 = vec_distance_cosine(a,b) - 1.
        // Useful sanity check that the two functions agree on unit vectors.
        let a = vec![0.6f32, 0.8]; // unit norm: √(0.36+0.64) = 1
        let b = vec![0.8f32, 0.6]; // unit norm too
        let dot = vec_distance_dot(&a, &b);
        let cos = vec_distance_cosine(&a, &b).unwrap();
        assert!(approx_eq(dot, cos - 1.0, 1e-5));
    }
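
// Illustrative sketch, not the engine's implementation: the tests above
// pin down the contracts of the three distance functions (L2 is the
// Euclidean norm of the difference, cosine distance is 1 minus cosine
// similarity and is undefined for a zero vector, dot "distance" is the
// negated dot product). The functions below are one possible way to
// satisfy those contracts. The names `l2`, `cosine_distance`, and
// `neg_dot` are hypothetical, and `Option` stands in for the engine's
// error type. Equal-length inputs are assumed.

```rust
/// Euclidean distance: sqrt of the sum of squared component differences.
fn l2(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(x, y)| (x - y) * (x - y))
        .sum::<f32>()
        .sqrt()
}

/// Negated dot product, so that "smaller is closer" holds like a distance.
fn neg_dot(a: &[f32], b: &[f32]) -> f32 {
    -a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>()
}

/// Cosine distance = 1 - cosine similarity.
/// Returns `None` when either vector has zero magnitude.
fn cosine_distance(a: &[f32], b: &[f32]) -> Option<f32> {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|y| y * y).sum::<f32>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        return None;
    }
    Some(1.0 - dot / (norm_a * norm_b))
}
```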

    // -----------------------------------------------------------------
    // Phase 7c: bounded-heap top-k correctness + benchmark
    // -----------------------------------------------------------------

    use crate::sql::db::database::Database;
    use crate::sql::parser::select::SelectQuery;
    use sqlparser::dialect::SQLiteDialect;
    use sqlparser::parser::Parser;

    /// Builds a `docs(id INTEGER PK, score REAL)` table with N rows of
    /// distinct non-negative scores so top-k tests aren't sensitive to
    /// tie-breaking (heap is unstable; full-sort is stable; we want
    /// both to agree without arguing about equal-score row order).
    ///
    /// **Why non-negative scores:** the INSERT parser doesn't currently
    /// handle `Expr::UnaryOp(Minus, …)` for negative number literals
    /// (it would parse `-3.14` as a unary expression and the value
    /// extractor would skip it). That's a pre-existing bug, out of
    /// scope for 7c. Using the Knuth multiplicative hash gives us
    /// distinct non-negative scrambled values without dancing around
    /// the negative-literal limitation.
    fn seed_score_table(n: usize) -> Database {
        let mut db = Database::new("tempdb".to_string());
        crate::sql::process_command(
            "CREATE TABLE docs (id INTEGER PRIMARY KEY, score REAL);",
            &mut db,
        )
        .expect("create");
        for i in 0..n {
            // Knuth multiplicative hash mod 1_000_000: values lie in
            // [0, 999_999] and are collision-free for any n up to
            // 1_000_000, because the multiplier is odd and not
            // divisible by 5, hence coprime to 1_000_000.
            let score = ((i as u64).wrapping_mul(2_654_435_761) % 1_000_000) as f64;
            let sql = format!("INSERT INTO docs (score) VALUES ({score});");
            crate::sql::process_command(&sql, &mut db).expect("insert");
        }
        db
    }

    /// Helper: parses an SQL SELECT into a SelectQuery so we can drive
    /// `select_topk` / `sort_rowids` directly without the rest of the
    /// process_command pipeline.
    fn parse_select(sql: &str) -> SelectQuery {
        let dialect = SQLiteDialect {};
        let mut ast = Parser::parse_sql(&dialect, sql).expect("parse");
        let stmt = ast.pop().expect("one statement");
        SelectQuery::new(&stmt).expect("select-query")
    }

    #[test]
    fn topk_matches_full_sort_asc() {
        // Build N=200, top-k=10. Bounded heap output must equal
        // full-sort-then-truncate output (both produce ASC order).
        let db = seed_score_table(200);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
        let order = q.order_by.as_ref().unwrap();
        let all_rowids = table.rowids();

        // Full-sort path
        let mut full = all_rowids.clone();
        sort_rowids(&mut full, table, order).unwrap();
        full.truncate(10);

        // Bounded-heap path
        let topk = select_topk(&all_rowids, table, order, 10).unwrap();

        assert_eq!(topk, full, "top-k via heap should match full-sort+truncate");
    }

    #[test]
    fn topk_matches_full_sort_desc() {
        // Same with DESC: verifies the direction-aware Ord wrapper.
        let db = seed_score_table(200);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score DESC LIMIT 10;");
        let order = q.order_by.as_ref().unwrap();
        let all_rowids = table.rowids();

        let mut full = all_rowids.clone();
        sort_rowids(&mut full, table, order).unwrap();
        full.truncate(10);

        let topk = select_topk(&all_rowids, table, order, 10).unwrap();

        assert_eq!(
            topk, full,
            "top-k DESC via heap should match full-sort+truncate"
        );
    }

    #[test]
    fn topk_k_larger_than_n_returns_everything_sorted() {
        // The executor branches off to the full-sort path when k >= N,
        // but if a caller invokes select_topk directly with k > N, it
        // should still return every row in sorted order (there is
        // nothing to truncate, because there are fewer than k rows).
        let db = seed_score_table(50);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1000;");
        let order = q.order_by.as_ref().unwrap();
        let topk = select_topk(&table.rowids(), table, order, 1000).unwrap();
        assert_eq!(topk.len(), 50);
        // All scores in ascending order.
        let scores: Vec<f64> = topk
            .iter()
            .filter_map(|r| match table.get_value("score", *r) {
                Some(Value::Real(f)) => Some(f),
                _ => None,
            })
            .collect();
        assert!(scores.windows(2).all(|w| w[0] <= w[1]));
    }

    #[test]
    fn topk_k_zero_returns_empty() {
        let db = seed_score_table(10);
        let table = db.get_table("docs".to_string()).unwrap();
        // The LIMIT in the SQL is irrelevant here: only the parsed
        // ORDER BY is used, and k = 0 is passed to select_topk directly.
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1;");
        let order = q.order_by.as_ref().unwrap();
        let topk = select_topk(&table.rowids(), table, order, 0).unwrap();
        assert!(topk.is_empty());
    }

    #[test]
    fn topk_empty_input_returns_empty() {
        let db = seed_score_table(0);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 5;");
        let order = q.order_by.as_ref().unwrap();
        let topk = select_topk(&[], table, order, 5).unwrap();
        assert!(topk.is_empty());
    }

    #[test]
    fn topk_works_through_select_executor_with_distance_function() {
        // Integration check that a KNN-shaped query runs end to end
        // through the executor (which should pick the bounded-heap
        // path) and returns the expected number of rows.
        let mut db = Database::new("tempdb".to_string());
        crate::sql::process_command(
            "CREATE TABLE docs (id INTEGER PRIMARY KEY, e VECTOR(2));",
            &mut db,
        )
        .unwrap();
        // Five rows with distinct distances from probe [1.0, 0.0]:
        //   id=1 [1.0, 0.0]   distance=0
        //   id=2 [2.0, 0.0]   distance=1
        //   id=3 [0.0, 3.0]   distance=√(1+9) = √10 ≈ 3.16
        //   id=4 [1.0, 4.0]   distance=4
        //   id=5 [10.0, 10.0] distance=√(81+100) ≈ 13.45
        for v in &[
            "[1.0, 0.0]",
            "[2.0, 0.0]",
            "[0.0, 3.0]",
            "[1.0, 4.0]",
            "[10.0, 10.0]",
        ] {
            crate::sql::process_command(&format!("INSERT INTO docs (e) VALUES ({v});"), &mut db)
                .unwrap();
        }
        let resp = crate::sql::process_command(
            "SELECT id FROM docs ORDER BY vec_distance_l2(e, [1.0, 0.0]) ASC LIMIT 3;",
            &mut db,
        )
        .unwrap();
        // Top-3 closest to [1.0, 0.0] are id=1, id=2, id=3 (in that order).
        // The status message only reports how many rows came back, so
        // this assertion checks the row count, not the ordering.
        assert!(resp.contains("3 rows returned"), "got: {resp}");
    }

    /// Manual benchmark, not run by default. Recommended invocation:
    ///
    ///     cargo test -p sqlrite-engine --lib topk_benchmark --release \
    ///         -- --ignored --nocapture
    ///
    /// (`--release` matters: Rust's sort is very fast in optimized
    /// builds, so the heap's relative advantage should be measured
    /// against a sort that has also been optimized.)
    ///
    /// Measured numbers on an Apple Silicon laptop with N=10_000 + k=10:
    ///   - bounded heap:    ~820µs
    ///   - full sort+trunc: ~1.5ms
    ///   - ratio:           ~1.8×
    ///
    /// The advantage is real but moderate at this size because the sort
    /// key here is a single REAL column read (cheap) and Rust's sort_by
    /// has a very low constant factor. The asymptotic O(N log k) vs
    /// O(N log N) advantage scales with N and with per-row work. KNN
    /// queries where the sort key is `vec_distance_l2(col, [...])` are
    /// where this path really pays off, because each key evaluation is
    /// itself O(dim) and the heap path skips the per-row evaluation
    /// in the comparator (see `sort_rowids` for the contrast).
    #[test]
    #[ignore]
    fn topk_benchmark() {
        use std::time::Instant;
        const N: usize = 10_000;
        const K: usize = 10;

        let db = seed_score_table(N);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
        let order = q.order_by.as_ref().unwrap();
        let all_rowids = table.rowids();

        // Time bounded heap.
        let t0 = Instant::now();
        let _topk = select_topk(&all_rowids, table, order, K).unwrap();
        let heap_dur = t0.elapsed();

        // Time full sort + truncate.
        let t1 = Instant::now();
        let mut full = all_rowids.clone();
        sort_rowids(&mut full, table, order).unwrap();
        full.truncate(K);
        let sort_dur = t1.elapsed();

        let ratio = sort_dur.as_secs_f64() / heap_dur.as_secs_f64().max(1e-9);
        println!("\n--- topk_benchmark (N={N}, k={K}) ---");
        println!("  bounded heap:   {heap_dur:?}");
        println!("  full sort+trunc: {sort_dur:?}");
        println!("  speedup ratio:  {ratio:.2}×");

        // Soft assertion. Floor is 1.4× because the cheap-key
        // benchmark hovers around 1.8× empirically; setting this too
        // close to the measured value risks flaky CI on slower
        // runners. Floor of 1.4× still catches an actual regression
        // (e.g., if select_topk became O(N²) or stopped using the
        // heap entirely).
        assert!(
            ratio > 1.4,
            "bounded heap should be substantially faster than full sort, but ratio = {ratio:.2}"
        );
    }
}
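
// Illustrative sketch, not the engine's `select_topk`: the benchmark
// comment above contrasts an O(N log k) bounded heap with an
// O(N log N) full sort. The standalone version below shows the
// technique on plain `u64` keys using `std::collections::BinaryHeap`.
// The names `topk_smallest` and `full_sort_truncate` are hypothetical,
// and only the ascending case is covered (the real code also handles
// DESC and evaluates sort keys from table rows).

```rust
use std::collections::BinaryHeap;

/// Returns the k smallest keys in ascending order.
/// A max-heap capped at k entries keeps the current "worst" candidate
/// on top, so each incoming key costs at most O(log k).
fn topk_smallest(keys: &[u64], k: usize) -> Vec<u64> {
    if k == 0 {
        return Vec::new();
    }
    let mut heap: BinaryHeap<u64> = BinaryHeap::with_capacity(k);
    for &key in keys {
        if heap.len() < k {
            heap.push(key);
        } else if let Some(&worst) = heap.peek() {
            // Only displace the current worst if the new key beats it.
            if key < worst {
                heap.pop();
                heap.push(key);
            }
        }
    }
    // `into_sorted_vec` yields ascending order.
    heap.into_sorted_vec()
}

/// Baseline for comparison: sort everything, then keep the first k.
fn full_sort_truncate(keys: &[u64], k: usize) -> Vec<u64> {
    let mut sorted = keys.to_vec();
    sorted.sort();
    sorted.truncate(k);
    sorted
}
```

// With distinct keys, as produced by the multiplicative hash in
// `seed_score_table`, both functions return the same list, which is
// the property the `topk_matches_full_sort_*` tests check.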

© 2026 Coveralls, Inc