• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 12744218044

13 Jan 2025 09:26AM UTC coverage: 91.352% (+0.3%) from 91.099%
12744218044

Pull #8693

github

web-flow
Merge 4a3ad897d into db76ae284
Pull Request #8693: feat: Add `ComponentTool` to Haystack tools

8968 of 9817 relevant lines covered (91.35%)

0.91 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

71.74
haystack/components/embedders/azure_text_embedder.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import os
1✔
6
from typing import Any, Dict, List, Optional
1✔
7

8
from openai.lib.azure import AzureOpenAI
1✔
9

10
from haystack import Document, component, default_from_dict, default_to_dict
1✔
11
from haystack.utils import Secret, deserialize_secrets_inplace
1✔
12

13

14
@component
class AzureOpenAITextEmbedder:
    """
    Embeds strings using OpenAI models deployed on Azure.

    ### Usage example

    ```python
    from haystack.components.embedders import AzureOpenAITextEmbedder

    text_to_embed = "I love pizza!"

    text_embedder = AzureOpenAITextEmbedder()

    print(text_embedder.run(text_to_embed))

    # {'embedding': [0.017020374536514282, -0.023255806416273117, ...],
    # 'meta': {'model': 'text-embedding-ada-002-v2',
    #          'usage': {'prompt_tokens': 4, 'total_tokens': 4}}}
    ```
    """

    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        azure_endpoint: Optional[str] = None,
        api_version: Optional[str] = "2023-05-15",
        azure_deployment: str = "text-embedding-ada-002",
        dimensions: Optional[int] = None,
        api_key: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
        azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
        organization: Optional[str] = None,
        timeout: Optional[float] = None,
        max_retries: Optional[int] = None,
        prefix: str = "",
        suffix: str = "",
        *,
        default_headers: Optional[Dict[str, str]] = None,
    ):
        """
        Creates an AzureOpenAITextEmbedder component.

        :param azure_endpoint:
            The endpoint of the model deployed on Azure. If not set, the `AZURE_OPENAI_ENDPOINT`
            environment variable is used.
        :param api_version:
            The version of the API to use.
        :param azure_deployment:
            The name of the model deployed on Azure. The default model is text-embedding-ada-002.
        :param dimensions:
            The number of dimensions the resulting output embeddings should have. Only supported in text-embedding-3
            and later models.
        :param api_key:
            The Azure OpenAI API key.
            You can set it with an environment variable `AZURE_OPENAI_API_KEY`, or pass with this
            parameter during initialization.
        :param azure_ad_token:
            Microsoft Entra ID token, see Microsoft's
            [Entra ID](https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id)
            documentation for more information. You can set it with an environment variable
            `AZURE_OPENAI_AD_TOKEN`, or pass with this parameter during initialization.
            Previously called Azure Active Directory.
        :param organization:
            Your organization ID. See OpenAI's
            [Setting Up Your Organization](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization)
            for more information.
        :param timeout: The timeout for `AzureOpenAI` client calls, in seconds.
            If `None`, defaults to either the `OPENAI_TIMEOUT` environment variable, or 30 seconds.
            An explicitly passed value is always used as-is.
        :param max_retries: Maximum number of retries to contact AzureOpenAI after an internal error.
            If `None`, defaults to either the `OPENAI_MAX_RETRIES` environment variable, or to 5 retries.
            Pass `0` to disable retries.
        :param prefix:
            A string to add at the beginning of each text.
        :param suffix:
            A string to add at the end of each text.
        :param default_headers: Default headers to send to the AzureOpenAI client.
        :raises ValueError:
            If no Azure endpoint is available (neither passed nor set via `AZURE_OPENAI_ENDPOINT`),
            or if both `api_key` and `azure_ad_token` are `None`.
        """
        # Why is this here?
        # AzureOpenAI init is forcing us to use an init method that takes either base_url or azure_endpoint as not
        # None init parameters. This way we accommodate the use case where env var AZURE_OPENAI_ENDPOINT is set instead
        # of passing it as a parameter.
        azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
        if not azure_endpoint:
            raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")

        if api_key is None and azure_ad_token is None:
            raise ValueError("Please provide an API key or an Azure Active Directory token.")

        # Fall back to the environment/defaults only when the caller passed nothing. A truthiness
        # check (`x or default`) would silently turn an explicit `max_retries=0` into 5 retries.
        if timeout is None:
            timeout = float(os.environ.get("OPENAI_TIMEOUT", 30.0))
        if max_retries is None:
            max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", 5))

        self.api_key = api_key
        self.azure_ad_token = azure_ad_token
        self.api_version = api_version
        self.azure_endpoint = azure_endpoint
        self.azure_deployment = azure_deployment
        self.dimensions = dimensions
        self.organization = organization
        self.timeout = timeout
        self.max_retries = max_retries
        self.prefix = prefix
        self.suffix = suffix
        self.default_headers = default_headers or {}

        self._client = AzureOpenAI(
            api_version=api_version,
            azure_endpoint=azure_endpoint,
            azure_deployment=azure_deployment,
            api_key=api_key.resolve_value() if api_key is not None else None,
            azure_ad_token=azure_ad_token.resolve_value() if azure_ad_token is not None else None,
            organization=organization,
            timeout=self.timeout,
            max_retries=self.max_retries,
            default_headers=self.default_headers,
        )

    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        return {"model": self.azure_deployment}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        Secrets are serialized via `Secret.to_dict()` rather than storing their resolved values.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            azure_endpoint=self.azure_endpoint,
            azure_deployment=self.azure_deployment,
            dimensions=self.dimensions,
            organization=self.organization,
            api_version=self.api_version,
            prefix=self.prefix,
            suffix=self.suffix,
            api_key=self.api_key.to_dict() if self.api_key is not None else None,
            azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
            timeout=self.timeout,
            max_retries=self.max_retries,
            default_headers=self.default_headers,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "AzureOpenAITextEmbedder":
        """
        Deserializes the component from a dictionary.

        Note: `data["init_parameters"]` is modified in place (the secrets are deserialized there).

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_ad_token"])
        return default_from_dict(cls, data)

    @component.output_types(embedding=List[float], meta=Dict[str, Any])
    def run(self, text: str):
        """
        Embeds a single string.

        :param text:
            Text to embed.

        :returns:
            A dictionary with the following keys:
            - `embedding`: The embedding of the input text.
            - `meta`: Information about the usage of the model (the `model` name reported by the
              service and the token `usage`).
        :raises TypeError:
            If `text` is not a string.
        """
        if not isinstance(text, str):
            # Check if input is a list and all elements are instances of Document
            if isinstance(text, list) and all(isinstance(elem, Document) for elem in text):
                error_message = "Input must be a string. Use AzureOpenAIDocumentEmbedder for a list of Documents."
            else:
                error_message = "Input must be a string."
            raise TypeError(error_message)

        # Preprocess the text by adding prefixes/suffixes
        # finally, replace newlines as recommended by OpenAI docs
        processed_text = f"{self.prefix}{text}{self.suffix}".replace("\n", " ")

        # `dimensions` is only forwarded when set, since older models reject the parameter.
        create_kwargs: Dict[str, Any] = {"model": self.azure_deployment, "input": processed_text}
        if self.dimensions is not None:
            create_kwargs["dimensions"] = self.dimensions
        response = self._client.embeddings.create(**create_kwargs)

        return {
            "embedding": response.data[0].embedding,
            "meta": {"model": response.model, "usage": dict(response.usage)},
        }
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc