• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 14156679192

30 Mar 2025 01:51PM UTC coverage: 89.887% (-0.3%) from 90.161%
14156679192

Pull #9140

github

web-flow
Merge a2b39ed39 into e483ec6f5
Pull Request #9140: feat(embedders): Add async support for OpenAI document embedder

10222 of 11372 relevant lines covered (89.89%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

59.29
haystack/components/embedders/openai_document_embedder.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import os
1✔
6
from typing import Any, Dict, List, Optional, Tuple
1✔
7

8
from more_itertools import batched
1✔
9
from openai import APIError, OpenAI
1✔
10
from tqdm import tqdm
1✔
11
from tqdm.asyncio import tqdm as async_tqdm
1✔
12

13
from haystack import Document, component, default_from_dict, default_to_dict, logging
1✔
14
from haystack.utils import Secret, deserialize_secrets_inplace
1✔
15

16
# Module-level logger; haystack's logging wrapper supports keyword-argument
# interpolation (e.g. logger.exception("... {ids}", ids=ids)).
logger = logging.getLogger(__name__)
17

18

19
@component
class OpenAIDocumentEmbedder:
    """
    Computes document embeddings using OpenAI models.

    ### Usage example

    ```python
    from haystack import Document
    from haystack.components.embedders import OpenAIDocumentEmbedder

    doc = Document(content="I love pizza!")

    document_embedder = OpenAIDocumentEmbedder()

    result = document_embedder.run([doc])
    print(result['documents'][0].embedding)

    # [0.017020374536514282, -0.023255806416273117, ...]
    ```
    """

    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
        model: str = "text-embedding-ada-002",
        dimensions: Optional[int] = None,
        api_base_url: Optional[str] = None,
        organization: Optional[str] = None,
        prefix: str = "",
        suffix: str = "",
        batch_size: int = 32,
        progress_bar: bool = True,
        meta_fields_to_embed: Optional[List[str]] = None,
        embedding_separator: str = "\n",
        timeout: Optional[float] = None,
        max_retries: Optional[int] = None,
    ):
        """
        Creates an OpenAIDocumentEmbedder component.

        Before initializing the component, you can set the 'OPENAI_TIMEOUT' and 'OPENAI_MAX_RETRIES'
        environment variables to override the `timeout` and `max_retries` parameters respectively
        in the OpenAI client.

        :param api_key:
            The OpenAI API key.
            You can set it with an environment variable `OPENAI_API_KEY`, or pass with this parameter
            during initialization.
        :param model:
            The name of the model to use for calculating embeddings.
            The default model is `text-embedding-ada-002`.
        :param dimensions:
            The number of dimensions of the resulting embeddings. Only `text-embedding-3` and
            later models support this parameter.
        :param api_base_url:
            Overrides the default base URL for all HTTP requests.
        :param organization:
            Your OpenAI organization ID. See OpenAI's
            [Setting Up Your Organization](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization)
            for more information.
        :param prefix:
            A string to add at the beginning of each text.
        :param suffix:
            A string to add at the end of each text.
        :param batch_size:
            Number of documents to embed at once.
        :param progress_bar:
            If `True`, shows a progress bar when running.
        :param meta_fields_to_embed:
            List of metadata fields to embed along with the document text.
        :param embedding_separator:
            Separator used to concatenate the metadata fields to the document text.
        :param timeout:
            Timeout for OpenAI client calls. If not set, it defaults to either the
            `OPENAI_TIMEOUT` environment variable, or 30 seconds.
        :param max_retries:
            Maximum number of retries to contact OpenAI after an internal error.
            If not set, it defaults to either the `OPENAI_MAX_RETRIES` environment variable, or 5 retries.
        """
        self.api_key = api_key
        self.model = model
        self.dimensions = dimensions
        self.api_base_url = api_base_url
        self.organization = organization
        self.prefix = prefix
        self.suffix = suffix
        self.batch_size = batch_size
        self.progress_bar = progress_bar
        self.meta_fields_to_embed = meta_fields_to_embed or []
        self.embedding_separator = embedding_separator

        # Explicit arguments win; fall back to env vars, then library-style defaults.
        if timeout is None:
            timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
        if max_retries is None:
            max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5"))

        self.client = OpenAI(
            api_key=api_key.resolve_value(),
            organization=organization,
            base_url=api_base_url,
            timeout=timeout,
            max_retries=max_retries,
        )

    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        return {"model": self.model}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            model=self.model,
            dimensions=self.dimensions,
            organization=self.organization,
            api_base_url=self.api_base_url,
            prefix=self.prefix,
            suffix=self.suffix,
            batch_size=self.batch_size,
            progress_bar=self.progress_bar,
            meta_fields_to_embed=self.meta_fields_to_embed,
            embedding_separator=self.embedding_separator,
            api_key=self.api_key.to_dict(),
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "OpenAIDocumentEmbedder":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
        return default_from_dict(cls, data)

    def _prepare_texts_to_embed(self, documents: List[Document]) -> Dict[str, str]:
        """
        Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.

        :param documents:
            Documents to build embedding inputs for.
        :returns:
            Mapping from document ID to the text that should be embedded for it.
        """
        texts_to_embed = {}
        for doc in documents:
            meta_values_to_embed = [
                str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None
            ]

            text_to_embed = (
                self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + self.suffix
            )

            # copied from OpenAI embedding_utils (https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py)
            # replace newlines, which can negatively affect performance.
            texts_to_embed[doc.id] = text_to_embed.replace("\n", " ")
        return texts_to_embed

    def _build_embedding_args(self, batch: Tuple) -> Dict[str, Any]:
        """
        Build the keyword arguments for a single `embeddings.create` call on one batch
        of `(doc_id, text)` pairs.
        """
        args: Dict[str, Any] = {"model": self.model, "input": [b[1] for b in batch]}
        if self.dimensions is not None:
            args["dimensions"] = self.dimensions
        return args

    @staticmethod
    def _merge_response(response: Any, all_embeddings: List[List[float]], meta: Dict[str, Any]) -> None:
        """
        Fold one API response into the running results, in place.

        Appends the batch's embedding vectors to `all_embeddings` and accumulates
        token usage in `meta["usage"]` (first response seeds `model` and `usage`,
        subsequent ones only add token counts).
        """
        all_embeddings.extend(el.embedding for el in response.data)
        if "model" not in meta:
            meta["model"] = response.model
        if "usage" not in meta:
            meta["usage"] = dict(response.usage)
        else:
            meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens
            meta["usage"]["total_tokens"] += response.usage.total_tokens

    def _embed_batch(self, texts_to_embed: Dict[str, str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]:
        """
        Embed a list of texts in batches.

        Batches that fail with an `APIError` are logged and skipped, so the returned
        embeddings list can be shorter than `texts_to_embed`.

        :param texts_to_embed:
            Mapping from document ID to text, as produced by `_prepare_texts_to_embed`.
        :param batch_size:
            Number of texts sent per API call.
        :returns:
            A tuple of (embeddings in input order, metadata with model name and token usage).
        """
        all_embeddings: List[List[float]] = []
        meta: Dict[str, Any] = {}
        for batch in tqdm(
            batched(texts_to_embed.items(), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
        ):
            args = self._build_embedding_args(batch)

            try:
                response = self.client.embeddings.create(**args)
            except APIError as exc:
                # Best-effort behavior: log the failed batch and keep going.
                ids = ", ".join(b[0] for b in batch)
                msg = "Failed embedding of documents {ids} caused by {exc}"
                logger.exception(msg, ids=ids, exc=exc)
                continue

            self._merge_response(response, all_embeddings, meta)

        return all_embeddings, meta

    async def _embed_batch_async(
        self, texts_to_embed: Dict[str, str], batch_size: int
    ) -> Tuple[List[List[float]], Dict[str, Any]]:
        """
        Embed a list of texts in batches asynchronously.

        Mirrors `_embed_batch` but uses a temporary `AsyncOpenAI` client configured
        like the synchronous one. The client is always closed, even if an unexpected
        exception escapes the embedding loop.

        :param texts_to_embed:
            Mapping from document ID to text, as produced by `_prepare_texts_to_embed`.
        :param batch_size:
            Number of texts sent per API call.
        :returns:
            A tuple of (embeddings in input order, metadata with model name and token usage).
        """
        # Imported lazily so the async client is only required when async embedding is used.
        from openai import AsyncOpenAI

        async_client = AsyncOpenAI(
            api_key=self.api_key.resolve_value(),
            organization=self.organization,
            base_url=self.api_base_url,
            timeout=self.client.timeout,
            max_retries=self.client.max_retries,
        )

        all_embeddings: List[List[float]] = []
        meta: Dict[str, Any] = {}

        batches = list(batched(texts_to_embed.items(), batch_size))
        if self.progress_bar:
            batches = async_tqdm(batches, desc="Calculating embeddings")

        # try/finally guarantees the client is closed even if a non-APIError escapes the loop.
        try:
            for batch in batches:
                args = self._build_embedding_args(batch)

                try:
                    response = await async_client.embeddings.create(**args)
                except APIError as exc:
                    # Best-effort behavior: log the failed batch and keep going.
                    ids = ", ".join(b[0] for b in batch)
                    msg = "Failed embedding of documents {ids} caused by {exc}"
                    logger.exception(msg, ids=ids, exc=exc)
                    continue

                self._merge_response(response, all_embeddings, meta)
        finally:
            await async_client.close()

        return all_embeddings, meta

    @component.output_types(documents=List[Document], meta=Dict[str, Any])
    def run(self, documents: List[Document]):
        """
        Embeds a list of documents.

        :param documents:
            A list of documents to embed.

        :returns:
            A dictionary with the following keys:
            - `documents`: A list of documents with embeddings.
            - `meta`: Information about the usage of the model.
        """
        if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
            # Note: message text matches run_async (a missing space after "input." was fixed here).
            raise TypeError(
                "OpenAIDocumentEmbedder expects a list of Documents as input. "
                "In case you want to embed a string, please use the OpenAITextEmbedder."
            )

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)

        embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)

        # If some batches failed, zip stops early and the trailing documents keep embedding=None.
        for doc, emb in zip(documents, embeddings):
            doc.embedding = emb

        return {"documents": documents, "meta": meta}

    @component.output_types(documents=List[Document], meta=Dict[str, Any])
    async def run_async(self, documents: List[Document]):
        """
        Embeds a list of documents asynchronously.

        :param documents:
            A list of documents to embed.

        :returns:
            A dictionary with the following keys:
            - `documents`: A list of documents with embeddings.
            - `meta`: Information about the usage of the model.
        """
        if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
            raise TypeError(
                "OpenAIDocumentEmbedder expects a list of Documents as input. "
                "In case you want to embed a string, please use the OpenAITextEmbedder."
            )

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)

        embeddings, meta = await self._embed_batch_async(texts_to_embed=texts_to_embed, batch_size=self.batch_size)

        # If some batches failed, zip stops early and the trailing documents keep embedding=None.
        for doc, emb in zip(documents, embeddings):
            doc.embedding = emb

        return {"documents": documents, "meta": meta}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc