• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 17760482097

16 Sep 2025 08:58AM UTC coverage: 92.059% (+0.01%) from 92.047%
17760482097

Pull #9754

github

web-flow
Merge 0c0073114 into e3d4e9e94
Pull Request #9754: feat: support structured outputs in `OpenAIChatGenerator`

12996 of 14117 relevant lines covered (92.06%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.55
haystack/components/generators/chat/azure.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import os
1✔
6
from typing import Any, Optional, Union
1✔
7

8
from openai.lib._pydantic import to_strict_json_schema
1✔
9
from openai.lib.azure import AsyncAzureADTokenProvider, AsyncAzureOpenAI, AzureADTokenProvider, AzureOpenAI
1✔
10
from pydantic import BaseModel
1✔
11

12
from haystack import component, default_from_dict, default_to_dict
1✔
13
from haystack.components.generators.chat import OpenAIChatGenerator
1✔
14
from haystack.dataclasses.streaming_chunk import StreamingCallbackT
1✔
15
from haystack.tools import (
1✔
16
    Tool,
17
    Toolset,
18
    _check_duplicate_tool_names,
19
    deserialize_tools_or_toolset_inplace,
20
    serialize_tools_or_toolset,
21
)
22
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
1✔
23
from haystack.utils.http_client import init_http_client
1✔
24

25

26
@component
class AzureOpenAIChatGenerator(OpenAIChatGenerator):
    """
    Generates text using OpenAI's models on Azure.

    It works with the gpt-4 - type models and supports streaming responses
    from OpenAI API. It uses [ChatMessage](https://docs.haystack.deepset.ai/docs/chatmessage)
    format in input and output.

    You can customize how the text is generated by passing parameters to the
    OpenAI API. Use the `**generation_kwargs` argument when you initialize
    the component or when you run it. Any parameter that works with
    `openai.ChatCompletion.create` will work here too.

    For details on OpenAI API parameters, see
    [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat).

    ### Usage example

    ```python
    from haystack.components.generators.chat import AzureOpenAIChatGenerator
    from haystack.dataclasses import ChatMessage
    from haystack.utils import Secret

    messages = [ChatMessage.from_user("What's Natural Language Processing?")]

    client = AzureOpenAIChatGenerator(
        azure_endpoint="<Your Azure endpoint e.g. `https://your-company.azure.openai.com/`>",
        api_key=Secret.from_token("<your-api-key>"),
        azure_deployment="<this is a model name, e.g. gpt-4o-mini>")
    response = client.run(messages)
    print(response)
    ```

    ```
    {'replies':
        [ChatMessage(_role=<ChatRole.ASSISTANT: 'assistant'>, _content=[TextContent(text=
        "Natural Language Processing (NLP) is a branch of artificial intelligence that focuses on
         enabling computers to understand, interpret, and generate human language in a way that is useful.")],
         _name=None,
         _meta={'model': 'gpt-4o-mini', 'index': 0, 'finish_reason': 'stop',
         'usage': {'prompt_tokens': 15, 'completion_tokens': 36, 'total_tokens': 51}})]
    }
    ```
    """

    # pylint: disable=super-init-not-called
    # ruff: noqa: PLR0913
    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        azure_endpoint: Optional[str] = None,
        api_version: Optional[str] = "2023-05-15",
        azure_deployment: Optional[str] = "gpt-4o-mini",
        api_key: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_API_KEY", strict=False),
        azure_ad_token: Optional[Secret] = Secret.from_env_var("AZURE_OPENAI_AD_TOKEN", strict=False),
        organization: Optional[str] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
        timeout: Optional[float] = None,
        max_retries: Optional[int] = None,
        generation_kwargs: Optional[dict[str, Any]] = None,
        default_headers: Optional[dict[str, str]] = None,
        tools: Optional[Union[list[Tool], Toolset]] = None,
        tools_strict: bool = False,
        *,
        azure_ad_token_provider: Optional[Union[AzureADTokenProvider, AsyncAzureADTokenProvider]] = None,
        http_client_kwargs: Optional[dict[str, Any]] = None,
    ):
        """
        Initialize the Azure OpenAI Chat Generator component.

        :param azure_endpoint: The endpoint of the deployed model, for example `"https://example-resource.azure.openai.com/"`.
            If not set, the `AZURE_OPENAI_ENDPOINT` environment variable is used.
        :param api_version: The version of the API to use. Defaults to 2023-05-15.
        :param azure_deployment: The deployment of the model, usually the model name.
        :param api_key: The API key to use for authentication.
        :param azure_ad_token: [Azure Active Directory token](https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id).
        :param organization: Your organization ID, defaults to `None`. For help, see
        [Setting up your organization](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
        :param streaming_callback: A callback function called when a new token is received from the stream.
            It accepts [StreamingChunk](https://docs.haystack.deepset.ai/docs/data-classes#streamingchunk)
            as an argument.
        :param timeout: Timeout for OpenAI client calls. If not set, it defaults to either the
            `OPENAI_TIMEOUT` environment variable, or 30 seconds.
        :param max_retries: Maximum number of retries to contact OpenAI after an internal error.
            If not set, it defaults to either the `OPENAI_MAX_RETRIES` environment variable, or set to 5.
        :param generation_kwargs: Other parameters to use for the model. These parameters are sent directly to
            the OpenAI endpoint. For details, see [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat).
            Some of the supported parameters:
            - `max_tokens`: The maximum number of tokens the output text can have.
            - `temperature`: The sampling temperature to use. Higher values mean the model takes more risks.
                Try 0.9 for more creative applications and 0 (argmax sampling) for ones with a well-defined answer.
            - `top_p`: Nucleus sampling is an alternative to sampling with temperature, where the model considers
                tokens with a top_p probability mass. For example, 0.1 means only the tokens comprising
                the top 10% probability mass are considered.
            - `n`: The number of completions to generate for each prompt. For example, with 3 prompts and n=2,
                the LLM will generate two completions per prompt, resulting in 6 completions total.
            - `stop`: One or more sequences after which the LLM should stop generating tokens.
            - `presence_penalty`: The penalty applied if a token is already present.
                Higher values make the model less likely to repeat the token.
            - `frequency_penalty`: Penalty applied if a token has already been generated.
                Higher values make the model less likely to repeat the token.
            - `logit_bias`: Adds a logit bias to specific tokens. The keys of the dictionary are tokens, and the
                values are the bias to add to that token.
            - `response_format`: A JSON schema or a Pydantic model that enforces the structure of the model's response.
                If provided, the output will always be validated against this
                format (unless the model returns a tool call).
                For details, see the [OpenAI Structured Outputs documentation](https://platform.openai.com/docs/guides/structured-outputs).
                Notes:
                - This parameter accepts Pydantic models and JSON schemas for latest models starting from GPT-4o.
                  Older models only support basic version of structured outputs through `{"type": "json_object"}`.
                  For detailed information on JSON mode, see the [OpenAI Structured Outputs documentation](https://platform.openai.com/docs/guides/structured-outputs#json-mode).
                - For structured outputs with streaming,
                  the `response_format` must be a JSON schema and not a Pydantic model.
        :param default_headers: Default headers to use for the AzureOpenAI client.
        :param tools:
            A list of tools or a Toolset for which the model can prepare calls. This parameter can accept either a
            list of `Tool` objects or a `Toolset` instance.
        :param tools_strict:
            Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly
            the schema provided in the `parameters` field of the tool definition, but this may increase latency.
        :param azure_ad_token_provider: A function that returns an Azure Active Directory token, will be invoked on
            every request.
        :param http_client_kwargs:
            A dictionary of keyword arguments to configure a custom `httpx.Client` or `httpx.AsyncClient`.
            For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
        :raises ValueError:
            If no Azure endpoint is given and `AZURE_OPENAI_ENDPOINT` is not set, or if both `api_key` and
            `azure_ad_token` are `None`.
        """
        # We intentionally do not call super().__init__ here because we only need to instantiate the client to interact
        # with the API.

        # Why is this here?
        # AzureOpenAI init is forcing us to use an init method that takes either base_url or azure_endpoint as not
        # None init parameters. This way we accommodate the use case where env var AZURE_OPENAI_ENDPOINT is set instead
        # of passing it as a parameter.
        azure_endpoint = azure_endpoint or os.environ.get("AZURE_OPENAI_ENDPOINT")
        if not azure_endpoint:
            raise ValueError("Please provide an Azure endpoint or set the environment variable AZURE_OPENAI_ENDPOINT.")

        # NOTE(review): this check does not look at azure_ad_token_provider, so passing only a provider together with
        # api_key=None and azure_ad_token=None is rejected - confirm whether that is intended.
        if api_key is None and azure_ad_token is None:
            raise ValueError("Please provide an API key or an Azure Active Directory token.")

        # The check above makes mypy incorrectly infer that api_key is never None,
        # which propagates the incorrect type.
        self.api_key = api_key  # type: ignore
        self.azure_ad_token = azure_ad_token
        self.generation_kwargs = generation_kwargs or {}
        self.streaming_callback = streaming_callback
        self.api_version = api_version
        self.azure_endpoint = azure_endpoint
        self.azure_deployment = azure_deployment
        self.organization = organization
        self.model = azure_deployment or "gpt-4o-mini"
        # Explicit arguments win; otherwise fall back to the environment variable, then to the hard-coded default.
        self.timeout = timeout if timeout is not None else float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
        self.max_retries = max_retries if max_retries is not None else int(os.environ.get("OPENAI_MAX_RETRIES", "5"))
        self.default_headers = default_headers or {}
        self.azure_ad_token_provider = azure_ad_token_provider
        self.http_client_kwargs = http_client_kwargs
        _check_duplicate_tool_names(list(tools or []))
        self.tools = tools
        self.tools_strict = tools_strict

        # The same arguments are shared by the sync and the async client; only the HTTP client differs.
        client_args: dict[str, Any] = {
            "api_version": api_version,
            "azure_endpoint": azure_endpoint,
            "azure_deployment": azure_deployment,
            "api_key": api_key.resolve_value() if api_key is not None else None,
            "azure_ad_token": azure_ad_token.resolve_value() if azure_ad_token is not None else None,
            "organization": organization,
            "timeout": self.timeout,
            "max_retries": self.max_retries,
            "default_headers": self.default_headers,
            "azure_ad_token_provider": azure_ad_token_provider,
        }

        self.client = AzureOpenAI(
            http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_args
        )
        self.async_client = AsyncAzureOpenAI(
            http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_args
        )

    def to_dict(self) -> dict[str, Any]:
        """
        Serialize this component to a dictionary.

        `streaming_callback` and `azure_ad_token_provider` are serialized with `serialize_callable`, and the secrets
        with their own `to_dict`. If `generation_kwargs["response_format"]` is a Pydantic model class, it is replaced
        in the serialized copy by its strict JSON schema; any other value (for example a JSON schema dictionary) is
        serialized unchanged.

        :returns:
            The serialized component as a dictionary.
        """
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        azure_ad_token_provider_name = None
        if self.azure_ad_token_provider:
            azure_ad_token_provider_name = serialize_callable(self.azure_ad_token_provider)
        # If the response format is a Pydantic model, it's converted to openai's json schema format
        # If it's already a json schema, it's left as is
        # A shallow copy keeps self.generation_kwargs (and the model class stored in it) untouched.
        generation_kwargs = self.generation_kwargs.copy()
        response_format = generation_kwargs.get("response_format")
        # `issubclass` raises TypeError when its first argument is not a class, so a dict-based JSON schema must be
        # ruled out with the `isinstance(..., type)` check before asking whether it is a Pydantic model.
        if isinstance(response_format, type) and issubclass(response_format, BaseModel):
            json_schema = {
                "type": "json_schema",
                "json_schema": {
                    "name": response_format.__name__,
                    "strict": True,
                    "schema": to_strict_json_schema(response_format),
                },
            }
            generation_kwargs["response_format"] = json_schema
        return default_to_dict(
            self,
            azure_endpoint=self.azure_endpoint,
            azure_deployment=self.azure_deployment,
            organization=self.organization,
            api_version=self.api_version,
            streaming_callback=callback_name,
            generation_kwargs=generation_kwargs,
            timeout=self.timeout,
            max_retries=self.max_retries,
            api_key=self.api_key.to_dict() if self.api_key is not None else None,
            azure_ad_token=self.azure_ad_token.to_dict() if self.azure_ad_token is not None else None,
            default_headers=self.default_headers,
            tools=serialize_tools_or_toolset(self.tools),
            tools_strict=self.tools_strict,
            azure_ad_token_provider=azure_ad_token_provider_name,
            http_client_kwargs=self.http_client_kwargs,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "AzureOpenAIChatGenerator":
        """
        Deserialize this component from a dictionary.

        Secrets, tools and serialized callables found under `init_parameters` are converted back in place before
        the dictionary is handed to `default_from_dict`.

        :param data: The dictionary representation of this component.
        :returns:
            The deserialized component instance.
        """
        # Look the init parameters up once. When the key exists this is the very dict stored in `data`, so the
        # in-place conversions below are visible to `default_from_dict`; when it is missing there is nothing to
        # convert and no KeyError is raised here.
        init_params = data.get("init_parameters", {})
        deserialize_secrets_inplace(init_params, keys=["api_key", "azure_ad_token"])
        deserialize_tools_or_toolset_inplace(init_params, key="tools")
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            init_params["streaming_callback"] = deserialize_callable(serialized_callback_handler)
        serialized_azure_ad_token_provider = init_params.get("azure_ad_token_provider")
        if serialized_azure_ad_token_provider:
            init_params["azure_ad_token_provider"] = deserialize_callable(serialized_azure_ad_token_provider)
        return default_from_dict(cls, data)
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc