• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 16933015230

13 Aug 2025 09:18AM UTC coverage: 92.184% (+0.2%) from 91.969%
16933015230

Pull #9699

github

web-flow
Merge cfbd602e7 into 8160ea8bf
Pull Request #9699: feat: Update `source_id_meta_field` in `SentenceWindowRetriever` to also accept a list of values

12891 of 13984 relevant lines covered (92.18%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

67.72
haystack/components/embedders/openai_document_embedder.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import os
1✔
6
from dataclasses import replace
1✔
7
from typing import Any, Optional
1✔
8

9
from more_itertools import batched
1✔
10
from openai import APIError, AsyncOpenAI, OpenAI
1✔
11
from tqdm import tqdm
1✔
12
from tqdm.asyncio import tqdm as async_tqdm
1✔
13

14
from haystack import Document, component, default_from_dict, default_to_dict, logging
1✔
15
from haystack.utils import Secret, deserialize_secrets_inplace
1✔
16
from haystack.utils.http_client import init_http_client
1✔
17

18
# Module-scoped logger, obtained through haystack's logging wrapper (supports the
# keyword-argument style used by logger.exception(...) below).
logger = logging.getLogger(__name__)
19

20

21
@component
class OpenAIDocumentEmbedder:
    """
    Computes document embeddings using OpenAI models.

    Documents are embedded in batches of `batch_size`. Metadata fields listed in
    `meta_fields_to_embed` are concatenated with the document content (using
    `embedding_separator`) before embedding.

    ### Usage example

    ```python
    from haystack import Document
    from haystack.components.embedders import OpenAIDocumentEmbedder

    doc = Document(content="I love pizza!")

    document_embedder = OpenAIDocumentEmbedder()

    result = document_embedder.run([doc])
    print(result['documents'][0].embedding)

    # [0.017020374536514282, -0.023255806416273117, ...]
    ```
    """
42

43
    def __init__(  # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-positional-arguments
        self,
        api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
        model: str = "text-embedding-ada-002",
        dimensions: Optional[int] = None,
        api_base_url: Optional[str] = None,
        organization: Optional[str] = None,
        prefix: str = "",
        suffix: str = "",
        batch_size: int = 32,
        progress_bar: bool = True,
        meta_fields_to_embed: Optional[list[str]] = None,
        embedding_separator: str = "\n",
        timeout: Optional[float] = None,
        max_retries: Optional[int] = None,
        http_client_kwargs: Optional[dict[str, Any]] = None,
        *,
        raise_on_failure: bool = False,
    ):
        """
        Creates an OpenAIDocumentEmbedder component.

        Before initializing the component, you can set the 'OPENAI_TIMEOUT' and 'OPENAI_MAX_RETRIES'
        environment variables to override the `timeout` and `max_retries` parameters respectively
        in the OpenAI client.

        :param api_key:
            The OpenAI API key.
            You can set it with an environment variable `OPENAI_API_KEY`, or pass with this parameter
            during initialization.
        :param model:
            The name of the model to use for calculating embeddings.
            The default model is `text-embedding-ada-002`.
        :param dimensions:
            The number of dimensions of the resulting embeddings. Only `text-embedding-3` and
            later models support this parameter.
        :param api_base_url:
            Overrides the default base URL for all HTTP requests.
        :param organization:
            Your OpenAI organization ID. See OpenAI's
            [Setting Up Your Organization](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization)
            for more information.
        :param prefix:
            A string to add at the beginning of each text.
        :param suffix:
            A string to add at the end of each text.
        :param batch_size:
            Number of documents to embed at once.
        :param progress_bar:
            If `True`, shows a progress bar when running.
        :param meta_fields_to_embed:
            List of metadata fields to embed along with the document text.
        :param embedding_separator:
            Separator used to concatenate the metadata fields to the document text.
        :param timeout:
            Timeout for OpenAI client calls. If not set, it defaults to either the
            `OPENAI_TIMEOUT` environment variable, or 30 seconds.
        :param max_retries:
            Maximum number of retries to contact OpenAI after an internal error.
            If not set, it defaults to either the `OPENAI_MAX_RETRIES` environment variable, or 5 retries.
        :param http_client_kwargs:
            A dictionary of keyword arguments to configure a custom `httpx.Client`or `httpx.AsyncClient`.
            For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
        :param raise_on_failure:
            Whether to raise an exception if the embedding request fails. If `False`, the component will log the error
            and continue processing the remaining documents. If `True`, it will raise an exception on failure.
        """
        # Store the *raw* constructor arguments (including None for timeout/max_retries)
        # so that to_dict() serializes exactly what the caller passed; the environment
        # variable fallbacks applied below intentionally affect only the client objects,
        # not the serialized state.
        self.api_key = api_key
        self.model = model
        self.dimensions = dimensions
        self.api_base_url = api_base_url
        self.organization = organization
        self.prefix = prefix
        self.suffix = suffix
        self.batch_size = batch_size
        self.progress_bar = progress_bar
        self.meta_fields_to_embed = meta_fields_to_embed or []
        self.embedding_separator = embedding_separator
        self.timeout = timeout
        self.max_retries = max_retries
        self.http_client_kwargs = http_client_kwargs
        self.raise_on_failure = raise_on_failure

        # Resolve client-level defaults from the environment only when the caller did not
        # provide explicit values (note: these locals are NOT written back to self).
        if timeout is None:
            timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
        if max_retries is None:
            max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5"))

        client_kwargs: dict[str, Any] = {
            "api_key": api_key.resolve_value(),
            "organization": organization,
            "base_url": api_base_url,
            "timeout": timeout,
            "max_retries": max_retries,
        }

        # Both a sync and an async client are created so that run() and run_async()
        # can each use the appropriate one.
        self.client = OpenAI(http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_kwargs)
        self.async_client = AsyncOpenAI(
            http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_kwargs
        )
143

144
    def _get_telemetry_data(self) -> dict[str, Any]:
        """
        Usage-analytics payload sent to Posthog.
        """
        telemetry: dict[str, Any] = {"model": self.model}
        return telemetry
149

150
    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        # Collect the init parameters first, then delegate to the shared serializer.
        init_parameters: dict[str, Any] = {
            "api_key": self.api_key.to_dict(),
            "model": self.model,
            "dimensions": self.dimensions,
            "api_base_url": self.api_base_url,
            "organization": self.organization,
            "prefix": self.prefix,
            "suffix": self.suffix,
            "batch_size": self.batch_size,
            "progress_bar": self.progress_bar,
            "meta_fields_to_embed": self.meta_fields_to_embed,
            "embedding_separator": self.embedding_separator,
            "timeout": self.timeout,
            "max_retries": self.max_retries,
            "http_client_kwargs": self.http_client_kwargs,
            "raise_on_failure": self.raise_on_failure,
        }
        return default_to_dict(self, **init_parameters)
175

176
    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "OpenAIDocumentEmbedder":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        # The API key is stored as a serialized Secret; restore it in place before
        # handing the data to the generic deserializer.
        init_parameters = data["init_parameters"]
        deserialize_secrets_inplace(init_parameters, keys=["api_key"])
        return default_from_dict(cls, data)
188

189
    def _prepare_texts_to_embed(self, documents: list[Document]) -> dict[str, str]:
        """
        Build the text to embed for each document, keyed by document id.

        Selected metadata fields (those in ``self.meta_fields_to_embed`` that are present
        and not None) are joined with the document content using the configured separator,
        then wrapped with the configured prefix and suffix.
        """
        prepared: dict[str, str] = {}
        for document in documents:
            parts = [
                str(document.meta[field])
                for field in self.meta_fields_to_embed
                if field in document.meta and document.meta[field] is not None
            ]
            parts.append(document.content or "")
            prepared[document.id] = f"{self.prefix}{self.embedding_separator.join(parts)}{self.suffix}"

        return prepared
204

205
    def _embed_batch(
        self, texts_to_embed: dict[str, str], batch_size: int
    ) -> tuple[dict[str, list[float]], dict[str, Any]]:
        """
        Embed the given id -> text mapping in batches of ``batch_size``.

        Returns a mapping from document id to embedding (ids from failed batches are
        omitted unless ``raise_on_failure`` is set, in which case the APIError is re-raised)
        and a metadata dict with the responding model name and accumulated token usage.
        """
        embeddings_by_id: dict[str, list[float]] = {}
        meta: dict[str, Any] = {}

        progress = tqdm(
            batched(texts_to_embed.items(), batch_size),
            disable=not self.progress_bar,
            desc="Calculating embeddings",
        )
        for batch in progress:
            ids = [pair[0] for pair in batch]
            request: dict[str, Any] = {
                "model": self.model,
                "input": [pair[1] for pair in batch],
                "encoding_format": "float",
            }
            if self.dimensions is not None:
                request["dimensions"] = self.dimensions

            try:
                response = self.client.embeddings.create(**request)
            except APIError as exc:
                msg = "Failed embedding of documents {ids} caused by {exc}"
                logger.exception(msg, ids=", ".join(ids), exc=exc)
                if self.raise_on_failure:
                    raise exc
                # Best-effort mode: skip this batch and keep going.
                continue

            embeddings_by_id.update(zip(ids, (item.embedding for item in response.data)))

            # Record the model once; accumulate token usage across batches.
            meta.setdefault("model", response.model)
            if "usage" not in meta:
                meta["usage"] = dict(response.usage)
            else:
                meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens
                meta["usage"]["total_tokens"] += response.usage.total_tokens

        return embeddings_by_id, meta
244

245
    async def _embed_batch_async(
        self, texts_to_embed: dict[str, str], batch_size: int
    ) -> tuple[dict[str, list[float]], dict[str, Any]]:
        """
        Embed a list of texts in batches asynchronously.

        Mirrors the synchronous ``_embed_batch``: returns a mapping from document id to
        embedding (ids from failed batches are omitted unless ``raise_on_failure`` is set,
        in which case the APIError is re-raised) and a metadata dict with the responding
        model name and accumulated token usage.
        """

        doc_ids_to_embeddings: dict[str, list[float]] = {}
        meta: dict[str, Any] = {}

        batches = list(batched(texts_to_embed.items(), batch_size))
        if self.progress_bar:
            batches = async_tqdm(batches, desc="Calculating embeddings")

        for batch in batches:
            # "encoding_format" is set for parity with the synchronous _embed_batch,
            # which already requests float encoding; previously only the sync path set it.
            args: dict[str, Any] = {
                "model": self.model,
                "input": [b[1] for b in batch],
                "encoding_format": "float",
            }

            if self.dimensions is not None:
                args["dimensions"] = self.dimensions

            try:
                response = await self.async_client.embeddings.create(**args)
            except APIError as exc:
                ids = ", ".join(b[0] for b in batch)
                msg = "Failed embedding of documents {ids} caused by {exc}"
                logger.exception(msg, ids=ids, exc=exc)
                if self.raise_on_failure:
                    raise exc
                # Best-effort mode: skip this batch and keep going.
                continue

            embeddings = [el.embedding for el in response.data]
            doc_ids_to_embeddings.update(dict(zip((b[0] for b in batch), embeddings)))

            # Record the model once; accumulate token usage across batches.
            if "model" not in meta:
                meta["model"] = response.model
            if "usage" not in meta:
                meta["usage"] = dict(response.usage)
            else:
                meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens
                meta["usage"]["total_tokens"] += response.usage.total_tokens

        return doc_ids_to_embeddings, meta
287

288
    @component.output_types(documents=list[Document], meta=dict[str, Any])
    def run(self, documents: list[Document]):
        """
        Embeds a list of documents.

        :param documents:
            A list of documents to embed.

        :returns:
            A dictionary with the following keys:
            - `documents`: A list of documents with embeddings.
            - `meta`: Information about the usage of the model.

        :raises TypeError:
            If the input is not a list of Documents.
        """
        if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
            raise TypeError(
                # Trailing space added: without it the adjacent literals concatenated
                # into "input.In case"; run_async already phrases it correctly.
                "OpenAIDocumentEmbedder expects a list of Documents as input. "
                "In case you want to embed a string, please use the OpenAITextEmbedder."
            )

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)

        doc_ids_to_embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)

        # Return copies of the inputs: documents that were embedded successfully get the
        # embedding attached, the rest are copied through unchanged (best-effort mode).
        new_documents = []
        for doc in documents:
            if doc.id in doc_ids_to_embeddings:
                new_documents.append(replace(doc, embedding=doc_ids_to_embeddings[doc.id]))
            else:
                new_documents.append(replace(doc))

        return {"documents": new_documents, "meta": meta}
319

320
    @component.output_types(documents=list[Document], meta=dict[str, Any])
    async def run_async(self, documents: list[Document]):
        """
        Embeds a list of documents asynchronously.

        :param documents:
            A list of documents to embed.

        :returns:
            A dictionary with the following keys:
            - `documents`: A list of documents with embeddings.
            - `meta`: Information about the usage of the model.
        """
        is_doc_list = isinstance(documents, list) and (not documents or isinstance(documents[0], Document))
        if not is_doc_list:
            raise TypeError(
                "OpenAIDocumentEmbedder expects a list of Documents as input. "
                "In case you want to embed a string, please use the OpenAITextEmbedder."
            )

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)

        doc_ids_to_embeddings, meta = await self._embed_batch_async(
            texts_to_embed=texts_to_embed, batch_size=self.batch_size
        )

        # Copy every input document; attach the embedding where one was produced.
        new_documents = [
            replace(doc, embedding=doc_ids_to_embeddings[doc.id]) if doc.id in doc_ids_to_embeddings else replace(doc)
            for doc in documents
        ]

        return {"documents": new_documents, "meta": meta}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc