• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 13259218501

11 Feb 2025 09:00AM UTC coverage: 91.459% (-1.3%) from 92.709%
13259218501

Pull #8829

github

web-flow
Merge 427e76339 into ad90e106a
Pull Request #8829: fix: Look through all streaming chunks for tools calls

9413 of 10292 relevant lines covered (91.46%)

0.91 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

96.32
haystack/components/generators/chat/openai.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import json
1✔
6
import os
1✔
7
from datetime import datetime
1✔
8
from typing import Any, Callable, Dict, List, Optional, Union
1✔
9

10
from openai import OpenAI, Stream
1✔
11
from openai.types.chat import ChatCompletion, ChatCompletionChunk, ChatCompletionMessage
1✔
12
from openai.types.chat.chat_completion import Choice
1✔
13
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
1✔
14

15
from haystack import component, default_from_dict, default_to_dict, logging
1✔
16
from haystack.dataclasses import ChatMessage, StreamingChunk, ToolCall
1✔
17
from haystack.tools.tool import Tool, _check_duplicate_tool_names, deserialize_tools_inplace
1✔
18
from haystack.utils import Secret, deserialize_callable, deserialize_secrets_inplace, serialize_callable
1✔
19

20
logger = logging.getLogger(__name__)


# Shape of a streaming hook: it receives each `StreamingChunk` as it arrives; its return value is ignored.
StreamingCallbackT = Callable[[StreamingChunk], None]
24

25

26
@component
class OpenAIChatGenerator:
    """
    Completes chats using OpenAI's large language models (LLMs).

    It works with the gpt-4 and gpt-3.5-turbo models and supports streaming responses
    from OpenAI API. It uses [ChatMessage](https://docs.haystack.deepset.ai/docs/chatmessage)
    format in input and output.

    You can customize how the text is generated by passing parameters to the
    OpenAI API. Use the `**generation_kwargs` argument when you initialize
    the component or when you run it. Any parameter that works with
    `openai.ChatCompletion.create` will work here too.

    For details on OpenAI API parameters, see
    [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat).

    ### Usage example

    ```python
    from haystack.components.generators.chat import OpenAIChatGenerator
    from haystack.dataclasses import ChatMessage

    messages = [ChatMessage.from_user("What's Natural Language Processing?")]

    client = OpenAIChatGenerator()
    response = client.run(messages)
    print(response)
    ```
    Output:
    ```
    {'replies':
        [ChatMessage(content='Natural Language Processing (NLP) is a branch of artificial intelligence
            that focuses on enabling computers to understand, interpret, and generate human language in
            a way that is meaningful and useful.',
         role=<ChatRole.ASSISTANT: 'assistant'>, name=None,
         meta={'model': 'gpt-4o-mini', 'index': 0, 'finish_reason': 'stop',
         'usage': {'prompt_tokens': 15, 'completion_tokens': 36, 'total_tokens': 51}})
        ]
    }
    ```
    """

    def __init__(  # pylint: disable=too-many-positional-arguments
        self,
        api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
        model: str = "gpt-4o-mini",
        streaming_callback: Optional[StreamingCallbackT] = None,
        api_base_url: Optional[str] = None,
        organization: Optional[str] = None,
        generation_kwargs: Optional[Dict[str, Any]] = None,
        timeout: Optional[float] = None,
        max_retries: Optional[int] = None,
        tools: Optional[List[Tool]] = None,
        tools_strict: bool = False,
    ):
        """
        Creates an instance of OpenAIChatGenerator. Unless specified otherwise in `model`, uses OpenAI's gpt-4o-mini.

        Before initializing the component, you can set the 'OPENAI_TIMEOUT' and 'OPENAI_MAX_RETRIES'
        environment variables to override the `timeout` and `max_retries` parameters respectively
        in the OpenAI client.

        :param api_key: The OpenAI API key.
            You can set it with an environment variable `OPENAI_API_KEY`, or pass with this parameter
            during initialization.
        :param model: The name of the model to use.
        :param streaming_callback: A callback function that is called when a new token is received from the stream.
            The callback function accepts [StreamingChunk](https://docs.haystack.deepset.ai/docs/data-classes#streamingchunk)
            as an argument.
        :param api_base_url: An optional base URL.
        :param organization: Your organization ID, defaults to `None`. See
        [production best practices](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization).
        :param generation_kwargs: Other parameters to use for the model. These parameters are sent directly to
            the OpenAI endpoint. See OpenAI [documentation](https://platform.openai.com/docs/api-reference/chat) for
            more details.
            Some of the supported parameters:
            - `max_tokens`: The maximum number of tokens the output text can have.
            - `temperature`: What sampling temperature to use. Higher values mean the model will take more risks.
                Try 0.9 for more creative applications and 0 (argmax sampling) for ones with a well-defined answer.
            - `top_p`: An alternative to sampling with temperature, called nucleus sampling, where the model
                considers the results of the tokens with top_p probability mass. For example, 0.1 means only the tokens
                comprising the top 10% probability mass are considered.
            - `n`: How many completions to generate for each prompt. For example, if the LLM gets 3 prompts and n is 2,
                it will generate two completions for each of the three prompts, ending up with 6 completions in total.
            - `stop`: One or more sequences after which the LLM should stop generating tokens.
            - `presence_penalty`: What penalty to apply if a token is already present at all. Bigger values mean
                the model will be less likely to repeat the same token in the text.
            - `frequency_penalty`: What penalty to apply if a token has already been generated in the text.
                Bigger values mean the model will be less likely to repeat the same token in the text.
            - `logit_bias`: Add a logit bias to specific tokens. The keys of the dictionary are tokens, and the
                values are the bias to add to that token.
        :param timeout:
            Timeout for OpenAI client calls. If not set, it defaults to either the
            `OPENAI_TIMEOUT` environment variable, or 30 seconds.
        :param max_retries:
            Maximum number of retries to contact OpenAI after an internal error.
            If not set, it defaults to either the `OPENAI_MAX_RETRIES` environment variable, or set to 5.
        :param tools:
            A list of tools for which the model can prepare calls.
        :param tools_strict:
            Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly
            the schema provided in the `parameters` field of the tool definition, but this may increase latency.
        """
        self.api_key = api_key
        self.model = model
        self.generation_kwargs = generation_kwargs or {}
        self.streaming_callback = streaming_callback
        self.api_base_url = api_base_url
        self.organization = organization
        # the raw (possibly None) values are stored so that to_dict() round-trips what the user passed
        self.timeout = timeout
        self.max_retries = max_retries
        self.tools = tools
        self.tools_strict = tools_strict

        _check_duplicate_tool_names(tools)

        # env-var / hard-coded fallbacks only affect the client, not the serialized init parameters
        if timeout is None:
            timeout = float(os.environ.get("OPENAI_TIMEOUT", 30.0))
        if max_retries is None:
            max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", 5))

        self.client = OpenAI(
            api_key=api_key.resolve_value(),
            organization=organization,
            base_url=api_base_url,
            timeout=timeout,
            max_retries=max_retries,
        )

    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        return {"model": self.model}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serialize this component to a dictionary.

        :returns:
            The serialized component as a dictionary.
        """
        callback_name = serialize_callable(self.streaming_callback) if self.streaming_callback else None
        return default_to_dict(
            self,
            model=self.model,
            streaming_callback=callback_name,
            api_base_url=self.api_base_url,
            organization=self.organization,
            generation_kwargs=self.generation_kwargs,
            api_key=self.api_key.to_dict(),
            timeout=self.timeout,
            max_retries=self.max_retries,
            tools=[tool.to_dict() for tool in self.tools] if self.tools else None,
            tools_strict=self.tools_strict,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "OpenAIChatGenerator":
        """
        Deserialize this component from a dictionary.

        :param data: The dictionary representation of this component.
        :returns:
            The deserialized component instance.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
        deserialize_tools_inplace(data["init_parameters"], key="tools")
        init_params = data.get("init_parameters", {})
        serialized_callback_handler = init_params.get("streaming_callback")
        if serialized_callback_handler:
            data["init_parameters"]["streaming_callback"] = deserialize_callable(serialized_callback_handler)
        return default_from_dict(cls, data)

    @component.output_types(replies=List[ChatMessage])
    def run(
        self,
        messages: List[ChatMessage],
        streaming_callback: Optional[StreamingCallbackT] = None,
        generation_kwargs: Optional[Dict[str, Any]] = None,
        *,
        tools: Optional[List[Tool]] = None,
        tools_strict: Optional[bool] = None,
    ):
        """
        Invokes chat completion based on the provided messages and generation parameters.

        :param messages:
            A list of ChatMessage instances representing the input messages.
        :param streaming_callback:
            A callback function that is called when a new token is received from the stream.
        :param generation_kwargs:
            Additional keyword arguments for text generation. These parameters will
            override the parameters passed during component initialization.
            For details on OpenAI API parameters, see [OpenAI documentation](https://platform.openai.com/docs/api-reference/chat/create).
        :param tools:
            A list of tools for which the model can prepare calls. If set, it will override the `tools` parameter set
            during component initialization.
        :param tools_strict:
            Whether to enable strict schema adherence for tool calls. If set to `True`, the model will follow exactly
            the schema provided in the `parameters` field of the tool definition, but this may increase latency.
            If set, it will override the `tools_strict` parameter set during component initialization.

        :returns:
            A dictionary with the following key:
            - `replies`: A list containing the generated responses as ChatMessage instances.
        """
        if len(messages) == 0:
            return {"replies": []}

        streaming_callback = streaming_callback or self.streaming_callback

        api_args = self._prepare_api_call(
            messages=messages,
            streaming_callback=streaming_callback,
            generation_kwargs=generation_kwargs,
            tools=tools,
            tools_strict=tools_strict,
        )
        chat_completion: Union[Stream[ChatCompletionChunk], ChatCompletion] = self.client.chat.completions.create(
            **api_args
        )

        is_streaming = isinstance(chat_completion, Stream)
        assert is_streaming or streaming_callback is None

        if is_streaming:
            completions = self._handle_stream_response(
                chat_completion,  # type: ignore
                streaming_callback,  # type: ignore
            )
        else:
            assert isinstance(chat_completion, ChatCompletion), "Unexpected response type for non-streaming request."
            completions = [
                self._convert_chat_completion_to_chat_message(chat_completion, choice)
                for choice in chat_completion.choices
            ]

        # before returning, do post-processing of the completions
        for message in completions:
            self._check_finish_reason(message.meta)

        return {"replies": completions}

    def _prepare_api_call(  # noqa: PLR0913
        self,
        *,
        messages: List[ChatMessage],
        streaming_callback: Optional[StreamingCallbackT] = None,
        generation_kwargs: Optional[Dict[str, Any]] = None,
        tools: Optional[List[Tool]] = None,
        tools_strict: Optional[bool] = None,
    ) -> Dict[str, Any]:
        """
        Builds the keyword arguments passed to `client.chat.completions.create`.

        Run-time `generation_kwargs` are merged over the init-time ones (run-time keys win). A non-empty run-time
        `tools` list replaces the init-time one, and a non-None `tools_strict` replaces the init-time flag.

        :param messages: The conversation to send.
        :param streaming_callback: If not None, the request is made with `stream=True`.
        :param generation_kwargs: Run-time generation parameters.
        :param tools: Run-time tools.
        :param tools_strict: Run-time strict-schema flag for tool definitions.
        :returns: The keyword arguments for the OpenAI chat completions call.
        :raises ValueError: If streaming is requested together with `n > 1`.
        """
        # update generation kwargs by merging with the generation kwargs passed to the run method
        generation_kwargs = {**self.generation_kwargs, **(generation_kwargs or {})}

        # adapt ChatMessage(s) to the format expected by the OpenAI API
        openai_formatted_messages = [message.to_openai_dict_format() for message in messages]

        tools = tools or self.tools
        tools_strict = tools_strict if tools_strict is not None else self.tools_strict
        _check_duplicate_tool_names(tools)

        openai_tools = {}
        if tools:
            tool_definitions = [
                {"type": "function", "function": {**t.tool_spec, **({"strict": tools_strict} if tools_strict else {})}}
                for t in tools
            ]
            openai_tools = {"tools": tool_definitions}

        is_streaming = streaming_callback is not None
        # `n` is popped so it is sent exactly once, via the explicit "n" key below
        num_responses = generation_kwargs.pop("n", 1)
        if is_streaming and num_responses > 1:
            raise ValueError("Cannot stream multiple responses, please set n=1.")

        return {
            "model": self.model,
            "messages": openai_formatted_messages,  # type: ignore[arg-type] # openai expects list of specific message types
            "stream": streaming_callback is not None,
            "n": num_responses,
            **openai_tools,
            **generation_kwargs,
        }

    def _handle_stream_response(self, chat_completion: Stream, callback: StreamingCallbackT) -> List[ChatMessage]:
        """
        Consumes an OpenAI stream, forwarding each converted chunk to `callback`, and assembles the final reply.

        :param chat_completion: The stream returned by the OpenAI client.
        :param callback: Invoked once per received chunk with the converted `StreamingChunk`.
        :returns: A single-element list with the assembled assistant `ChatMessage`.
        """
        chunks: List[StreamingChunk] = []
        chunk = None

        for chunk in chat_completion:  # pylint: disable=not-an-iterable
            assert len(chunk.choices) == 1, "Streaming responses should have only one choice."
            chunk_delta: StreamingChunk = self._convert_chat_completion_chunk_to_streaming_chunk(chunk)
            chunks.append(chunk_delta)

            callback(chunk_delta)

        # `chunk` is the last raw chunk; it provides the model name and finish reason for the final message
        return [self._convert_streaming_chunks_to_chat_message(chunk, chunks)]

    def _check_finish_reason(self, meta: Dict[str, Any]) -> None:
        """
        Logs a warning when a completion was cut short by the token limit or by the content filter.

        :param meta: The reply's meta; must contain `finish_reason` and `index`.
        """
        if meta["finish_reason"] == "length":
            logger.warning(
                "The completion for index {index} has been truncated before reaching a natural stopping point. "
                "Increase the max_tokens parameter to allow for longer completions.",
                index=meta["index"],
                finish_reason=meta["finish_reason"],
            )
        if meta["finish_reason"] == "content_filter":
            logger.warning(
                "The completion for index {index} has been truncated due to the content filter.",
                index=meta["index"],
                finish_reason=meta["finish_reason"],
            )

    def _convert_streaming_chunks_to_chat_message(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
        """
        Connects the streaming chunks into a single ChatMessage.

        OpenAI streams each tool call as several deltas that share the same `index`. The tool call `id` and function
        `name` typically arrive only once, while the JSON `arguments` string is split across the deltas. Deltas are
        therefore grouped by `index`, and the argument fragments are concatenated in arrival order before parsing.

        :param chunk: The last chunk returned by the OpenAI API.
        :param chunks: The list of all `StreamingChunk` objects.
        :returns: The assembled assistant `ChatMessage`. Tool calls whose arguments are not valid JSON are skipped
            with a warning.
        """
        text = "".join([streaming_chunk.content for streaming_chunk in chunks])
        tool_calls = []

        # Accumulate tool call fragments from every chunk, keyed by the tool call's position (`index`).
        # Keying by `id` is wrong: follow-up deltas of the same tool call come with `id=None`.
        tool_call_data: Dict[int, Dict[str, Any]] = {}
        for chunk_payload in chunks:
            tool_calls_meta = chunk_payload.meta.get("tool_calls")
            if not tool_calls_meta:
                continue
            for delta in tool_calls_meta:
                call_data = tool_call_data.setdefault(delta.index, {"id": None, "name": "", "arguments": ""})
                if delta.id:
                    call_data["id"] = delta.id
                if delta.function:
                    if delta.function.name:
                        call_data["name"] = delta.function.name
                    if delta.function.arguments:
                        # arguments are streamed piecewise: append, never overwrite
                        call_data["arguments"] += delta.function.arguments

        # Convert accumulated tool call data into ToolCall objects
        for call_data in tool_call_data.values():
            try:
                arguments = json.loads(call_data["arguments"])
                tool_calls.append(ToolCall(id=call_data["id"], tool_name=call_data["name"], arguments=arguments))
            except json.JSONDecodeError:
                logger.warning(
                    "Skipping malformed tool call due to invalid JSON. Set `tools_strict=True` for valid JSON. "
                    "Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
                    _id=call_data["id"],
                    _name=call_data["name"],
                    _arguments=call_data["arguments"],
                )

        meta = {
            "model": chunk.model,
            "index": 0,
            "finish_reason": chunk.choices[0].finish_reason,
            "completion_start_time": chunks[0].meta.get("received_at"),  # first chunk received
            "usage": {},  # we don't have usage data for streaming responses
        }

        return ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)

    def _convert_chat_completion_to_chat_message(self, completion: ChatCompletion, choice: Choice) -> ChatMessage:
        """
        Converts the non-streaming response from the OpenAI API to a ChatMessage.

        :param completion: The completion returned by the OpenAI API.
        :param choice: The choice returned by the OpenAI API.
        :return: The ChatMessage.
        """
        message: ChatCompletionMessage = choice.message
        text = message.content
        tool_calls = []
        if openai_tool_calls := message.tool_calls:
            for openai_tc in openai_tool_calls:
                arguments_str = openai_tc.function.arguments
                try:
                    arguments = json.loads(arguments_str)
                    tool_calls.append(ToolCall(id=openai_tc.id, tool_name=openai_tc.function.name, arguments=arguments))
                except json.JSONDecodeError:
                    logger.warning(
                        "OpenAI returned a malformed JSON string for tool call arguments. This tool call "
                        "will be skipped. To always generate a valid JSON, set `tools_strict` to `True`. "
                        "Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
                        _id=openai_tc.id,
                        _name=openai_tc.function.name,
                        _arguments=arguments_str,
                    )

        chat_message = ChatMessage.from_assistant(text=text, tool_calls=tool_calls)
        chat_message._meta.update(
            {
                "model": completion.model,
                "index": choice.index,
                "finish_reason": choice.finish_reason,
                "usage": dict(completion.usage or {}),
            }
        )
        return chat_message

    def _convert_chat_completion_chunk_to_streaming_chunk(self, chunk: ChatCompletionChunk) -> StreamingChunk:
        """
        Converts the streaming response chunk from the OpenAI API to a StreamingChunk.

        :param chunk: The chunk returned by the OpenAI API.

        :returns:
            The StreamingChunk.
        """
        # we stream the content of the chunk if it's not a tool or function call
        choice: ChunkChoice = chunk.choices[0]
        content = choice.delta.content or ""
        chunk_message = StreamingChunk(content)
        # but save the tool calls and function call in the meta if they are present
        # and then connect the chunks in the _convert_streaming_chunks_to_chat_message method
        chunk_message.meta.update(
            {
                "model": chunk.model,
                "index": choice.index,
                "tool_calls": choice.delta.tool_calls,
                "finish_reason": choice.finish_reason,
                "received_at": datetime.now().isoformat(),
            }
        )
        return chunk_message
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc