• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 14497519261

16 Apr 2025 04:16PM UTC coverage: 90.372% (-0.002%) from 90.374%
14497519261

Pull #9215

github

web-flow
Merge 6ceca1c16 into e5dc4ef94
Pull Request #9215: feat: Allow OpenAI client config in `OpenAIChatGenerator` and `AzureOpenAIChatGenerator`

10691 of 11830 relevant lines covered (90.37%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.0
haystack/components/embedders/azure_document_embedder.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import os
1✔
6
from typing import Any, Dict, List, Optional
1✔
7

8
from openai.lib.azure import AsyncAzureOpenAI, AzureADTokenProvider, AzureOpenAI
1✔
9

10
from haystack import component, default_from_dict, default_to_dict, logging
1✔
11
from haystack.components.embedders import OpenAIDocumentEmbedder
1✔
12
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
1✔
13
from haystack.utils.http_client import init_http_client
1✔
14

15
logger = logging.getLogger(__name__)
1✔
16

17

18
@component
class AzureOpenAIDocumentEmbedder(OpenAIDocumentEmbedder):
    """
    Calculates document embeddings using OpenAI models deployed on Azure.

    ### Usage example

    ```python
    from haystack import Document
    from haystack.components.embedders import AzureOpenAIDocumentEmbedder

    doc = Document(content="I love pizza!")

    document_embedder = AzureOpenAIDocumentEmbedder()

    result = document_embedder.run([doc])
    print(result['documents'][0].embedding)

    # [0.017020374536514282, -0.023255806416273117, ...]
    ```
    """

    # pylint: disable=super-init-not-called
    def __init__(  # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-positional-arguments
        self,
        azure_endpoint: Optional[str] = None,
        api_version: Optional[str] = "2023-05-15",
        azure_deployment: str = "text-embedding-ada-002",
        dimensions: Optional[int] = None,
        api_key: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
        azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
        organization: Optional[str] = None,
        prefix: str = "",
        suffix: str = "",
        batch_size: int = 32,
        progress_bar: bool = True,
        meta_fields_to_embed: Optional[List[str]] = None,
        embedding_separator: str = "\n",
        timeout: Optional[float] = None,
        max_retries: Optional[int] = None,
        *,
        default_headers: Optional[Dict[str, str]] = None,
        azure_ad_token_provider: Optional[AzureADTokenProvider] = None,
        http_client_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Creates an AzureOpenAIDocumentEmbedder component.

        :param azure_endpoint:
            The endpoint of the model deployed on Azure.
            If not provided, it is read from the `AZURE_OPENAI_ENDPOINT` environment variable.
        :param api_version:
            The version of the API to use.
        :param azure_deployment:
            The name of the model deployed on Azure. The default model is text-embedding-ada-002.
        :param dimensions:
            The number of dimensions of the resulting embeddings. Only supported in text-embedding-3
            and later models.
        :param api_key:
            The Azure OpenAI API key.
            You can set it with an environment variable `AZURE_OPENAI_API_KEY`, or pass with this
            parameter during initialization.
        :param azure_ad_token:
            Microsoft Entra ID token, see Microsoft's
            [Entra ID](https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id)
            documentation for more information. You can set it with an environment variable
            `AZURE_OPENAI_AD_TOKEN`, or pass with this parameter during initialization.
            Previously called Azure Active Directory.
        :param organization:
            Your organization ID. See OpenAI's
            [Setting Up Your Organization](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization)
            for more information.
        :param prefix:
            A string to add at the beginning of each text.
        :param suffix:
            A string to add at the end of each text.
        :param batch_size:
            Number of documents to embed at once.
        :param progress_bar:
            If `True`, shows a progress bar when running.
        :param meta_fields_to_embed:
            List of metadata fields to embed along with the document text.
        :param embedding_separator:
            Separator used to concatenate the metadata fields to the document text.
        :param timeout: The timeout for `AzureOpenAI` client calls, in seconds.
            If not set, defaults to either the
            `OPENAI_TIMEOUT` environment variable, or 30 seconds.
        :param max_retries: Maximum number of retries to contact AzureOpenAI after an internal error.
            If not set, defaults to either the `OPENAI_MAX_RETRIES` environment variable or to 5 retries.
        :param default_headers: Default headers to send to the AzureOpenAI client.
        :param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked on
            every request. Counts as a credential, so `api_key` and `azure_ad_token` may both be `None` when
            it is provided.
        :param http_client_kwargs:
            A dictionary of keyword arguments to configure a custom `httpx.Client` or `httpx.AsyncClient`.
            For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).

        :raises ValueError:
            If no Azure endpoint is given (neither as a parameter nor through `AZURE_OPENAI_ENDPOINT`), or if
            `api_key`, `azure_ad_token` and `azure_ad_token_provider` are all `None`.
        """
        # We intentionally do not call super().__init__ here because we only need to instantiate the client to interact
        # with the API.

        # if not provided as a parameter, azure_endpoint is read from the env var AZURE_OPENAI_ENDPOINT
        azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
        if not azure_endpoint:
            raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")

        # A token provider is a complete credential on its own, so only fail when no credential source was
        # supplied at all. Previously, passing only `azure_ad_token_provider` (with the two secrets explicitly
        # set to None) was rejected here even though the Azure client accepts it.
        if api_key is None and azure_ad_token is None and azure_ad_token_provider is None:
            raise ValueError("Please provide an API key or an Azure Active Directory token.")

        self.api_key = api_key  # type: ignore[assignment] # mypy does not understand that api_key can be None
        self.azure_ad_token = azure_ad_token
        self.api_version = api_version
        self.azure_endpoint = azure_endpoint
        self.azure_deployment = azure_deployment
        # `model` mirrors the deployment name.
        self.model = azure_deployment
        self.dimensions = dimensions
        self.organization = organization
        self.prefix = prefix
        self.suffix = suffix
        self.batch_size = batch_size
        self.progress_bar = progress_bar
        self.meta_fields_to_embed = meta_fields_to_embed or []
        self.embedding_separator = embedding_separator
        # Explicit arguments win; otherwise fall back to the environment, then to the hard-coded defaults.
        self.timeout = timeout if timeout is not None else float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
        self.max_retries = max_retries if max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
        self.default_headers = default_headers or {}
        self.azure_ad_token_provider = azure_ad_token_provider
        self.http_client_kwargs = http_client_kwargs

        # Shared by the sync and async clients so both are configured identically.
        client_args: Dict[str, Any] = {
            "api_version": api_version,
            "azure_endpoint": azure_endpoint,
            "azure_deployment": azure_deployment,
            "azure_ad_token_provider": azure_ad_token_provider,
            "api_key": api_key.resolve_value() if api_key is not None else None,
            "azure_ad_token": azure_ad_token.resolve_value() if azure_ad_token is not None else None,
            "organization": organization,
            "timeout": self.timeout,
            "max_retries": self.max_retries,
            "default_headers": self.default_headers,
        }

        self.client = AzureOpenAI(
            http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_args
        )
        self.async_client = AsyncAzureOpenAI(
            http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_args
        )

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        Secrets are stored through their own `to_dict()` representation and the token provider callable is
        stored through `serialize_callable`.

        :returns:
            Dictionary with serialized data.
        """
        azure_ad_token_provider_name = None
        if self.azure_ad_token_provider:
            azure_ad_token_provider_name = serialize_callable(self.azure_ad_token_provider)
        return default_to_dict(
            self,
            azure_endpoint=self.azure_endpoint,
            azure_deployment=self.azure_deployment,
            dimensions=self.dimensions,
            organization=self.organization,
            api_version=self.api_version,
            prefix=self.prefix,
            suffix=self.suffix,
            batch_size=self.batch_size,
            progress_bar=self.progress_bar,
            meta_fields_to_embed=self.meta_fields_to_embed,
            embedding_separator=self.embedding_separator,
            api_key=self.api_key.to_dict() if self.api_key is not None else None,
            azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
            timeout=self.timeout,
            max_retries=self.max_retries,
            default_headers=self.default_headers,
            azure_ad_token_provider=azure_ad_token_provider_name,
            http_client_kwargs=self.http_client_kwargs,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AzureOpenAIDocumentEmbedder":
        """
        Deserializes the component from a dictionary.

        The `init_parameters` entry of `data`, when present, is updated in place: serialized secrets and the
        serialized token provider are replaced by their deserialized counterparts.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        # Tolerate payloads without "init_parameters" instead of failing with a KeyError; in that case there
        # is nothing to deserialize and the component is built from its defaults.
        init_params = data.get("init_parameters", {})
        deserialize_secrets_inplace(init_params, keys=["api_key", "azure_ad_token"])
        serialized_azure_ad_token_provider = init_params.get("azure_ad_token_provider")
        if serialized_azure_ad_token_provider:
            init_params["azure_ad_token_provider"] = deserialize_callable(serialized_azure_ad_token_provider)
        return default_from_dict(cls, data)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc