• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

neuml / paperai / 12528686399

28 Dec 2024 06:05PM UTC coverage: 97.789% (+1.3%) from 96.538%
12528686399

push

github

davidmezzetti
Update build script

752 of 769 relevant lines covered (97.79%)

0.98 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

97.27
/src/python/paperai/query.py
1
"""
2
Query module
3
"""
4

5
import datetime
1✔
6
import re
1✔
7
import sys
1✔
8

9
from rich.console import Console
1✔
10

11
from txtai.pipeline import Tokenizer
1✔
12

13
from .highlights import Highlights
1✔
14
from .models import Models
1✔
15

16

17
class Query:
    """
    Methods to query an embeddings index.
    """

    @staticmethod
    def search(embeddings, cur, query, topn, threshold):
        """
        Executes an embeddings search for the input query. Each returned result is resolved
        to the full section row.

        Query tokens prefixed with "+" must appear in a section's text and tokens prefixed with
        "-" must not appear. Both checks are case-insensitive substring matches.

        Args:
            embeddings: embeddings model
            cur: database cursor
            query: query text, "*" returns no results
            topn: number of documents to return, topn * 5 candidates are retrieved
            threshold: require at least this score to include result, 0.25 when None

        Returns:
            list of (uid, score, article id, section text) tuples
        """

        if query == "*":
            return []

        # Default threshold if None
        threshold = threshold if threshold is not None else 0.25

        results = []

        # Get list of required and prohibited tokens. Only the leading operator characters are
        # removed so tokens such as "+c++" keep their trailing characters. Operator-only tokens
        # ("+", "--") are ignored; an empty prohibited token would otherwise exclude every result.
        tokens = query.split()
        must = [token.lstrip("+").lower() for token in tokens if token.startswith("+") and token.lstrip("+")]
        mnot = [token.lstrip("-").lower() for token in tokens if token.startswith("-") and token.lstrip("-")]

        # Tokenize search query, if necessary
        query = Tokenizer.tokenize(query) if embeddings.isweighted() else query

        # Retrieve topn * 5 to account for duplicate matches
        for result in embeddings.search(query, topn * 5):
            # Results are either dicts with "id"/"score" keys or (id, score) tuples
            uid, score = (result["id"], result["score"]) if isinstance(result, dict) else result

            if score < threshold:
                continue

            cur.execute("SELECT Article, Text FROM sections WHERE id = ?", [uid])

            # Get matching row, skipping ids that no longer resolve to a section
            row = cur.fetchone()
            if row is None:
                continue

            sid, text = row
            lowered = text.lower()

            # Add result if:
            #   - all required tokens are present (or there are no required tokens) AND
            #   - no prohibited token is present (or there are no prohibited tokens)
            if all(token in lowered for token in must) and not any(token in lowered for token in mnot):
                # Save result
                results.append((uid, score, sid, text))

        return results

    @staticmethod
    def highlights(results, topn):
        """
        Builds a list of highlights for the search results. Returns top ranked sections by importance
        over the result list.

        Args:
            results: search results as (uid, score, article id, text) tuples
            topn: number of highlights to extract, capped at 5

        Returns:
            top ranked sections
        """

        # Deduplicate by section text; a later duplicate overwrites an earlier one
        sections = {}
        for uid, score, _, text in results:
            # Filter out lower scored results
            if score >= 0.1:
                sections[text] = (uid, text)

        # Return up to 5 highlights
        return Highlights.build(sections.values(), min(topn, 5))

    @staticmethod
    def documents(results, topn):
        """
        Processes search results and groups by article.

        Args:
            results: search results as (uid, score, article id, text) tuples
            topn: number of documents to return

        Returns:
            dict of article id -> list of (score, text) sections sorted by score descending,
            limited to the topn articles with the best single section score
        """

        documents = {}

        # Group by article; the set drops duplicate (score, text) pairs
        for _, score, article, text in results:
            documents.setdefault(article, set()).add((score, text))

        # Sort each article's sections by score descending (ties broken by text descending)
        for article in documents:
            documents[article] = sorted(documents[article], reverse=True)

        # Get documents with top n best sections
        best = sorted(documents, key=lambda k: max(x[0] for x in documents[k]), reverse=True)[:topn]
        return {article: documents[article] for article in best}

    @staticmethod
    def all(cur):
        """
        Gets a list of all article ids.

        Args:
            cur: database cursor

        Returns:
            dict with every article id as a key and None as the value
        """

        cur.execute("SELECT Id FROM articles")
        return {row[0]: None for row in cur.fetchall()}

    @staticmethod
    def authors(authors):
        """
        Formats a short authors string using the first author's surname.

        Args:
            authors: full authors string, authors separated by "; "

        Returns:
            short author string ("<surname> et al") or None if no author name is available
        """

        if authors:
            # Only the first author is reported
            first = authors.split("; ")[0]
            if "," in first:
                # "Last, First" format
                name = first.split(",")[0]
            else:
                # "First Last" format, use the final word as the surname
                parts = first.split()
                if not parts:
                    # Empty/blank first author segment (e.g. "; Smith")
                    return None
                name = parts[-1]

            return f"{name} et al"

        return None

    @staticmethod
    def date(date):
        """
        Formats a date string.

        Args:
            date: input date string in "%Y-%m-%d %H:%M:%S" format

        Returns:
            formatted date ("YYYY" for January 1st dates, else "YYYY-MM-DD") or None

        Raises:
            ValueError: if date does not match the expected format
        """

        if date:
            date = datetime.datetime.strptime(date, "%Y-%m-%d %H:%M:%S")

            # 1/1 dates had no month/day specified, use only year
            if date.month == 1 and date.day == 1:
                return date.strftime("%Y")

            return date.strftime("%Y-%m-%d")

        return None

    @staticmethod
    def text(text):
        """
        Formats match text.

        Args:
            text: input text

        Returns:
            text with reference markers, bullets and http links removed; falsy input is
            returned unchanged
        """

        if text:
            # Remove reference links ([1], [2], etc)
            text = re.sub(r"\s*[\[(][0-9, ]+[\])]\s*", " ", text)

            # Remove •
            text = text.replace("•", "")

            # Remove http links, including a link at the very end of the text that has no
            # trailing whitespace
            text = re.sub(r"http\S+\s?", " ", text)

        return text

    @staticmethod
    def query(embeddings, db, query, topn, threshold):
        """
        Executes a query against the embeddings model and prints formatted results.

        Output is captured with a rich Console and printed to stdout once at the end. It contains
        the query, highlights (requested count topn / 5, capped at 5) and the matching articles
        with their matching sections.

        Args:
            embeddings: embeddings model
            db: open SQLite database
            query: query string
            topn: number of query results, defaults to 10
            threshold: query match score threshold
        """

        # Default to 10 results if not specified
        topn = topn if topn else 10

        cur = db.cursor()

        # Create console printer
        console = Console(soft_wrap=True)
        with console.capture() as output:
            # Print query
            console.print(f"[dark_orange]Query: {query}[/dark_orange]")
            console.print()

            # Execute query
            results = Query.search(embeddings, cur, query, topn, threshold)

            # Extract top sections as highlights
            highlights = Query.highlights(results, int(topn / 5))
            if highlights:
                console.print("[deep_sky_blue1]Highlights[/deep_sky_blue1]")
                for highlight in highlights:
                    console.print(
                        (f"[bright_blue] - {Query.text(highlight)}[/bright_blue]"),
                        highlight=False,
                    )

                console.print()

            # Get results grouped by document
            documents = Query.documents(results, topn)

            # Article header
            console.print("[deep_sky_blue1]Articles[/deep_sky_blue1]")
            console.print()

            # Print each article, sorted by total (summed) section score descending
            for uid in sorted(documents, key=lambda k: sum(x[0] for x in documents[k]), reverse=True):
                cur.execute(
                    "SELECT Title, Published, Publication, Entry, Id, Reference FROM articles WHERE id = ?",
                    [uid],
                )
                article = cur.fetchone()

                # Skip sections whose article row is missing instead of failing on None
                if article is None:
                    continue

                console.print(f"Title: {article[0]}", highlight=False)
                console.print(f"Published: {Query.date(article[1])}", highlight=False)
                console.print(f"Publication: {article[2]}", highlight=False)
                console.print(f"Entry: {article[3]}", highlight=False)
                console.print(f"Id: {article[4]}", highlight=False)
                console.print(f"Reference: {article[5]}")

                # Print top matches
                for score, text in documents[uid]:
                    console.print(
                        f"[bright_blue] - ({score:.4f}): {Query.text(text)}[/bright_blue]",
                        highlight=False,
                    )

                console.print()

        # Print console output
        print(output.get())

    @staticmethod
    def run(query, topn=None, path=None, threshold=None):
        """
        Executes a query against an index.

        Args:
            query: input query
            topn: number of results
            path: model path
            threshold: query match score threshold
        """

        # Load model
        embeddings, db = Models.load(path)

        try:
            # Query the database
            Query.query(embeddings, db, query, topn, threshold)
        finally:
            # Free resources, even if the query fails
            Models.close(db)
307

308

309
if __name__ == "__main__":
    # Positional CLI arguments: query [topn] [model path] [threshold]
    args = sys.argv[1:]
    if args:
        count = len(args)
        Query.run(
            args[0],
            int(args[1]) if count > 1 else None,
            args[2] if count > 2 else None,
            float(args[3]) if count > 3 else None,
        )
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc