• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 18530018322

15 Oct 2025 01:10PM UTC coverage: 92.075% (-0.03%) from 92.103%
18530018322

Pull #9880

github

web-flow
Merge 6dad544fe into cfa5d2761
Pull Request #9880: draft: Expand tools param to include list[Toolset]

13279 of 14422 relevant lines covered (92.07%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.9
haystack/components/tools/tool_invoker.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import asyncio
1✔
6
import contextvars
1✔
7
import inspect
1✔
8
import json
1✔
9
from concurrent.futures import ThreadPoolExecutor
1✔
10
from functools import partial
1✔
11
from typing import Any, Callable, Optional, Union
1✔
12

13
from haystack.components.agents import State
1✔
14
from haystack.core.component.component import component
1✔
15
from haystack.core.component.sockets import Sockets
1✔
16
from haystack.core.serialization import default_from_dict, default_to_dict, logging
1✔
17
from haystack.dataclasses import ChatMessage, ToolCall
1✔
18
from haystack.dataclasses.streaming_chunk import StreamingCallbackT, StreamingChunk, select_streaming_callback
1✔
19
from haystack.tools import (
1✔
20
    ComponentTool,
21
    Tool,
22
    Toolset,
23
    _check_duplicate_tool_names,
24
    deserialize_tools_or_toolset_inplace,
25
    flatten_tools_or_toolsets,
26
    serialize_tools_or_toolset,
27
)
28
from haystack.tools.errors import ToolInvocationError
1✔
29
from haystack.tracing.utils import _serializable_value
1✔
30
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
1✔
31

32
logger = logging.getLogger(__name__)
1✔
33

34

35
class ToolInvokerError(Exception):
    """Root of the ToolInvoker exception hierarchy."""

    def __init__(self, message: str) -> None:
        # Delegate message storage/formatting to the base Exception.
        super().__init__(message)
40

41

42
class ToolNotFoundException(ToolInvokerError):
    """Raised when a requested tool name is not among the available tools."""

    def __init__(self, tool_name: str, available_tools: list[str]) -> None:
        known_tools = ", ".join(available_tools)
        super().__init__(f"Tool '{tool_name}' not found. Available tools: {known_tools}")
48

49

50
class StringConversionError(ToolInvokerError):
    """Raised when converting a tool result to a string fails."""

    def __init__(self, tool_name: str, conversion_function: str, error: Exception) -> None:
        super().__init__(
            f"Failed to convert tool result from tool {tool_name} using '{conversion_function}'. Error: {error}"
        )
56

57

58
class ToolOutputMergeError(ToolInvokerError):
    """Raised when merging tool outputs into the shared State fails."""

    @classmethod
    def from_exception(cls, tool_name: str, error: Exception) -> "ToolOutputMergeError":
        """Build a ToolOutputMergeError that wraps the underlying exception."""
        return cls(f"Failed to merge tool outputs from tool {tool_name} into State: {error}")
68

69

70
@component
1✔
71
class ToolInvoker:
1✔
72
    """
73
    Invokes tools based on prepared tool calls and returns the results as a list of ChatMessage objects.
74

75
    Also handles reading/writing from a shared `State`.
76
    At initialization, the ToolInvoker component is provided with a list of available tools.
77
    At runtime, the component processes a list of ChatMessage object containing tool calls
78
    and invokes the corresponding tools.
79
    The results of the tool invocations are returned as a list of ChatMessage objects with tool role.
80

81
    Usage example:
82
    ```python
83
    from haystack.dataclasses import ChatMessage, ToolCall
84
    from haystack.tools import Tool
85
    from haystack.components.tools import ToolInvoker
86

87
    # Tool definition
88
    def dummy_weather_function(city: str):
89
        return f"The weather in {city} is 20 degrees."
90

91
    parameters = {"type": "object",
92
                "properties": {"city": {"type": "string"}},
93
                "required": ["city"]}
94

95
    tool = Tool(name="weather_tool",
96
                description="A tool to get the weather",
97
                function=dummy_weather_function,
98
                parameters=parameters)
99

100
    # Usually, the ChatMessage with tool_calls is generated by a Language Model
101
    # Here, we create it manually for demonstration purposes
102
    tool_call = ToolCall(
103
        tool_name="weather_tool",
104
        arguments={"city": "Berlin"}
105
    )
106
    message = ChatMessage.from_assistant(tool_calls=[tool_call])
107

108
    # ToolInvoker initialization and run
109
    invoker = ToolInvoker(tools=[tool])
110
    result = invoker.run(messages=[message])
111

112
    print(result)
113
    ```
114

115
    ```
116
    >>  {
117
    >>      'tool_messages': [
118
    >>          ChatMessage(
119
    >>              _role=<ChatRole.TOOL: 'tool'>,
120
    >>              _content=[
121
    >>                  ToolCallResult(
122
    >>                      result='"The weather in Berlin is 20 degrees."',
123
    >>                      origin=ToolCall(
124
    >>                          tool_name='weather_tool',
125
    >>                          arguments={'city': 'Berlin'},
126
    >>                          id=None
127
    >>                      )
128
    >>                  )
129
    >>              ],
130
    >>              _meta={}
131
    >>          )
132
    >>      ]
133
    >>  }
134
    ```
135

136
    Usage example with a Toolset:
137
    ```python
138
    from haystack.dataclasses import ChatMessage, ToolCall
139
    from haystack.tools import Tool, Toolset
140
    from haystack.components.tools import ToolInvoker
141

142
    # Tool definition
143
    def dummy_weather_function(city: str):
144
        return f"The weather in {city} is 20 degrees."
145

146
    parameters = {"type": "object",
147
                "properties": {"city": {"type": "string"}},
148
                "required": ["city"]}
149

150
    tool = Tool(name="weather_tool",
151
                description="A tool to get the weather",
152
                function=dummy_weather_function,
153
                parameters=parameters)
154

155
    # Create a Toolset
156
    toolset = Toolset([tool])
157

158
    # Usually, the ChatMessage with tool_calls is generated by a Language Model
159
    # Here, we create it manually for demonstration purposes
160
    tool_call = ToolCall(
161
        tool_name="weather_tool",
162
        arguments={"city": "Berlin"}
163
    )
164
    message = ChatMessage.from_assistant(tool_calls=[tool_call])
165

166
    # ToolInvoker initialization and run with Toolset
167
    invoker = ToolInvoker(tools=toolset)
168
    result = invoker.run(messages=[message])
169

170
    print(result)
171
    """
172

173
    def __init__(
        self,
        tools: Union[list[Tool], Toolset, list[Toolset]],
        raise_on_failure: bool = True,
        convert_result_to_json_string: bool = False,
        streaming_callback: Optional[StreamingCallbackT] = None,
        *,
        enable_streaming_callback_passthrough: bool = False,
        max_workers: int = 4,
    ):
        """
        Initialize the ToolInvoker component.

        :param tools:
            A list of tools, a Toolset instance, or a list of Toolset instances that can resolve tools.
        :param raise_on_failure:
            If True, raise on errors (tool not found, tool invocation errors, tool result
            conversion errors). If False, return a ChatMessage with `error=True` and a
            description of the error in `result` instead.
        :param convert_result_to_json_string:
            If True, convert tool invocation results to strings with `json.dumps`;
            otherwise use `str`.
        :param streaming_callback:
            A callback invoked to emit tool results. Note that the result is only emitted
            once it becomes available - it is not streamed incrementally in real time.
        :param enable_streaming_callback_passthrough:
            If True, pass `streaming_callback` on to the tool invocation when the tool's
            `invoke` signature accepts a `streaming_callback` parameter, allowing tools to
            stream their own results back to the client. If False, it is never passed.
        :param max_workers:
            Maximum number of thread-pool workers; this is also the maximum number of
            concurrent tool invocations.
        :raises ValueError:
            If no tools are provided or if duplicate tool names are found.
        """
        self.tools = tools
        self.raise_on_failure = raise_on_failure
        self.convert_result_to_json_string = convert_result_to_json_string
        self.streaming_callback = streaming_callback
        self.enable_streaming_callback_passthrough = enable_streaming_callback_passthrough
        self.max_workers = max_workers

        # Flatten and validate the tools once so misconfiguration fails at construction time.
        self._tools_with_names = self._validate_and_prepare_tools(tools)
219

220
    @staticmethod
    def _make_context_bound_invoke(tool_to_invoke: Tool, final_args: dict[str, Any]) -> Callable[[], Any]:
        """
        Build a zero-argument callable that invokes the tool inside the caller's contextvars Context.

        The current Context is captured here, in the caller's thread, and the tool is invoked via
        ``Context.run`` inside the worker thread. Python's contextvars do not automatically propagate
        to worker threads (or to threadpool tasks spawned via run_in_executor), so without this the
        ambient execution state - for example the active tracing Span - would be lost and nested tool
        calls would appear as separate trace roots. Running under the captured Context preserves span
        parentage, consistent tagging, and log/trace correlation in both sync and async paths.

        The returned callable returns a ToolInvocationError instead of raising it, so that parallel
        execution can collect failures from individual tools.
        """
        captured_ctx = contextvars.copy_context()

        def _invoke_in_context() -> Any:
            bound_call = partial(tool_to_invoke.invoke, **final_args)
            try:
                return captured_ctx.run(bound_call)
            except ToolInvocationError as invocation_error:
                # Returned rather than raised: the caller inspects results for errors.
                return invocation_error

        return _invoke_in_context
242

243
    @staticmethod
    def _validate_and_prepare_tools(tools: Union[list[Tool], Toolset, list[Toolset]]) -> dict[str, Tool]:
        """
        Validates and prepares tools for use by the ToolInvoker.

        :param tools: A list of tools, a Toolset instance, or a list of Toolset instances.
        :returns: A dictionary mapping tool names to Tool instances.
        :raises ValueError: If no tools are provided or if duplicate tool names are found.
        """
        if not tools:
            raise ValueError("ToolInvoker requires at least one tool.")

        converted_tools = flatten_tools_or_toolsets(tools)

        # _check_duplicate_tool_names raises ValueError when duplicates are present, which made
        # the previous manual duplicate scan after it unreachable dead code; it has been removed.
        _check_duplicate_tool_names(converted_tools)

        return {tool.name: tool for tool in converted_tools}
264

265
    def _default_output_to_string_handler(self, result: Any) -> str:
        """
        Default handler for converting a tool result to a string.

        Haystack objects inside the result are first converted to serializable values via
        `_serializable_value` (items with `to_dict()` become dicts). This matters because:
        - with `convert_result_to_json_string`, Haystack objects must be JSON serializable dicts
        - with plain `str()`, converting to dicts is preferable to relying on `__repr__`

        :param result: The tool result to convert to a string.
        :returns: The converted tool result as a string.
        """
        serializable = _serializable_value(result)

        if self.convert_result_to_json_string:
            try:
                # We disable ensure_ascii so special chars like emojis are not converted
                str_result = json.dumps(serializable, ensure_ascii=False)
            except Exception as error:
                # If the result is not JSON serializable, we fall back to str.
                # Bug fix: the kwarg must be named `error` (was `err`) so it fills the
                # {error} placeholder in the log message.
                logger.warning(
                    "Tool result is not JSON serializable. Falling back to str conversion. "
                    "Result: {result}\nError: {error}",
                    result=result,
                    error=error,
                )
                str_result = str(result)
            return str_result

        return str(serializable)
295

296
    def _prepare_tool_result_message(self, result: Any, tool_call: ToolCall, tool_to_invoke: Tool) -> ChatMessage:
        """
        Build the tool-role ChatMessage carrying the result of a tool invocation.

        :param result:
            The tool result.
        :param tool_call:
            The ToolCall object containing the tool name and arguments.
        :param tool_to_invoke:
            The Tool object that was invoked.
        :returns:
            A ChatMessage wrapping the stringified tool result.
        :raises StringConversionError:
            If converting the tool result to a string fails and `raise_on_failure` is True.
        """
        string_config = tool_to_invoke.outputs_to_string or {}
        # Without a configured handler, fall back to the default string conversion.
        to_string = string_config.get("handler", self._default_output_to_string_handler)
        source = string_config.get("source")
        # With a configured source key, only that part of the result is converted.
        payload = result if source is None else result.get(source)

        try:
            converted = to_string(payload)
            return ChatMessage.from_tool(tool_result=converted, origin=tool_call)
        except Exception as e:
            conversion_error = StringConversionError(tool_call.tool_name, to_string.__name__, e)
            if self.raise_on_failure:
                raise conversion_error from e
            logger.error("{error_exception}", error_exception=conversion_error)
            return ChatMessage.from_tool(tool_result=str(conversion_error), origin=tool_call, error=True)
331

332
    @staticmethod
    def _get_func_params(tool: Tool) -> set:
        """
        Return the set of parameter names accepted by the tool's underlying function.

        For a ComponentTool the wrapped function accepts only kwargs, so the accepted
        parameters are read from the wrapped component's input sockets rather than
        from the function signature.
        """
        if not isinstance(tool, ComponentTool):
            return set(inspect.signature(tool.function).parameters.keys())

        # mypy doesn't know that ComponentMeta always adds __haystack_input__ to Component
        assert hasattr(tool._component, "__haystack_input__") and isinstance(
            tool._component.__haystack_input__, Sockets
        )
        return set(tool._component.__haystack_input__._sockets_dict.keys())
351

352
    @staticmethod
    def _inject_state_args(tool: Tool, llm_args: dict[str, Any], state: State) -> dict[str, Any]:
        """
        Merge LLM-provided arguments with values pulled from the shared State.

        Precedence rules:
          - an argument supplied by the LLM always wins over a state value
          - explicit `tool.inputs_from_state` mappings (when defined) decide which state
            keys feed which tool parameters
          - otherwise state keys are matched to function parameters by name
        """
        merged_args = dict(llm_args)  # LLM-provided values take precedence
        func_params = ToolInvoker._get_func_params(tool)

        # Choose the source of state-key -> parameter-name mappings: explicit tool inputs
        # (e.g. .inputs_from_state = {"state_key": "tool_param_name"}) or name matching.
        if hasattr(tool, "inputs_from_state") and isinstance(tool.inputs_from_state, dict):
            mappings = tool.inputs_from_state
        else:
            mappings = {name: name for name in func_params}

        # Fill in anything the LLM did not supply from state, when available.
        for state_key, param_name in mappings.items():
            if param_name not in merged_args and state.has(state_key):
                merged_args[param_name] = state.get(state_key)

        return merged_args
378

379
    @staticmethod
    def _merge_tool_outputs(tool: Tool, result: Any, state: State) -> None:
        """
        Merges the tool result into the State.

        This method processes the output of a tool execution and integrates it into the global state.

        Processing Steps:
        1. If `result` is not a dictionary, nothing is stored into state.
        2. If the `tool` does not define an `outputs_to_state` mapping (a dict), nothing is stored into state.
        3. Otherwise, each entry of `outputs_to_state` is applied: the value is taken from
           `result[source]` when the entry configures a "source" key (or the whole `result` otherwise)
           and written to the state under the entry's key, optionally through a custom "handler".

        :param tool: Tool instance containing an optional `outputs_to_state` mapping to guide result processing.
        :param result: The output from tool execution. Can be a dictionary, or any other type.
        :param state: The global State object to which results should be merged.
        """
        # If result is not a dictionary we exit
        if not isinstance(result, dict):
            return

        # If there is no specific `outputs_to_state` mapping, we exit
        if not hasattr(tool, "outputs_to_state") or not isinstance(tool.outputs_to_state, dict):
            return

        # Update the state with the tool outputs
        for state_key, config in tool.outputs_to_state.items():
            # Get the source key from the output config, otherwise use the entire result
            source_key = config.get("source", None)
            output_value = result.get(source_key) if source_key else result

            # Merge other outputs into the state
            state.set(state_key, output_value, handler_override=config.get("handler"))
420

421
    @staticmethod
    def _create_tool_result_streaming_chunk(tool_messages: list[ChatMessage], tool_call: ToolCall) -> StreamingChunk:
        """Build the StreamingChunk that announces the most recent tool result."""
        latest_result = tool_messages[-1].tool_call_results[0]
        return StreamingChunk(
            content="",
            # Index of the tool message this chunk belongs to.
            index=len(tool_messages) - 1,
            tool_call_result=latest_result,
            start=True,
            meta={"tool_result": latest_result.result, "tool_call": tool_call},
        )
431

432
    def _prepare_tool_call_params(
        self,
        *,
        messages_with_tool_calls: list[ChatMessage],
        state: State,
        streaming_callback: Optional[StreamingCallbackT],
        enable_streaming_passthrough: bool,
        tools_with_names: dict[str, Tool],
    ) -> tuple[list[ToolCall], list[dict[str, Any]], list[ChatMessage]]:
        """
        Collect the parameters needed to execute each tool call, plus error messages for unknown tools.

        :param messages_with_tool_calls: Messages containing tool calls to process.
        :param state: The current state for argument injection.
        :param streaming_callback: Optional streaming callback to inject into tool arguments.
        :param enable_streaming_passthrough: Whether to pass the streaming callback to tools.
        :param tools_with_names: Mapping of tool name to Tool, used to resolve tool calls.
        :returns: Tuple of (tool_calls, tool_call_params, error_messages).
        :raises ToolNotFoundException: If a tool is unknown and `raise_on_failure` is True.
        """
        valid_tool_calls: list[ToolCall] = []
        invocation_params: list[dict[str, Any]] = []
        not_found_messages: list[ChatMessage] = []

        all_tool_calls = (tc for message in messages_with_tool_calls for tc in message.tool_calls)
        for tool_call in all_tool_calls:
            tool_name = tool_call.tool_name

            # Unknown tool: raise, or record an error message, depending on configuration.
            if tool_name not in tools_with_names:
                error = ToolNotFoundException(tool_name, list(tools_with_names.keys()))
                if self.raise_on_failure:
                    raise error
                logger.error("{error_exception}", error_exception=error)
                not_found_messages.append(
                    ChatMessage.from_tool(tool_result=str(error), origin=tool_call, error=True)
                )
                continue

            tool_to_invoke = tools_with_names[tool_name]

            # Combine the LLM-provided arguments with values injected from state.
            final_args = self._inject_state_args(tool_to_invoke, tool_call.arguments.copy(), state)

            # Pass the streaming callback through only when requested, not already supplied,
            # and actually accepted by the tool's function signature.
            should_inject_callback = (
                enable_streaming_passthrough
                and streaming_callback is not None
                and "streaming_callback" not in final_args
                and "streaming_callback" in self._get_func_params(tool_to_invoke)
            )
            if should_inject_callback:
                final_args["streaming_callback"] = streaming_callback

            invocation_params.append({"tool_to_invoke": tool_to_invoke, "final_args": final_args})
            valid_tool_calls.append(tool_call)

        return valid_tool_calls, invocation_params, not_found_messages
486

487
    @component.output_types(tool_messages=list[ChatMessage], state=State)
    def run(
        self,
        messages: list[ChatMessage],
        state: Optional[State] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
        *,
        enable_streaming_callback_passthrough: Optional[bool] = None,
        tools: Optional[Union[list[Tool], Toolset, list[Toolset]]] = None,
    ) -> dict[str, Any]:
        """
        Processes ChatMessage objects containing tool calls and invokes the corresponding tools, if available.

        Valid tool calls are executed in parallel on a thread pool (bounded by `max_workers`),
        and their results are gathered in submission order.

        :param messages:
            A list of ChatMessage objects.
        :param state: The runtime state that should be used by the tools.
        :param streaming_callback: A callback function that will be called to emit tool results.
            Note that the result is only emitted once it becomes available — it is not
            streamed incrementally in real time.
        :param enable_streaming_callback_passthrough:
            If True, the `streaming_callback` will be passed to the tool invocation if the tool supports it.
            This allows tools to stream their results back to the client.
            Note that this requires the tool to have a `streaming_callback` parameter in its `invoke` method signature.
            If False, the `streaming_callback` will not be passed to the tool invocation.
            If None, the value from the constructor will be used.
        :param tools:
            A list of tools to use for the tool invoker. If set, overrides the tools set in the constructor.
        :returns:
            A dictionary with the key `tool_messages` containing a list of ChatMessage objects with tool role.
            Each ChatMessage objects wraps the result of a tool invocation.

        :raises ToolNotFoundException:
            If the tool is not found in the list of available tools and `raise_on_failure` is True.
        :raises ToolInvocationError:
            If the tool invocation fails and `raise_on_failure` is True.
        :raises StringConversionError:
            If the conversion of the tool result to a string fails and `raise_on_failure` is True.
        :raises ToolOutputMergeError:
            If merging tool outputs into state fails and `raise_on_failure` is True.
        """
        # Runtime `tools` override takes precedence over the constructor tools for this call only.
        tools_with_names = self._tools_with_names
        if tools is not None:
            tools_with_names = self._validate_and_prepare_tools(tools)
            logger.debug(
                f"For this invocation, overriding constructor tools with: {', '.join(tools_with_names.keys())}"
            )

        if state is None:
            state = State(schema={})

        # Runtime passthrough flag wins over the constructor default when provided.
        resolved_enable_streaming_passthrough = (
            enable_streaming_callback_passthrough
            if enable_streaming_callback_passthrough is not None
            else self.enable_streaming_callback_passthrough
        )

        # Only keep messages with tool calls
        messages_with_tool_calls = [message for message in messages if message.tool_calls]
        # Resolve which callback to use (runtime wins over init); sync path, so requires_async=False.
        streaming_callback = select_streaming_callback(
            init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
        )

        if not messages_with_tool_calls:
            return {"tool_messages": [], "state": state}

        # 1) Collect all tool calls and their parameters for parallel execution
        tool_messages = []
        tool_calls, tool_call_params, error_messages = self._prepare_tool_call_params(
            messages_with_tool_calls=messages_with_tool_calls,
            state=state,
            streaming_callback=streaming_callback,
            enable_streaming_passthrough=resolved_enable_streaming_passthrough,
            tools_with_names=tools_with_names,
        )
        # Messages for tool calls that referenced unknown tools.
        tool_messages.extend(error_messages)

        if not tool_call_params:
            return {"tool_messages": tool_messages, "state": state}

        # 2) Execute valid tool calls in parallel
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = []
            for params in tool_call_params:
                # Context-bound callable keeps the caller's contextvars (e.g. tracing span)
                # alive inside the worker thread.
                callable_ = self._make_context_bound_invoke(params["tool_to_invoke"], params["final_args"])
                futures.append(executor.submit(callable_))

            # 3) Gather and process results: handle errors and merge outputs into state
            # Results are consumed in submission order, so tool_messages stays aligned with tool_calls.
            for future, tool_call in zip(futures, tool_calls):
                result = future.result()

                if isinstance(result, ToolInvocationError):
                    # a) This is an error, create error Tool message
                    if self.raise_on_failure:
                        raise result
                    logger.error("{error_exception}", error_exception=result)
                    tool_messages.append(ChatMessage.from_tool(tool_result=str(result), origin=tool_call, error=True))
                else:
                    # b) In case of success, merge outputs into state
                    try:
                        tool_to_invoke = tools_with_names[tool_call.tool_name]
                        self._merge_tool_outputs(tool=tool_to_invoke, result=result, state=state)
                        tool_messages.append(
                            self._prepare_tool_result_message(
                                result=result, tool_call=tool_call, tool_to_invoke=tool_to_invoke
                            )
                        )
                    except Exception as e:
                        error = ToolOutputMergeError.from_exception(tool_name=tool_call.tool_name, error=e)
                        if self.raise_on_failure:
                            raise error from e
                        logger.error("{error_exception}", error_exception=error)
                        tool_messages.append(
                            ChatMessage.from_tool(tool_result=str(error), origin=tool_call, error=True)
                        )

                # c) Handle streaming callback
                if streaming_callback is not None:
                    streaming_callback(
                        self._create_tool_result_streaming_chunk(tool_messages=tool_messages, tool_call=tool_call)
                    )

        # We stream one more chunk that contains a finish_reason if tool_messages were generated
        if len(tool_messages) > 0 and streaming_callback is not None:
            streaming_callback(
                StreamingChunk(
                    content="", finish_reason="tool_call_results", meta={"finish_reason": "tool_call_results"}
                )
            )

        return {"tool_messages": tool_messages, "state": state}
617

618
    @component.output_types(tool_messages=list[ChatMessage], state=State)
    async def run_async(
        self,
        messages: list[ChatMessage],
        state: Optional[State] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
        *,
        enable_streaming_callback_passthrough: Optional[bool] = None,
        tools: Optional[Union[list[Tool], Toolset, list[Toolset]]] = None,
    ) -> dict[str, Any]:
        """
        Asynchronously processes ChatMessage objects containing tool calls.

        Multiple tool calls are performed concurrently.
        :param messages:
            A list of ChatMessage objects.
        :param state: The runtime state that should be used by the tools.
        :param streaming_callback: An asynchronous callback function that will be called to emit tool results.
            Note that the result is only emitted once it becomes available — it is not
            streamed incrementally in real time.
        :param enable_streaming_callback_passthrough:
            If True, the `streaming_callback` will be passed to the tool invocation if the tool supports it.
            This allows tools to stream their results back to the client.
            Note that this requires the tool to have a `streaming_callback` parameter in its `invoke` method signature.
            If False, the `streaming_callback` will not be passed to the tool invocation.
            If None, the value from the constructor will be used.
        :param tools:
            A list of tools to use for the tool invoker. If set, overrides the tools set in the constructor.
        :returns:
            A dictionary with the key `tool_messages` containing a list of ChatMessage objects with tool role.
            Each ChatMessage objects wraps the result of a tool invocation.

        :raises ToolNotFoundException:
            If the tool is not found in the list of available tools and `raise_on_failure` is True.
        :raises ToolInvocationError:
            If the tool invocation fails and `raise_on_failure` is True.
        :raises StringConversionError:
            If the conversion of the tool result to a string fails and `raise_on_failure` is True.
        :raises ToolOutputMergeError:
            If merging tool outputs into state fails and `raise_on_failure` is True.
        """
        tools_with_names = self._tools_with_names
        if tools is not None:
            tools_with_names = self._validate_and_prepare_tools(tools)
            # Lazy placeholder-style logging for consistency with the other logger calls in this file
            # (avoids building the joined string when debug logging is disabled).
            logger.debug(
                "For this invocation, overriding constructor tools with: {tool_names}",
                tool_names=", ".join(tools_with_names.keys()),
            )

        if state is None:
            state = State(schema={})

        resolved_enable_streaming_passthrough = (
            enable_streaming_callback_passthrough
            if enable_streaming_callback_passthrough is not None
            else self.enable_streaming_callback_passthrough
        )

        # Only keep messages with tool calls
        messages_with_tool_calls = [message for message in messages if message.tool_calls]
        streaming_callback = select_streaming_callback(
            init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True
        )

        if not messages_with_tool_calls:
            return {"tool_messages": [], "state": state}

        # 1) Collect all tool calls and their parameters for parallel execution
        tool_messages = []
        tool_calls, tool_call_params, error_messages = self._prepare_tool_call_params(
            messages_with_tool_calls=messages_with_tool_calls,
            state=state,
            streaming_callback=streaming_callback,
            enable_streaming_passthrough=resolved_enable_streaming_passthrough,
            tools_with_names=tools_with_names,
        )
        tool_messages.extend(error_messages)

        if not tool_call_params:
            return {"tool_messages": tool_messages, "state": state}

        # 2) Execute valid tool calls in parallel
        tool_call_tasks = []
        # The event loop is the same for every scheduled call, so resolve it once outside the loop.
        loop = asyncio.get_running_loop()
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            for params in tool_call_params:
                callable_ = ToolInvoker._make_context_bound_invoke(params["tool_to_invoke"], params["final_args"])
                tool_call_tasks.append(loop.run_in_executor(executor, callable_))

            # 3) Gather and process results: handle errors and merge outputs into state
            tool_results = await asyncio.gather(*tool_call_tasks)
            for tool_result, tool_call in zip(tool_results, tool_calls):
                # a) This is an error, create error Tool message
                if isinstance(tool_result, ToolInvocationError):
                    if self.raise_on_failure:
                        raise tool_result
                    logger.error("{error_exception}", error_exception=tool_result)
                    tool_messages.append(
                        ChatMessage.from_tool(tool_result=str(tool_result), origin=tool_call, error=True)
                    )
                else:
                    # b) In case of success, merge outputs into state
                    try:
                        tool_to_invoke = tools_with_names[tool_call.tool_name]
                        self._merge_tool_outputs(tool=tool_to_invoke, result=tool_result, state=state)
                        tool_messages.append(
                            self._prepare_tool_result_message(
                                result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke
                            )
                        )
                    except Exception as e:
                        error = ToolOutputMergeError.from_exception(tool_name=tool_call.tool_name, error=e)
                        if self.raise_on_failure:
                            raise error from e
                        logger.error("{error_exception}", error_exception=error)
                        tool_messages.append(
                            ChatMessage.from_tool(tool_result=str(error), origin=tool_call, error=True)
                        )

                # c) Handle streaming callback
                if streaming_callback is not None:
                    await streaming_callback(
                        self._create_tool_result_streaming_chunk(tool_messages=tool_messages, tool_call=tool_call)
                    )

        # 4) We stream one more chunk that contains a finish_reason if tool_messages were generated
        if len(tool_messages) > 0 and streaming_callback is not None:
            await streaming_callback(
                StreamingChunk(
                    content="", finish_reason="tool_call_results", meta={"finish_reason": "tool_call_results"}
                )
            )

        return {"tool_messages": tool_messages, "state": state}
    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        # A callback is only serializable when one was actually configured.
        serialized_callback = (
            serialize_callable(self.streaming_callback) if self.streaming_callback is not None else None
        )
        return default_to_dict(
            self,
            tools=serialize_tools_or_toolset(self.tools),
            raise_on_failure=self.raise_on_failure,
            convert_result_to_json_string=self.convert_result_to_json_string,
            streaming_callback=serialized_callback,
            enable_streaming_callback_passthrough=self.enable_streaming_callback_passthrough,
            max_workers=self.max_workers,
        )
    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ToolInvoker":
        """
        Deserializes the component from a dictionary.

        :param data:
            The dictionary to deserialize from.
        :returns:
            The deserialized component.
        """
        init_params = data["init_parameters"]
        # Rebuild the Tool/Toolset objects in place before the generic deserialization.
        deserialize_tools_or_toolset_inplace(init_params, key="tools")
        serialized_callback = init_params.get("streaming_callback")
        if serialized_callback is not None:
            init_params["streaming_callback"] = deserialize_callable(serialized_callback)
        return default_from_dict(cls, data)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc