• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 14794014302

02 May 2025 11:06AM UTC coverage: 90.448% (-0.07%) from 90.513%
14794014302

Pull #9290

github

web-flow
Merge de5d77f03 into e3f9da13d
Pull Request #9290: feat: enable streaming ToolCall/Result from Agent

10899 of 12050 relevant lines covered (90.45%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.78
haystack/components/tools/tool_invoker.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import inspect
import json
from typing import Any, Dict, List, Optional, Union

from haystack import component, default_from_dict, default_to_dict, logging
from haystack.core.component.sockets import Sockets
from haystack.dataclasses import ChatMessage, State, ToolCall
from haystack.dataclasses.streaming_chunk import StreamingCallbackT, StreamingChunk, select_streaming_callback
from haystack.tools import (
    ComponentTool,
    Tool,
    Toolset,
    _check_duplicate_tool_names,
    deserialize_tools_or_toolset_inplace,
    serialize_tools_or_toolset,
)
from haystack.tools.errors import ToolInvocationError
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
22

23
# Module-level logger obtained through haystack's `logging` wrapper (imported from `haystack` above).
logger = logging.getLogger(__name__)
1✔
24

25

26
class ToolInvokerError(Exception):
    """
    Common base class for the ToolInvoker-specific exceptions.

    It carries a single human-readable message, which becomes the exception's only argument.
    """

    def __init__(self, message: str):
        # Delegate straight to Exception so `str(err)` and `err.args` expose the message unchanged.
        super().__init__(message)
31

32

33
class ToolNotFoundException(ToolInvokerError):
    """
    Raised when a tool call references a tool name that is not among the available tools.

    The message names the missing tool and lists every available tool name, comma-separated.
    """

    def __init__(self, tool_name: str, available_tools: List[str]):
        known_names = ", ".join(available_tools)
        super().__init__(f"Tool '{tool_name}' not found. Available tools: {known_names}")
39

40

41
class StringConversionError(ToolInvokerError):
    """
    Raised when converting a tool's result into a string fails.

    The message records the tool name, the name of the conversion function that was used,
    and the text of the underlying error.
    """

    def __init__(self, tool_name: str, conversion_function: str, error: Exception):
        details = (
            f"Failed to convert tool result from tool {tool_name} using '{conversion_function}'. "
            f"Error: {error}"
        )
        super().__init__(details)
47

48

49
class ToolOutputMergeError(ToolInvokerError):
    """
    Raised when a tool's outputs cannot be merged into the runtime `State`.

    It adds no behavior of its own beyond its base class; it only gives callers a narrower type to catch.
    """
53

54

55
@component
class ToolInvoker:
    """
    Invokes tools based on prepared tool calls and returns the results as a list of ChatMessage objects.

    Also handles reading/writing from a shared `State`.
    At initialization, the ToolInvoker component is provided with a list of available tools.
    At runtime, the component processes a list of ChatMessage objects containing tool calls
    and invokes the corresponding tools.
    The results of the tool invocations are returned as a list of ChatMessage objects with tool role.

    Usage example:
    ```python
    from haystack.dataclasses import ChatMessage, ToolCall
    from haystack.tools import Tool
    from haystack.components.tools import ToolInvoker

    # Tool definition
    def dummy_weather_function(city: str):
        return f"The weather in {city} is 20 degrees."

    parameters = {"type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"]}

    tool = Tool(name="weather_tool",
                description="A tool to get the weather",
                function=dummy_weather_function,
                parameters=parameters)

    # Usually, the ChatMessage with tool_calls is generated by a Language Model
    # Here, we create it manually for demonstration purposes
    tool_call = ToolCall(
        tool_name="weather_tool",
        arguments={"city": "Berlin"}
    )
    message = ChatMessage.from_assistant(tool_calls=[tool_call])

    # ToolInvoker initialization and run
    invoker = ToolInvoker(tools=[tool])
    result = invoker.run(messages=[message])

    print(result)
    ```

    ```
    >>  {
    >>      'tool_messages': [
    >>          ChatMessage(
    >>              _role=<ChatRole.TOOL: 'tool'>,
    >>              _content=[
    >>                  ToolCallResult(
    >>                      result='"The weather in Berlin is 20 degrees."',
    >>                      origin=ToolCall(
    >>                          tool_name='weather_tool',
    >>                          arguments={'city': 'Berlin'},
    >>                          id=None
    >>                      )
    >>                  )
    >>              ],
    >>              _meta={}
    >>          )
    >>      ],
    >>      ...  # the result also contains a 'state' entry holding the runtime State
    >>  }
    ```

    Usage example with a Toolset:
    ```python
    from haystack.dataclasses import ChatMessage, ToolCall
    from haystack.tools import Tool, Toolset
    from haystack.components.tools import ToolInvoker

    # Tool definition
    def dummy_weather_function(city: str):
        return f"The weather in {city} is 20 degrees."

    parameters = {"type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"]}

    tool = Tool(name="weather_tool",
                description="A tool to get the weather",
                function=dummy_weather_function,
                parameters=parameters)

    # Create a Toolset
    toolset = Toolset([tool])

    # Usually, the ChatMessage with tool_calls is generated by a Language Model
    # Here, we create it manually for demonstration purposes
    tool_call = ToolCall(
        tool_name="weather_tool",
        arguments={"city": "Berlin"}
    )
    message = ChatMessage.from_assistant(tool_calls=[tool_call])

    # ToolInvoker initialization and run with Toolset
    invoker = ToolInvoker(tools=toolset)
    result = invoker.run(messages=[message])

    print(result)
    """

    def __init__(
        self,
        tools: Union[List[Tool], Toolset],
        raise_on_failure: bool = True,
        convert_result_to_json_string: bool = False,
        streaming_callback: Optional[StreamingCallbackT] = None,
    ):
        """
        Initialize the ToolInvoker component.

        :param tools:
            A list of tools that can be invoked or a Toolset instance that can resolve tools.
        :param raise_on_failure:
            If True, the component will raise an exception in case of errors
            (tool not found, tool invocation errors, tool result conversion errors).
            If False, the component will return a ChatMessage object with `error=True`
            and a description of the error in `result`.
        :param convert_result_to_json_string:
            If True, the tool invocation result will be converted to a string using `json.dumps`.
            If False, the tool invocation result will be converted to a string using `str`.
        :param streaming_callback:
            A callback function that will be called to emit tool results.
            Note that the result is only emitted once it becomes available — it is not
            streamed incrementally in real time.
            It is included in `to_dict` output, so it must be serializable when the component is serialized.
        :raises ValueError:
            If no tools are provided or if duplicate tool names are found.
        """
        if not tools:
            raise ValueError("ToolInvoker requires at least one tool.")

        # Keep the original object (Toolset or list of Tools) so `to_dict` can serialize it as given.
        self.tools = tools
        self.streaming_callback = streaming_callback

        # Flatten a Toolset into a plain list for internal name-based lookup.
        converted_tools = list(tools) if isinstance(tools, Toolset) else tools

        _check_duplicate_tool_names(converted_tools)
        tool_names = [tool.name for tool in converted_tools]

        # Defensive re-check done in a single linear pass (previously O(n^2) via `list.count`).
        seen_names: set = set()
        duplicates: set = set()
        for name in tool_names:
            if name in seen_names:
                duplicates.add(name)
            seen_names.add(name)
        if duplicates:
            raise ValueError(f"Duplicate tool names found: {duplicates}")

        self._tools_with_names = dict(zip(tool_names, converted_tools))
        self.raise_on_failure = raise_on_failure
        self.convert_result_to_json_string = convert_result_to_json_string

    def _handle_error(self, error: Exception) -> str:
        """
        Handles errors by logging and either raising or returning a fallback error message.

        :param error: The exception instance.
        :returns: The fallback error message when `raise_on_failure` is False.
        :raises: The provided error if `raise_on_failure` is True.
        """
        logger.error("{error_exception}", error_exception=error)
        if self.raise_on_failure:
            # We re-raise the original error maintaining the exception chain
            raise error
        return str(error)

    def _default_output_to_string_handler(self, result: Any) -> str:
        """
        Default handler for converting a tool result to a string.

        :param result: The tool result to convert to a string.
        :returns: The converted tool result as a string.
        """
        if self.convert_result_to_json_string:
            # We disable ensure_ascii so special chars like emojis are not converted
            return json.dumps(result, ensure_ascii=False)
        return str(result)

    def _prepare_tool_result_message(self, result: Any, tool_call: ToolCall, tool_to_invoke: Tool) -> ChatMessage:
        """
        Prepares a ChatMessage with the result of a tool invocation.

        The tool's optional `outputs_to_string` config can select a `source` key from a dict result
        and/or provide a custom string `handler`; otherwise the whole result is converted with the
        default handler.

        :param result:
            The tool result.
        :param tool_call:
            The ToolCall object containing the tool name and arguments.
        :param tool_to_invoke:
            The Tool object that was invoked.
        :returns:
            A ChatMessage object containing the tool result as a string.
        :raises
            StringConversionError: If the conversion of the tool result to a string fails
            and `raise_on_failure` is True.
        """
        outputs_to_string = tool_to_invoke.outputs_to_string or {}
        source_key = outputs_to_string.get("source") or None
        output_to_string_handler = outputs_to_string.get("handler") or self._default_output_to_string_handler

        # A source key can only be looked up in a dict result; any other result type is converted whole.
        # This mirrors `_merge_tool_outputs`, which also ignores non-dict results, and avoids an
        # AttributeError escaping the `raise_on_failure` handling below.
        if source_key is not None and isinstance(result, dict):
            result_to_convert = result.get(source_key)
        else:
            result_to_convert = result

        error = False
        try:
            tool_result_str = output_to_string_handler(result_to_convert)
        except Exception as e:
            # Handlers may be partials or callable objects without `__name__`; that must not mask `e`.
            handler_name = getattr(output_to_string_handler, "__name__", repr(output_to_string_handler))
            try:
                tool_result_str = self._handle_error(StringConversionError(tool_call.tool_name, handler_name, e))
                error = True
            except StringConversionError as conversion_error:
                # If _handle_error re-raises, this properly preserves the chain
                raise conversion_error from e
        return ChatMessage.from_tool(tool_result=tool_result_str, error=error, origin=tool_call)

    @staticmethod
    def _inject_state_args(tool: Tool, llm_args: Dict[str, Any], state: State) -> Dict[str, Any]:
        """
        Combine LLM-provided arguments (llm_args) with state-based arguments.

        Arguments supplied by the LLM always win. Any parameter the LLM did not supply is filled from
        `state` (when the state has the key), using:
          - the tool's `inputs_from_state` mapping (`{"state_key": "tool_param_name"}`) when it is a dict;
          - otherwise, an identity mapping over the tool's parameter names: the component's input
            sockets for a ComponentTool, or the function signature for any other Tool.

        :param tool: The tool about to be invoked.
        :param llm_args: Arguments from the LLM's tool call. Not mutated.
        :param state: The runtime State that missing arguments are read from.
        :returns: A new dictionary with the merged arguments.
        """
        final_args = dict(llm_args)  # start with LLM-provided

        if hasattr(tool, "inputs_from_state") and isinstance(tool.inputs_from_state, dict):
            param_mappings = tool.inputs_from_state
        else:
            # Parameter introspection is only needed when there is no explicit mapping.
            if isinstance(tool, ComponentTool):
                # ComponentTool wraps the component in a function that accepts kwargs, so the accepted
                # parameter names come from the component's input sockets.
                # mypy doesn't know that ComponentMeta always adds __haystack_input__ to Component
                assert hasattr(tool._component, "__haystack_input__") and isinstance(
                    tool._component.__haystack_input__, Sockets
                )
                func_params = set(tool._component.__haystack_input__._sockets_dict.keys())
            else:
                func_params = set(inspect.signature(tool.function).parameters.keys())
            param_mappings = {name: name for name in func_params}

        # Populate final_args from state if not provided by LLM
        for state_key, param_name in param_mappings.items():
            if param_name not in final_args and state.has(state_key):
                final_args[param_name] = state.get(state_key)

        return final_args

    @staticmethod
    def _merge_tool_outputs(tool: Tool, result: Any, state: State) -> None:
        """
        Merges the tool result into the State.

        Processing steps:
        1. If `result` is not a dictionary, nothing is stored into state.
        2. If the `tool` does not define an `outputs_to_state` dict mapping, nothing is stored into state.
        3. Otherwise, for each `state_key -> config` entry of `outputs_to_state`, the value
           `result.get(config["source"])` (or the entire `result` when no `source` is configured) is
           written to `state_key`, passing `config["handler"]` (if any) as `handler_override`.

        :param tool: Tool instance containing optional `outputs_to_state` mapping to guide result processing.
        :param result: The output from tool execution. Can be a dictionary, or any other type.
        :param state: The global State object to which results should be merged.
        :returns: Nothing; `state` is updated in place.
        """
        # If result is not a dictionary we exit
        if not isinstance(result, dict):
            return

        # If there is no specific `outputs_to_state` mapping, we exit
        if not hasattr(tool, "outputs_to_state") or not isinstance(tool.outputs_to_state, dict):
            return

        # Update the state with the tool outputs
        for state_key, config in tool.outputs_to_state.items():
            # Get the source key from the output config, otherwise use the entire result
            source_key = config.get("source", None)
            output_value = result if source_key is None else result.get(source_key)

            # Get the handler function, if any
            handler = config.get("handler", None)

            # Merge other outputs into the state
            state.set(state_key, output_value, handler_override=handler)

    @component.output_types(tool_messages=List[ChatMessage], state=State)
    def run(
        self,
        messages: List[ChatMessage],
        state: Optional[State] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
    ) -> Dict[str, Any]:
        """
        Processes ChatMessage objects containing tool calls and invokes the corresponding tools, if available.

        :param messages:
            A list of ChatMessage objects.
        :param state: The runtime state that should be used by the tools.
        :param streaming_callback: A callback function that will be called to emit tool results.
            Note that the result is only emitted once it becomes available — it is not
            streamed incrementally in real time.
        :returns:
            A dictionary with the following keys:
            - `tool_messages`: A list of ChatMessage objects with tool role. Each ChatMessage wraps the
              result of a tool invocation (or an error description when `raise_on_failure` is False).
            - `state`: The State object, after the tool outputs have been merged into it.

        :raises ToolNotFoundException:
            If the tool is not found in the list of available tools and `raise_on_failure` is True.
        :raises ToolInvocationError:
            If the tool invocation fails and `raise_on_failure` is True.
        :raises StringConversionError:
            If the conversion of the tool result to a string fails and `raise_on_failure` is True.
        :raises ToolOutputMergeError:
            If merging tool outputs into state fails and `raise_on_failure` is True.
        """
        if state is None:
            state = State(schema={})

        # Only keep messages with tool calls
        messages_with_tool_calls = [message for message in messages if message.tool_calls]
        # Resolve the callback to use from the init-time and run-time ones (only sync callbacks are allowed here).
        streaming_callback = select_streaming_callback(
            init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
        )

        tool_messages = []
        for message in messages_with_tool_calls:
            for tool_call in message.tool_calls:
                tool_name = tool_call.tool_name

                # Check if the tool is available, otherwise return an error message
                if tool_name not in self._tools_with_names:
                    error_message = self._handle_error(
                        ToolNotFoundException(tool_name, list(self._tools_with_names.keys()))
                    )
                    tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True))
                    continue

                tool_to_invoke = self._tools_with_names[tool_name]

                # 1) Combine user + state inputs
                llm_args = tool_call.arguments.copy()
                final_args = self._inject_state_args(tool_to_invoke, llm_args, state)

                # 2) Invoke the tool
                try:
                    tool_result = tool_to_invoke.invoke(**final_args)
                except ToolInvocationError as e:
                    error_message = self._handle_error(e)
                    tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True))
                    continue

                # 3) Merge outputs into state
                try:
                    self._merge_tool_outputs(tool_to_invoke, tool_result, state)
                except Exception as e:
                    try:
                        error_message = self._handle_error(
                            ToolOutputMergeError(f"Failed to merge tool outputs from tool {tool_name} into State: {e}")
                        )
                        tool_messages.append(
                            ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)
                        )
                        continue
                    except ToolOutputMergeError as propagated_e:
                        # Re-raise with proper error chain
                        raise propagated_e from e

                # 4) Prepare the tool result ChatMessage message
                tool_messages.append(
                    self._prepare_tool_result_message(
                        result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke
                    )
                )

                # Emit the raw result (not the stringified message) once it is available.
                if streaming_callback is not None:
                    streaming_callback(
                        StreamingChunk(content="", meta={"tool_result": tool_result, "tool_call": tool_call})
                    )

        return {"tool_messages": tool_messages, "state": state}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        # Serialize the init-time callback too; previously it was silently dropped, so a
        # to_dict/from_dict round trip lost it.
        serialized_callback = (
            serialize_callable(self.streaming_callback) if self.streaming_callback is not None else None
        )
        return default_to_dict(
            self,
            tools=serialize_tools_or_toolset(self.tools),
            raise_on_failure=self.raise_on_failure,
            convert_result_to_json_string=self.convert_result_to_json_string,
            streaming_callback=serialized_callback,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ToolInvoker":
        """
        Deserializes the component from a dictionary.

        :param data:
            The dictionary to deserialize from.
        :returns:
            The deserialized component.
        """
        init_params = data["init_parameters"]
        deserialize_tools_or_toolset_inplace(init_params, key="tools")
        # Data serialized by older versions has no `streaming_callback` entry; only deserialize a present value.
        serialized_callback = init_params.get("streaming_callback")
        if serialized_callback:
            init_params["streaming_callback"] = deserialize_callable(serialized_callback)
        return default_from_dict(cls, data)
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc