haystack/components/generators/openai.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Any, Optional, Union

from openai import OpenAI, Stream
from openai.types.chat import ChatCompletion, ChatCompletionChunk

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.components.generators.chat.openai import (
    _check_finish_reason,
    _convert_chat_completion_chunk_to_streaming_chunk,
    _convert_chat_completion_to_chat_message,
)
from haystack.components.generators.utils import _convert_streaming_chunks_to_chat_message
from haystack.dataclasses import (
    ChatMessage,
    ComponentInfo,
    StreamingCallbackT,
    StreamingChunk,
    select_streaming_callback,
)
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
from haystack.utils.http_client import init_http_client

logger = logging.getLogger(__name__)


@component
class OpenAIGenerator:
    """
    Generates text using OpenAI's large language models (LLMs).

    It works with the gpt-4 and o-series models and supports streaming responses
    from the OpenAI API. It uses strings as input and output.

    You can customize how the text is generated by passing parameters to the
    OpenAI API. Use the `**generation_kwargs` argument when you initialize
    the component or when you run it. Any parameter that works with the
    OpenAI Chat Completions endpoint (`chat.completions.create`) will work here too.

    For details on OpenAI API parameters, see
    [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat).

    ### Usage example

    ```python
    from haystack.components.generators import OpenAIGenerator
    client = OpenAIGenerator()
    response = client.run("What's Natural Language Processing? Be brief.")
    print(response)

    >> {'replies': ['Natural Language Processing (NLP) is a branch of artificial intelligence that focuses on
    >> the interaction between computers and human language. It involves enabling computers to understand, interpret,
    >> and respond to natural human language in a way that is both meaningful and useful.'], 'meta': [{'model':
    >> 'gpt-4o-mini', 'index': 0, 'finish_reason': 'stop', 'usage': {'prompt_tokens': 16,
    >> 'completion_tokens': 49, 'total_tokens': 65}}]}
    ```
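
    To stream tokens as they arrive, pass a `streaming_callback`. A minimal sketch
    (any callable that accepts a `StreamingChunk` works; printing is just one option):

    ```python
    from haystack.components.generators import OpenAIGenerator

    client = OpenAIGenerator(streaming_callback=lambda chunk: print(chunk.content, end="", flush=True))
    client.run("What's Natural Language Processing? Be brief.")
    ```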
    """

    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
        model: str = "gpt-4o-mini",
        streaming_callback: Optional[StreamingCallbackT] = None,
        api_base_url: Optional[str] = None,
        organization: Optional[str] = None,
        system_prompt: Optional[str] = None,
        generation_kwargs: Optional[dict[str, Any]] = None,
        timeout: Optional[float] = None,
        max_retries: Optional[int] = None,
        http_client_kwargs: Optional[dict[str, Any]] = None,
    ):
        """
        Creates an instance of OpenAIGenerator. Unless specified otherwise in `model`, uses OpenAI's gpt-4o-mini.

        By setting the `OPENAI_TIMEOUT` and `OPENAI_MAX_RETRIES` environment variables, you can change the timeout
        and max_retries parameters used by the OpenAI client.

        :param api_key: The OpenAI API key to connect to OpenAI.
        :param model: The name of the model to use.
        :param streaming_callback: A callback function that is called when a new token is received from the stream.
            The callback function accepts StreamingChunk as an argument.
        :param api_base_url: An optional base URL.
        :param organization: The Organization ID, defaults to `None`.
        :param system_prompt: The system prompt to use for text generation. If not provided, the system prompt is
            omitted, and the default system prompt of the model is used.
        :param generation_kwargs: Other parameters to use for the model. These parameters are all sent directly to
            the OpenAI endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat) for
            more details.
            Some of the supported parameters:
            - `max_completion_tokens`: An upper bound for the number of tokens that can be generated for a completion,
                including visible output tokens and reasoning tokens.
            - `temperature`: The sampling temperature to use. Higher values mean the model takes more risks.
                Try 0.9 for more creative applications and 0 (argmax sampling) for ones with a well-defined answer.
            - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
                considers the results of the tokens with top_p probability mass. So, 0.1 means only the tokens
                comprising the top 10% probability mass are considered.
            - `n`: How many completions to generate for each prompt. For example, if the LLM gets 3 prompts and n is 2,
                it will generate two completions for each of the three prompts, ending up with 6 completions in total.
            - `stop`: One or more sequences after which the LLM should stop generating tokens.
            - `presence_penalty`: The penalty to apply if a token is already present in the text. Bigger values make
                the model less likely to repeat the same token.
            - `frequency_penalty`: The penalty to apply if a token has already been generated in the text. Bigger
                values make the model less likely to repeat the same token.
            - `logit_bias`: Adds a logit bias to specific tokens. The keys of the dictionary are tokens, and the
                values are the bias to add to that token.
        :param timeout:
            Timeout for OpenAI client calls. If not set, it is inferred from the `OPENAI_TIMEOUT` environment variable
            or defaults to 30.
        :param max_retries:
            Maximum number of retries to contact OpenAI if it returns an internal error. If not set, it is inferred
            from the `OPENAI_MAX_RETRIES` environment variable or defaults to 5.
        :param http_client_kwargs:
            A dictionary of keyword arguments to configure a custom `httpx.Client` or `httpx.AsyncClient`.
            For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
        """
        self.api_key = api_key
        self.model = model
        self.generation_kwargs = generation_kwargs or {}
        self.system_prompt = system_prompt
        self.streaming_callback = streaming_callback

        self.api_base_url = api_base_url
        self.organization = organization
        self.http_client_kwargs = http_client_kwargs

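        # if not set explicitly, fall back to the OPENAI_TIMEOUT / OPENAI_MAX_RETRIES environment variables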
        if timeout is None:
            timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
        if max_retries is None:
            max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5"))

        self.client = OpenAI(
            api_key=api_key.resolve_value(),
            organization=organization,
            base_url=api_base_url,
            timeout=timeout,
            max_retries=max_retries,
            http_client=init_http_client(self.http_client_kwargs, async_client=False),
        )

    def _get_telemetry_data(self) -> dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        return {"model": self.model}

    def to_dict(self) -> dict[str, Any]:
        """
        Serialize this component to a dictionary.
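
        A round-trip sketch (illustrative; assumes `OPENAI_API_KEY` is set in the environment):

        ```python
        generator = OpenAIGenerator(model="gpt-4o-mini")
        data = generator.to_dict()
        restored = OpenAIGenerator.from_dict(data)
        ```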

        :returns:
            The serialized component as a dictionary.
        """
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        return default_to_dict(
            self,
            model=self.model,
            streaming_callback=callback_name,
            api_base_url=self.api_base_url,
            organization=self.organization,
            generation_kwargs=self.generation_kwargs,
            system_prompt=self.system_prompt,
            api_key=self.api_key.to_dict(),
            http_client_kwargs=self.http_client_kwargs,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "OpenAIGenerator":
        """
        Deserialize this component from a dictionary.

        :param data:
            The dictionary representation of this component.
        :returns:
            The deserialized component instance.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
        return default_from_dict(cls, data)

    @component.output_types(replies=list[str], meta=list[dict[str, Any]])
    def run(
        self,
        prompt: str,
        system_prompt: Optional[str] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
        generation_kwargs: Optional[dict[str, Any]] = None,
    ):
        """
        Invoke the text generation inference based on the provided prompt and generation parameters.

        :param prompt:
            The string prompt to use for text generation.
        :param system_prompt:
            The system prompt to use for text generation. If this runtime system prompt is omitted, the system
            prompt defined at initialization time, if any, is used.
        :param streaming_callback:
            A callback function that is called when a new token is received from the stream.
        :param generation_kwargs:
            Additional keyword arguments for text generation. These parameters override the parameters passed in the
            `__init__` method. For more details on the parameters supported by the OpenAI API, refer to the OpenAI
            [documentation](https://platform.openai.com/docs/api-reference/chat/create).
        :returns:
            A dictionary with the following keys:
            - `replies`: A list of strings containing the generated responses.
            - `meta`: A list of dictionaries containing the metadata for each response.
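
        A sketch of overriding generation parameters at run time (illustrative values; assumes
        `OPENAI_API_KEY` is set in the environment):

        ```python
        client = OpenAIGenerator()
        result = client.run(
            "Summarize NLP in one sentence.",
            generation_kwargs={"temperature": 0.2, "max_completion_tokens": 50},
        )
        print(result["replies"][0])
        ```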
        """
        message = ChatMessage.from_user(prompt)
        if system_prompt is not None:
            messages = [ChatMessage.from_system(system_prompt), message]
        elif self.system_prompt:
            messages = [ChatMessage.from_system(self.system_prompt), message]
        else:
            messages = [message]

        # merge the generation kwargs set at initialization with those passed to the run method
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        # choose between the init and runtime streaming callbacks; a synchronous callback is required here
        streaming_callback = select_streaming_callback(
            init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
        )

        # adapt ChatMessage(s) to the format expected by the OpenAI API
        openai_formatted_messages = [message.to_openai_dict_format() for message in messages]

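        # a single API call serves both modes: stream=True yields chunks, otherwise a full ChatCompletion is returned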
        completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create(
            model=self.model,
            messages=openai_formatted_messages,  # type: ignore
            stream=streaming_callback is not None,
            **generation_kwargs,
        )

        completions: list[ChatMessage] = []
        if streaming_callback is not None:
            num_responses = generation_kwargs.pop("n", 1)
            if num_responses > 1:
                raise ValueError("Cannot stream multiple responses, please set n=1.")

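            # attach component info to each chunk so downstream consumers can identify the producing component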
            component_info = ComponentInfo.from_component(self)
            chunks: list[StreamingChunk] = []
            for chunk in completion:
                chunk_delta: StreamingChunk = _convert_chat_completion_chunk_to_streaming_chunk(
                    chunk=chunk,  # type: ignore
                    previous_chunks=chunks,
                    component_info=component_info,
                )
                chunks.append(chunk_delta)
                streaming_callback(chunk_delta)

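            # merge the accumulated streaming chunks into a single ChatMessage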
            completions = [_convert_streaming_chunks_to_chat_message(chunks=chunks)]
        elif isinstance(completion, ChatCompletion):
            completions = [
                _convert_chat_completion_to_chat_message(completion=completion, choice=choice)
                for choice in completion.choices
            ]

        # before returning, check the finish reason of each completion
        for response in completions:
            _check_finish_reason(response.meta)

        return {"replies": [message.text for message in completions], "meta": [message.meta for message in completions]}