• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 18818043546

26 Oct 2025 12:38PM UTC coverage: 92.24% (+0.02%) from 92.219%
18818043546

Pull #9942

github

web-flow
Merge 9ca93ecfb into 554616981
Pull Request #9942: feat: Add warm_up() method to ChatGenerators for tool initialization

13491 of 14626 relevant lines covered (92.24%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

94.03
haystack/components/generators/chat/azure.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import os
1✔
6
from typing import Any, Optional, Union
1✔
7

8
from openai.lib._pydantic import to_strict_json_schema
1✔
9
from openai.lib.azure import AsyncAzureADTokenProvider, AsyncAzureOpenAI, AzureADTokenProvider, AzureOpenAI
1✔
10
from pydantic import BaseModel
1✔
11

12
from haystack import component, default_from_dict, default_to_dict
1✔
13
from haystack.components.generators.chat import OpenAIChatGenerator
1✔
14
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
1✔
15
from haystack.tools import (
1✔
16
    ToolsType,
17
    _check_duplicate_tool_names,
18
    deserialize_tools_or_toolset_inplace,
19
    flatten_tools_or_toolsets,
20
    serialize_tools_or_toolset,
21
    warm_up_tools,
22
)
23
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
1✔
24
from haystack.utils.http_client import init_http_client
1✔
25

26

27
@component
class AzureOpenAIChatGenerator(OpenAIChatGenerator):
    """
    Generates text using OpenAI's models on Azure.

    It works with gpt-4-type models and supports streaming responses
    from the OpenAI API. It uses [ChatMessage](https://docs.haystack.deepset.ai/docs/chatmessage)
    format in input and output.

    You can customize how the text is generated by passing parameters to the
    OpenAI API. Use the `**generation_kwargs` argument when you initialize
    the component or when you run it. Any parameter that works with
    `openai.ChatCompletion.create` will work here too.

    For details on OpenAI API parameters, see
    [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat).

    ### Usage example

    ```python
    from haystack.components.generators.chat import AzureOpenAIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.utils import Secret

    messages = [ChatMessage.from_user("What's Natural Language Processing?")]

    client = AzureOpenAIChatGenerator(
        azure_endpoint="<Your Azure endpoint e.g. `https://your-company.azure.openai.com/>",
        api_key=Secret.from_token("<your-api-key>"),
        azure_deployment="<the model name, e.g. gpt-4o-mini>")
    response = client.run(messages)
    print(response)
    ```

    ```
    {'replies':
        [ChatMessage(_role=<ChatRole.ASSISTANT: 'assistant'>, _content=[TextContent(text=
        "Natural Language Processing (NLP) is a branch of artificial intelligence that focuses on
         enabling computers to understand, interpret, and generate human language in a way that is useful.")],
         _name=None,
         _meta={'model': 'gpt-4o-mini', 'index': 0, 'finish_reason': 'stop',
         'usage': {'prompt_tokens': 15, 'completion_tokens': 36, 'total_tokens': 51}})]
    }
    ```
    """

    # pylint: disable=super-init-not-called
    # ruff: noqa: PLR0913
    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        azure_endpoint: Optional[str] = None,
        api_version: Optional[str] = "2023-05-15",
        azure_deployment: Optional[str] = "gpt-4o-mini",
        api_key: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
        azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
        organization: Optional[str] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
        timeout: Optional[float] = None,
        max_retries: Optional[int] = None,
        generation_kwargs: Optional[dict[str, Any]] = None,
        default_headers: Optional[dict[str, str]] = None,
        tools: Optional[ToolsType] = None,
        tools_strict: bool = False,
        *,
        azure_ad_token_provider: Optional[Union[AzureADTokenProvider, AsyncAzureADTokenProvider]] = None,
        http_client_kwargs: Optional[dict[str, Any]] = None,
    ):
        """
        Initialize the Azure OpenAI Chat Generator component.

        :param azure_endpoint: The endpoint of the deployed model, for example `"https://example-resource.azure.openai.com/"`.
        :param api_version: The version of the API to use. Defaults to 2023-05-15.
        :param azure_deployment: The deployment of the model, usually the model name.
        :param api_key: The API key to use for authentication.
        :param azure_ad_token: [Azure Active Directory token](https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id).
        :param organization: Your organization ID, defaults to `None`. For help, see
        [Setting up your organization](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
        :param streaming_callback: A callback function called when a new token is received from the stream.
            It accepts [StreamingChunk](https://docs.haystack.deepset.ai/docs/data-classes#streamingchunk)
            as an argument.
        :param timeout: Timeout for OpenAI client calls. If not set, it defaults to either the
            `OPENAI_TIMEOUT` environment variable, or 30 seconds.
        :param max_retries: Maximum number of retries to contact OpenAI after an internal error.
            If not set, it defaults to either the `OPENAI_MAX_RETRIES` environment variable, or set to 5.
        :param generation_kwargs: Other parameters to use for the model. These parameters are sent directly to
            the OpenAI endpoint. For details, see [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat).
            Some of the supported parameters:
            - `max_completion_tokens`: An upper bound for the number of tokens that can be generated for a completion,
                including visible output tokens and reasoning tokens.
            - `temperature`: The sampling temperature to use. Higher values mean the model takes more risks.
                Try 0.9 for more creative applications and 0 (argmax sampling) for ones with a well-defined answer.
            - `top_p`: Nucleus sampling is an alternative to sampling with temperature, where the model considers
                tokens with a top_p probability mass. For example, 0.1 means only the tokens comprising
                the top 10% probability mass are considered.
            - `n`: The number of completions to generate for each prompt. For example, with 3 prompts and n=2,
                the LLM will generate two completions per prompt, resulting in 6 completions total.
            - `stop`: One or more sequences after which the LLM should stop generating tokens.
            - `presence_penalty`: The penalty applied if a token is already present.
                Higher values make the model less likely to repeat the token.
            - `frequency_penalty`: Penalty applied if a token has already been generated.
                Higher values make the model less likely to repeat the token.
            - `logit_bias`: Adds a logit bias to specific tokens. The keys of the dictionary are tokens, and the
                values are the bias to add to that token.
            - `response_format`: A JSON schema or a Pydantic model that enforces the structure of the model's response.
                If provided, the output will always be validated against this
                format (unless the model returns a tool call).
                For details, see the [OpenAI Structured Outputs documentation](https://platform.openai.com/docs/guides/structured-outputs).
                Notes:
                - This parameter accepts Pydantic models and JSON schemas for latest models starting from GPT-4o.
                  Older models only support basic version of structured outputs through `{"type": "json_object"}`.
                  For detailed information on JSON mode, see the [OpenAI Structured Outputs documentation](https://platform.openai.com/docs/guides/structured-outputs#json-mode).
                - For structured outputs with streaming,
                  the `response_format` must be a JSON schema and not a Pydantic model.
        :param default_headers: Default headers to use for the AzureOpenAI client.
        :param tools:
            A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
        :param tools_strict:
            Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly
            the schema provided in the `parameters` field of the tool definition, but this may increase latency.
        :param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked on
            every request.
        :param http_client_kwargs:
            A dictionary of keyword arguments to configure a custom `httpx.Client`or `httpx.AsyncClient`.
            For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
        """
        # We intentionally do not call super().__init__ here because we only need to instantiate the client to interact
        # with the API.

        # Why is this here?
        # AzureOpenAI init is forcing us to use an init method that takes either base_url or azure_endpoint as not
        # None init parameters. This way we accommodate the use case where env var AZURE_OPENAI_ENDPOINT is set instead
        # of passing it as a parameter.
        azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
        if not azure_endpoint:
            raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")

        if api_key is None and azure_ad_token is None:
            raise ValueError("Please provide an API key or an Azure Active Directory token.")

        # The check above makes mypy incorrectly infer that api_key is never None,
        # which propagates the incorrect type.
        self.api_key = api_key  # type: ignore
        self.azure_ad_token = azure_ad_token
        self.generation_kwargs = generation_kwargs or {}
        self.streaming_callback = streaming_callback
        self.api_version = api_version
        self.azure_endpoint = azure_endpoint
        self.azure_deployment = azure_deployment
        self.organization = organization
        # `model` mirrors the deployment name so the OpenAIChatGenerator machinery can use it.
        self.model = azure_deployment or "gpt-4o-mini"
        self.timeout = timeout if timeout is not None else float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
        self.max_retries = max_retries if max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
        self.default_headers = default_headers or {}
        self.azure_ad_token_provider = azure_ad_token_provider
        self.http_client_kwargs = http_client_kwargs
        # Fail fast on duplicate tool names before any API call is made.
        _check_duplicate_tool_names(flatten_tools_or_toolsets(tools))
        self.tools = tools
        self.tools_strict = tools_strict

        # Shared keyword arguments for both the sync and async Azure clients.
        client_args: dict[str, Any] = {
            "api_version": api_version,
            "azure_endpoint": azure_endpoint,
            "azure_deployment": azure_deployment,
            "api_key": api_key.resolve_value() if api_key is not None else None,
            "azure_ad_token": azure_ad_token.resolve_value() if azure_ad_token is not None else None,
            "organization": organization,
            "timeout": self.timeout,
            "max_retries": self.max_retries,
            "default_headers": self.default_headers,
            "azure_ad_token_provider": azure_ad_token_provider,
        }

        self.client = AzureOpenAI(
            http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_args
        )
        self.async_client = AsyncAzureOpenAI(
            http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_args
        )
        self._is_warmed_up = False

    def warm_up(self):
        """
        Warm up the Azure OpenAI chat generator.

        This will warm up the tools registered in the chat generator.
        This method is idempotent and will only warm up the tools once.
        """
        if not self._is_warmed_up:
            warm_up_tools(self.tools)
            self._is_warmed_up = True

    def to_dict(self) -> dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :returns:
            The serialized component as a dictionary.
        """
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        azure_ad_token_provider_name = None
        if self.azure_ad_token_provider:
            azure_ad_token_provider_name = serialize_callable(self.azure_ad_token_provider)
        # If the response format is a Pydantic model, it's converted to openai's json schema format
        # If it's already a json schema, it's left as is
        generation_kwargs = self.generation_kwargs.copy()
        response_format = generation_kwargs.get("response_format")
        # `response_format` may be a Pydantic model class or a plain JSON-schema dict (e.g. {"type": "json_object"}).
        # The isinstance(..., type) guard is required: issubclass() raises TypeError when its first argument is not
        # a class, which would make to_dict crash for dict-based response formats.
        if response_format and isinstance(response_format, type) and issubclass(response_format, BaseModel):
            json_schema = {
                "type": "json_schema",
                "json_schema": {
                    "name": response_format.__name__,
                    "strict": True,
                    "schema": to_strict_json_schema(response_format),
                },
            }
            generation_kwargs["response_format"] = json_schema
        return default_to_dict(
            self,
            azure_endpoint=self.azure_endpoint,
            azure_deployment=self.azure_deployment,
            organization=self.organization,
            api_version=self.api_version,
            streaming_callback=callback_name,
            generation_kwargs=generation_kwargs,
            timeout=self.timeout,
            max_retries=self.max_retries,
            api_key=self.api_key.to_dict() if self.api_key is not None else None,
            azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
            default_headers=self.default_headers,
            tools=serialize_tools_or_toolset(self.tools),
            tools_strict=self.tools_strict,
            azure_ad_token_provider=azure_ad_token_provider_name,
            http_client_kwargs=self.http_client_kwargs,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "AzureOpenAIChatGenerator":
        """
        Deserialize this component from a dictionary.

        :param data: The dictionary representation of this component.
        :returns:
            The deserialized component instance.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["api_key", "azure_ad_token"])
        deserialize_tools_or_toolset_inplace(data["init_parameters"], key="tools")
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
        serialized_azure_ad_token_provider = init_params.get("azure_ad_token_provider")
        if serialized_azure_ad_token_provider:
            data["init_parameters"]["azure_ad_token_provider"] = deserialize_callable(
                serialized_azure_ad_token_provider
            )
        return default_from_dict(cls, data)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc