• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

neuml / txtai / 11062719915

27 Sep 2024 01:06AM CUT coverage: 99.946%. Remained the same
11062719915

push

github

davidmezzetti
Update documentation

7406 of 7410 relevant lines covered (99.95%)

1.0 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

100.0
/src/python/txtai/ann/sqlite.py
1
"""
2
SQLite module
3
"""
4

5
import os
1✔
6
import sqlite3
1✔
7

8
# Conditional import
9
try:
1✔
10
    import sqlite_vec
1✔
11

12
    SQLITEVEC = True
1✔
13
except ImportError:
1✔
14
    SQLITEVEC = False
1✔
15

16
from .base import ANN
1✔
17

18

19
class SQLite(ANN):
    """
    Builds an ANN index backed by a SQLite database.

    Vectors are stored in a sqlite-vec `vec0` virtual table whose name comes from the "table"
    setting (default "vectors"). The optional "quantize" setting selects the vector column type:
    1 = binary (BIT), 8 = INT8, anything else (or unset) = FLOAT32.
    """

    def __init__(self, config):
        super().__init__(config)

        if not SQLITEVEC:
            raise ImportError('sqlite-vec is not available - install "ann" extra to enable')

        # Database parameters. The connection is opened lazily by database().
        self.connection, self.cursor, self.path = None, None, ""

        # Quantization setting: True selects the INT8 default, any other truthy value is parsed as
        # the number of bits, and False/None/0 disable quantization. Test "is True" explicitly -
        # an isinstance(bool) check would also map False to INT8.
        quantize = self.setting("quantize")
        if quantize is True:
            self.quantize = 8
        elif quantize:
            self.quantize = int(quantize)
        else:
            self.quantize = None

    def load(self, path):
        """
        Points this index at an existing database file. No connection is opened here;
        database() connects on first use.

        Args:
            path: path to database file
        """

        self.path = path

    def index(self, embeddings):
        """
        Builds a new index from embeddings, replacing any vectors already in the table.

        Args:
            embeddings: vectors to index; row i is stored with indexid i
                        (assumes an array-like exposing .shape, e.g. a 2D numpy array - TODO confirm)
        """

        # Initialize tables
        self.initialize(recreate=True)

        # Add vectors
        self.database().executemany(self.insertsql(), enumerate(embeddings))

        # Add id offset and index build metadata
        self.config["offset"] = embeddings.shape[0]
        self.metadata(self.settings())

    def append(self, embeddings):
        """
        Appends embeddings to the index. New rows get indexids starting at config["offset"],
        which is then advanced by the number of rows added.

        Args:
            embeddings: vectors to append (array-like exposing .shape)
        """

        offset = self.config["offset"]
        self.database().executemany(self.insertsql(), [(x + offset, row) for x, row in enumerate(embeddings)])

        self.config["offset"] += embeddings.shape[0]
        self.metadata()

    def delete(self, ids):
        """
        Deletes vectors by indexid.

        Args:
            ids: iterable of indexids to delete
        """

        self.database().executemany(self.deletesql(), [(x,) for x in ids])

    def search(self, queries, limit):
        """
        Runs a k-nearest-neighbor search for each query.

        Args:
            queries: iterable of query vectors
            limit: maximum number of results per query (bound to sqlite-vec's k parameter)

        Returns:
            list with one list of (indexid, 1 - distance) tuples per query, ordered by distance
        """

        results = []
        for query in queries:
            # Execute query
            self.database().execute(self.searchsql(), [query, limit])

            # Add query results
            results.append(list(self.database()))

        return results

    def count(self):
        """
        Counts the number of vectors stored in the index table.

        Returns:
            number of rows
        """

        return self.database().execute(self.countsql()).fetchone()[0]

    def save(self, path):
        """
        Saves the index to path.

        - No current path (temporary database): commit, copy into path, close the temporary
          connection and switch this instance over to the new file.
        - Same path: commit pending changes.
        - Different path: write a copy to path and keep using the current connection.

        Args:
            path: target database path
        """

        # Temporary database
        if not self.path:
            # Save temporary database
            self.connection.commit()

            # Copy data from current to new
            connection = self.copy(path)

            # Close temporary database
            self.connection.close()

            # Point connection to new connection
            self.connection = connection
            self.cursor = self.connection.cursor()
            self.path = path

        # Paths are equal, commit changes
        elif self.path == path:
            self.connection.commit()

        # New path is different from current path, copy data and continue using current connection
        else:
            self.copy(path).close()

    def initialize(self, recreate=False):
        """
        Initializes a new database session.

        Args:
            recreate: Recreates the database tables if True
        """

        # Create table
        self.database().execute(self.tablesql())

        # Clear data
        if recreate:
            self.database().execute(self.tosql("DELETE FROM {table}"))

    def settings(self):
        """
        Returns settings for this index.

        Returns:
            dict with the SQLite and sqlite-vec library versions
        """

        sqlite, sqlitevec = self.database().execute("SELECT sqlite_version(), vec_version()").fetchone()

        return {"sqlite": sqlite, "sqlite-vec": sqlitevec}

    def database(self):
        """
        Gets the current database cursor. Creates a new connection
        if there isn't one.

        Returns:
            cursor
        """

        if not self.connection:
            self.connection = self.connect(self.path)
            self.cursor = self.connection.cursor()

        return self.cursor

    def connect(self, path):
        """
        Creates a new database connection with the sqlite-vec extension loaded.

        Args:
            path: path to database file

        Returns:
            database connection
        """

        # Create connection
        connection = sqlite3.connect(path, check_same_thread=False)

        # Load sqlite-vec extension, then disable extension loading again
        connection.enable_load_extension(True)
        sqlite_vec.load(connection)
        connection.enable_load_extension(False)

        return connection

    def copy(self, path):
        """
        Copies content from the current database into target. Any existing file at path is deleted.

        Args:
            path: target database path

        Returns:
            new database connection, with the copied data committed
        """

        # Delete existing file, if necessary
        if os.path.exists(path):
            os.remove(path)

        # Create new connection
        connection = self.connect(path)

        if self.connection.in_transaction:
            # Initialize connection
            connection.execute(self.tablesql())

            # The backup call will hang if there are uncommitted changes, need to copy over
            # with iterdump (which is much slower). Only this table's row inserts are replayed.
            # Lowercase both sides so table names containing uppercase letters still match.
            insert = self.tosql('insert into "{table}"').lower()
            for sql in self.connection.iterdump():
                if insert in sql.lower():
                    connection.execute(sql)

            # The filter above drops every non-insert statement (including any transaction
            # statements iterdump emits). Commit explicitly, otherwise closing the returned
            # connection (as save() does) discards the copied rows.
            connection.commit()
        else:
            # Database is up to date, can do a more efficient copy with SQLite C API
            self.connection.backup(connection)

        return connection

    def tablesql(self):
        """
        Builds a CREATE table statement for table.

        Returns:
            CREATE TABLE
        """

        dimensions = self.config["dimensions"]

        # Binary quantization
        if self.quantize == 1:
            embedding = f"embedding BIT[{dimensions}]"

        # INT8 quantization
        elif self.quantize == 8:
            embedding = f"embedding INT8[{dimensions}] distance=cosine"

        # Standard FLOAT32 (also used for any unrecognized quantize value)
        else:
            embedding = f"embedding FLOAT[{dimensions}] distance=cosine"

        # Return CREATE TABLE sql
        return self.tosql(f"CREATE VIRTUAL TABLE IF NOT EXISTS {{table}} USING vec0(indexid INTEGER PRIMARY KEY, {embedding})")

    def insertsql(self):
        """
        Creates an INSERT SQL statement.

        Returns:
            INSERT
        """

        return self.tosql(f"INSERT INTO {{table}}(indexid, embedding) VALUES (?, {self.embeddingsql()})")

    def deletesql(self):
        """
        Creates a DELETE SQL statement.

        Returns:
            DELETE
        """

        return self.tosql("DELETE FROM {table} WHERE indexid = ?")

    def searchsql(self):
        """
        Creates a SELECT SQL statement for search. The score is reported as 1 - distance.

        NOTE(review): the BIT column created for quantize == 1 sets no distance=cosine, so its
        distance is presumably a bit-count (hamming) distance and 1 - distance is not a [0, 1]
        similarity there - confirm this is intended.

        Returns:
            SELECT
        """

        return self.tosql(f"SELECT indexid, 1 - distance FROM {{table}} WHERE embedding MATCH {self.embeddingsql()} AND k = ? ORDER BY distance")

    def countsql(self):
        """
        Creates a SELECT COUNT statement.

        Returns:
            SELECT COUNT
        """

        return self.tosql("SELECT count(indexid) FROM {table}")

    def embeddingsql(self):
        """
        Creates an embeddings column SQL snippet, wrapping the bound parameter in the matching
        sqlite-vec quantization function when quantization is enabled.

        Returns:
            embeddings column SQL
        """

        # Binary quantization
        if self.quantize == 1:
            return "vec_quantize_binary(?)"

        # INT8 quantization
        if self.quantize == 8:
            return "vec_quantize_int8(?, 'unit')"

        # Standard FLOAT32
        return "?"

    def tosql(self, sql):
        """
        Creates a SQL statement substituting in the configured table name.

        Args:
            sql: SQL statement with a {table} parameter

        Returns:
            fully resolved SQL statement
        """

        table = self.setting("table", "vectors")
        return sql.format(table=table)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc