
neuml / txtai · build 3875424846 (github push) · pending completion

davidmezzetti: Add language-modeling task to HFTrainer, closes #403

36 of 36 new or added lines in 4 files covered (100.0%)
4501 of 4505 relevant lines covered (99.91%)
1.0 hits per line

Source File: /src/python/txtai/vectors/words.py (97.59% covered)
"""
Word Vectors module
"""

import logging
import os
import pickle
import tempfile

from errno import ENOENT
from multiprocessing import Pool

import numpy as np

# Conditionally import Word Vector libraries as they aren't installed by default
try:
    import fasttext
    from pymagnitude import converter, Magnitude

    WORDS = True
except ImportError:
    WORDS = False

from .. import __pickle__

from .base import Vectors
from ..pipeline import Tokenizer

# Logging configuration
logger = logging.getLogger(__name__)

# Multiprocessing helper methods
# pylint: disable=W0603
VECTORS = None

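# Note: fasttext and pymagnitude are optional dependencies; when the import
# above fails, WORDS is False and word vector functionality is unavailable.
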
def create(config, scoring):
    """
    Multiprocessing helper method. Creates a global embeddings object to be accessed in a new subprocess.

    Args:
        config: vector configuration
        scoring: scoring instance
    """

    global VECTORS

    # Create a global embeddings object using the configuration and scoring instance
    VECTORS = WordVectors(config, scoring)  # uncovered in this report


def transform(document):
    """
    Multiprocessing helper method. Transforms document into an embeddings vector.

    Args:
        document: (id, data, tags)

    Returns:
        (id, embedding)
    """

    return (document[0], VECTORS.transform(document))  # uncovered in this report

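# Why the global/initializer pattern above: multiprocessing workers need
# picklable, module-level callables. create() runs once per worker process to
# build a per-process WordVectors instance, and transform() then reuses it for
# every document instead of reloading the model on each call. Both lines show
# as uncovered in this report, likely because they execute in worker processes
# that the coverage tracer does not follow by default.
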
class WordVectors(Vectors):
    """
    Builds sentence embeddings/vectors using weighted word embeddings.
    """

    def load(self, path):
        # Ensure that vector path exists
        if not path or not os.path.isfile(path):
            raise IOError(ENOENT, "Vector model file not found", path)

        # Load magnitude model. If this is a training run (uninitialized config), block until vectors are fully loaded
        return Magnitude(path, case_insensitive=True, blocking=not self.initialized)

    def encode(self, data):
        # Iterate over each data element, tokenize (if necessary) and build an aggregated embeddings vector
        embeddings = []
        for tokens in data:
            # Convert to tokens if necessary
            if isinstance(tokens, str):
                tokens = Tokenizer.tokenize(tokens)

            # Generate weights for each vector using a scoring method
            weights = self.scoring.weights(tokens) if self.scoring else None

            # pylint: disable=E1133
            if weights and [x for x in weights if x > 0]:
                # Build weighted average embeddings vector. Create weights array as float32 to match embeddings precision.
                embedding = np.average(self.lookup(tokens), weights=np.array(weights, dtype=np.float32), axis=0)
            else:
                # If no weights, use mean
                embedding = np.mean(self.lookup(tokens), axis=0)

            embeddings.append(embedding)

        return np.array(embeddings, dtype=np.float32)

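    # A note on the math in encode(): given token vectors x_i and weights w_i,
    # np.average computes sum(w_i * x_i) / sum(w_i) along axis 0, yielding one
    # sentence vector per input; with no positive weights, encode() falls back
    # to the unweighted mean.
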
    def index(self, documents, batchsize=1):
        # Use default single process indexing logic
        if "parallel" in self.config and not self.config["parallel"]:
            return super().index(documents, batchsize)

        # Customize indexing logic with multiprocessing pool to efficiently build vectors
        ids, dimensions, batches, stream = [], None, 0, None

        # Shared objects with Pool
        args = (self.config, self.scoring)

        # Convert all documents to embedding arrays, stream embeddings to disk to control memory usage
        with Pool(os.cpu_count(), initializer=create, initargs=args) as pool:
            with tempfile.NamedTemporaryFile(mode="wb", suffix=".npy", delete=False) as output:
                stream = output.name
                embeddings = []
                for uid, embedding in pool.imap(transform, documents):
                    if not dimensions:
                        # Set number of dimensions for embeddings
                        dimensions = embedding.shape[0]

                    ids.append(uid)
                    embeddings.append(embedding)

                    if len(embeddings) == batchsize:
                        pickle.dump(np.array(embeddings, dtype=np.float32), output, protocol=__pickle__)
                        batches += 1

                        embeddings = []

                # Final embeddings batch
                if embeddings:
                    pickle.dump(np.array(embeddings, dtype=np.float32), output, protocol=__pickle__)
                    batches += 1

        return (ids, dimensions, batches, stream)

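    # index() returns (ids, dimensions, batches, stream): stream is the path
    # to a temporary file holding `batches` consecutive pickled float32 arrays,
    # each of shape (batchsize, dimensions) except possibly the last. Callers
    # recover the embeddings with repeated pickle.load calls, keeping memory
    # usage flat while indexing large document sets.
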
    def lookup(self, tokens):
        """
        Queries word vectors for given list of input tokens.

        Args:
            tokens: list of tokens to query

        Returns:
            word vectors array
        """

        return self.model.query(tokens)

    @staticmethod
    def isdatabase(path):
        """
        Checks if this is a SQLite database file which is the file format used for word vectors databases.

        Args:
            path: path to check

        Returns:
            True if this is a SQLite database
        """

        if isinstance(path, str) and os.path.isfile(path) and os.path.getsize(path) >= 100:
            # Read 100 byte SQLite header
            with open(path, "rb") as f:
                header = f.read(100)

            # Check for SQLite header
            return header.startswith(b"SQLite format 3\000")

        return False

    @staticmethod
    def build(data, size, mincount, path):
        """
        Builds fastText vectors from a file.

        Args:
            data: path to input data file
            size: number of vector dimensions
            mincount: minimum number of occurrences required to register a token
            path: path to output file
        """

        # Train on data file using the given dimension size
        model = fasttext.train_unsupervised(data, dim=size, minCount=mincount)

        # Log vector model build
        logger.info("Building %d dimension model", size)

        # Output vectors in vec/txt format
        with open(path + ".txt", "w", encoding="utf-8") as output:
            words = model.get_words()
            output.write(f"{len(words)} {model.get_dimension()}\n")

            for word in words:
                # Skip end of line token
                if word != "</s>":
                    vector = model.get_word_vector(word)
                    data = ""
                    for v in vector:
                        data += " " + str(v)

                    output.write(word + data + "\n")

        # Build magnitude vectors database
        logger.info("Converting vectors to magnitude format")
        converter.convert(path + ".txt", path + ".magnitude", subword=True)
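
A minimal usage sketch for the static helpers above, assuming fasttext and
pymagnitude are installed. The file names (corpus.txt, vectors) are
hypothetical placeholders, not paths from the project.

    from txtai.vectors.words import WordVectors

    # Train 300-dimension fastText vectors from a tokenized text corpus,
    # keeping tokens that occur at least 3 times, then convert the output
    # (vectors.txt) to a magnitude database (vectors.magnitude)
    WordVectors.build("corpus.txt", 300, 3, "vectors")

    # Magnitude files are SQLite databases, so the header check passes
    print(WordVectors.isdatabase("vectors.magnitude"))  # True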