• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 15186489091

22 May 2025 12:24PM UTC coverage: 90.423% (-0.03%) from 90.454%
15186489091

Pull #9406

github

web-flow
Merge 7a1615cb3 into e6a53b9dc
Pull Request #9406: feat: Extend AnswerBuilder for Agent

11104 of 12280 relevant lines covered (90.42%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

62.61
haystack/components/embedders/openai_document_embedder.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import os
1✔
6
from typing import Any, Dict, List, Optional, Tuple
1✔
7

8
from more_itertools import batched
1✔
9
from openai import APIError, AsyncOpenAI, OpenAI
1✔
10
from tqdm import tqdm
1✔
11
from tqdm.asyncio import tqdm as async_tqdm
1✔
12

13
from haystack import Document, component, default_from_dict, default_to_dict, logging
1✔
14
from haystack.utils import Secret, deserialize_secrets_inplace
1✔
15
from haystack.utils.http_client import init_http_client
1✔
16

17
logger = logging.getLogger(__name__)
1✔
18

19

20
@component
class OpenAIDocumentEmbedder:
    """
    Computes document embeddings using OpenAI models.

    ### Usage example

    ```python
    from haystack import Document
    from haystack.components.embedders import OpenAIDocumentEmbedder

    doc = Document(content="I love pizza!")

    document_embedder = OpenAIDocumentEmbedder()

    result = document_embedder.run([doc])
    print(result['documents'][0].embedding)

    # [0.017020374536514282, -0.023255806416273117, ...]
    ```
    """

    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
        model: str = "text-embedding-ada-002",
        dimensions: Optional[int] = None,
        api_base_url: Optional[str] = None,
        organization: Optional[str] = None,
        prefix: str = "",
        suffix: str = "",
        batch_size: int = 32,
        progress_bar: bool = True,
        meta_fields_to_embed: Optional[List[str]] = None,
        embedding_separator: str = "\n",
        timeout: Optional[float] = None,
        max_retries: Optional[int] = None,
        http_client_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Creates an OpenAIDocumentEmbedder component.

        Before initializing the component, you can set the 'OPENAI_TIMEOUT' and 'OPENAI_MAX_RETRIES'
        environment variables to override the `timeout` and `max_retries` parameters respectively
        in the OpenAI client.

        :param api_key:
            The OpenAI API key.
            You can set it with an environment variable `OPENAI_API_KEY`, or pass with this parameter
            during initialization.
        :param model:
            The name of the model to use for calculating embeddings.
            The default model is `text-embedding-ada-002`.
        :param dimensions:
            The number of dimensions of the resulting embeddings. Only `text-embedding-3` and
            later models support this parameter.
        :param api_base_url:
            Overrides the default base URL for all HTTP requests.
        :param organization:
            Your OpenAI organization ID. See OpenAI's
            [Setting Up Your Organization](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization)
            for more information.
        :param prefix:
            A string to add at the beginning of each text.
        :param suffix:
            A string to add at the end of each text.
        :param batch_size:
            Number of documents to embed at once.
        :param progress_bar:
            If `True`, shows a progress bar when running.
        :param meta_fields_to_embed:
            List of metadata fields to embed along with the document text.
        :param embedding_separator:
            Separator used to concatenate the metadata fields to the document text.
        :param timeout:
            Timeout for OpenAI client calls. If not set, it defaults to either the
            `OPENAI_TIMEOUT` environment variable, or 30 seconds.
        :param max_retries:
            Maximum number of retries to contact OpenAI after an internal error.
            If not set, it defaults to either the `OPENAI_MAX_RETRIES` environment variable, or 5 retries.
        :param http_client_kwargs:
            A dictionary of keyword arguments to configure a custom `httpx.Client` or `httpx.AsyncClient`.
            For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
        """
        self.api_key = api_key
        self.model = model
        self.dimensions = dimensions
        self.api_base_url = api_base_url
        self.organization = organization
        self.prefix = prefix
        self.suffix = suffix
        self.batch_size = batch_size
        self.progress_bar = progress_bar
        self.meta_fields_to_embed = meta_fields_to_embed or []
        self.embedding_separator = embedding_separator
        self.timeout = timeout
        self.max_retries = max_retries
        self.http_client_kwargs = http_client_kwargs

        # The attributes above deliberately keep the *raw* constructor values (including None)
        # so that to_dict() serializes exactly what the caller passed and round-trips through
        # from_dict(). The env-var fallbacks below only affect the local variables that are
        # used to construct the OpenAI clients.
        if timeout is None:
            timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
        if max_retries is None:
            max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5"))

        # Keyword arguments shared between the sync and async OpenAI clients.
        client_kwargs: Dict[str, Any] = {
            "api_key": api_key.resolve_value(),
            "organization": organization,
            "base_url": api_base_url,
            "timeout": timeout,
            "max_retries": max_retries,
        }

        # Both clients are always created so run() and run_async() are usable interchangeably.
        self.client = OpenAI(http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_kwargs)
        self.async_client = AsyncOpenAI(
            http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_kwargs
        )
137
    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        return dict(model=self.model)
    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        # Collect the init parameters first, then hand them to the generic serializer.
        # The Secret API key must be serialized via its own to_dict() — never as plain text.
        init_parameters: Dict[str, Any] = {
            "api_key": self.api_key.to_dict(),
            "model": self.model,
            "dimensions": self.dimensions,
            "api_base_url": self.api_base_url,
            "organization": self.organization,
            "prefix": self.prefix,
            "suffix": self.suffix,
            "batch_size": self.batch_size,
            "progress_bar": self.progress_bar,
            "meta_fields_to_embed": self.meta_fields_to_embed,
            "embedding_separator": self.embedding_separator,
            "timeout": self.timeout,
            "max_retries": self.max_retries,
            "http_client_kwargs": self.http_client_kwargs,
        }
        return default_to_dict(self, **init_parameters)
    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "OpenAIDocumentEmbedder":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        init_parameters = data["init_parameters"]
        # The API key was serialized as a Secret dict; restore it in place before construction.
        deserialize_secrets_inplace(init_parameters, keys=["api_key"])
        return default_from_dict(cls, data)
    def _prepare_texts_to_embed(self, documents: List[Document]) -> Dict[str, str]:
        """
        Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.

        Returns a mapping of document id -> text to send to the embedding model, preserving
        the input order of `documents`.
        """
        prepared: Dict[str, str] = {}
        for document in documents:
            # Only embed metadata fields that are present and non-None on this document.
            meta_parts = []
            for field in self.meta_fields_to_embed:
                value = document.meta.get(field)
                if value is not None:
                    meta_parts.append(str(value))

            # Metadata values come first, then the document content (empty string if None),
            # all joined by the configured separator and wrapped with prefix/suffix.
            joined = self.embedding_separator.join(meta_parts + [document.content or ""])
            prepared[document.id] = f"{self.prefix}{joined}{self.suffix}"

        return prepared
    def _embed_batch(self, texts_to_embed: Dict[str, str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]:
        """
        Embed a list of texts in batches.

        Returns the embeddings (in input order for all successfully embedded batches) and a
        meta dict with the model name and accumulated token usage.
        """
        all_embeddings: List[List[float]] = []
        meta: Dict[str, Any] = {}

        progress = tqdm(
            batched(texts_to_embed.items(), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
        )
        for batch in progress:
            # Each batch item is a (doc_id, text) pair; only the texts go to the API.
            args: Dict[str, Any] = {"model": self.model, "input": [item[1] for item in batch]}
            if self.dimensions is not None:
                args["dimensions"] = self.dimensions

            try:
                response = self.client.embeddings.create(**args)
            except APIError as exc:
                # Best-effort: log the failed batch and continue with the remaining ones.
                # NOTE(review): skipping a failed batch shortens the returned embedding list,
                # so positional zipping with the full document list downstream may misalign —
                # confirm callers account for this.
                ids = ", ".join(item[0] for item in batch)
                msg = "Failed embedding of documents {ids} caused by {exc}"
                logger.exception(msg, ids=ids, exc=exc)
                continue

            all_embeddings.extend(entry.embedding for entry in response.data)

            # Record the model once, and accumulate token usage across batches.
            meta.setdefault("model", response.model)
            if "usage" in meta:
                meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens
                meta["usage"]["total_tokens"] += response.usage.total_tokens
            else:
                meta["usage"] = dict(response.usage)

        return all_embeddings, meta
233
    async def _embed_batch_async(
        self, texts_to_embed: Dict[str, str], batch_size: int
    ) -> Tuple[List[List[float]], Dict[str, Any]]:
        """
        Embed a list of texts in batches asynchronously.

        Mirrors `_embed_batch` but awaits the async OpenAI client; returns the embeddings and
        a meta dict with the model name and accumulated token usage.
        """
        all_embeddings: List[List[float]] = []
        meta: Dict[str, Any] = {}

        # Materialize the batches so they can optionally be wrapped in an async progress bar.
        batches = list(batched(texts_to_embed.items(), batch_size))
        if self.progress_bar:
            batches = async_tqdm(batches, desc="Calculating embeddings")

        for batch in batches:
            # Each batch item is a (doc_id, text) pair; only the texts go to the API.
            args: Dict[str, Any] = {"model": self.model, "input": [item[1] for item in batch]}
            if self.dimensions is not None:
                args["dimensions"] = self.dimensions

            try:
                response = await self.async_client.embeddings.create(**args)
            except APIError as exc:
                # Best-effort: log the failed batch and continue with the remaining ones.
                # NOTE(review): as in the sync variant, a skipped batch shortens the returned
                # list and can misalign positional zipping downstream — confirm callers handle it.
                ids = ", ".join(item[0] for item in batch)
                msg = "Failed embedding of documents {ids} caused by {exc}"
                logger.exception(msg, ids=ids, exc=exc)
                continue

            all_embeddings.extend(entry.embedding for entry in response.data)

            # Record the model once, and accumulate token usage across batches.
            meta.setdefault("model", response.model)
            if "usage" in meta:
                meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens
                meta["usage"]["total_tokens"] += response.usage.total_tokens
            else:
                meta["usage"] = dict(response.usage)

        return all_embeddings, meta
274
    @component.output_types(documents=List[Document], meta=Dict[str, Any])
    def run(self, documents: List[Document]):
        """
        Embeds a list of documents.

        :param documents:
            A list of documents to embed.

        :returns:
            A dictionary with the following keys:
            - `documents`: A list of documents with embeddings.
            - `meta`: Information about the usage of the model.

        :raises TypeError:
            If the input is not a list of Documents.
        """
        if not isinstance(documents, list) or documents and not isinstance(documents[0], Document):
            # Fix: the two adjacent string literals previously had no separating space,
            # producing "...as input.In case..."; now consistent with run_async's message.
            raise TypeError(
                "OpenAIDocumentEmbedder expects a list of Documents as input. "
                "In case you want to embed a string, please use the OpenAITextEmbedder."
            )

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)

        embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)

        # Attach embeddings positionally; _embed_batch preserves the input order of texts_to_embed.
        for doc, emb in zip(documents, embeddings):
            doc.embedding = emb

        return {"documents": documents, "meta": meta}
302
    @component.output_types(documents=List[Document], meta=Dict[str, Any])
    async def run_async(self, documents: List[Document]):
        """
        Embeds a list of documents asynchronously.

        :param documents:
            A list of documents to embed.

        :returns:
            A dictionary with the following keys:
            - `documents`: A list of documents with embeddings.
            - `meta`: Information about the usage of the model.
        """
        # Guard clause: accept only a list whose first element (if any) is a Document.
        is_document_list = isinstance(documents, list) and (not documents or isinstance(documents[0], Document))
        if not is_document_list:
            raise TypeError(
                "OpenAIDocumentEmbedder expects a list of Documents as input. "
                "In case you want to embed a string, please use the OpenAITextEmbedder."
            )

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)
        embeddings, meta = await self._embed_batch_async(texts_to_embed=texts_to_embed, batch_size=self.batch_size)

        # Assign embeddings positionally; order matches the documents used to build texts_to_embed.
        for document, embedding in zip(documents, embeddings):
            document.embedding = embedding

        return {"documents": documents, "meta": meta}
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc