• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 15049844454

15 May 2025 04:07PM UTC coverage: 90.446% (+0.04%) from 90.41%
15049844454

Pull #9345

github

web-flow
Merge 9e4071f83 into 2a64cd4e9
Pull Request #9345: feat: add serialization to `State` / move `State` to utils

10981 of 12141 relevant lines covered (90.45%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.85
haystack/components/tools/tool_invoker.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import inspect
1✔
6
import json
1✔
7
from typing import Any, Dict, List, Optional, Union
1✔
8

9
from haystack import component, default_from_dict, default_to_dict, logging
1✔
10
from haystack.core.component.sockets import Sockets
1✔
11
from haystack.dataclasses import ChatMessage, ToolCall
1✔
12
from haystack.dataclasses.streaming_chunk import StreamingCallbackT, StreamingChunk, select_streaming_callback
1✔
13
from haystack.tools import (
1✔
14
    ComponentTool,
15
    Tool,
16
    Toolset,
17
    _check_duplicate_tool_names,
18
    deserialize_tools_or_toolset_inplace,
19
    serialize_tools_or_toolset,
20
)
21
from haystack.tools.errors import ToolInvocationError
1✔
22
from haystack.utils.state import State
1✔
23

24
# Module-level logger named after this module. Note that `logging` here is the
# wrapper imported from `haystack`, not the standard-library module; it is called
# below with a message template plus keyword arguments.
logger = logging.getLogger(__name__)
1✔
25

26

27
class ToolInvokerError(Exception):
    """
    Root of the exception hierarchy raised by the ToolInvoker component.

    Catch this type to handle every ToolInvoker-specific failure at once.
    """

    def __init__(self, message: str):
        """
        Create the error.

        :param message: Human-readable description of what went wrong.
        """
        super().__init__(message)
1✔
32

33

34
class ToolNotFoundException(ToolInvokerError):
    """Raised when a tool call names a tool that is not among the available tools."""

    def __init__(self, tool_name: str, available_tools: List[str]):
        """
        Create the error.

        :param tool_name: Name of the tool that was requested.
        :param available_tools: Names of the tools that are actually available; listed in the message.
        """
        known_tools = ", ".join(available_tools)
        super().__init__(f"Tool '{tool_name}' not found. Available tools: {known_tools}")
1✔
40

41

42
class StringConversionError(ToolInvokerError):
    """Raised when a tool result cannot be converted to a string."""

    def __init__(self, tool_name: str, conversion_function: str, error: Exception):
        """
        Create the error.

        :param tool_name: Name of the tool whose result could not be converted.
        :param conversion_function: Name of the function that attempted the conversion.
        :param error: The exception raised by the conversion function.
        """
        super().__init__(
            f"Failed to convert tool result from tool {tool_name} using '{conversion_function}'. Error: {error}"
        )
1✔
48

49

50
class ToolOutputMergeError(ToolInvokerError):
    """
    Raised when the outputs of a tool cannot be merged into the shared state.

    Adds nothing to its base class; the distinct type only lets callers tell merge failures apart.
    """
1✔
54

55

56
@component
class ToolInvoker:
    """
    Invokes tools based on prepared tool calls and returns the results as a list of ChatMessage objects.

    Also handles reading/writing from a shared `State`.
    At initialization, the ToolInvoker component is provided with a list of available tools.
    At runtime, the component processes a list of ChatMessage objects containing tool calls
    and invokes the corresponding tools.
    The results of the tool invocations are returned as a list of ChatMessage objects with tool role.

    Usage example:
    ```python
    from haystack.dataclasses import ChatMessage, ToolCall
    from haystack.tools import Tool
    from haystack.components.tools import ToolInvoker

    # Tool definition
    def dummy_weather_function(city: str):
        return f"The weather in {city} is 20 degrees."

    parameters = {"type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"]}

    tool = Tool(name="weather_tool",
                description="A tool to get the weather",
                function=dummy_weather_function,
                parameters=parameters)

    # Usually, the ChatMessage with tool_calls is generated by a Language Model
    # Here, we create it manually for demonstration purposes
    tool_call = ToolCall(
        tool_name="weather_tool",
        arguments={"city": "Berlin"}
    )
    message = ChatMessage.from_assistant(tool_calls=[tool_call])

    # ToolInvoker initialization and run
    invoker = ToolInvoker(tools=[tool])
    result = invoker.run(messages=[message])

    print(result)
    ```

    ```
    >>  {
    >>      'tool_messages': [
    >>          ChatMessage(
    >>              _role=<ChatRole.TOOL: 'tool'>,
    >>              _content=[
    >>                  ToolCallResult(
    >>                      result='"The weather in Berlin is 20 degrees."',
    >>                      origin=ToolCall(
    >>                          tool_name='weather_tool',
    >>                          arguments={'city': 'Berlin'},
    >>                          id=None
    >>                      )
    >>                  )
    >>              ],
    >>              _meta={}
    >>          )
    >>      ]
    >>  }
    ```

    Usage example with a Toolset:
    ```python
    from haystack.dataclasses import ChatMessage, ToolCall
    from haystack.tools import Tool, Toolset
    from haystack.components.tools import ToolInvoker

    # Tool definition
    def dummy_weather_function(city: str):
        return f"The weather in {city} is 20 degrees."

    parameters = {"type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"]}

    tool = Tool(name="weather_tool",
                description="A tool to get the weather",
                function=dummy_weather_function,
                parameters=parameters)

    # Create a Toolset
    toolset = Toolset([tool])

    # Usually, the ChatMessage with tool_calls is generated by a Language Model
    # Here, we create it manually for demonstration purposes
    tool_call = ToolCall(
        tool_name="weather_tool",
        arguments={"city": "Berlin"}
    )
    message = ChatMessage.from_assistant(tool_calls=[tool_call])

    # ToolInvoker initialization and run with Toolset
    invoker = ToolInvoker(tools=toolset)
    result = invoker.run(messages=[message])

    print(result)
    ```
    """

    def __init__(
        self,
        tools: Union[List[Tool], Toolset],
        raise_on_failure: bool = True,
        convert_result_to_json_string: bool = False,
        streaming_callback: Optional[StreamingCallbackT] = None,
    ):
        """
        Initialize the ToolInvoker component.

        :param tools:
            A list of tools that can be invoked or a Toolset instance that can resolve tools.
        :param raise_on_failure:
            If True, the component will raise an exception in case of errors
            (tool not found, tool invocation errors, tool result conversion errors).
            If False, the component will return a ChatMessage object with `error=True`
            and a description of the error in `result`.
        :param convert_result_to_json_string:
            If True, the tool invocation result will be converted to a string using `json.dumps`.
            If False, the tool invocation result will be converted to a string using `str`.
        :param streaming_callback:
            A callback function that will be called to emit tool results.
            Note that the result is only emitted once it becomes available — it is not
            streamed incrementally in real time.
        :raises ValueError:
            If no tools are provided or if duplicate tool names are found.
        """
        if not tools:
            raise ValueError("ToolInvoker requires at least one tool.")

        # Keep the argument exactly as given (Toolset or list) so that `to_dict` can serialize it faithfully.
        self.tools = tools
        self.streaming_callback = streaming_callback

        # Convert Toolset to list for internal use
        converted_tools = list(tools) if isinstance(tools, Toolset) else tools

        # Single source of truth for duplicate detection: this helper raises ValueError on duplicate
        # names, so no second hand-rolled check is needed here.
        _check_duplicate_tool_names(converted_tools)

        # Name -> tool lookup used by `run` to resolve each tool call.
        self._tools_with_names = {tool.name: tool for tool in converted_tools}
        self.raise_on_failure = raise_on_failure
        self.convert_result_to_json_string = convert_result_to_json_string

    def _handle_error(self, error: Exception) -> str:
        """
        Handles errors by logging and either raising or returning a fallback error message.

        :param error: The exception instance.
        :returns: The fallback error message when `raise_on_failure` is False.
        :raises: The provided error if `raise_on_failure` is True.
        """
        logger.error("{error_exception}", error_exception=error)
        if self.raise_on_failure:
            # We re-raise the original error maintaining the exception chain
            raise error
        return str(error)

    def _default_output_to_string_handler(self, result: Any) -> str:
        """
        Default handler for converting a tool result to a string.

        :param result: The tool result to convert to a string.
        :returns: The converted tool result as a string.
        """
        if self.convert_result_to_json_string:
            # We disable ensure_ascii so special chars like emojis are not converted
            return json.dumps(result, ensure_ascii=False)
        return str(result)

    def _prepare_tool_result_message(self, result: Any, tool_call: ToolCall, tool_to_invoke: Tool) -> ChatMessage:
        """
        Prepares a ChatMessage with the result of a tool invocation.

        :param result:
            The tool result.
        :param tool_call:
            The ToolCall object containing the tool name and arguments.
        :param tool_to_invoke:
            The Tool object that was invoked.
        :returns:
            A ChatMessage object containing the tool result as a string. If the conversion fails and
            `raise_on_failure` is False, the message carries the error description and `error=True`.
        :raises StringConversionError:
            If the conversion of the tool result to a string fails and `raise_on_failure` is True.
        """
        source_key = None
        output_to_string_handler = None
        if tool_to_invoke.outputs_to_string is not None:
            if tool_to_invoke.outputs_to_string.get("source"):
                source_key = tool_to_invoke.outputs_to_string["source"]
            if tool_to_invoke.outputs_to_string.get("handler"):
                output_to_string_handler = tool_to_invoke.outputs_to_string["handler"]

        # If a source key is provided, we extract the result from the source key
        result_to_convert = result.get(source_key) if source_key is not None else result

        # If no handler is provided, we use the default handler
        if output_to_string_handler is None:
            output_to_string_handler = self._default_output_to_string_handler

        error = False
        try:
            tool_result_str = output_to_string_handler(result_to_convert)
        except Exception as e:
            # Not every callable has `__name__` (e.g. `functools.partial` objects or instances defining
            # `__call__`); fall back to `repr` so building the error never masks the original failure.
            handler_name = getattr(output_to_string_handler, "__name__", repr(output_to_string_handler))
            try:
                tool_result_str = self._handle_error(StringConversionError(tool_call.tool_name, handler_name, e))
                error = True
            except StringConversionError as conversion_error:
                # If _handle_error re-raises, this properly preserves the chain
                raise conversion_error from e
        return ChatMessage.from_tool(tool_result=tool_result_str, error=error, origin=tool_call)

    @staticmethod
    def _inject_state_args(tool: Tool, llm_args: Dict[str, Any], state: State) -> Dict[str, Any]:
        """
        Combine LLM-provided arguments (llm_args) with state-based arguments.

        Arguments provided by the LLM always win: a value is read from the state only for parameters
        the LLM did not supply. Which state keys feed which parameters is decided as follows:
          - if the tool has a dict `inputs_from_state`, it is used as a `{state_key: param_name}` mapping
          - otherwise every parameter the tool accepts is looked up in the state under its own name

        :param tool: The tool about to be invoked.
        :param llm_args: Arguments taken from the tool call. Not modified.
        :param state: The state to read missing arguments from.
        :returns: A new dictionary with the arguments to pass to the tool.
        """
        final_args = dict(llm_args)  # start with LLM-provided

        # ComponentTool wraps the function with a function that accepts kwargs, so we need to look at input sockets
        # to find out which parameters the tool accepts.
        if isinstance(tool, ComponentTool):
            # mypy doesn't know that ComponentMeta always adds __haystack_input__ to Component
            assert hasattr(tool._component, "__haystack_input__") and isinstance(
                tool._component.__haystack_input__, Sockets
            )
            func_params = set(tool._component.__haystack_input__._sockets_dict.keys())
        else:
            func_params = set(inspect.signature(tool.function).parameters.keys())

        # Determine the source of parameter mappings (explicit tool inputs or direct function parameters)
        # Typically, a "Tool" might have .inputs_from_state = {"state_key": "tool_param_name"}
        if hasattr(tool, "inputs_from_state") and isinstance(tool.inputs_from_state, dict):
            param_mappings = tool.inputs_from_state
        else:
            param_mappings = {name: name for name in func_params}

        # Populate final_args from state if not provided by LLM
        for state_key, param_name in param_mappings.items():
            if param_name not in final_args and state.has(state_key):
                final_args[param_name] = state.get(state_key)

        return final_args

    @staticmethod
    def _merge_tool_outputs(tool: Tool, result: Any, state: State) -> None:
        """
        Merges the tool result into the State.

        The state is updated in place and nothing is returned.

        Processing Steps:
        1. If `result` is not a dictionary, nothing is stored into state.
        2. If the `tool` does not define an `outputs_to_state` dictionary, nothing is stored into state.
        3. Otherwise, for every `state_key -> config` entry of `outputs_to_state`:
           - the value is `result[config["source"]]` if a `source` is configured (`None` if that key is
             missing from `result`), otherwise the whole `result` dictionary;
           - the value is written with `state.set`, passing `config["handler"]` (if any) as
             `handler_override`.

        :param tool: Tool instance containing optional `outputs_to_state` mapping to guide result processing.
        :param result: The output from tool execution. Can be a dictionary, or any other type.
        :param state: The global State object to which results should be merged.
        """
        # If result is not a dictionary we exit
        if not isinstance(result, dict):
            return

        # If there is no specific `outputs_to_state` mapping, we exit
        if not hasattr(tool, "outputs_to_state") or not isinstance(tool.outputs_to_state, dict):
            return

        # Update the state with the tool outputs
        for state_key, config in tool.outputs_to_state.items():
            # Get the source key from the output config, otherwise use the entire result
            source_key = config.get("source", None)
            output_value = result if source_key is None else result.get(source_key)

            # Get the handler function, if any
            handler = config.get("handler", None)

            # Merge other outputs into the state
            state.set(state_key, output_value, handler_override=handler)

    @component.output_types(tool_messages=List[ChatMessage], state=State)
    def run(
        self,
        messages: List[ChatMessage],
        state: Optional[State] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
    ) -> Dict[str, Any]:
        """
        Processes ChatMessage objects containing tool calls and invokes the corresponding tools, if available.

        :param messages:
            A list of ChatMessage objects.
        :param state: The runtime state that should be used by the tools.
            If not provided, a new State with an empty schema is created.
        :param streaming_callback: A callback function that will be called to emit tool results.
            Note that the result is only emitted once it becomes available — it is not
            streamed incrementally in real time. It is called only for tool calls that completed
            without an error.
        :returns:
            A dictionary with the following keys:
            - `tool_messages`: A list of ChatMessage objects with tool role.
              Each ChatMessage object wraps the result of a tool invocation.
            - `state`: The State object after the tool outputs have been merged into it.

        :raises ToolNotFoundException:
            If the tool is not found in the list of available tools and `raise_on_failure` is True.
        :raises ToolInvocationError:
            If the tool invocation fails and `raise_on_failure` is True.
        :raises StringConversionError:
            If the conversion of the tool result to a string fails and `raise_on_failure` is True.
        :raises ToolOutputMergeError:
            If merging tool outputs into state fails and `raise_on_failure` is True.
        """
        if state is None:
            state = State(schema={})

        # Only keep messages with tool calls
        messages_with_tool_calls = [message for message in messages if message.tool_calls]
        streaming_callback = select_streaming_callback(
            init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
        )

        tool_messages = []
        for message in messages_with_tool_calls:
            for tool_call in message.tool_calls:
                tool_name = tool_call.tool_name

                # Check if the tool is available, otherwise return an error message
                if tool_name not in self._tools_with_names:
                    error_message = self._handle_error(
                        ToolNotFoundException(tool_name, list(self._tools_with_names.keys()))
                    )
                    tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True))
                    continue

                tool_to_invoke = self._tools_with_names[tool_name]

                # 1) Combine user + state inputs
                llm_args = tool_call.arguments.copy()
                final_args = self._inject_state_args(tool_to_invoke, llm_args, state)

                # 2) Invoke the tool
                try:
                    tool_result = tool_to_invoke.invoke(**final_args)

                except ToolInvocationError as e:
                    error_message = self._handle_error(e)
                    tool_messages.append(ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True))
                    continue

                # 3) Merge outputs into state
                try:
                    self._merge_tool_outputs(tool_to_invoke, tool_result, state)
                except Exception as e:
                    try:
                        error_message = self._handle_error(
                            ToolOutputMergeError(f"Failed to merge tool outputs from tool {tool_name} into State: {e}")
                        )
                        tool_messages.append(
                            ChatMessage.from_tool(tool_result=error_message, origin=tool_call, error=True)
                        )
                        continue
                    except ToolOutputMergeError as propagated_e:
                        # Re-raise with proper error chain
                        raise propagated_e from e

                # 4) Prepare the tool result ChatMessage message
                tool_messages.append(
                    self._prepare_tool_result_message(
                        result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke
                    )
                )

                if streaming_callback is not None:
                    streaming_callback(
                        StreamingChunk(content="", meta={"tool_result": tool_result, "tool_call": tool_call})
                    )

        return {"tool_messages": tool_messages, "state": state}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        # Imported locally so this block is self-contained.
        from haystack.utils.callable_serialization import serialize_callable

        # Serialize the init-time callback as well; otherwise it is silently lost on a
        # `to_dict` / `from_dict` round-trip.
        serialized_streaming_callback = (
            serialize_callable(self.streaming_callback) if self.streaming_callback is not None else None
        )
        return default_to_dict(
            self,
            tools=serialize_tools_or_toolset(self.tools),
            raise_on_failure=self.raise_on_failure,
            convert_result_to_json_string=self.convert_result_to_json_string,
            streaming_callback=serialized_streaming_callback,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "ToolInvoker":
        """
        Deserializes the component from a dictionary.

        :param data:
            The dictionary to deserialize from. Its `init_parameters` are modified in place.
        :returns:
            The deserialized component.
        """
        # Imported locally so this block is self-contained.
        from haystack.utils.callable_serialization import deserialize_callable

        init_parameters = data["init_parameters"]
        deserialize_tools_or_toolset_inplace(init_parameters, key="tools")

        # `.get` keeps dictionaries written before `streaming_callback` was serialized loadable.
        serialized_streaming_callback = init_parameters.get("streaming_callback")
        if serialized_streaming_callback is not None:
            init_parameters["streaming_callback"] = deserialize_callable(serialized_streaming_callback)
        return default_from_dict(cls, data)
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc