
joaoh82 / rust_sqlite / 25093849524

29 Apr 2026 06:15AM UTC coverage: 69.044% (-1.9%) from 70.927%

push · github · web-flow
Phase 7e: JSON column type + path queries (#54)

Adds the JSON storage class and four path-aware query functions, completing
the second of Phase 7's two storage primitives (the first was VECTOR(N)
in 7a). The design mirrors SQLite's JSON1 extension: JSON values are stored
as canonical UTF-8 text and validated via `serde_json::from_str` at INSERT
and UPDATE time. Q3 of the Phase 7 plan originally proposed bincode-encoded
`serde_json::Value`, but bincode was removed from the engine in Phase 3c
(cell-based encoding replaced it). Rather than re-add bincode for one column
type, storing JSON as text matches SQLite's choice and reuses the existing
Text storage path. Q3 in `docs/phase-7-plan.md` records the scope correction inline.

Engine surface:

- `DataType::Json` variant alongside `Vector(N)`. `JSONB` parses as an
  alias (Postgres convention; both store as text in our case).
- INSERT/UPDATE on a JSON column runs `serde_json::from_str::<Value>`;
  malformed JSON is rejected with `Type mismatch: expected JSON for
  column 'foo': <serde error>`. NULLs pass through untouched.
- UNIQUE on a JSON column treats the value as raw text (string equality
  on the canonical form).
- `table_to_create_sql` round-trips JSON columns; `build_empty_table`,
  `Row::Text(BTreeMap::new())` storage, and the `clone_datatype` helpers
  in `executor.rs` and `pager/mod.rs` all gained the new arm.

Functions (executor.rs, ~370 LOC):

- `json_extract(json[, path])` — walks the path, returns the resolved
  node coerced to the closest SQL type. Strings → TEXT, numbers →
  INTEGER/REAL, booleans → BOOLEAN, `null` → NULL, composites
  (object/array) → canonical JSON text.
- `json_type(json[, path])` — returns one of `'object'`, `'array'`,
  `'string'`, `'integer'`, `'real'`, `'true'`, `'false'`, `'null'`.
- `json_array_length(json[, path])` — element count; errors if the
  resolved node isn't an array.
- `json_object_keys(json[, path])` — keys as a JSON-array text in
  insertion order (e.g. `'["a","b","c"]'`). Diverges f... (continued)
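
The path semantics listed above can be illustrated with a std-only sketch. This is not the engine's code. The `Json` enum is a hypothetical stand-in for `serde_json::Value`, and the path grammar is cut down to `$`, `.key` and `[index]` segments. The names `walk`, `json_type_name` and `json_array_length` are illustrative only. Note that `BTreeMap` sorts object keys, which differs from the insertion order the commit describes for `json_object_keys`.

```rust
use std::collections::BTreeMap;

/// Hypothetical stand-in for `serde_json::Value`. Object keys are kept
/// sorted by `BTreeMap` here purely for simplicity.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
enum Json {
    Null,
    Bool(bool),
    Int(i64),
    Real(f64),
    Str(String),
    Array(Vec<Json>),
    Object(BTreeMap<String, Json>),
}

/// Walks a simplified path such as `$.a.b[0]` (only `$`, `.key` and
/// `[index]` segments). Returns `None` when a segment doesn't resolve:
/// missing key, out-of-range index, or the wrong kind of node.
fn walk<'a>(root: &'a Json, path: &str) -> Option<&'a Json> {
    let rest = path.strip_prefix('$')?;
    let mut node = root;
    let mut chars = rest.chars().peekable();
    while let Some(c) = chars.next() {
        match c {
            '.' => {
                // The key runs until the next '.' or '['.
                let mut key = String::new();
                while let Some(&n) = chars.peek() {
                    if n == '.' || n == '[' {
                        break;
                    }
                    key.push(n);
                    chars.next();
                }
                match node {
                    Json::Object(map) => node = map.get(&key)?,
                    _ => return None,
                }
            }
            '[' => {
                // Collect digits up to the closing ']'.
                let mut digits = String::new();
                let mut closed = false;
                for n in chars.by_ref() {
                    if n == ']' {
                        closed = true;
                        break;
                    }
                    digits.push(n);
                }
                if !closed {
                    return None;
                }
                let i: usize = digits.parse().ok()?;
                match node {
                    Json::Array(items) => node = items.get(i)?,
                    _ => return None,
                }
            }
            _ => return None,
        }
    }
    Some(node)
}

/// The type names listed for `json_type` in the commit message.
fn json_type_name(v: &Json) -> &'static str {
    match v {
        Json::Null => "null",
        Json::Bool(true) => "true",
        Json::Bool(false) => "false",
        Json::Int(_) => "integer",
        Json::Real(_) => "real",
        Json::Str(_) => "string",
        Json::Array(_) => "array",
        Json::Object(_) => "object",
    }
}

/// `json_array_length` semantics: element count, error on non-arrays.
fn json_array_length(v: &Json) -> Result<usize, String> {
    match v {
        Json::Array(items) => Ok(items.len()),
        other => Err(format!("expected array, got {}", json_type_name(other))),
    }
}
```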

154 of 210 new or added lines in 5 files covered. (73.33%)

227 existing lines in 3 files now uncovered.

5382 of 7795 relevant lines covered (69.04%)

1.42 hits per line

Source File

/src/sql/executor.rs (coverage 74.02%)
```rust
//! Query executors — evaluate parsed SQL statements against the in-memory
//! storage and produce formatted output.

use std::cmp::Ordering;

use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
use sqlparser::ast::{
    AssignmentTarget, BinaryOperator, CreateIndex, Delete, Expr, FromTable, FunctionArg,
    FunctionArgExpr, FunctionArguments, IndexType, ObjectNamePart, Statement, TableFactor,
    TableWithJoins, UnaryOperator, Update,
};

use crate::error::{Result, SQLRiteError};
use crate::sql::db::database::Database;
use crate::sql::db::secondary_index::{IndexOrigin, SecondaryIndex};
use crate::sql::db::table::{DataType, HnswIndexEntry, Table, Value, parse_vector_literal};
use crate::sql::hnsw::{DistanceMetric, HnswIndex};
use crate::sql::parser::select::{OrderByClause, Projection, SelectQuery};
```

```rust
/// Structured result of a SELECT: column names in projection order,
/// and each matching row as a `Vec<Value>` aligned with the columns.
/// Phase 5a introduced this so the public `Connection` / `Statement`
/// API has typed rows to yield; the existing `execute_select` that
/// returns pre-rendered text is now a thin wrapper on top.
pub struct SelectResult {
    pub columns: Vec<String>,
    pub rows: Vec<Vec<Value>>,
}
```

```rust
/// Executes a SELECT and returns structured rows. The typed rows are
/// what the new public API streams to callers; the REPL / Tauri app
/// pre-render into a prettytable via `execute_select`.
pub fn execute_select_rows(query: SelectQuery, db: &Database) -> Result<SelectResult> {
    let table = db
        .get_table(query.table_name.clone())
        .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", query.table_name)))?;

    // Resolve projection to a concrete ordered column list.
    let projected_cols: Vec<String> = match &query.projection {
        Projection::All => table.column_names(),
        Projection::Columns(cols) => {
            for c in cols {
                if !table.contains_column(c.to_string()) {
                    return Err(SQLRiteError::Internal(format!(
                        "Column '{c}' does not exist on table '{}'",
                        query.table_name
                    )));
                }
            }
            cols.clone()
        }
    };

    // Collect matching rowids. If the WHERE is the shape `col = literal`
    // and `col` has a secondary index, probe the index for an O(log N)
    // seek; otherwise fall back to the full table scan.
    let matching = match select_rowids(table, query.selection.as_ref())? {
        RowidSource::IndexProbe(rowids) => rowids,
        RowidSource::FullScan => {
            let mut out = Vec::new();
            for rowid in table.rowids() {
                if let Some(expr) = &query.selection {
                    if !eval_predicate(expr, table, rowid)? {
                        continue;
                    }
                }
                out.push(rowid);
            }
            out
        }
    };
    let mut matching = matching;

    // Phase 7c — bounded-heap top-k optimization.
    //
    // The naive "ORDER BY <expr>" path (Phase 7b) sorts every matching
    // rowid: an O(N log N) sort_by followed by a truncate. For a KNN
    // query such as
    //
    //     SELECT id FROM docs
    //     ORDER BY vec_distance_l2(embedding, [...])
    //     LIMIT 10;
    //
    // N is the number of matching rows and k is the LIMIT. A bounded
    // max-heap of size k finds the top k in O(N log k): every row still
    // costs one comparison against the heap, but each heap operation is
    // O(log k), and k is typically 10-100 while N can be millions.
    //
    // Phase 7d.2 — HNSW ANN probe.
    //
    // Even better than the bounded heap: if the ORDER BY expression is
    // exactly `vec_distance_l2(<col>, <bracket-array literal>)` AND
    // `<col>` has an HNSW index attached, skip the linear scan
    // entirely and probe the graph in O(log N). Approximate but
    // typically ≥ 0.95 recall (verified by the recall tests in
    // src/sql/hnsw.rs). The probe works from the graph alone and never
    // sees the WHERE clause, so it is only taken when there is no WHERE.
    //
    // We branch on four cases:
    //   1. No WHERE; ORDER BY + LIMIT k matches the HNSW pattern → graph probe.
    //   2. ORDER BY + LIMIT k where k < |matching|              → bounded heap (7c).
    //   3. ORDER BY without LIMIT, or LIMIT >= |matching|       → full sort.
    //   4. LIMIT without ORDER BY                                → just truncate.
    match (&query.order_by, query.limit) {
        (Some(order), Some(k))
            if query.selection.is_none() && try_hnsw_probe(table, &order.expr, k).is_some() =>
        {
            matching = try_hnsw_probe(table, &order.expr, k).unwrap();
        }
        (Some(order), Some(k)) if k < matching.len() => {
            matching = select_topk(&matching, table, order, k)?;
        }
        (Some(order), _) => {
            sort_rowids(&mut matching, table, order)?;
            if let Some(k) = query.limit {
                matching.truncate(k);
            }
        }
        (None, Some(k)) => {
            matching.truncate(k);
        }
        (None, None) => {}
    }

    // Build typed rows. Missing cells surface as `Value::Null` — that
    // maps a column-not-present-for-this-rowid case onto the public
    // `Row::get` → `Option<T>` surface cleanly.
    let mut rows: Vec<Vec<Value>> = Vec::with_capacity(matching.len());
    for rowid in &matching {
        let row: Vec<Value> = projected_cols
            .iter()
            .map(|col| table.get_value(col, *rowid).unwrap_or(Value::Null))
            .collect();
        rows.push(row);
    }

    Ok(SelectResult {
        columns: projected_cols,
        rows,
    })
}
```
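
The Phase 7c bounded-heap idea described in the comments of `execute_select_rows` can be shown in isolation with `std::collections::BinaryHeap`. This sketch is not the engine's `select_topk`. It assumes each rowid arrives with a precomputed `f64` sort key (for example a distance), and `Entry` and `topk_ascending` are hypothetical names.

```rust
use std::cmp::Ordering;
use std::collections::BinaryHeap;

/// Heap entry ordered by `key`, then by `rowid` so ties are deterministic.
/// `f64::total_cmp` gives floats a total order, so the comparison never
/// has to deal with an "unordered" NaN result.
struct Entry {
    key: f64,
    rowid: i64,
}

impl Ord for Entry {
    fn cmp(&self, other: &Self) -> Ordering {
        self.key
            .total_cmp(&other.key)
            .then(self.rowid.cmp(&other.rowid))
    }
}
impl PartialOrd for Entry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl PartialEq for Entry {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}
impl Eq for Entry {}

/// Returns the `k` rowids with the smallest keys, in ascending key order.
/// `BinaryHeap` is a max-heap, so its top is always the worst of the
/// current best `k`. The heap never holds more than `k` entries, which
/// gives O(N log k) overall.
fn topk_ascending(rows: &[(i64, f64)], k: usize) -> Vec<i64> {
    if k == 0 {
        return Vec::new();
    }
    let mut heap = BinaryHeap::with_capacity(k);
    for &(rowid, key) in rows {
        let candidate = Entry { key, rowid };
        if heap.len() < k {
            heap.push(candidate);
        } else if heap.peek().map_or(false, |worst| candidate < *worst) {
            // The candidate beats the current worst: swap them.
            heap.pop();
            heap.push(candidate);
        }
    }
    // `into_sorted_vec` returns entries in ascending order.
    heap.into_sorted_vec().into_iter().map(|e| e.rowid).collect()
}
```

Breaking ties by rowid keeps this sketch's output deterministic. The excerpt doesn't show how the engine breaks ties.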

```rust
/// Executes a SELECT and returns `(rendered_table, row_count)`. The
/// REPL and Tauri app use this to keep the table-printing behaviour
/// the engine has always shipped. Structured callers use
/// `execute_select_rows` instead.
pub fn execute_select(query: SelectQuery, db: &Database) -> Result<(String, usize)> {
    let result = execute_select_rows(query, db)?;
    let row_count = result.rows.len();

    let mut print_table = PrintTable::new();
    let header_cells: Vec<PrintCell> = result.columns.iter().map(|c| PrintCell::new(c)).collect();
    print_table.add_row(PrintRow::new(header_cells));

    for row in &result.rows {
        let cells: Vec<PrintCell> = row
            .iter()
            .map(|v| PrintCell::new(&v.to_display_string()))
            .collect();
        print_table.add_row(PrintRow::new(cells));
    }

    Ok((print_table.to_string(), row_count))
}
```

```rust
/// Executes a DELETE statement. Returns the number of rows removed.
pub fn execute_delete(stmt: &Statement, db: &mut Database) -> Result<usize> {
    let Statement::Delete(Delete {
        from, selection, ..
    }) = stmt
    else {
        return Err(SQLRiteError::Internal(
            "execute_delete called on a non-DELETE statement".to_string(),
        ));
    };

    let tables = match from {
        FromTable::WithFromKeyword(t) | FromTable::WithoutKeyword(t) => t,
    };
    let table_name = extract_single_table_name(tables)?;

    // Compute matching rowids with an immutable borrow, then mutate.
    let matching: Vec<i64> = {
        let table = db
            .get_table(table_name.clone())
            .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
        match select_rowids(table, selection.as_ref())? {
            RowidSource::IndexProbe(rowids) => rowids,
            RowidSource::FullScan => {
                let mut out = Vec::new();
                for rowid in table.rowids() {
                    if let Some(expr) = selection {
                        if !eval_predicate(expr, table, rowid)? {
                            continue;
                        }
                    }
                    out.push(rowid);
                }
                out
            }
        }
    };

    let table = db.get_table_mut(table_name)?;
    for rowid in &matching {
        table.delete_row(*rowid);
    }
    // Phase 7d.3 — any DELETE invalidates every HNSW index on this
    // table (the deleted node could still appear in other nodes'
    // neighbor lists, breaking subsequent searches). Mark dirty so
    // the next save rebuilds from current rows before serializing.
    if !matching.is_empty() {
        for entry in &mut table.hnsw_indexes {
            entry.needs_rebuild = true;
        }
    }
    Ok(matching.len())
}
```
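
`execute_delete` gathers the matching rowids under an immutable borrow and only then calls `get_table_mut`. The toy version below follows the same compute-then-mutate shape and the "any actual deletion dirties every HNSW index" rule. `MiniTable` and `delete_where` are hypothetical stand-ins, not engine types.

```rust
use std::collections::BTreeMap;

/// Hypothetical toy table: rowid -> integer cell, plus one
/// `needs_rebuild` flag per HNSW index.
struct MiniTable {
    rows: BTreeMap<i64, i64>,
    hnsw_needs_rebuild: Vec<bool>,
}

/// Deletes every row whose cell satisfies `pred` and returns the count.
fn delete_where(table: &mut MiniTable, pred: impl Fn(i64) -> bool) -> usize {
    // Phase 1: read-only pass. A `BTreeMap` can't be mutated while it is
    // being iterated, so collect the victims first.
    let matching: Vec<i64> = table
        .rows
        .iter()
        .filter(|(_, cell)| pred(**cell))
        .map(|(rowid, _)| *rowid)
        .collect();

    // Phase 2: the shared borrow has ended, so mutation is allowed.
    for rowid in &matching {
        table.rows.remove(rowid);
    }

    // Deleted nodes may survive in other nodes' neighbor lists, so any
    // actual deletion marks every HNSW index for rebuild.
    if !matching.is_empty() {
        for flag in &mut table.hnsw_needs_rebuild {
            *flag = true;
        }
    }
    matching.len()
}
```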

```rust
/// Executes an UPDATE statement. Returns the number of rows updated.
pub fn execute_update(stmt: &Statement, db: &mut Database) -> Result<usize> {
    let Statement::Update(Update {
        table,
        assignments,
        from,
        selection,
        ..
    }) = stmt
    else {
        return Err(SQLRiteError::Internal(
            "execute_update called on a non-UPDATE statement".to_string(),
        ));
    };

    if from.is_some() {
        return Err(SQLRiteError::NotImplemented(
            "UPDATE ... FROM is not supported yet".to_string(),
        ));
    }

    let table_name = extract_table_name(table)?;

    // Resolve assignment targets to plain column names and verify they exist.
    let mut parsed_assignments: Vec<(String, Expr)> = Vec::with_capacity(assignments.len());
    {
        let tbl = db
            .get_table(table_name.clone())
            .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
        for a in assignments {
            let col = match &a.target {
                AssignmentTarget::ColumnName(name) => name
                    .0
                    .last()
                    .map(|p| p.to_string())
                    .ok_or_else(|| SQLRiteError::Internal("empty column name".to_string()))?,
                AssignmentTarget::Tuple(_) => {
                    return Err(SQLRiteError::NotImplemented(
                        "tuple assignment targets are not supported".to_string(),
                    ));
                }
            };
            if !tbl.contains_column(col.clone()) {
                return Err(SQLRiteError::Internal(format!(
                    "UPDATE references unknown column '{col}'"
                )));
            }
            parsed_assignments.push((col, a.value.clone()));
        }
    }

    // Gather matching rowids + the new values to write for each assignment, under
    // an immutable borrow. Uses the index-probe fast path when the WHERE is
    // `col = literal` on an indexed column.
    let work: Vec<(i64, Vec<(String, Value)>)> = {
        let tbl = db.get_table(table_name.clone())?;
        let matched_rowids: Vec<i64> = match select_rowids(tbl, selection.as_ref())? {
            RowidSource::IndexProbe(rowids) => rowids,
            RowidSource::FullScan => {
                let mut out = Vec::new();
                for rowid in tbl.rowids() {
                    if let Some(expr) = selection {
                        if !eval_predicate(expr, tbl, rowid)? {
                            continue;
                        }
                    }
                    out.push(rowid);
                }
                out
            }
        };
        let mut rows_to_update = Vec::new();
        for rowid in matched_rowids {
            let mut values = Vec::with_capacity(parsed_assignments.len());
            for (col, expr) in &parsed_assignments {
                // UPDATE's RHS is evaluated in the context of the row being updated,
                // so column references on the right resolve to the current row's values.
                let v = eval_expr(expr, tbl, rowid)?;
                values.push((col.clone(), v));
            }
            rows_to_update.push((rowid, values));
        }
        rows_to_update
    };

    let tbl = db.get_table_mut(table_name)?;
    for (rowid, values) in &work {
        for (col, v) in values {
            tbl.set_value(col, *rowid, v.clone())?;
        }
    }

    // Phase 7d.3 — UPDATE may have changed a vector column that an
    // HNSW index covers. Mark every covering index dirty so save
    // rebuilds from current rows. Only indexes whose column is one of
    // the assignment targets are flagged; an UPDATE that touched only
    // non-indexed columns leaves every HNSW index clean. (The flag is
    // set even if the new vector equals the old one, which is
    // conservative but harmless: the rebuild cost is only paid on save.)
    if !work.is_empty() {
        let updated_columns: std::collections::HashSet<&str> = work
            .iter()
            .flat_map(|(_, values)| values.iter().map(|(c, _)| c.as_str()))
            .collect();
        for entry in &mut tbl.hnsw_indexes {
            if updated_columns.contains(entry.column_name.as_str()) {
                entry.needs_rebuild = true;
            }
        }
    }
    Ok(work.len())
}
```
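
At the end of `execute_update`, the dirty-marking step reduces to a set-membership test over the assigned column names. In this standalone sketch, `IndexEntry` is a hypothetical mirror of the `column_name` / `needs_rebuild` fields of `HnswIndexEntry`. Values are plain `String`s for illustration.

```rust
use std::collections::HashSet;

/// Hypothetical mirror of the two `HnswIndexEntry` fields used here.
struct IndexEntry {
    column_name: String,
    needs_rebuild: bool,
}

/// `work` pairs each updated rowid with its (column, new value) writes.
/// Flags only the indexes whose column was actually assigned.
fn mark_dirty(work: &[(i64, Vec<(String, String)>)], indexes: &mut [IndexEntry]) {
    if work.is_empty() {
        return;
    }
    let updated: HashSet<&str> = work
        .iter()
        .flat_map(|(_, writes)| writes.iter().map(|(col, _)| col.as_str()))
        .collect();
    for entry in indexes.iter_mut() {
        if updated.contains(entry.column_name.as_str()) {
            entry.needs_rebuild = true;
        }
    }
}
```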

```rust
/// Handles `CREATE [UNIQUE] INDEX <name> ON <table> [USING <method>] (<column>)`.
/// Single-column indexes only.
///
/// Two flavours, branching on the optional `USING <method>` clause:
///   - **No USING, or `USING btree`**: regular B-Tree secondary index
///     (Phase 3e). Indexable types: Integer, Text.
///   - **`USING hnsw`**: HNSW ANN index (Phase 7d.2). Indexable types:
///     Vector(N) only. Distance metric is L2 by default; cosine and
///     dot variants are deferred to Phase 7d.x.
///
/// Returns the index name for the status message.
pub fn execute_create_index(stmt: &Statement, db: &mut Database) -> Result<String> {
    let Statement::CreateIndex(CreateIndex {
        name,
        table_name,
        columns,
        using,
        unique,
        if_not_exists,
        predicate,
        ..
    }) = stmt
    else {
        return Err(SQLRiteError::Internal(
            "execute_create_index called on a non-CREATE-INDEX statement".to_string(),
        ));
    };

    if predicate.is_some() {
        return Err(SQLRiteError::NotImplemented(
            "partial indexes (CREATE INDEX ... WHERE) are not supported yet".to_string(),
        ));
    }

    if columns.len() != 1 {
        return Err(SQLRiteError::NotImplemented(format!(
            "multi-column indexes are not supported yet ({} columns given)",
            columns.len()
        )));
    }

    let index_name = name.as_ref().map(|n| n.to_string()).ok_or_else(|| {
        SQLRiteError::NotImplemented(
            "anonymous CREATE INDEX (no name) is not supported — give it a name".to_string(),
        )
    })?;

    // Detect USING <method>. The `using` field on CreateIndex covers the
    // pre-column form `CREATE INDEX … USING hnsw (col)`. (sqlparser also
    // accepts a post-column form `… (col) USING hnsw` and parks that in
    // `index_options`; we don't bother with it — the canonical form is
    // pre-column and matches PG/pgvector convention.)
    let method = match using {
        Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("hnsw") => {
            IndexMethod::Hnsw
        }
        // sqlparser parses the BTREE keyword into the dedicated
        // `IndexType::BTree` variant; the `Custom` spelling below is
        // kept as a fallback.
        Some(IndexType::BTree) => IndexMethod::Btree,
        Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("btree") => {
            IndexMethod::Btree
        }
        Some(other) => {
            return Err(SQLRiteError::NotImplemented(format!(
                "CREATE INDEX … USING {other:?} is not supported (try `hnsw`, `btree`, or no USING clause)"
            )));
        }
        None => IndexMethod::Btree,
    };

    let table_name_str = table_name.to_string();
    let column_name = match &columns[0].column.expr {
        Expr::Identifier(ident) => ident.value.clone(),
        Expr::CompoundIdentifier(parts) => parts
            .last()
            .map(|p| p.value.clone())
            .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
        other => {
            return Err(SQLRiteError::NotImplemented(format!(
                "CREATE INDEX only supports simple column references, got {other:?}"
            )));
        }
    };

    // Validate that the table and column exist and that the name is unique
    // across both index kinds (column-type compatibility with the index
    // method is checked in the `create_*_index` helpers). Snapshot
    // (rowid, value) pairs up front under the immutable borrow so the
    // mutable attach later doesn't conflict with it.
    let (datatype, existing_rowids_and_values): (DataType, Vec<(i64, Value)>) = {
        let table = db.get_table(table_name_str.clone()).map_err(|_| {
            SQLRiteError::General(format!(
                "CREATE INDEX references unknown table '{table_name_str}'"
            ))
        })?;
        if !table.contains_column(column_name.clone()) {
            return Err(SQLRiteError::General(format!(
                "CREATE INDEX references unknown column '{column_name}' on table '{table_name_str}'"
            )));
        }
        let col = table
            .columns
            .iter()
            .find(|c| c.column_name == column_name)
            .expect("we just verified the column exists");

        // Name uniqueness check spans BOTH index kinds — a btree and an
        // hnsw can't share a name.
        if table.index_by_name(&index_name).is_some()
            || table.hnsw_indexes.iter().any(|i| i.name == index_name)
        {
            if *if_not_exists {
                return Ok(index_name);
            }
            return Err(SQLRiteError::General(format!(
                "index '{index_name}' already exists"
            )));
        }
        let datatype = clone_datatype(&col.datatype);

        let mut pairs = Vec::new();
        for rowid in table.rowids() {
            if let Some(v) = table.get_value(&column_name, rowid) {
                pairs.push((rowid, v));
            }
        }
        (datatype, pairs)
    };

    match method {
        IndexMethod::Btree => create_btree_index(
            db,
            &table_name_str,
            &index_name,
            &column_name,
            &datatype,
            *unique,
            &existing_rowids_and_values,
        ),
        IndexMethod::Hnsw => create_hnsw_index(
            db,
            &table_name_str,
            &index_name,
            &column_name,
            &datatype,
            *unique,
            &existing_rowids_and_values,
        ),
    }
}
```
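
Two decisions in `execute_create_index` are small enough to isolate. The first maps the optional `USING` method name to an index kind, case-insensitively and defaulting to B-Tree. The second checks the index name against both index kinds. This sketch uses plain strings in place of sqlparser's `IndexType`, and `parse_index_method` and `name_taken` are hypothetical helpers.

```rust
/// Hypothetical mirror of the executor's private `IndexMethod` enum.
#[derive(Debug, Clone, Copy, PartialEq)]
enum IndexMethod {
    Btree,
    Hnsw,
}

/// Maps an optional `USING <method>` name to an index kind. Matching is
/// ASCII case-insensitive, and a missing clause means B-Tree.
fn parse_index_method(using: Option<&str>) -> Result<IndexMethod, String> {
    match using {
        None => Ok(IndexMethod::Btree),
        Some(m) if m.eq_ignore_ascii_case("btree") => Ok(IndexMethod::Btree),
        Some(m) if m.eq_ignore_ascii_case("hnsw") => Ok(IndexMethod::Hnsw),
        Some(other) => Err(format!("USING {other} is not supported")),
    }
}

/// Index names must be unique across B-Tree and HNSW indexes alike.
fn name_taken(btree_names: &[&str], hnsw_names: &[&str], name: &str) -> bool {
    btree_names.contains(&name) || hnsw_names.contains(&name)
}
```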

```rust
/// `USING <method>` choices recognized by `execute_create_index`. A
/// missing USING clause defaults to `Btree` so existing CREATE INDEX
/// statements (Phase 3e) keep working unchanged.
#[derive(Debug, Clone, Copy)]
enum IndexMethod {
    Btree,
    Hnsw,
}
```

```rust
/// Builds a Phase 3e B-Tree secondary index and attaches it to the table.
fn create_btree_index(
    db: &mut Database,
    table_name: &str,
    index_name: &str,
    column_name: &str,
    datatype: &DataType,
    unique: bool,
    existing: &[(i64, Value)],
) -> Result<String> {
    let mut idx = SecondaryIndex::new(
        index_name.to_string(),
        table_name.to_string(),
        column_name.to_string(),
        datatype,
        unique,
        IndexOrigin::Explicit,
    )?;

    // Populate from existing rows. UNIQUE violations here mean the
    // existing data already breaks the new index's constraint — a
    // common source of user confusion, so be explicit.
    for (rowid, v) in existing {
        if unique && idx.would_violate_unique(v) {
            return Err(SQLRiteError::General(format!(
                "cannot create UNIQUE index '{index_name}': column '{column_name}' \
                 already contains the duplicate value {}",
                v.to_display_string()
            )));
        }
        idx.insert(v, *rowid)?;
    }

    let table_mut = db.get_table_mut(table_name.to_string())?;
    table_mut.secondary_indexes.push(idx);
    Ok(index_name.to_string())
}
```
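
When `create_btree_index` populates a new UNIQUE index from existing rows, it fails on the first duplicate it encounters. The toy version below uses a `BTreeMap<String, Vec<i64>>` in place of `SecondaryIndex`. The `MiniIndex` type and the `build_index` function are hypothetical.

```rust
use std::collections::BTreeMap;

/// Hypothetical stand-in for `SecondaryIndex`: value -> rowids.
#[derive(Default)]
struct MiniIndex {
    map: BTreeMap<String, Vec<i64>>,
}

impl MiniIndex {
    /// A UNIQUE index rejects any value that is already present.
    fn would_violate_unique(&self, v: &str) -> bool {
        self.map.contains_key(v)
    }

    fn insert(&mut self, v: &str, rowid: i64) {
        self.map.entry(v.to_string()).or_default().push(rowid);
    }
}

/// Builds the index from a (rowid, value) snapshot of existing rows,
/// failing on the first duplicate when `unique` is set.
fn build_index(existing: &[(i64, &str)], unique: bool) -> Result<MiniIndex, String> {
    let mut idx = MiniIndex::default();
    for &(rowid, v) in existing {
        if unique && idx.would_violate_unique(v) {
            return Err(format!(
                "cannot create UNIQUE index: column already contains the duplicate value {v}"
            ));
        }
        idx.insert(v, rowid);
    }
    Ok(idx)
}
```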

```rust
/// Builds a Phase 7d.2 HNSW index and attaches it to the table.
fn create_hnsw_index(
    db: &mut Database,
    table_name: &str,
    index_name: &str,
    column_name: &str,
    datatype: &DataType,
    unique: bool,
    existing: &[(i64, Value)],
) -> Result<String> {
    // HNSW only makes sense on VECTOR columns. Reject anything else
    // with a clear message — this is the most likely user error.
    let dim = match datatype {
        DataType::Vector(d) => *d,
        other => {
            return Err(SQLRiteError::General(format!(
                "USING hnsw requires a VECTOR column; '{column_name}' is {other}"
            )));
        }
    };

    if unique {
        return Err(SQLRiteError::General(
            "UNIQUE has no meaning for HNSW indexes".to_string(),
        ));
    }

    // Build the in-memory graph. Distance metric is L2 by default
    // (Phase 7d.2 doesn't yet expose a knob for picking cosine/dot —
    // see `docs/phase-7-plan.md` for the deferral).
    //
    // Seed: hash the index name so different indexes get different
    // graph topologies, but the same index always gets the same one
    // — useful when debugging recall / index size.
    let seed = hash_str_to_seed(index_name);
    let mut idx = HnswIndex::new(DistanceMetric::L2, seed);

    // Snapshot the (rowid, vector) pairs into a side map so the
    // get_vec closure below can serve them by id without re-borrowing
    // the table (we're already holding `existing` — flatten it).
    let mut vec_map: std::collections::HashMap<i64, Vec<f32>> =
        std::collections::HashMap::with_capacity(existing.len());
    for (rowid, v) in existing {
        match v {
            Value::Vector(vec) => {
                if vec.len() != dim {
                    return Err(SQLRiteError::Internal(format!(
                        "row {rowid} stores a {}-dim vector in column '{column_name}' \
                         declared as VECTOR({dim}) — schema invariant violated",
                        vec.len()
                    )));
                }
                vec_map.insert(*rowid, vec.clone());
            }
            // Non-vector values (theoretical NULL, type coercion bug)
            // get skipped — they wouldn't have a sensible graph
            // position anyway.
            _ => continue,
        }
    }

    for (rowid, _) in existing {
        if let Some(v) = vec_map.get(rowid) {
            let v_clone = v.clone();
            idx.insert(*rowid, &v_clone, |id| {
                vec_map.get(&id).cloned().unwrap_or_default()
            });
        }
    }

    let table_mut = db.get_table_mut(table_name.to_string())?;
    table_mut.hnsw_indexes.push(HnswIndexEntry {
        name: index_name.to_string(),
        column_name: column_name.to_string(),
        index: idx,
        // Freshly built — no DELETE/UPDATE has invalidated it yet.
        needs_rebuild: false,
    });
    Ok(index_name.to_string())
}
```

```rust
/// Stable, deterministic hash of a string into a u64 RNG seed (64-bit
/// FNV-1a). Avoids `std::hash::DefaultHasher`, whose algorithm is
/// unspecified and may change between Rust releases.
fn hash_str_to_seed(s: &str) -> u64 {
    let mut h: u64 = 0xCBF29CE484222325;
    for b in s.as_bytes() {
        h ^= *b as u64;
        h = h.wrapping_mul(0x100000001B3);
    }
    h
}
```
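
`hash_str_to_seed` uses the standard 64-bit FNV-1a parameters: offset basis `0xcbf29ce484222325` and prime `0x100000001b3`. The standalone copy below is checked against two known values. The empty string must hash to the offset basis, and `"a"` hashes to `0xaf63dc4c8601ec8c`. The name `fnv1a64` is only illustrative.

```rust
/// 64-bit FNV-1a with the same constants as `hash_str_to_seed`.
fn fnv1a64(s: &str) -> u64 {
    const OFFSET_BASIS: u64 = 0xCBF2_9CE4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01B3;
    let mut h = OFFSET_BASIS;
    for b in s.bytes() {
        h ^= u64::from(b); // FNV-1a: xor the byte in first...
        h = h.wrapping_mul(PRIME); // ...then multiply, wrapping mod 2^64
    }
    h
}
```

The multiply step is a bijection modulo 2^64 because the prime is odd, so names that differ only in their last byte always get different seeds.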

```rust
/// Cheap clone helper — `DataType` intentionally doesn't derive `Clone`
/// because the enum has no ergonomic reason to be cloneable elsewhere.
fn clone_datatype(dt: &DataType) -> DataType {
    match dt {
        DataType::Integer => DataType::Integer,
        DataType::Text => DataType::Text,
        DataType::Real => DataType::Real,
        DataType::Bool => DataType::Bool,
        DataType::Vector(dim) => DataType::Vector(*dim),
        DataType::Json => DataType::Json,
        DataType::None => DataType::None,
        DataType::Invalid => DataType::Invalid,
    }
}
```

```rust
fn extract_single_table_name(tables: &[TableWithJoins]) -> Result<String> {
    if tables.len() != 1 {
        return Err(SQLRiteError::NotImplemented(
            "multi-table DELETE is not supported yet".to_string(),
        ));
    }
    extract_table_name(&tables[0])
}

fn extract_table_name(twj: &TableWithJoins) -> Result<String> {
    if !twj.joins.is_empty() {
        return Err(SQLRiteError::NotImplemented(
            "JOIN is not supported yet".to_string(),
        ));
    }
    match &twj.relation {
        TableFactor::Table { name, .. } => Ok(name.to_string()),
        _ => Err(SQLRiteError::NotImplemented(
            "only plain table references are supported".to_string(),
        )),
    }
}
```

```rust
/// Tells the executor how to produce its candidate rowid list.
enum RowidSource {
    /// The WHERE was simple enough to probe a secondary index directly.
    /// The `Vec` already contains exactly the rows the index matched;
    /// no further WHERE evaluation is needed (the probe is precise).
    IndexProbe(Vec<i64>),
    /// No applicable index; caller falls back to walking `table.rowids()`
    /// and evaluating the WHERE on each row.
    FullScan,
}
```

```rust
/// Try to satisfy `WHERE` with an index probe. Currently supports the
/// simplest shape: a single `col = literal` (or `literal = col`) where
/// `col` is on a secondary index. AND/OR/range predicates fall back to
/// full scan — those can be layered on later without changing the caller.
fn select_rowids(table: &Table, selection: Option<&Expr>) -> Result<RowidSource> {
    let Some(expr) = selection else {
        return Ok(RowidSource::FullScan);
    };
    let Some((col, literal)) = try_extract_equality(expr) else {
        return Ok(RowidSource::FullScan);
    };
    let Some(idx) = table.index_for_column(&col) else {
        return Ok(RowidSource::FullScan);
    };

    // Convert the literal into a runtime Value. If the literal type doesn't
    // match the column's index we still need correct semantics — evaluate
    // the WHERE against every row. Fall back to full scan.
    let literal_value = match convert_literal(&literal) {
        Ok(v) => v,
        Err(_) => return Ok(RowidSource::FullScan),
    };

    // Index lookup returns the full list of rowids matching this equality
    // predicate. For unique indexes that's at most one; for non-unique it
    // can be many.
    let mut rowids = idx.lookup(&literal_value);
    rowids.sort_unstable();
    Ok(RowidSource::IndexProbe(rowids))
}
```
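
The probe-or-scan decision in `select_rowids` can be modelled with a toy index keyed by integer values. A probe for a value the index doesn't contain still yields an (empty) `IndexProbe`, because the probe is exact. The toy falls back to `FullScan` only when it has no index or no equality literal to probe with. `RowidSource` here is a hypothetical mirror, `choose_rowids` is an illustrative name, and the index is a plain `BTreeMap`.

```rust
use std::collections::BTreeMap;

/// Hypothetical mirror of the executor's `RowidSource`.
#[derive(Debug, PartialEq)]
enum RowidSource {
    IndexProbe(Vec<i64>),
    FullScan,
}

/// `index` maps a column value to the rowids holding it; `eq_literal`
/// is the literal from a recognized `col = literal` WHERE, if any.
fn choose_rowids(index: Option<&BTreeMap<i64, Vec<i64>>>, eq_literal: Option<i64>) -> RowidSource {
    match (index, eq_literal) {
        (Some(idx), Some(v)) => {
            // The probe is exact: a value absent from the index means
            // "no rows match", not "go scan".
            let mut rowids = idx.get(&v).cloned().unwrap_or_default();
            rowids.sort_unstable(); // ascending, as the executor does
            RowidSource::IndexProbe(rowids)
        }
        // No usable index or no equality literal: evaluate WHERE per row.
        _ => RowidSource::FullScan,
    }
}
```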

```rust
/// Recognizes `expr` as a simple equality on a column reference against a
/// literal. Returns `(column_name, literal_value)` if the shape matches;
/// `None` otherwise. Accepts both `col = literal` and `literal = col`.
fn try_extract_equality(expr: &Expr) -> Option<(String, sqlparser::ast::Value)> {
    // Peel off Nested parens so `WHERE (x = 1)` is recognized too.
    let peeled = match expr {
        Expr::Nested(inner) => inner.as_ref(),
        other => other,
    };
    let Expr::BinaryOp { left, op, right } = peeled else {
        return None;
    };
    if !matches!(op, BinaryOperator::Eq) {
        return None;
    }
    let col_from = |e: &Expr| -> Option<String> {
        match e {
            Expr::Identifier(ident) => Some(ident.value.clone()),
            Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
            _ => None,
        }
    };
    let literal_from = |e: &Expr| -> Option<sqlparser::ast::Value> {
        if let Expr::Value(v) = e {
            Some(v.value.clone())
        } else {
            None
        }
    };
    if let (Some(c), Some(l)) = (col_from(left), literal_from(right)) {
        return Some((c, l));
    }
    if let (Some(l), Some(c)) = (literal_from(left), col_from(right)) {
        return Some((c, l));
    }
    None
}
```
735
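
// A minimal sketch of the `col = literal` / `literal = col` recognition,
// over a hypothetical simplified AST rather than sqlparser's `Expr`. It
// also peels any depth of parentheses; any other shape (e.g. `<`) returns
// `None`, which in the engine means "fall back to a full scan".
```rust
/// Hypothetical, simplified expression tree (not sqlparser's `Expr`).
#[derive(Debug)]
#[allow(dead_code)]
enum Expr {
    Column(String),
    Int(i64),
    Nested(Box<Expr>),
    Eq(Box<Expr>, Box<Expr>),
    Lt(Box<Expr>, Box<Expr>),
}

/// Returns `(column, literal)` for `col = lit` or `lit = col`.
fn extract_equality(expr: &Expr) -> Option<(String, i64)> {
    // Strip every layer of parentheses: ((x = 1)) -> x = 1.
    let mut e = expr;
    while let Expr::Nested(inner) = e {
        e = &**inner;
    }
    let Expr::Eq(l, r) = e else {
        return None;
    };
    // Accept either operand order.
    match (&**l, &**r) {
        (Expr::Column(c), Expr::Int(v)) | (Expr::Int(v), Expr::Column(c)) => Some((c.clone(), *v)),
        _ => None,
    }
}
```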

/// Recognizes the query shape an HNSW index can serve and probes the
/// graph if a matching index exists.
///
/// Looks for ORDER BY `vec_distance_l2(<col>, <bracket-array literal>)`
/// where the table has an HNSW index attached to `<col>`. On a match,
/// returns the top-k rowids straight from the graph (an approximate
/// nearest-neighbor search, roughly O(log N)). On any miss (different
/// function name, no matching index, wrong query dimension, etc.) it
/// returns `None` and the caller falls through to the bounded-heap
/// brute-force path (7c) or the full sort (7b), so a miss costs speed
/// but never correctness.
///
/// Phase 7d.2 caveats:
/// - Only `vec_distance_l2` is recognized. Cosine and dot fall through
///   to brute-force because we don't yet expose a per-index distance
///   knob (deferred to Phase 7d.x — see `docs/phase-7-plan.md`).
/// - Only ASCENDING order makes sense for "k nearest" — DESC ORDER BY
///   `vec_distance_l2(...) LIMIT k` would mean "k farthest", which
///   isn't what the index is built for. We don't bother to detect
///   `ascending == false` here; the optimizer just skips and the
///   fallback path handles it correctly (slower).
fn try_hnsw_probe(table: &Table, order_expr: &Expr, k: usize) -> Option<Vec<i64>> {
    if k == 0 {
        return None;
    }

    // Pattern-match: order expr must be a function call vec_distance_l2(a, b).
    let func = match order_expr {
        Expr::Function(f) => f,
        _ => return None,
    };
    let fname = match func.name.0.as_slice() {
        [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
        _ => return None,
    };
    if fname != "vec_distance_l2" {
        return None;
    }

    // Extract the two args as raw Exprs.
    let arg_list = match &func.args {
        FunctionArguments::List(l) => &l.args,
        _ => return None,
    };
    if arg_list.len() != 2 {
        return None;
    }
    let exprs: Vec<&Expr> = arg_list
        .iter()
        .filter_map(|a| match a {
            FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
            _ => None,
        })
        .collect();
    if exprs.len() != 2 {
        return None;
    }

    // One arg must be a column reference (the indexed col); the other
    // must be a bracket-array literal (the query vector). Try both
    // orderings: pgvector's idiom puts the column on the left, but L2
    // distance is symmetric, so either order is valid.
    let (col_name, query_vec) = match identify_indexed_arg_and_literal(exprs[0], exprs[1]) {
        Some(v) => v,
        None => match identify_indexed_arg_and_literal(exprs[1], exprs[0]) {
            Some(v) => v,
            None => return None,
        },
    };

    // Find the HNSW index on this column.
    let entry = table
        .hnsw_indexes
        .iter()
        .find(|e| e.column_name == col_name)?;

    // Dimension sanity check: the query vector must match the indexed
    // column's declared dimension. If it doesn't, the brute-force
    // fallback would also error at the vec_distance_l2 dimension check;
    // returning None here lets that path produce the user-visible
    // error message.
    let declared_dim = match table.columns.iter().find(|c| c.column_name == col_name) {
        Some(c) => match &c.datatype {
            DataType::Vector(d) => *d,
            _ => return None,
        },
        None => return None,
    };
    if query_vec.len() != declared_dim {
        return None;
    }

    // Probe the graph. Vectors are looked up from the table's row
    // storage — a closure rather than a `&Table` so the algorithm
    // module stays decoupled from the SQL types.
    let column_for_closure = col_name.clone();
    let table_ref = table;
    let result = entry.index.search(&query_vec, k, |id| {
        match table_ref.get_value(&column_for_closure, id) {
            Some(Value::Vector(v)) => v,
            _ => Vec::new(),
        }
    });
    Some(result)
}
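
// For reference, the exact result the brute-force fallback computes for
// `ORDER BY vec_distance_l2(col, query) LIMIT k` can be sketched as below.
// `knn_exact` and its `(rowid, vector)` input are hypothetical. HNSW is an
// approximate method, so the graph probe's list can differ from this
// exact one. Like the probe, the sketch returns `None` on `k == 0` or a
// dimension mismatch.
```rust
/// Euclidean (L2) distance between two equal-length vectors.
fn l2(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum::<f32>().sqrt()
}

/// Hypothetical exact k-nearest-neighbour search over `(rowid, vector)` pairs.
fn knn_exact(rows: &[(i64, Vec<f32>)], query: &[f32], k: usize) -> Option<Vec<i64>> {
    if k == 0 || rows.iter().any(|(_, v)| v.len() != query.len()) {
        return None;
    }
    let mut scored: Vec<(f32, i64)> = rows.iter().map(|(id, v)| (l2(v, query), *id)).collect();
    // total_cmp gives a total order even for NaN; ties break by rowid.
    scored.sort_by(|a, b| a.0.total_cmp(&b.0).then(a.1.cmp(&b.1)));
    Some(scored.into_iter().take(k).map(|(_, id)| id).collect())
}
```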

/// Helper for `try_hnsw_probe`: checks that `a` is a bare column
/// identifier (the indexed column) and `b` is a bracket-array literal
/// (the query vector); the caller tries both argument orders. Returns
/// `Some((column_name, query_vec))` on a match, `None` otherwise.
fn identify_indexed_arg_and_literal(a: &Expr, b: &Expr) -> Option<(String, Vec<f32>)> {
    let col_name = match a {
        Expr::Identifier(ident) if ident.quote_style.is_none() => ident.value.clone(),
        _ => return None,
    };
    let lit_str = match b {
        Expr::Identifier(ident) if ident.quote_style == Some('[') => {
            format!("[{}]", ident.value)
        }
        _ => return None,
    };
    let v = parse_vector_literal(&lit_str).ok()?;
    Some((col_name, v))
}
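
// `parse_vector_literal` is not shown in this excerpt, so this is a
// hypothetical stand-in that only illustrates the round trip: sqlparser
// keeps the bracket contents in `ident.value` without the brackets, the
// code re-wraps them as `[...]`, and the result is parsed into `Vec<f32>`.
// The real parser's accepted syntax and error messages may differ.
```rust
/// Hypothetical stand-in for `parse_vector_literal`: parses `[a, b, c]`.
fn parse_bracket_vector(s: &str) -> Result<Vec<f32>, String> {
    let inner = s
        .trim()
        .strip_prefix('[')
        .and_then(|t| t.strip_suffix(']'))
        .ok_or_else(|| format!("not a bracket array: `{s}`"))?;
    inner
        .split(',')
        .map(|p| {
            let p = p.trim();
            p.parse::<f32>().map_err(|e| format!("bad vector element `{p}`: {e}"))
        })
        .collect()
}

/// Re-wraps the bracket-quoted identifier's contents before parsing,
/// as `identify_indexed_arg_and_literal` does with `format!("[{}]", ...)`.
fn vector_from_bracket_ident(value: &str) -> Result<Vec<f32>, String> {
    parse_bracket_vector(&format!("[{value}]"))
}
```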

/// One entry in the bounded-heap top-k path. Holds a pre-evaluated
/// sort key + the rowid it came from. When `asc` is false, `Ord` is
/// inverted, so a single `BinaryHeap<HeapEntry>` works for both ASC and
/// DESC without wrapping in `std::cmp::Reverse` at the call site:
///
///   - ASC LIMIT k = "k smallest": natural Ord. Max-heap top is the
///     largest currently kept; a new item smaller than the top displaces it.
///   - DESC LIMIT k = "k largest": Ord reversed. Max-heap top is now
///     the smallest currently kept (under reversed Ord, smallest
///     looks largest); a new item larger than the top displaces it.
///
/// In both cases the displacement test reduces to "new entry < heap top".
struct HeapEntry {
    key: Value,
    rowid: i64,
    asc: bool,
}

impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for HeapEntry {}

impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for HeapEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        let raw = compare_values(Some(&self.key), Some(&other.key));
        if self.asc { raw } else { raw.reverse() }
    }
}

/// Bounded-heap top-k selection. Returns at most `k` rowids in the
/// caller's desired order (ascending key for `order.ascending`,
/// descending otherwise).
///
/// O(N log k) where N = `matching.len()`. Caller must check
/// `k < matching.len()` for this to be a win — for k ≥ N the
/// `sort_rowids` full-sort path is the same asymptotic cost without
/// the heap overhead.
fn select_topk(
    matching: &[i64],
    table: &Table,
    order: &OrderByClause,
    k: usize,
) -> Result<Vec<i64>> {
    use std::collections::BinaryHeap;

    if k == 0 || matching.is_empty() {
        return Ok(Vec::new());
    }

    let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(k + 1);

    for &rowid in matching {
        let key = eval_expr(&order.expr, table, rowid)?;
        let entry = HeapEntry {
            key,
            rowid,
            asc: order.ascending,
        };

        if heap.len() < k {
            heap.push(entry);
        } else {
            // peek() returns the largest under our direction-aware Ord
            // — the worst entry currently kept. Displace it iff the
            // new entry is "better" (i.e. compares Less).
            if entry < *heap.peek().unwrap() {
                heap.pop();
                heap.push(entry);
            }
        }
    }

    // `into_sorted_vec` returns ascending under our direction-aware Ord:
    //   ASC: ascending by raw key (what we want)
    //   DESC: ascending under reversed Ord = descending by raw key (what
    //         we want for an ORDER BY DESC LIMIT k result)
    Ok(heap
        .into_sorted_vec()
        .into_iter()
        .map(|e| e.rowid)
        .collect())
}
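
// A self-contained sketch of the same bounded-heap technique, using a
// simplified entry with an `i64` key instead of `Value` (an assumption to
// keep it runnable on its own). The reversed `Ord` for DESC is what lets
// one max-heap serve both directions.
```rust
use std::cmp::Ordering;
use std::collections::BinaryHeap;

/// Simplified `HeapEntry`: an `i64` sort key instead of the engine's `Value`.
struct Entry {
    key: i64,
    rowid: i64,
    asc: bool,
}

impl PartialEq for Entry {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}
impl Eq for Entry {}
impl PartialOrd for Entry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for Entry {
    fn cmp(&self, other: &Self) -> Ordering {
        // DESC flips the order, so the heap top is always the worst kept entry.
        let raw = self.key.cmp(&other.key);
        if self.asc { raw } else { raw.reverse() }
    }
}

/// Returns the rowids of the `k` best `(rowid, key)` pairs in ORDER BY order.
fn top_k(rows: &[(i64, i64)], k: usize, asc: bool) -> Vec<i64> {
    if k == 0 {
        return Vec::new();
    }
    let mut heap = BinaryHeap::with_capacity(k + 1);
    for &(rowid, key) in rows {
        let entry = Entry { key, rowid, asc };
        if heap.len() < k {
            heap.push(entry);
        } else if entry < *heap.peek().unwrap() {
            // The new entry beats the worst one kept: swap them.
            heap.pop();
            heap.push(entry);
        }
    }
    heap.into_sorted_vec().into_iter().map(|e| e.rowid).collect()
}
```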

fn sort_rowids(rowids: &mut [i64], table: &Table, order: &OrderByClause) -> Result<()> {
    // Phase 7b: ORDER BY now accepts any expression (column ref,
    // arithmetic, function call, …). Pre-compute the sort key for
    // every rowid up front so the comparator is called O(N log N)
    // times against pre-evaluated Values rather than re-evaluating
    // the expression O(N log N) times. Not strictly necessary today,
    // but vital once 7d's HNSW index lands and this same code path
    // could be running tens of millions of distance computations.
    let mut keys: Vec<(i64, Result<Value>)> = rowids
        .iter()
        .map(|r| (*r, eval_expr(&order.expr, table, *r)))
        .collect();

    // Surface the FIRST evaluation error if any. We could be lazy and
    // let sort_by encounter it, but the sort_by comparator must return
    // an `Ordering`, not a `Result`, so we'd have to swallow errors
    // silently.
    for (_, k) in &keys {
        if let Err(e) = k {
            return Err(SQLRiteError::General(format!(
                "ORDER BY expression failed: {e}"
            )));
        }
    }

    keys.sort_by(|(_, ka), (_, kb)| {
        // Both unwrap()s are safe — we just verified above that
        // every key Result is Ok.
        let va = ka.as_ref().unwrap();
        let vb = kb.as_ref().unwrap();
        let ord = compare_values(Some(va), Some(vb));
        if order.ascending { ord } else { ord.reverse() }
    });

    // Write the sorted rowids back into the caller's slice.
    for (i, (rowid, _)) in keys.into_iter().enumerate() {
        rowids[i] = rowid;
    }
    Ok(())
}
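
// The same "evaluate every key once, fail before sorting" pattern, as a
// generic sketch with a hypothetical `key_of` callback standing in for
// `eval_expr`. It collects into `Result` so the first error stops the
// work, and uses `f64::total_cmp` where the engine uses `compare_values`
// (the two differ in how they treat NaN).
```rust
/// Sorts `rowids` by a key computed once per rowid; the first key error
/// aborts before any reordering, leaving `rowids` untouched.
fn sort_by_precomputed_key<F>(rowids: &mut [i64], ascending: bool, mut key_of: F) -> Result<(), String>
where
    F: FnMut(i64) -> Result<f64, String>,
{
    let mut keyed = rowids
        .iter()
        .map(|&r| key_of(r).map(|k| (r, k)))
        .collect::<Result<Vec<(i64, f64)>, String>>()
        .map_err(|e| format!("ORDER BY expression failed: {e}"))?;
    keyed.sort_by(|a, b| {
        let ord = a.1.total_cmp(&b.1);
        if ascending { ord } else { ord.reverse() }
    });
    // Write the sorted rowids back into the caller's slice.
    for (slot, (rowid, _)) in rowids.iter_mut().zip(keyed) {
        *slot = rowid;
    }
    Ok(())
}
```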

fn compare_values(a: Option<&Value>, b: Option<&Value>) -> Ordering {
    match (a, b) {
        (None, None) => Ordering::Equal,
        (None, _) => Ordering::Less,
        (_, None) => Ordering::Greater,
        (Some(a), Some(b)) => match (a, b) {
            (Value::Null, Value::Null) => Ordering::Equal,
            (Value::Null, _) => Ordering::Less,
            (_, Value::Null) => Ordering::Greater,
            (Value::Integer(x), Value::Integer(y)) => x.cmp(y),
            (Value::Real(x), Value::Real(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
            (Value::Integer(x), Value::Real(y)) => {
                (*x as f64).partial_cmp(y).unwrap_or(Ordering::Equal)
            }
            (Value::Real(x), Value::Integer(y)) => {
                x.partial_cmp(&(*y as f64)).unwrap_or(Ordering::Equal)
            }
            (Value::Text(x), Value::Text(y)) => x.cmp(y),
            (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
            // Cross-type fallback: stringify and compare; keeps ORDER BY total.
            (x, y) => x.to_display_string().cmp(&y.to_display_string()),
        },
    }
}
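
// A consequence of `compare_values` treating NULL as smallest, combined
// with `sort_rowids` reversing the whole comparison for DESC: NULLs sort
// first under ASC and last under DESC. This sketch reproduces that with
// `Option<i64>`, whose standard ordering also puts `None` before `Some`.
```rust
use std::cmp::Ordering;

/// Sorts nullable integers the way `sort_rowids` orders NULLs.
fn order_nullable(mut v: Vec<Option<i64>>, ascending: bool) -> Vec<Option<i64>> {
    v.sort_by(|a, b| {
        let ord: Ordering = a.cmp(b); // Option's Ord: None < Some(_)
        if ascending { ord } else { ord.reverse() }
    });
    v
}
```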

/// Returns `true` if the row at `rowid` matches the predicate expression.
pub fn eval_predicate(expr: &Expr, table: &Table, rowid: i64) -> Result<bool> {
    let v = eval_expr(expr, table, rowid)?;
    match v {
        Value::Bool(b) => Ok(b),
        Value::Null => Ok(false), // SQL NULL in a WHERE is treated as false
        Value::Integer(i) => Ok(i != 0),
        other => Err(SQLRiteError::Internal(format!(
            "WHERE clause must evaluate to boolean, got {}",
            other.to_display_string()
        ))),
    }
}

fn eval_expr(expr: &Expr, table: &Table, rowid: i64) -> Result<Value> {
    match expr {
        Expr::Nested(inner) => eval_expr(inner, table, rowid),

        Expr::Identifier(ident) => {
            // Phase 7b — sqlparser parses bracket-array literals like
            // `[0.1, 0.2, 0.3]` as bracket-quoted identifiers (it inherits
            // MSSQL `[name]` syntax). When we see `quote_style == Some('[')`
            // in expression-evaluation position (SELECT projection, WHERE,
            // ORDER BY, function args), parse the bracketed content as a
            // vector literal so the rest of the executor can compare /
            // distance-compute against it. Same trick the INSERT parser
            // uses; the executor needed its own copy because expression
            // eval runs on a different code path.
            if ident.quote_style == Some('[') {
                let raw = format!("[{}]", ident.value);
                let v = parse_vector_literal(&raw)?;
                return Ok(Value::Vector(v));
            }
            Ok(table.get_value(&ident.value, rowid).unwrap_or(Value::Null))
        }

        Expr::CompoundIdentifier(parts) => {
            // Accept `table.col` — we only have one table in scope, so ignore the qualifier.
            let col = parts
                .last()
                .map(|i| i.value.as_str())
                .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?;
            Ok(table.get_value(col, rowid).unwrap_or(Value::Null))
        }

        Expr::Value(v) => convert_literal(&v.value),

        Expr::UnaryOp { op, expr } => {
            let inner = eval_expr(expr, table, rowid)?;
            match op {
                UnaryOperator::Not => match inner {
                    Value::Bool(b) => Ok(Value::Bool(!b)),
                    Value::Null => Ok(Value::Null),
                    other => Err(SQLRiteError::Internal(format!(
                        "NOT applied to non-boolean value: {}",
                        other.to_display_string()
                    ))),
                },
                UnaryOperator::Minus => match inner {
                    Value::Integer(i) => Ok(Value::Integer(-i)),
                    Value::Real(f) => Ok(Value::Real(-f)),
                    Value::Null => Ok(Value::Null),
                    other => Err(SQLRiteError::Internal(format!(
                        "unary minus on non-numeric value: {}",
                        other.to_display_string()
                    ))),
                },
                UnaryOperator::Plus => Ok(inner),
                other => Err(SQLRiteError::NotImplemented(format!(
                    "unary operator {other:?} is not supported"
                ))),
            }
        }

        Expr::BinaryOp { left, op, right } => match op {
            BinaryOperator::And => {
                let l = eval_expr(left, table, rowid)?;
                let r = eval_expr(right, table, rowid)?;
                Ok(Value::Bool(as_bool(&l)? && as_bool(&r)?))
            }
            BinaryOperator::Or => {
                let l = eval_expr(left, table, rowid)?;
                let r = eval_expr(right, table, rowid)?;
                Ok(Value::Bool(as_bool(&l)? || as_bool(&r)?))
            }
            cmp @ (BinaryOperator::Eq
            | BinaryOperator::NotEq
            | BinaryOperator::Lt
            | BinaryOperator::LtEq
            | BinaryOperator::Gt
            | BinaryOperator::GtEq) => {
                let l = eval_expr(left, table, rowid)?;
                let r = eval_expr(right, table, rowid)?;
                // Any comparison involving NULL is unknown → false in a WHERE.
                if matches!(l, Value::Null) || matches!(r, Value::Null) {
                    return Ok(Value::Bool(false));
                }
                let ord = compare_values(Some(&l), Some(&r));
                let result = match cmp {
                    BinaryOperator::Eq => ord == Ordering::Equal,
                    BinaryOperator::NotEq => ord != Ordering::Equal,
                    BinaryOperator::Lt => ord == Ordering::Less,
                    BinaryOperator::LtEq => ord != Ordering::Greater,
                    BinaryOperator::Gt => ord == Ordering::Greater,
                    BinaryOperator::GtEq => ord != Ordering::Less,
                    _ => unreachable!(),
                };
                Ok(Value::Bool(result))
            }
            arith @ (BinaryOperator::Plus
            | BinaryOperator::Minus
            | BinaryOperator::Multiply
            | BinaryOperator::Divide
            | BinaryOperator::Modulo) => {
                let l = eval_expr(left, table, rowid)?;
                let r = eval_expr(right, table, rowid)?;
                eval_arith(arith, &l, &r)
            }
            BinaryOperator::StringConcat => {
                let l = eval_expr(left, table, rowid)?;
                let r = eval_expr(right, table, rowid)?;
                if matches!(l, Value::Null) || matches!(r, Value::Null) {
                    return Ok(Value::Null);
                }
                Ok(Value::Text(format!(
                    "{}{}",
                    l.to_display_string(),
                    r.to_display_string()
                )))
            }
            other => Err(SQLRiteError::NotImplemented(format!(
                "binary operator {other:?} is not supported yet"
            ))),
        },

        // Function-call dispatch (Phase 7b added the three vector-distance
        // functions; Phase 7e added the JSON functions). This arm delegates
        // to `eval_function`, the single place to register more SQL
        // functions later (e.g. abs(), length(), …) without re-touching
        // the rest of the executor.
        //
        // Operator forms (`<->` `<=>` `<#>`) are NOT plumbed here: two
        // of three don't parse natively in sqlparser (we'd need a
        // string-preprocessing pass or a sqlparser fork). Deferred to
        // a follow-up sub-phase; see docs/phase-7-plan.md's "Scope
        // corrections" note.
        Expr::Function(func) => eval_function(func, table, rowid),

        other => Err(SQLRiteError::NotImplemented(format!(
            "unsupported expression in WHERE/projection: {other:?}"
        ))),
    }
}
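
// The comparison arm above returns `false` whenever either operand is
// NULL, rather than SQL's UNKNOWN. This sketch isolates that rule with
// `Option<i64>` operands and a hypothetical `CmpOp` enum. One consequence
// visible in the code: `NOT (x = NULL)` evaluates to `true` here, whereas
// standard SQL would leave it UNKNOWN and filter the row out.
```rust
use std::cmp::Ordering;

/// Hypothetical stand-in for the six comparison `BinaryOperator`s.
#[derive(Clone, Copy)]
#[allow(dead_code)]
enum CmpOp {
    Eq,
    NotEq,
    Lt,
    LtEq,
    Gt,
    GtEq,
}

/// Mirrors the executor's rule: a NULL (`None`) operand yields `false`.
fn eval_cmp(op: CmpOp, l: Option<i64>, r: Option<i64>) -> bool {
    let (Some(l), Some(r)) = (l, r) else {
        return false;
    };
    let ord = l.cmp(&r);
    match op {
        CmpOp::Eq => ord == Ordering::Equal,
        CmpOp::NotEq => ord != Ordering::Equal,
        CmpOp::Lt => ord == Ordering::Less,
        CmpOp::LtEq => ord != Ordering::Greater,
        CmpOp::Gt => ord == Ordering::Greater,
        CmpOp::GtEq => ord != Ordering::Less,
    }
}
```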

/// Dispatches an `Expr::Function` to its built-in implementation.
/// Currently supports the three `vec_distance_*` functions and the four
/// `json_*` functions added in Phase 7e; any other function surfaces as
/// a `NotImplemented` error with the function name in the message so
/// users see what they tried.
fn eval_function(func: &sqlparser::ast::Function, table: &Table, rowid: i64) -> Result<Value> {
    // Function name lives in `name.0[0]` for unqualified calls. Anything
    // qualified (e.g. `pkg.fn(...)`) falls through to NotImplemented.
    let name = match func.name.0.as_slice() {
        [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
        _ => {
            return Err(SQLRiteError::NotImplemented(format!(
                "qualified function names not supported: {:?}",
                func.name
            )));
        }
    };

    match name.as_str() {
        "vec_distance_l2" | "vec_distance_cosine" | "vec_distance_dot" => {
            let (a, b) = extract_two_vector_args(&name, &func.args, table, rowid)?;
            let dist = match name.as_str() {
                "vec_distance_l2" => vec_distance_l2(&a, &b),
                "vec_distance_cosine" => vec_distance_cosine(&a, &b)?,
                "vec_distance_dot" => vec_distance_dot(&a, &b),
                _ => unreachable!(),
            };
            // Widen f32 → f64 for the runtime Value. Vectors are stored
            // as f32 (consistent with industry convention for embeddings),
            // but the executor's numeric type is f64 so distances slot
            // into Value::Real cleanly and can be compared / ordered with
            // other reals via the existing arithmetic + comparison paths.
            Ok(Value::Real(dist as f64))
        }
        // Phase 7e — JSON functions. All four parse the JSON text on
        // demand (we don't cache parsed values), then resolve a path
        // (default `$` = root). The path resolver handles `.key` for
        // object access and `[N]` for array index, following a subset
        // of SQLite's JSON1 path syntax.
        "json_extract" => json_fn_extract(&name, &func.args, table, rowid),
        "json_type" => json_fn_type(&name, &func.args, table, rowid),
        "json_array_length" => json_fn_array_length(&name, &func.args, table, rowid),
        "json_object_keys" => json_fn_object_keys(&name, &func.args, table, rowid),
        other => Err(SQLRiteError::NotImplemented(format!(
            "unknown function: {other}(...)"
        ))),
    }
}
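
// The standard definitions of two of the distances dispatched above,
// sketched over `f32` slices, with the result widened to `f64` as the
// comment in `eval_function` describes. The engine's own `vec_distance_*`
// bodies are not shown, so reporting a zero-magnitude vector as an error
// in cosine distance is an assumption (the real function does return a
// `Result`). `vec_distance_dot` is omitted because its sign convention
// isn't visible here. Both sketches assume equal-length inputs.
```rust
/// Euclidean distance: sqrt(sum((a_i - b_i)^2)).
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum::<f32>().sqrt()
}

/// Cosine distance: 1 - (a . b) / (|a| |b|).
fn cosine_distance(a: &[f32], b: &[f32]) -> Result<f32, String> {
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        // Assumed behaviour: the angle to a zero vector is undefined.
        return Err("cosine distance is undefined for a zero-magnitude vector".to_string());
    }
    Ok(1.0 - dot / (norm_a * norm_b))
}
```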

// -----------------------------------------------------------------
// Phase 7e — JSON path-extraction functions
// -----------------------------------------------------------------

/// Extracts the JSON-typed text + optional path string out of a
/// function call's args. Used by all four json_* functions.
///
/// Arity rules (matching SQLite JSON1):
///   - 1 arg  → JSON value, path defaults to `$` (root)
///   - 2 args → (JSON value, path text)
///
/// Returns `(json_text, path)` so the caller can parse the text with
/// `serde_json::from_str` and then resolve the path with `walk_json_path`.
fn extract_json_and_path(
    fn_name: &str,
    args: &FunctionArguments,
    table: &Table,
    rowid: i64,
) -> Result<(String, String)> {
    let arg_list = match args {
        FunctionArguments::List(l) => &l.args,
        _ => {
            return Err(SQLRiteError::General(format!(
                "{fn_name}() expects 1 or 2 arguments"
            )));
        }
    };
    if !(arg_list.len() == 1 || arg_list.len() == 2) {
        return Err(SQLRiteError::General(format!(
            "{fn_name}() expects 1 or 2 arguments, got {}",
            arg_list.len()
        )));
    }
    // Evaluate first arg → must produce text.
    let first_expr = match &arg_list[0] {
        FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
        other => {
            return Err(SQLRiteError::NotImplemented(format!(
                "{fn_name}() argument 0 has unsupported shape: {other:?}"
            )));
        }
    };
    let json_text = match eval_expr(first_expr, table, rowid)? {
        Value::Text(s) => s,
        Value::Null => {
            return Err(SQLRiteError::General(format!(
                "{fn_name}() called on NULL — JSON column has no value for this row"
            )));
        }
        other => {
            return Err(SQLRiteError::General(format!(
                "{fn_name}() argument 0 is not JSON-typed: got {}",
                other.to_display_string()
            )));
        }
    };

    // Path defaults to root `$` when omitted.
    let path = if arg_list.len() == 2 {
        let path_expr = match &arg_list[1] {
            FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
            other => {
                return Err(SQLRiteError::NotImplemented(format!(
                    "{fn_name}() argument 1 has unsupported shape: {other:?}"
                )));
            }
        };
        match eval_expr(path_expr, table, rowid)? {
            Value::Text(s) => s,
            other => {
                return Err(SQLRiteError::General(format!(
                    "{fn_name}() path argument must be a string literal, got {}",
                    other.to_display_string()
                )));
            }
        }
    } else {
        "$".to_string()
    };

    Ok((json_text, path))
}
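
// The arity rule in isolation: a hypothetical helper over already
// evaluated string arguments (the real function evaluates sqlparser
// expressions first). One argument means the path defaults to `$`.
```rust
/// Hypothetical arity check mirroring `extract_json_and_path`.
fn json_args(fn_name: &str, args: &[&str]) -> Result<(String, String), String> {
    match args {
        [json] => Ok((json.to_string(), "$".to_string())),
        [json, path] => Ok((json.to_string(), path.to_string())),
        _ => Err(format!("{fn_name}() expects 1 or 2 arguments, got {}", args.len())),
    }
}
```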

/// Walks a `serde_json::Value` along a JSONPath subset:
///   - `$` is the root
///   - `.key` for object access (key may not contain `.` or `[`)
///   - `[N]` for array index (N a non-negative integer)
///   - chains arbitrarily: `$.foo.bar[0].baz`
///
/// Returns `Ok(None)` for "path didn't match anything" (NULL in SQL),
/// `Err` for malformed paths. Matches SQLite JSON1's semantic
/// distinction: missing-key = NULL, malformed-path = error.
fn walk_json_path<'a>(
    value: &'a serde_json::Value,
    path: &str,
) -> Result<Option<&'a serde_json::Value>> {
    let mut chars = path.chars().peekable();
    if chars.next() != Some('$') {
        return Err(SQLRiteError::General(format!(
            "JSON path must start with '$', got `{path}`"
        )));
    }
    let mut current = value;
    while let Some(&c) = chars.peek() {
        match c {
            '.' => {
                chars.next();
                let mut key = String::new();
                while let Some(&c) = chars.peek() {
                    if c == '.' || c == '[' {
                        break;
                    }
                    key.push(c);
                    chars.next();
                }
                if key.is_empty() {
                    return Err(SQLRiteError::General(format!(
                        "JSON path has empty key after '.' in `{path}`"
                    )));
                }
                match current.get(&key) {
                    Some(v) => current = v,
                    None => return Ok(None),
                }
            }
            '[' => {
                chars.next();
                let mut idx_str = String::new();
                while let Some(&c) = chars.peek() {
                    if c == ']' {
                        break;
                    }
                    idx_str.push(c);
                    chars.next();
                }
                if chars.next() != Some(']') {
                    return Err(SQLRiteError::General(format!(
                        "JSON path has unclosed `[` in `{path}`"
                    )));
                }
                let idx: usize = idx_str.trim().parse().map_err(|_| {
                    SQLRiteError::General(format!(
                        "JSON path has non-integer index `[{idx_str}]` in `{path}`"
                    ))
                })?;
                match current.get(idx) {
                    Some(v) => current = v,
                    None => return Ok(None),
                }
            }
            other => {
                return Err(SQLRiteError::General(format!(
                    "JSON path has unexpected character `{other}` in `{path}` \
                     (expected `.`, `[`, or end-of-path)"
                )));
            }
        }
    }
    Ok(Some(current))
}
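
// The path grammar accepted by `walk_json_path`, sketched as a
// standalone tokenizer so it runs without serde_json (an assumption for
// this example). Each `Step` would then be applied with `get(key)` or
// `get(index)`, stopping with NULL when a step finds nothing. The error
// cases mirror the ones above: missing `$`, empty key, unclosed `[`,
// non-integer index, unexpected character.
```rust
/// One hypothetical path step: `.key` or `[N]`.
#[derive(Debug, PartialEq)]
enum Step {
    Key(String),
    Index(usize),
}

fn parse_path(path: &str) -> Result<Vec<Step>, String> {
    let mut chars = path.chars().peekable();
    if chars.next() != Some('$') {
        return Err(format!("JSON path must start with '$', got `{path}`"));
    }
    let mut steps = Vec::new();
    while let Some(c) = chars.next() {
        match c {
            '.' => {
                // Key runs until the next `.` or `[`.
                let mut key = String::new();
                while let Some(&c) = chars.peek() {
                    if c == '.' || c == '[' {
                        break;
                    }
                    key.push(c);
                    chars.next();
                }
                if key.is_empty() {
                    return Err(format!("empty key after '.' in `{path}`"));
                }
                steps.push(Step::Key(key));
            }
            '[' => {
                // Index runs until the closing `]`, which must be present.
                let mut idx = String::new();
                loop {
                    match chars.next() {
                        Some(']') => break,
                        Some(c) => idx.push(c),
                        None => return Err(format!("unclosed `[` in `{path}`")),
                    }
                }
                let n = idx
                    .trim()
                    .parse::<usize>()
                    .map_err(|_| format!("non-integer index `[{idx}]` in `{path}`"))?;
                steps.push(Step::Index(n));
            }
            other => return Err(format!("unexpected character `{other}` in `{path}`")),
        }
    }
    Ok(steps)
}
```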

/// Converts a resolved serde_json node to a SQLRite Value. Scalars map
/// to the closest SQL type; composite types (object, array) come back
/// as JSON-encoded text, so a caller that needs the node's shape must
/// inspect the serde_json value itself (as `json_fn_type` does).
fn json_value_to_sql(v: &serde_json::Value) -> Value {
    match v {
        serde_json::Value::Null => Value::Null,
        serde_json::Value::Bool(b) => Value::Bool(*b),
        serde_json::Value::Number(n) => {
            // Match SQLite: integer if it fits an i64, else f64.
            if let Some(i) = n.as_i64() {
                Value::Integer(i)
            } else if let Some(f) = n.as_f64() {
                Value::Real(f)
            } else {
                Value::Null
            }
        }
        serde_json::Value::String(s) => Value::Text(s.clone()),
        // Objects + arrays come out as JSON-encoded text. Same as
        // SQLite's json_extract: composite results round-trip through
        // text rather than being modeled as a richer Value type.
        composite => Value::Text(composite.to_string()),
    }
}
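
// The "integer if it fits an i64, else f64" rule, sketched on raw number
// lexemes with std parsing instead of `serde_json::Number` (an assumption
// to keep it self-contained). A value just past `i64::MAX`, such as
// `u64::MAX`, therefore comes back as a REAL.
```rust
/// Hypothetical subset of the engine's `Value` for numeric results.
#[derive(Debug, PartialEq)]
enum Sql {
    Integer(i64),
    Real(f64),
    Null,
}

fn number_to_sql(lexeme: &str) -> Sql {
    if let Ok(i) = lexeme.parse::<i64>() {
        Sql::Integer(i)
    } else if let Ok(f) = lexeme.parse::<f64>() {
        Sql::Real(f)
    } else {
        Sql::Null
    }
}
```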

1405
fn json_fn_extract(
1✔
1406
    name: &str,
1407
    args: &FunctionArguments,
1408
    table: &Table,
1409
    rowid: i64,
1410
) -> Result<Value> {
1411
    let (json_text, path) = extract_json_and_path(name, args, table, rowid)?;
1✔
1412
    let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
3✔
NEW
1413
        SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
×
1414
    })?;
1415
    match walk_json_path(&parsed, &path)? {
2✔
1416
        Some(v) => Ok(json_value_to_sql(v)),
2✔
1417
        None => Ok(Value::Null),
1✔
1418
    }
1419
}

fn json_fn_type(name: &str, args: &FunctionArguments, table: &Table, rowid: i64) -> Result<Value> {
    let (json_text, path) = extract_json_and_path(name, args, table, rowid)?;
    let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
        SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
    })?;
    let resolved = match walk_json_path(&parsed, &path)? {
        Some(v) => v,
        None => return Ok(Value::Null),
    };
    let ty = match resolved {
        serde_json::Value::Null => "null",
        serde_json::Value::Bool(true) => "true",
        serde_json::Value::Bool(false) => "false",
        serde_json::Value::Number(n) => {
            if n.is_i64() || n.is_u64() {
                "integer"
            } else {
                "real"
            }
        }
        serde_json::Value::String(_) => "text",
        serde_json::Value::Array(_) => "array",
        serde_json::Value::Object(_) => "object",
    };
    Ok(Value::Text(ty.to_string()))
}
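`json_fn_type` labels any integral JSON number `integer`, because it tests `is_i64() || is_u64()`. `json_value_to_sql`, by contrast, only returns an Integer when the number fits an i64. The std-only sketch below shows the consequence: a value such as `18446744073709551615` reports type `integer` but extracts as REAL. `JsonNumber` and `SqlValue` are hypothetical types, loosely modelled on serde_json's default number representation (positive int, negative int, float).

```rust
/// Stand-in for a parsed JSON number (hypothetical; loosely modelled
/// on serde_json's default representation).
#[derive(Clone, Copy, Debug)]
enum JsonNumber {
    PosInt(u64),
    NegInt(i64),
    Float(f64),
}

/// SQL value a JSON number is coerced to (hypothetical).
#[derive(Debug, PartialEq)]
enum SqlValue {
    Integer(i64),
    Real(f64),
}

/// Mirrors `n.is_i64() || n.is_u64()` in json_fn_type.
fn type_label(n: JsonNumber) -> &'static str {
    match n {
        JsonNumber::PosInt(_) | JsonNumber::NegInt(_) => "integer",
        JsonNumber::Float(_) => "real",
    }
}

/// Mirrors json_value_to_sql: Integer if it fits an i64, else Real.
fn to_sql(n: JsonNumber) -> SqlValue {
    match n {
        JsonNumber::PosInt(u) => {
            if u <= i64::MAX as u64 {
                SqlValue::Integer(u as i64)
            } else {
                // Too large for i64: falls back to f64 (may lose precision).
                SqlValue::Real(u as f64)
            }
        }
        JsonNumber::NegInt(i) => SqlValue::Integer(i),
        JsonNumber::Float(f) => SqlValue::Real(f),
    }
}
```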

fn json_fn_array_length(
    name: &str,
    args: &FunctionArguments,
    table: &Table,
    rowid: i64,
) -> Result<Value> {
    let (json_text, path) = extract_json_and_path(name, args, table, rowid)?;
    let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
        SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
    })?;
    let resolved = match walk_json_path(&parsed, &path)? {
        Some(v) => v,
        None => return Ok(Value::Null),
    };
    match resolved.as_array() {
        Some(arr) => Ok(Value::Integer(arr.len() as i64)),
        None => Err(SQLRiteError::General(format!(
            "{name}() resolved to a non-array value at path `{path}`"
        ))),
    }
}

fn json_fn_object_keys(
    name: &str,
    args: &FunctionArguments,
    table: &Table,
    rowid: i64,
) -> Result<Value> {
    let (json_text, path) = extract_json_and_path(name, args, table, rowid)?;
    let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
        SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
    })?;
    let resolved = match walk_json_path(&parsed, &path)? {
        Some(v) => v,
        None => return Ok(Value::Null),
    };
    let obj = resolved.as_object().ok_or_else(|| {
        SQLRiteError::General(format!(
            "{name}() resolved to a non-object value at path `{path}`"
        ))
    })?;
    // PostgreSQL's json_object_keys is a set-returning function (one
    // row per key); SQLite's closest analog is the json_each
    // table-valued function. Without set-returning function support we
    // can't reproduce that shape, so instead we return the keys as JSON
    // array text. Callers can iterate via json_array_length +
    // json_extract, or just treat it as a serialized list. Document this
    // divergence in supported-sql.md.
    let keys: Vec<serde_json::Value> = obj
        .keys()
        .map(|k| serde_json::Value::String(k.clone()))
        .collect();
    Ok(Value::Text(serde_json::Value::Array(keys).to_string()))
}
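This sketch assumes serde_json's default features. Without `preserve_order`, its `Map` is backed by a `BTreeMap`, so the keys above come out sorted and the result is a sorted JSON array such as `["alpha","zeta"]`. The std-only code reproduces that output shape. `json_quote` and `object_keys_as_json` are hypothetical helpers, and the escaper only covers quotes, backslashes, and control characters.

```rust
use std::collections::BTreeMap;

/// Escapes a string as a JSON string literal (minimal sketch: quotes,
/// backslashes, and control characters only).
fn json_quote(s: &str) -> String {
    let mut out = String::with_capacity(s.len() + 2);
    out.push('"');
    for c in s.chars() {
        match c {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            c if (c as u32) < 0x20 => out.push_str(&format!("\\u{:04x}", c as u32)),
            c => out.push(c),
        }
    }
    out.push('"');
    out
}

/// Returns the keys of `obj` as compact JSON array text, e.g. `["a","b"]`.
/// BTreeMap iterates in sorted key order.
fn object_keys_as_json<V>(obj: &BTreeMap<String, V>) -> String {
    let quoted: Vec<String> = obj.keys().map(|k| json_quote(k)).collect();
    format!("[{}]", quoted.join(","))
}
```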

/// Extracts exactly two `Vec<f32>` arguments from a function call,
/// validating arity and that both sides are Vector-typed with matching
/// dimensions. Used by all three vec_distance_* functions.
fn extract_two_vector_args(
    fn_name: &str,
    args: &FunctionArguments,
    table: &Table,
    rowid: i64,
) -> Result<(Vec<f32>, Vec<f32>)> {
    let arg_list = match args {
        FunctionArguments::List(l) => &l.args,
        _ => {
            return Err(SQLRiteError::General(format!(
                "{fn_name}() expects exactly two vector arguments"
            )));
        }
    };
    if arg_list.len() != 2 {
        return Err(SQLRiteError::General(format!(
            "{fn_name}() expects exactly 2 arguments, got {}",
            arg_list.len()
        )));
    }
    let mut out: Vec<Vec<f32>> = Vec::with_capacity(2);
    for (i, arg) in arg_list.iter().enumerate() {
        let expr = match arg {
            FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
            other => {
                return Err(SQLRiteError::NotImplemented(format!(
                    "{fn_name}() argument {i} has unsupported shape: {other:?}"
                )));
            }
        };
        let val = eval_expr(expr, table, rowid)?;
        match val {
            Value::Vector(v) => out.push(v),
            other => {
                return Err(SQLRiteError::General(format!(
                    "{fn_name}() argument {i} is not a vector: got {}",
                    other.to_display_string()
                )));
            }
        }
    }
    // The loop above pushed exactly two vectors, so both pops succeed.
    let b = out.pop().unwrap();
    let a = out.pop().unwrap();
    if a.len() != b.len() {
        return Err(SQLRiteError::General(format!(
            "{fn_name}(): vector dimensions don't match (lhs={}, rhs={})",
            a.len(),
            b.len()
        )));
    }
    Ok((a, b))
}

/// Euclidean (L2) distance: √Σ(aᵢ − bᵢ)².
/// Smaller-is-closer; identical vectors return 0.0.
pub(crate) fn vec_distance_l2(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let mut sum = 0.0f32;
    for i in 0..a.len() {
        let d = a[i] - b[i];
        sum += d * d;
    }
    sum.sqrt()
}

/// Cosine distance: 1 − (a·b) / (‖a‖·‖b‖).
/// Smaller-is-closer; identical (non-zero) vectors return 0.0,
/// orthogonal vectors return 1.0, opposite-direction vectors return 2.0.
///
/// Errors if either vector has zero magnitude: cosine similarity is
/// undefined for the zero vector, and silently returning NaN would
/// poison `ORDER BY` ranking. Callers who want the silent-NaN
/// behavior can compute `1.0 + vec_distance_dot(a, b) / (norm(a) * norm(b))`
/// themselves; the `1.0 +` is needed because `vec_distance_dot` returns
/// the negated dot product.
pub(crate) fn vec_distance_cosine(a: &[f32], b: &[f32]) -> Result<f32> {
    debug_assert_eq!(a.len(), b.len());
    let mut dot = 0.0f32;
    let mut norm_a_sq = 0.0f32;
    let mut norm_b_sq = 0.0f32;
    for i in 0..a.len() {
        dot += a[i] * b[i];
        norm_a_sq += a[i] * a[i];
        norm_b_sq += b[i] * b[i];
    }
    let denom = (norm_a_sq * norm_b_sq).sqrt();
    if denom == 0.0 {
        return Err(SQLRiteError::General(
            "vec_distance_cosine() is undefined for zero-magnitude vectors".to_string(),
        ));
    }
    Ok(1.0 - dot / denom)
}
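`vec_distance_dot` returns the negated dot product −(a·b). A NaN-tolerant cosine distance is therefore `1 + vec_distance_dot(a, b) / (‖a‖·‖b‖)`, which equals 1 − cos(a, b). The std-only sketch below returns NaN instead of an error for a zero vector. `neg_dot`, `norm`, and `cosine_distance_nan` are hypothetical helpers, not engine functions.

```rust
/// Negated dot product: −(a·b), same convention as vec_distance_dot.
fn neg_dot(a: &[f32], b: &[f32]) -> f32 {
    -a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>()
}

/// Euclidean norm ‖v‖.
fn norm(v: &[f32]) -> f32 {
    v.iter().map(|x| x * x).sum::<f32>().sqrt()
}

/// 1 + (−a·b) / (‖a‖·‖b‖) = 1 − cos(a, b).
/// A zero-magnitude input divides 0 by 0, so the result is NaN, not an error.
fn cosine_distance_nan(a: &[f32], b: &[f32]) -> f32 {
    1.0 + neg_dot(a, b) / (norm(a) * norm(b))
}
```

This NaN is exactly what the engine's error path avoids: NaN compares as unordered, so it would scramble an `ORDER BY` over the distance.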

/// Negated dot product: −(a·b).
/// pgvector convention: negated so smaller-is-closer like L2 / cosine.
/// For unit-norm vectors `vec_distance_dot(a, b) == vec_distance_cosine(a, b) - 1`.
pub(crate) fn vec_distance_dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let mut dot = 0.0f32;
    for i in 0..a.len() {
        dot += a[i] * b[i];
    }
    -dot
}

/// Evaluates an integer/real arithmetic op. NULL on either side propagates.
/// Integer op Integer stays Integer (wrapping on overflow); any other
/// numeric pairing (Real, or Bool coerced to 0/1) promotes to Real.
/// Divide/Modulo by zero → error.
fn eval_arith(op: &BinaryOperator, l: &Value, r: &Value) -> Result<Value> {
    if matches!(l, Value::Null) || matches!(r, Value::Null) {
        return Ok(Value::Null);
    }
    match (l, r) {
        (Value::Integer(a), Value::Integer(b)) => match op {
            BinaryOperator::Plus => Ok(Value::Integer(a.wrapping_add(*b))),
            BinaryOperator::Minus => Ok(Value::Integer(a.wrapping_sub(*b))),
            BinaryOperator::Multiply => Ok(Value::Integer(a.wrapping_mul(*b))),
            BinaryOperator::Divide => {
                if *b == 0 {
                    Err(SQLRiteError::General("division by zero".to_string()))
                } else {
                    // Plain `/` panics on i64::MIN / -1; wrap instead,
                    // consistent with the other integer ops.
                    Ok(Value::Integer(a.wrapping_div(*b)))
                }
            }
            BinaryOperator::Modulo => {
                if *b == 0 {
                    Err(SQLRiteError::General("modulo by zero".to_string()))
                } else {
                    // Plain `%` panics on i64::MIN % -1; wrapping_rem yields 0.
                    Ok(Value::Integer(a.wrapping_rem(*b)))
                }
            }
            _ => unreachable!(),
        },
        // Anything else (a Real or a Bool on either side) promotes both sides to f64.
        (a, b) => {
            let af = as_number(a)?;
            let bf = as_number(b)?;
            match op {
                BinaryOperator::Plus => Ok(Value::Real(af + bf)),
                BinaryOperator::Minus => Ok(Value::Real(af - bf)),
                BinaryOperator::Multiply => Ok(Value::Real(af * bf)),
                BinaryOperator::Divide => {
                    if bf == 0.0 {
                        Err(SQLRiteError::General("division by zero".to_string()))
                    } else {
                        Ok(Value::Real(af / bf))
                    }
                }
                BinaryOperator::Modulo => {
                    if bf == 0.0 {
                        Err(SQLRiteError::General("modulo by zero".to_string()))
                    } else {
                        Ok(Value::Real(af % bf))
                    }
                }
                _ => unreachable!(),
            }
        }
    }
}
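In Rust, integer `/` and `%` panic on the one overflowing case, `i64::MIN / -1`, even in release builds. `wrapping_div` and `wrapping_rem` return the wrapped result instead, which matches how `eval_arith` treats `+`, `-`, and `*`. Below is a stand-alone sketch with hypothetical helpers `int_div` and `int_rem`. They error on a zero divisor and otherwise wrap.

```rust
/// Integer division: error on zero, wrap on i64::MIN / -1.
fn int_div(a: i64, b: i64) -> Result<i64, String> {
    if b == 0 {
        return Err("division by zero".to_string());
    }
    // Truncates toward zero, like plain `/`, but never panics.
    Ok(a.wrapping_div(b))
}

/// Integer remainder: error on zero; i64::MIN % -1 yields 0 instead of panicking.
fn int_rem(a: i64, b: i64) -> Result<i64, String> {
    if b == 0 {
        return Err("modulo by zero".to_string());
    }
    Ok(a.wrapping_rem(b))
}
```

`i64::checked_div` is the alternative when overflow should surface as an error rather than wrap: it returns `None` for both a zero divisor and `i64::MIN / -1`.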

fn as_number(v: &Value) -> Result<f64> {
    match v {
        Value::Integer(i) => Ok(*i as f64),
        Value::Real(f) => Ok(*f),
        Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
        other => Err(SQLRiteError::General(format!(
            "arithmetic on non-numeric value '{}'",
            other.to_display_string()
        ))),
    }
}
1676

1677
fn as_bool(v: &Value) -> Result<bool> {
×
1678
    match v {
×
1679
        Value::Bool(b) => Ok(*b),
×
1680
        Value::Null => Ok(false),
1681
        Value::Integer(i) => Ok(*i != 0),
×
1682
        other => Err(SQLRiteError::Internal(format!(
×
1683
            "expected boolean, got {}",
1684
            other.to_display_string()
×
1685
        ))),
1686
    }
1687
}

fn convert_literal(v: &sqlparser::ast::Value) -> Result<Value> {
    use sqlparser::ast::Value as AstValue;
    match v {
        AstValue::Number(n, _) => {
            if let Ok(i) = n.parse::<i64>() {
                Ok(Value::Integer(i))
            } else if let Ok(f) = n.parse::<f64>() {
                Ok(Value::Real(f))
            } else {
                Err(SQLRiteError::Internal(format!(
                    "could not parse numeric literal '{n}'"
                )))
            }
        }
        AstValue::SingleQuotedString(s) => Ok(Value::Text(s.clone())),
        AstValue::Boolean(b) => Ok(Value::Bool(*b)),
        AstValue::Null => Ok(Value::Null),
        other => Err(SQLRiteError::NotImplemented(format!(
            "unsupported literal value: {other:?}"
        ))),
    }
}
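`convert_literal` tries `i64` first and falls back to `f64`. A numeric literal just past `i64::MAX`, or one written in exponent form, therefore becomes REAL rather than failing. The stand-alone sketch below applies the same rule, using a hypothetical `Literal` enum in place of the engine's `Value`.

```rust
/// SQL value a numeric literal becomes (hypothetical type for this sketch).
#[derive(Debug, PartialEq)]
enum Literal {
    Integer(i64),
    Real(f64),
}

/// Same rule as convert_literal's Number arm: i64 first, then f64.
fn classify_number(n: &str) -> Result<Literal, String> {
    if let Ok(i) = n.parse::<i64>() {
        Ok(Literal::Integer(i))
    } else if let Ok(f) = n.parse::<f64>() {
        Ok(Literal::Real(f))
    } else {
        Err(format!("could not parse numeric literal '{n}'"))
    }
}
```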

#[cfg(test)]
mod tests {
    use super::*;

    // -----------------------------------------------------------------
    // Phase 7b: Vector distance function math
    // -----------------------------------------------------------------

    /// Float comparison helper: distance results need a small epsilon
    /// because we accumulate sums across many f32 multiplies.
    fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
        (a - b).abs() < eps
    }

    #[test]
    fn vec_distance_l2_identical_is_zero() {
        let v = vec![0.1, 0.2, 0.3];
        assert_eq!(vec_distance_l2(&v, &v), 0.0);
    }

    #[test]
    fn vec_distance_l2_unit_basis_is_sqrt2() {
        // [1, 0] vs [0, 1]: distance = √((1-0)² + (0-1)²) = √2 ≈ 1.414
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(approx_eq(vec_distance_l2(&a, &b), 2.0_f32.sqrt(), 1e-6));
    }

    #[test]
    fn vec_distance_l2_known_value() {
        // [0, 0, 0] vs [3, 4, 0]: √(9 + 16 + 0) = 5 (the classic 3-4-5 triangle).
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![3.0, 4.0, 0.0];
        assert!(approx_eq(vec_distance_l2(&a, &b), 5.0, 1e-6));
    }

    #[test]
    fn vec_distance_cosine_identical_is_zero() {
        let v = vec![0.1, 0.2, 0.3];
        let d = vec_distance_cosine(&v, &v).unwrap();
        assert!(approx_eq(d, 0.0, 1e-6), "cos(v,v) = {d}, expected ≈ 0");
    }

    #[test]
    fn vec_distance_cosine_orthogonal_is_one() {
        // Two orthogonal unit vectors should have cosine distance = 1.0
        // (cosine similarity = 0 → distance = 1 - 0 = 1).
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(approx_eq(vec_distance_cosine(&a, &b).unwrap(), 1.0, 1e-6));
    }

    #[test]
    fn vec_distance_cosine_opposite_is_two() {
        // a and -a have cosine similarity = -1 → distance = 1 - (-1) = 2.
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![-1.0, 0.0, 0.0];
        assert!(approx_eq(vec_distance_cosine(&a, &b).unwrap(), 2.0, 1e-6));
    }

    #[test]
    fn vec_distance_cosine_zero_magnitude_errors() {
        // Cosine is undefined for the zero vector: error rather than NaN.
        let a = vec![0.0, 0.0];
        let b = vec![1.0, 0.0];
        let err = vec_distance_cosine(&a, &b).unwrap_err();
        assert!(format!("{err}").contains("zero-magnitude"));
    }

    #[test]
    fn vec_distance_dot_negates() {
        // a·b = 1*4 + 2*5 + 3*6 = 32. Negated → -32.
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        assert!(approx_eq(vec_distance_dot(&a, &b), -32.0, 1e-6));
    }

    #[test]
    fn vec_distance_dot_orthogonal_is_zero() {
        // Orthogonal vectors have dot product 0 → negated is also 0.
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert_eq!(vec_distance_dot(&a, &b), 0.0);
    }

    #[test]
    fn vec_distance_dot_unit_norm_matches_cosine_minus_one() {
        // For unit-norm vectors: dot(a,b) = cos(a,b)
        // → -dot(a,b) = -cos(a,b) = (1 - cos(a,b)) - 1 = vec_distance_cosine(a,b) - 1.
        // Useful sanity check that the two functions agree on unit vectors.
        let a = vec![0.6f32, 0.8]; // unit norm: √(0.36+0.64) = 1
        let b = vec![0.8f32, 0.6]; // unit norm too
        let dot = vec_distance_dot(&a, &b);
        let cos = vec_distance_cosine(&a, &b).unwrap();
        assert!(approx_eq(dot, cos - 1.0, 1e-5));
    }

    // -----------------------------------------------------------------
    // Phase 7c: bounded-heap top-k correctness + benchmark
    // -----------------------------------------------------------------

    use crate::sql::db::database::Database;
    use crate::sql::parser::select::SelectQuery;
    use sqlparser::dialect::SQLiteDialect;
    use sqlparser::parser::Parser;

    /// Builds a `docs(id INTEGER PK, score REAL)` table with N rows of
    /// distinct non-negative scores, so top-k tests aren't sensitive to
    /// tie-breaking (the heap is unstable and the full sort is stable;
    /// we want both to agree without arguing about the order of
    /// equal-score rows).
    ///
    /// **Why non-negative scores:** the INSERT parser doesn't currently
    /// handle `Expr::UnaryOp(Minus, …)` for negative number literals
    /// (it would parse `-3.14` as a unary expression and the value
    /// extractor would skip it). That's a pre-existing bug, out of
    /// scope for 7c. Using the Knuth multiplicative hash gives us
    /// distinct non-negative scrambled values without dancing around the
    /// negative-literal limitation.
    fn seed_score_table(n: usize) -> Database {
        let mut db = Database::new("tempdb".to_string());
        crate::sql::process_command(
            "CREATE TABLE docs (id INTEGER PRIMARY KEY, score REAL);",
            &mut db,
        )
        .expect("create");
        for i in 0..n {
            // Knuth multiplicative hash mod 1_000_000. The multiplier is
            // ≡ 435_761 (mod 1_000_000), which is coprime to 1_000_000,
            // so the scores are distinct values in [0, 999_999] for any
            // n ≤ 1_000_000 (i = 0 maps to 0).
            let score = ((i as u64).wrapping_mul(2_654_435_761) % 1_000_000) as f64;
            let sql = format!("INSERT INTO docs (score) VALUES ({score});");
            crate::sql::process_command(&sql, &mut db).expect("insert");
        }
        db
    }

    /// Helper: parses an SQL SELECT into a SelectQuery so we can drive
    /// `select_topk` / `sort_rowids` directly without the rest of the
    /// process_command pipeline.
    fn parse_select(sql: &str) -> SelectQuery {
        let dialect = SQLiteDialect {};
        let mut ast = Parser::parse_sql(&dialect, sql).expect("parse");
        let stmt = ast.pop().expect("one statement");
        SelectQuery::new(&stmt).expect("select-query")
    }

    #[test]
    fn topk_matches_full_sort_asc() {
        // Build N=200, top-k=10. Bounded heap output must equal
        // full-sort-then-truncate output (both produce ASC order).
        let db = seed_score_table(200);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
        let order = q.order_by.as_ref().unwrap();
        let all_rowids = table.rowids();

        // Full-sort path
        let mut full = all_rowids.clone();
        sort_rowids(&mut full, table, order).unwrap();
        full.truncate(10);

        // Bounded-heap path
        let topk = select_topk(&all_rowids, table, order, 10).unwrap();

        assert_eq!(topk, full, "top-k via heap should match full-sort+truncate");
    }

    #[test]
    fn topk_matches_full_sort_desc() {
        // Same with DESC: verifies the direction-aware Ord wrapper.
        let db = seed_score_table(200);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score DESC LIMIT 10;");
        let order = q.order_by.as_ref().unwrap();
        let all_rowids = table.rowids();

        let mut full = all_rowids.clone();
        sort_rowids(&mut full, table, order).unwrap();
        full.truncate(10);

        let topk = select_topk(&all_rowids, table, order, 10).unwrap();

        assert_eq!(
            topk, full,
            "top-k DESC via heap should match full-sort+truncate"
        );
    }

    #[test]
    fn topk_k_larger_than_n_returns_everything_sorted() {
        // The executor branches off to the full-sort path when k >= N,
        // but if a caller invokes select_topk directly with k > N, it
        // should still return all N rows, fully sorted (there is
        // nothing to truncate).
        let db = seed_score_table(50);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1000;");
        let order = q.order_by.as_ref().unwrap();
        let topk = select_topk(&table.rowids(), table, order, 1000).unwrap();
        assert_eq!(topk.len(), 50);
        // All scores in ascending order.
        let scores: Vec<f64> = topk
            .iter()
            .filter_map(|r| match table.get_value("score", *r) {
                Some(Value::Real(f)) => Some(f),
                _ => None,
            })
            .collect();
        assert!(scores.windows(2).all(|w| w[0] <= w[1]));
    }

    #[test]
    fn topk_k_zero_returns_empty() {
        let db = seed_score_table(10);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1;");
        let order = q.order_by.as_ref().unwrap();
        let topk = select_topk(&table.rowids(), table, order, 0).unwrap();
        assert!(topk.is_empty());
    }

    #[test]
    fn topk_empty_input_returns_empty() {
        let db = seed_score_table(0);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 5;");
        let order = q.order_by.as_ref().unwrap();
        let topk = select_topk(&[], table, order, 5).unwrap();
        assert!(topk.is_empty());
    }

    #[test]
    fn topk_works_through_select_executor_with_distance_function() {
        // Integration check that the executor handles a KNN-shaped
        // query (the shape that takes the bounded-heap path) and
        // returns k rows.
        let mut db = Database::new("tempdb".to_string());
        crate::sql::process_command(
            "CREATE TABLE docs (id INTEGER PRIMARY KEY, e VECTOR(2));",
            &mut db,
        )
        .unwrap();
        // Five rows with distinct distances from probe [1.0, 0.0]:
        //   id=1 [1.0, 0.0]   distance=0
        //   id=2 [2.0, 0.0]   distance=1
        //   id=3 [0.0, 3.0]   distance=√(1+9) = √10 ≈ 3.16
        //   id=4 [1.0, 4.0]   distance=4
        //   id=5 [10.0, 10.0] distance=√(81+100) ≈ 13.45
        for v in &[
            "[1.0, 0.0]",
            "[2.0, 0.0]",
            "[0.0, 3.0]",
            "[1.0, 4.0]",
            "[10.0, 10.0]",
        ] {
            crate::sql::process_command(&format!("INSERT INTO docs (e) VALUES ({v});"), &mut db)
                .unwrap();
        }
        let resp = crate::sql::process_command(
            "SELECT id FROM docs ORDER BY vec_distance_l2(e, [1.0, 0.0]) ASC LIMIT 3;",
            &mut db,
        )
        .unwrap();
        // Expected top-3 closest to [1.0, 0.0]: id=1, id=2, id=3 (in that order).
        // The status message only reports how many rows came back, so the
        // row count is what we assert.
        assert!(resp.contains("3 rows returned"), "got: {resp}");
    }

    /// Manual benchmark, not run by default. Recommended invocation:
    ///
    ///     cargo test -p sqlrite-engine --lib topk_benchmark --release \
    ///         -- --ignored --nocapture
    ///
    /// (`--release` matters: Rust's sort is heavily optimized in release
    /// builds, so the heap's relative advantage is only meaningful when
    /// measured against an equally optimized sort.)
    ///
    /// Measured numbers on an Apple Silicon laptop with N=10_000, k=10:
    ///   - bounded heap:    ~820µs
    ///   - full sort+trunc: ~1.5ms
    ///   - ratio:           ~1.8×
    ///
    /// The advantage is real but moderate at this size because the sort
    /// key here is a single REAL column read (cheap) and Rust's sort_by
    /// has a very low constant factor. The asymptotic O(N log k) vs
    /// O(N log N) advantage scales with N and with per-row work. KNN
    /// queries where the sort key is `vec_distance_l2(col, [...])` are
    /// where this path really pays off, because each key evaluation is
    /// itself O(dim) and the heap path skips the per-row evaluation
    /// in the comparator (see `sort_rowids` for the contrast).
    #[test]
    #[ignore]
    fn topk_benchmark() {
        use std::time::Instant;
        const N: usize = 10_000;
        const K: usize = 10;

        let db = seed_score_table(N);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
        let order = q.order_by.as_ref().unwrap();
        let all_rowids = table.rowids();

        // Time bounded heap.
        let t0 = Instant::now();
        let _topk = select_topk(&all_rowids, table, order, K).unwrap();
        let heap_dur = t0.elapsed();

        // Time full sort + truncate.
        let t1 = Instant::now();
        let mut full = all_rowids.clone();
        sort_rowids(&mut full, table, order).unwrap();
        full.truncate(K);
        let sort_dur = t1.elapsed();

        let ratio = sort_dur.as_secs_f64() / heap_dur.as_secs_f64().max(1e-9);
        println!("\n--- topk_benchmark (N={N}, k={K}) ---");
        println!("  bounded heap:    {heap_dur:?}");
        println!("  full sort+trunc: {sort_dur:?}");
        println!("  speedup ratio:   {ratio:.2}×");

        // Soft assertion. The floor is 1.4× because the cheap-key
        // benchmark hovers around 1.8× empirically; setting it too
        // close to the measured value risks flaky CI on slower
        // runners. A 1.4× floor still catches a real regression
        // (e.g., if select_topk became O(N²) or stopped using the
        // heap entirely).
        assert!(
            ratio > 1.4,
            "bounded heap should be substantially faster than full sort, but ratio = {ratio:.2}"
        );
    }
}
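The tests above exercise `select_topk`, but its body is not part of this excerpt. The following is a rough, generic illustration of the bounded-heap technique the benchmark compares against a full sort. It is not the engine's implementation, and `top_k_smallest` is a hypothetical name. A max-heap never holds more than k+1 items, so each push or pop costs O(log k). DESC order is handled by wrapping items in `std::cmp::Reverse`, which is one way to build a direction-aware `Ord`.

```rust
use std::collections::BinaryHeap;

/// Returns the `k` smallest items in ascending order.
/// Each item is pushed onto a max-heap; whenever the heap grows past
/// `k`, the largest (worst) candidate is popped. Total cost is
/// O(N log k), versus O(N log N) for sort-then-truncate.
fn top_k_smallest<T: Ord + Clone>(items: &[T], k: usize) -> Vec<T> {
    let mut heap: BinaryHeap<T> = BinaryHeap::new();
    for item in items {
        heap.push(item.clone());
        if heap.len() > k {
            // Evict the current worst candidate.
            heap.pop();
        }
    }
    // into_sorted_vec yields ascending order.
    heap.into_sorted_vec()
}
```

This push-then-evict form clones every item once. An implementation that compares against `peek()` first can skip the clone for items that would be evicted immediately.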