• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 13651977173

04 Mar 2025 10:48AM UTC coverage: 90.218% (+0.2%) from 90.017%
13651977173

Pull #8906

github

web-flow
Merge f9482a15a into 0d65b4caa
Pull Request #8906: refactor!: remove `dataframe` field from `Document` and `ExtractedTableAnswer`; make `pandas` optional

9601 of 10642 relevant lines covered (90.22%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

69.62
haystack/components/embedders/azure_document_embedder.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import os
1✔
6
from typing import Any, Dict, List, Optional, Tuple
1✔
7

8
from more_itertools import batched
1✔
9
from openai import APIError
1✔
10
from openai.lib.azure import AzureOpenAI
1✔
11
from tqdm import tqdm
1✔
12

13
from haystack import Document, component, default_from_dict, default_to_dict, logging
1✔
14
from haystack.utils import Secret, deserialize_secrets_inplace
1✔
15

16
# Module-level logger, named after this module via `__name__`.
logger = logging.getLogger(__name__)
1✔
17

18

19
@component
class AzureOpenAIDocumentEmbedder:
    """
    Calculates document embeddings using OpenAI models deployed on Azure.

    Documents are embedded in batches. If the request for a batch fails with an `APIError`, the error is
    logged and the documents of that batch are left without a new embedding; all other documents still
    receive the embedding computed for their own text.

    ### Usage example

    ```python
    from haystack import Document
    from haystack.components.embedders import AzureOpenAIDocumentEmbedder

    doc = Document(content="I love pizza!")

    document_embedder = AzureOpenAIDocumentEmbedder()

    result = document_embedder.run([doc])
    print(result['documents'][0].embedding)

    # [0.017020374536514282, -0.023255806416273117, ...]
    ```
    """

    def __init__(  # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-positional-arguments
        self,
        azure_endpoint: Optional[str] = None,
        api_version: Optional[str] = "2023-05-15",
        azure_deployment: str = "text-embedding-ada-002",
        dimensions: Optional[int] = None,
        api_key: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
        azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
        organization: Optional[str] = None,
        prefix: str = "",
        suffix: str = "",
        batch_size: int = 32,
        progress_bar: bool = True,
        meta_fields_to_embed: Optional[List[str]] = None,
        embedding_separator: str = "\n",
        timeout: Optional[float] = None,
        max_retries: Optional[int] = None,
        *,
        default_headers: Optional[Dict[str, str]] = None,
    ):
        """
        Creates an AzureOpenAIDocumentEmbedder component.

        :param azure_endpoint:
            The endpoint of the model deployed on Azure.
            If not set, it is read from the `AZURE_OPENAI_ENDPOINT` environment variable.
        :param api_version:
            The version of the API to use.
        :param azure_deployment:
            The name of the model deployed on Azure. The default model is text-embedding-ada-002.
        :param dimensions:
            The number of dimensions of the resulting embeddings. Only supported in text-embedding-3
            and later models.
        :param api_key:
            The Azure OpenAI API key.
            You can set it with an environment variable `AZURE_OPENAI_API_KEY`, or pass with this
            parameter during initialization.
        :param azure_ad_token:
            Microsoft Entra ID token, see Microsoft's
            [Entra ID](https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id)
            documentation for more information. You can set it with an environment variable
            `AZURE_OPENAI_AD_TOKEN`, or pass with this parameter during initialization.
            Previously called Azure Active Directory.
        :param organization:
            Your organization ID. See OpenAI's
            [Setting Up Your Organization](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization)
            for more information.
        :param prefix:
            A string to add at the beginning of each text.
        :param suffix:
            A string to add at the end of each text.
        :param batch_size:
            Number of documents to embed at once.
        :param progress_bar:
            If `True`, shows a progress bar when running.
        :param meta_fields_to_embed:
            List of metadata fields to embed along with the document text.
        :param embedding_separator:
            Separator used to concatenate the metadata fields to the document text.
        :param timeout: The timeout for `AzureOpenAI` client calls, in seconds.
            If not set, defaults to either the
            `OPENAI_TIMEOUT` environment variable, or 30 seconds.
        :param max_retries: Maximum number of retries to contact AzureOpenAI after an internal error.
            If not set, defaults to either the `OPENAI_MAX_RETRIES` environment variable or to 5 retries.
            An explicit `0` is honoured and is passed to the client as is.
        :param default_headers: Default headers to send to the AzureOpenAI client.

        :raises ValueError:
            If no Azure endpoint is given and `AZURE_OPENAI_ENDPOINT` is not set, or if both `api_key`
            and `azure_ad_token` are `None`.
        """
        # if not provided as a parameter, azure_endpoint is read from the env var AZURE_OPENAI_ENDPOINT
        azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
        if not azure_endpoint:
            raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")

        if api_key is None and azure_ad_token is None:
            raise ValueError("Please provide an API key or an Azure Active Directory token.")

        self.api_key = api_key
        self.azure_ad_token = azure_ad_token
        self.api_version = api_version
        self.azure_endpoint = azure_endpoint
        self.azure_deployment = azure_deployment
        self.dimensions = dimensions
        self.organization = organization
        self.prefix = prefix
        self.suffix = suffix
        self.batch_size = batch_size
        self.progress_bar = progress_bar
        self.meta_fields_to_embed = meta_fields_to_embed or []
        self.embedding_separator = embedding_separator
        # Compare against None instead of relying on truthiness: an explicit falsy value such as
        # `max_retries=0` must not be silently replaced by the environment/default value.
        self.timeout = timeout if timeout is not None else float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
        self.max_retries = max_retries if max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
        self.default_headers = default_headers or {}

        self._client = AzureOpenAI(
            api_version=api_version,
            azure_endpoint=azure_endpoint,
            azure_deployment=azure_deployment,
            api_key=api_key.resolve_value() if api_key is not None else None,
            azure_ad_token=azure_ad_token.resolve_value() if azure_ad_token is not None else None,
            organization=organization,
            timeout=self.timeout,
            max_retries=self.max_retries,
            default_headers=self.default_headers,
        )

    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        return {"model": self.azure_deployment}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        Secrets (`api_key`, `azure_ad_token`) are serialized through their own `to_dict()` when set.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            azure_endpoint=self.azure_endpoint,
            azure_deployment=self.azure_deployment,
            dimensions=self.dimensions,
            organization=self.organization,
            api_version=self.api_version,
            prefix=self.prefix,
            suffix=self.suffix,
            batch_size=self.batch_size,
            progress_bar=self.progress_bar,
            meta_fields_to_embed=self.meta_fields_to_embed,
            embedding_separator=self.embedding_separator,
            api_key=self.api_key.to_dict() if self.api_key is not None else None,
            azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
            timeout=self.timeout,
            max_retries=self.max_retries,
            default_headers=self.default_headers,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AzureOpenAIDocumentEmbedder":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from. Its `init_parameters` entry is modified in place: the
            `api_key` and `azure_ad_token` values are deserialized before the component is built.
        :returns:
            Deserialized component.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_ad_token"])
        return default_from_dict(cls, data)

    def _prepare_texts_to_embed(self, documents: List[Document]) -> Dict[str, str]:
        """
        Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.

        :param documents:
            Documents whose content and selected metadata are turned into texts.
        :returns:
            Dictionary mapping each document ID to its text. Documents that share an ID produce a
            single entry (the last one wins).
        """
        texts_to_embed = {}
        for doc in documents:
            # Only metadata fields that are present and not None take part in the text
            meta_values_to_embed = [
                str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None
            ]

            # Newlines (including those coming from the default separator) are flattened to spaces
            text_to_embed = (
                self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + self.suffix
            ).replace("\n", " ")

            texts_to_embed[doc.id] = text_to_embed
        return texts_to_embed

    def _embed_texts_by_id(
        self, texts_to_embed: Dict[str, str], batch_size: int
    ) -> Tuple[Dict[str, List[float]], Dict[str, Any]]:
        """
        Embed texts in batches, keeping track of which document each embedding belongs to.

        :param texts_to_embed:
            Dictionary mapping document IDs to the texts to embed.
        :param batch_size:
            Number of texts sent to the API per request.
        :returns:
            A tuple of:
            - a dictionary mapping document IDs to embeddings, containing only the documents of batches
              that were embedded successfully (in input order);
            - a `meta` dictionary with the model name and the accumulated token usage.
        """
        embeddings_by_id: Dict[str, List[float]] = {}
        meta: Dict[str, Any] = {"model": "", "usage": {"prompt_tokens": 0, "total_tokens": 0}}

        for batch in tqdm(
            batched(texts_to_embed.items(), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
        ):
            # Each batch element is a (document ID, text) pair
            batch_ids = [doc_id for doc_id, _ in batch]
            args: Dict[str, Any] = {"model": self.azure_deployment, "input": [text for _, text in batch]}

            if self.dimensions is not None:
                args["dimensions"] = self.dimensions

            try:
                response = self._client.embeddings.create(**args)
            except APIError as e:
                # Log the error but continue processing
                ids = ", ".join(batch_ids)
                logger.exception(f"Failed embedding of documents {ids} caused by {e}")
                continue

            # Pair the returned embeddings with the IDs of this batch, so that a failed batch elsewhere
            # cannot shift embeddings onto the wrong documents.
            embeddings_by_id.update(zip(batch_ids, (el.embedding for el in response.data)))

            # Update the meta information only once if it's empty
            if not meta["model"]:
                meta["model"] = response.model
                meta["usage"] = dict(response.usage)
            else:
                # Update the usage tokens
                meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens
                meta["usage"]["total_tokens"] += response.usage.total_tokens

        return embeddings_by_id, meta

    def _embed_batch(self, texts_to_embed: Dict[str, str], batch_size: int) -> Tuple[List[List[float]], Dict[str, Any]]:
        """
        Embed a list of texts in batches.

        :param texts_to_embed:
            Dictionary mapping document IDs to the texts to embed.
        :param batch_size:
            Number of texts sent to the API per request.
        :returns:
            A tuple of the embeddings of the successfully embedded texts (in input order; texts of
            failed batches are omitted) and the `meta` dictionary with model name and token usage.
        """
        embeddings_by_id, meta = self._embed_texts_by_id(texts_to_embed=texts_to_embed, batch_size=batch_size)
        return list(embeddings_by_id.values()), meta

    @component.output_types(documents=List[Document], meta=Dict[str, Any])
    def run(self, documents: List[Document]) -> Dict[str, Any]:
        """
        Embeds a list of documents.

        The embeddings are written to the `embedding` attribute of the given documents. Documents whose
        batch could not be embedded keep their previous `embedding` value.

        :param documents:
            Documents to embed.

        :returns:
            A dictionary with the following keys:
            - `documents`: A list of documents with embeddings.
            - `meta`: Information about the usage of the model.

        :raises TypeError:
            If `documents` is not a list of `Document` instances.
        """
        if not (isinstance(documents, list) and all(isinstance(doc, Document) for doc in documents)):
            raise TypeError("Input must be a list of Document instances. For strings, use AzureOpenAITextEmbedder.")

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)
        embeddings_by_id, meta = self._embed_texts_by_id(texts_to_embed=texts_to_embed, batch_size=self.batch_size)

        # Assign embeddings by document ID rather than by position: the number of embeddings can be
        # smaller than the number of documents (failed batches, duplicate IDs), and positional pairing
        # would then attach embeddings to the wrong documents.
        for doc in documents:
            embedding = embeddings_by_id.get(doc.id)
            if embedding is not None:
                doc.embedding = embedding

        return {"documents": documents, "meta": meta}
×
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc