• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 18595249452

17 Oct 2025 02:08PM UTC coverage: 92.22% (+0.02%) from 92.2%
18595249452

Pull #9886

github

web-flow
Merge ad30d1879 into cc4f024af
Pull Request #9886: feat: Update tools param to Optional[Union[list[Union[Tool, Toolset]], Toolset]]

13382 of 14511 relevant lines covered (92.22%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.9
haystack/components/tools/tool_invoker.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import asyncio
1✔
6
import contextvars
1✔
7
import inspect
1✔
8
import json
1✔
9
from concurrent.futures import ThreadPoolExecutor
1✔
10
from functools import partial
1✔
11
from typing import Any, Callable, Optional
1✔
12

13
from haystack.components.agents import State
1✔
14
from haystack.core.component.component import component
1✔
15
from haystack.core.component.sockets import Sockets
1✔
16
from haystack.core.serialization import default_from_dict, default_to_dict, logging
1✔
17
from haystack.dataclasses import ChatMessage, ToolCall
1✔
18
from haystack.dataclasses.streaming_chunk import StreamingCallbackT, StreamingChunk, select_streaming_callback
1✔
19
from haystack.tools import (
1✔
20
    ComponentTool,
21
    Tool,
22
    ToolsType,
23
    _check_duplicate_tool_names,
24
    deserialize_tools_or_toolset_inplace,
25
    flatten_tools_or_toolsets,
26
    serialize_tools_or_toolset,
27
)
28
from haystack.tools.errors import ToolInvocationError
1✔
29
from haystack.tracing.utils import _serializable_value
1✔
30
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
1✔
31

32
logger = logging.getLogger(__name__)
1✔
33

34

35
class ToolInvokerError(Exception):
    """Base class for all errors raised by the ToolInvoker component."""

    def __init__(self, message: str):
        super().__init__(message)
40

41

42
class ToolNotFoundException(ToolInvokerError):
    """Raised when a requested tool is missing from the set of available tools."""

    def __init__(self, tool_name: str, available_tools: list[str]):
        available = ", ".join(available_tools)
        super().__init__(f"Tool '{tool_name}' not found. Available tools: {available}")
48

49

50
class StringConversionError(ToolInvokerError):
    """Raised when converting a tool's result into a string fails."""

    def __init__(self, tool_name: str, conversion_function: str, error: Exception):
        msg = f"Failed to convert tool result from tool {tool_name} using '{conversion_function}'. Error: {error}"
        super().__init__(msg)
56

57

58
class ToolOutputMergeError(ToolInvokerError):
    """Raised when tool outputs cannot be merged into the State."""

    @classmethod
    def from_exception(cls, tool_name: str, error: Exception) -> "ToolOutputMergeError":
        """
        Build a ToolOutputMergeError wrapping the given exception.
        """
        return cls(f"Failed to merge tool outputs from tool {tool_name} into State: {error}")
68

69

70
@component
1✔
71
class ToolInvoker:
1✔
72
    """
73
    Invokes tools based on prepared tool calls and returns the results as a list of ChatMessage objects.
74

75
    Also handles reading/writing from a shared `State`.
76
    At initialization, the ToolInvoker component is provided with a list of available tools.
77
    At runtime, the component processes a list of ChatMessage object containing tool calls
78
    and invokes the corresponding tools.
79
    The results of the tool invocations are returned as a list of ChatMessage objects with tool role.
80

81
    Usage example:
82
    ```python
83
    from haystack.dataclasses import ChatMessage, ToolCall
84
    from haystack.tools import Tool
85
    from haystack.components.tools import ToolInvoker
86

87
    # Tool definition
88
    def dummy_weather_function(city: str):
89
        return f"The weather in {city} is 20 degrees."
90

91
    parameters = {"type": "object",
92
                "properties": {"city": {"type": "string"}},
93
                "required": ["city"]}
94

95
    tool = Tool(name="weather_tool",
96
                description="A tool to get the weather",
97
                function=dummy_weather_function,
98
                parameters=parameters)
99

100
    # Usually, the ChatMessage with tool_calls is generated by a Language Model
101
    # Here, we create it manually for demonstration purposes
102
    tool_call = ToolCall(
103
        tool_name="weather_tool",
104
        arguments={"city": "Berlin"}
105
    )
106
    message = ChatMessage.from_assistant(tool_calls=[tool_call])
107

108
    # ToolInvoker initialization and run
109
    invoker = ToolInvoker(tools=[tool])
110
    result = invoker.run(messages=[message])
111

112
    print(result)
113
    ```
114

115
    ```
116
    >>  {
117
    >>      'tool_messages': [
118
    >>          ChatMessage(
119
    >>              _role=<ChatRole.TOOL: 'tool'>,
120
    >>              _content=[
121
    >>                  ToolCallResult(
122
    >>                      result='"The weather in Berlin is 20 degrees."',
123
    >>                      origin=ToolCall(
124
    >>                          tool_name='weather_tool',
125
    >>                          arguments={'city': 'Berlin'},
126
    >>                          id=None
127
    >>                      )
128
    >>                  )
129
    >>              ],
130
    >>              _meta={}
131
    >>          )
132
    >>      ]
133
    >>  }
134
    ```
135

136
    Usage example with a Toolset:
137
    ```python
138
    from haystack.dataclasses import ChatMessage, ToolCall
139
    from haystack.tools import Tool, Toolset
140
    from haystack.components.tools import ToolInvoker
141

142
    # Tool definition
143
    def dummy_weather_function(city: str):
144
        return f"The weather in {city} is 20 degrees."
145

146
    parameters = {"type": "object",
147
                "properties": {"city": {"type": "string"}},
148
                "required": ["city"]}
149

150
    tool = Tool(name="weather_tool",
151
                description="A tool to get the weather",
152
                function=dummy_weather_function,
153
                parameters=parameters)
154

155
    # Create a Toolset
156
    toolset = Toolset([tool])
157

158
    # Usually, the ChatMessage with tool_calls is generated by a Language Model
159
    # Here, we create it manually for demonstration purposes
160
    tool_call = ToolCall(
161
        tool_name="weather_tool",
162
        arguments={"city": "Berlin"}
163
    )
164
    message = ChatMessage.from_assistant(tool_calls=[tool_call])
165

166
    # ToolInvoker initialization and run with Toolset
167
    invoker = ToolInvoker(tools=toolset)
168
    result = invoker.run(messages=[message])
169

170
    print(result)
    ```
    """
172

173
    def __init__(
        self,
        tools: ToolsType,
        raise_on_failure: bool = True,
        convert_result_to_json_string: bool = False,
        streaming_callback: Optional[StreamingCallbackT] = None,
        *,
        enable_streaming_callback_passthrough: bool = False,
        max_workers: int = 4,
    ):
        """
        Initialize the ToolInvoker component.

        :param tools:
            A list of Tool and/or Toolset objects, or a Toolset instance that can resolve tools.
        :param raise_on_failure:
            If True, the component raises an exception on errors
            (tool not found, tool invocation errors, tool result conversion errors).
            If False, the component returns a ChatMessage object with `error=True`
            and a description of the error in `result`.
        :param convert_result_to_json_string:
            If True, the tool invocation result is converted to a string using `json.dumps`.
            If False, it is converted using `str`.
        :param streaming_callback:
            A callback function called to emit tool results.
            The result is only emitted once it becomes available — it is not
            streamed incrementally in real time.
        :param enable_streaming_callback_passthrough:
            If True, the `streaming_callback` is passed to the tool invocation if the tool supports it,
            allowing tools to stream their results back to the client.
            This requires the tool to accept a `streaming_callback` parameter in its `invoke` method signature.
            If False, the `streaming_callback` is not passed to the tool invocation.
        :param max_workers:
            The maximum number of workers used in the thread pool executor,
            which also bounds the number of concurrent tool invocations.
        :raises ValueError:
            If no tools are provided or if duplicate tool names are found.
        """
        self.tools = tools
        self.raise_on_failure = raise_on_failure
        self.convert_result_to_json_string = convert_result_to_json_string
        self.streaming_callback = streaming_callback
        self.enable_streaming_callback_passthrough = enable_streaming_callback_passthrough
        self.max_workers = max_workers

        # Validates the tools and builds the name -> Tool lookup used at run time
        self._tools_with_names = self._validate_and_prepare_tools(tools)
219

220
    @staticmethod
    def _make_context_bound_invoke(tool_to_invoke: Tool, final_args: dict[str, Any]) -> Callable[[], Any]:
        """
        Create a zero-arg callable that invokes the tool under the caller's contextvars Context.

        contextvars do not automatically propagate to worker threads (or threadpool tasks spawned via
        run_in_executor), so without intervention ambient state such as the active tracing Span would be
        lost and nested tool calls would appear as separate trace roots. Capturing the current Context in
        the caller thread and running the tool via `ctx.run(...)` inside the executor preserves span
        parentage and log/trace correlation in both sync and async paths. The callable returns a
        ToolInvocationError instead of raising so parallel execution can collect failures.
        """
        captured_ctx = contextvars.copy_context()

        def _invoke_in_context() -> Any:
            try:
                # Run the tool invocation inside the captured caller context
                return captured_ctx.run(partial(tool_to_invoke.invoke, **final_args))
            except ToolInvocationError as invocation_error:
                # Returned (not raised) so the caller can gather failures from parallel futures
                return invocation_error

        return _invoke_in_context
242

243
    @staticmethod
    def _validate_and_prepare_tools(tools: ToolsType) -> dict[str, Tool]:
        """
        Validates and prepares tools for use by the ToolInvoker.

        :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.

        :returns: A dictionary mapping tool names to Tool instances.
        :raises ValueError: If no tools are provided or if duplicate tool names are found.
        """
        if not tools:
            raise ValueError("ToolInvoker requires at least one tool.")

        converted_tools = flatten_tools_or_toolsets(tools)

        # _check_duplicate_tool_names raises ValueError when two tools share a name, so no further
        # duplicate detection is needed here (a previous manual re-check after this call was unreachable).
        _check_duplicate_tool_names(converted_tools)

        return {tool.name: tool for tool in converted_tools}
265

266
    def _default_output_to_string_handler(self, result: Any) -> str:
        """
        Default handler for converting a tool result to a string.

        :param result: The tool result to convert to a string.
        :returns: The converted tool result as a string.
        """
        # We iterate through all items in result and call to_dict() if present
        # Relevant for a few reasons:
        # - If using convert_result_to_json_string we'd rather convert Haystack objects to JSON serializable dicts
        # - If using default str() we prefer converting Haystack objects to dicts rather than relying on the
        #   __repr__ method
        serializable = _serializable_value(result)

        if self.convert_result_to_json_string:
            try:
                # We disable ensure_ascii so special chars like emojis are not converted
                str_result = json.dumps(serializable, ensure_ascii=False)
            except Exception as error:
                # If the result is not JSON serializable, we fall back to str.
                # Bug fix: the kwarg must be named `error` to fill the "{error}" placeholder
                # in the template (it was previously misnamed `err`, so the error never appeared).
                logger.warning(
                    "Tool result is not JSON serializable. Falling back to str conversion. "
                    "Result: {result}\nError: {error}",
                    result=result,
                    error=error,
                )
                str_result = str(result)
            return str_result

        return str(serializable)
296

297
    def _prepare_tool_result_message(self, result: Any, tool_call: ToolCall, tool_to_invoke: Tool) -> ChatMessage:
        """
        Build the tool-role ChatMessage wrapping a tool invocation result.

        :param result:
            The tool result.
        :param tool_call:
            The ToolCall object containing the tool name and arguments.
        :param tool_to_invoke:
            The Tool object that was invoked.
        :returns:
            A ChatMessage object containing the tool result as a string.
        :raises
            StringConversionError: If the conversion of the tool result to a string fails
            and `raise_on_failure` is True.
        """
        string_config = tool_to_invoke.outputs_to_string or {}
        source = string_config.get("source")

        # Fall back to the default handler when the tool does not define its own
        handler = string_config.get("handler", self._default_output_to_string_handler)

        # When a source key is configured, only that part of the result is converted
        payload = result if source is None else result.get(source)

        try:
            return ChatMessage.from_tool(tool_result=handler(payload), origin=tool_call)
        except Exception as exc:
            conversion_error = StringConversionError(tool_call.tool_name, handler.__name__, exc)
            if self.raise_on_failure:
                raise conversion_error from exc
            logger.error("{error_exception}", error_exception=conversion_error)
            return ChatMessage.from_tool(tool_result=str(conversion_error), origin=tool_call, error=True)
332

333
    @staticmethod
    def _get_func_params(tool: Tool) -> set:
        """
        Return the set of parameter names accepted by the tool's invoke function.

        Inspects the tool's function signature to determine which parameters the tool accepts.
        ComponentTool wraps its function in a kwargs-accepting wrapper, so for it the accepted
        parameters are read from the wrapped component's input sockets instead of the signature.
        """
        if not isinstance(tool, ComponentTool):
            return set(inspect.signature(tool.function).parameters.keys())

        # mypy doesn't know that ComponentMeta always adds __haystack_input__ to Component
        assert hasattr(tool._component, "__haystack_input__") and isinstance(
            tool._component.__haystack_input__, Sockets
        )
        return set(tool._component.__haystack_input__._sockets_dict.keys())
352

353
    @staticmethod
    def _inject_state_args(tool: Tool, llm_args: dict[str, Any], state: State) -> dict[str, Any]:
        """
        Merge state-provided arguments into the LLM-provided ones.

        Precedence, highest first:
          - LLM-provided arguments (never overwritten by state)
          - explicit tool.inputs_from_state mappings (if defined)
          - state keys matching the function signature's parameter names
        """
        merged = dict(llm_args)  # start with LLM-provided
        func_params = ToolInvoker._get_func_params(tool)

        # An explicit inputs_from_state dict maps state keys to tool parameter names
        # (e.g. {"state_key": "tool_param_name"}); otherwise every signature parameter
        # maps to the state key of the same name.
        if hasattr(tool, "inputs_from_state") and isinstance(tool.inputs_from_state, dict):
            mappings = tool.inputs_from_state
        else:
            mappings = {name: name for name in func_params}

        # Fill in from state only what the LLM did not already provide
        for state_key, param_name in mappings.items():
            if param_name in merged or not state.has(state_key):
                continue
            merged[param_name] = state.get(state_key)

        return merged
379

380
    @staticmethod
    def _merge_tool_outputs(tool: Tool, result: Any, state: State) -> None:
        """
        Merges the tool result into the State.

        Processing steps:
        1. If `result` is not a dictionary, nothing is stored into state.
        2. If the `tool` does not define an `outputs_to_state` dictionary mapping, nothing is stored into state.
        3. Otherwise, for each entry of `outputs_to_state`, the value under the entry's optional "source"
           key of `result` (or the entire `result` when no "source" is given) is written to the state key,
           optionally converted via the entry's "handler" callable.

        :param tool: Tool instance containing an optional `outputs_to_state` mapping to guide result processing.
        :param result: The output from tool execution. Can be a dictionary, or any other type.
        :param state: The global State object to which results should be merged.
        """
        # Note: the previous docstring claimed this method returned one of three values; it
        # returns None — it only mutates `state`.

        # If result is not a dictionary we exit
        if not isinstance(result, dict):
            return

        # If there is no specific `outputs_to_state` mapping, we exit
        if not hasattr(tool, "outputs_to_state") or not isinstance(tool.outputs_to_state, dict):
            return

        # Update the state with the tool outputs
        for state_key, config in tool.outputs_to_state.items():
            # Get the source key from the output config, otherwise use the entire result
            source_key = config.get("source", None)
            output_value = result.get(source_key) if source_key else result

            # Merge other outputs into the state
            state.set(state_key, output_value, handler_override=config.get("handler"))
421

422
    @staticmethod
    def _create_tool_result_streaming_chunk(tool_messages: list[ChatMessage], tool_call: ToolCall) -> StreamingChunk:
        """Build the StreamingChunk carrying the most recently appended tool result."""
        latest_result = tool_messages[-1].tool_call_results[0]
        return StreamingChunk(
            content="",
            index=len(tool_messages) - 1,
            tool_call_result=latest_result,
            start=True,
            meta={"tool_result": latest_result.result, "tool_call": tool_call},
        )
432

433
    def _prepare_tool_call_params(
        self,
        *,
        messages_with_tool_calls: list[ChatMessage],
        state: State,
        streaming_callback: Optional[StreamingCallbackT],
        enable_streaming_passthrough: bool,
        tools_with_names: dict[str, Tool],
    ) -> tuple[list[ToolCall], list[dict[str, Any]], list[ChatMessage]]:
        """
        Prepare tool call parameters for execution and collect any error messages.

        :param messages_with_tool_calls: Messages containing tool calls to process
        :param state: The current state for argument injection
        :param streaming_callback: Optional streaming callback to inject
        :param enable_streaming_passthrough: Whether to pass streaming callback to tools
        :param tools_with_names: Mapping of tool name to Tool used to resolve the calls
        :returns: Tuple of (tool_calls, tool_call_params, error_messages)
        """
        valid_tool_calls: list[ToolCall] = []
        invocation_params: list[dict[str, Any]] = []
        error_messages: list[ChatMessage] = []

        all_tool_calls = (tc for msg in messages_with_tool_calls for tc in msg.tool_calls)
        for tool_call in all_tool_calls:
            requested_name = tool_call.tool_name

            # Unknown tool: either raise or record an error ChatMessage and move on
            if requested_name not in tools_with_names:
                not_found = ToolNotFoundException(requested_name, list(tools_with_names.keys()))
                if self.raise_on_failure:
                    raise not_found
                logger.error("{error_exception}", error_exception=not_found)
                error_messages.append(ChatMessage.from_tool(tool_result=str(not_found), origin=tool_call, error=True))
                continue

            tool = tools_with_names[requested_name]

            # Merge the LLM-provided arguments with values injected from state
            final_args = self._inject_state_args(tool, tool_call.arguments.copy(), state)

            # Pass the streaming callback through only when enabled, available, supported by
            # the tool's signature, and not already supplied by the LLM
            if (
                enable_streaming_passthrough
                and streaming_callback is not None
                and "streaming_callback" not in final_args
                and "streaming_callback" in self._get_func_params(tool)
            ):
                final_args["streaming_callback"] = streaming_callback

            invocation_params.append({"tool_to_invoke": tool, "final_args": final_args})
            valid_tool_calls.append(tool_call)

        return valid_tool_calls, invocation_params, error_messages
487

488
    @component.output_types(tool_messages=list[ChatMessage], state=State)
    def run(
        self,
        messages: list[ChatMessage],
        state: Optional[State] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
        *,
        enable_streaming_callback_passthrough: Optional[bool] = None,
        tools: Optional[ToolsType] = None,
    ) -> dict[str, Any]:
        """
        Processes ChatMessage objects containing tool calls and invokes the corresponding tools, if available.

        :param messages:
            A list of ChatMessage objects.
        :param state: The runtime state that should be used by the tools.
        :param streaming_callback: A callback function that will be called to emit tool results.
            Note that the result is only emitted once it becomes available — it is not
            streamed incrementally in real time.
        :param enable_streaming_callback_passthrough:
            If True, the `streaming_callback` will be passed to the tool invocation if the tool supports it.
            This allows tools to stream their results back to the client.
            Note that this requires the tool to have a `streaming_callback` parameter in its `invoke` method signature.
            If False, the `streaming_callback` will not be passed to the tool invocation.
            If None, the value from the constructor will be used.
        :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
            If set, it will override the `tools` parameter provided during initialization.
        :returns:
            A dictionary with the keys:
            - `tool_messages`: A list of ChatMessage objects with tool role, each wrapping the result
              of a tool invocation.
            - `state`: The (possibly updated) runtime State.

        :raises ToolNotFoundException:
            If the tool is not found in the list of available tools and `raise_on_failure` is True.
        :raises ToolInvocationError:
            If the tool invocation fails and `raise_on_failure` is True.
        :raises StringConversionError:
            If the conversion of the tool result to a string fails and `raise_on_failure` is True.
        :raises ToolOutputMergeError:
            If merging tool outputs into state fails and `raise_on_failure` is True.
        """
        tools_with_names = self._tools_with_names
        if tools is not None:
            tools_with_names = self._validate_and_prepare_tools(tools)
            # Use the logger's lazy keyword interpolation (consistent with the other log calls in
            # this file) instead of an eagerly evaluated f-string.
            logger.debug(
                "For this invocation, overriding constructor tools with: {tool_names}",
                tool_names=", ".join(tools_with_names.keys()),
            )

        if state is None:
            state = State(schema={})

        resolved_enable_streaming_passthrough = (
            enable_streaming_callback_passthrough
            if enable_streaming_callback_passthrough is not None
            else self.enable_streaming_callback_passthrough
        )

        # Only keep messages with tool calls
        messages_with_tool_calls = [message for message in messages if message.tool_calls]
        streaming_callback = select_streaming_callback(
            init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
        )

        if not messages_with_tool_calls:
            return {"tool_messages": [], "state": state}

        # 1) Collect all tool calls and their parameters for parallel execution
        tool_messages = []
        tool_calls, tool_call_params, error_messages = self._prepare_tool_call_params(
            messages_with_tool_calls=messages_with_tool_calls,
            state=state,
            streaming_callback=streaming_callback,
            enable_streaming_passthrough=resolved_enable_streaming_passthrough,
            tools_with_names=tools_with_names,
        )
        tool_messages.extend(error_messages)

        if not tool_call_params:
            return {"tool_messages": tool_messages, "state": state}

        # 2) Execute valid tool calls in parallel
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Each callable is bound to the caller's contextvars Context so tracing spans propagate
            futures = [
                executor.submit(self._make_context_bound_invoke(params["tool_to_invoke"], params["final_args"]))
                for params in tool_call_params
            ]

            # 3) Gather and process results: handle errors and merge outputs into state
            for future, tool_call in zip(futures, tool_calls):
                result = future.result()

                if isinstance(result, ToolInvocationError):
                    # a) This is an error, create error Tool message
                    if self.raise_on_failure:
                        raise result
                    logger.error("{error_exception}", error_exception=result)
                    tool_messages.append(ChatMessage.from_tool(tool_result=str(result), origin=tool_call, error=True))
                else:
                    # b) In case of success, merge outputs into state
                    try:
                        tool_to_invoke = tools_with_names[tool_call.tool_name]
                        self._merge_tool_outputs(tool=tool_to_invoke, result=result, state=state)
                        tool_messages.append(
                            self._prepare_tool_result_message(
                                result=result, tool_call=tool_call, tool_to_invoke=tool_to_invoke
                            )
                        )
                    except Exception as e:
                        error = ToolOutputMergeError.from_exception(tool_name=tool_call.tool_name, error=e)
                        if self.raise_on_failure:
                            raise error from e
                        logger.error("{error_exception}", error_exception=error)
                        tool_messages.append(
                            ChatMessage.from_tool(tool_result=str(error), origin=tool_call, error=True)
                        )

                # c) Handle streaming callback
                if streaming_callback is not None:
                    streaming_callback(
                        self._create_tool_result_streaming_chunk(tool_messages=tool_messages, tool_call=tool_call)
                    )

        # We stream one more chunk that contains a finish_reason if tool_messages were generated
        if len(tool_messages) > 0 and streaming_callback is not None:
            streaming_callback(
                StreamingChunk(
                    content="", finish_reason="tool_call_results", meta={"finish_reason": "tool_call_results"}
                )
            )

        return {"tool_messages": tool_messages, "state": state}
618

619
    @component.output_types(tool_messages=list[ChatMessage], state=State)
    async def run_async(
        self,
        messages: list[ChatMessage],
        state: Optional[State] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
        *,
        enable_streaming_callback_passthrough: Optional[bool] = None,
        tools: Optional[ToolsType] = None,
    ) -> dict[str, Any]:
        """
        Asynchronously processes ChatMessage objects containing tool calls.

        Multiple tool calls are performed concurrently.

        :param messages:
            A list of ChatMessage objects.
        :param state: The runtime state that should be used by the tools.
        :param streaming_callback: An asynchronous callback function that will be called to emit tool results.
            Note that the result is only emitted once it becomes available — it is not
            streamed incrementally in real time.
        :param enable_streaming_callback_passthrough:
            If True, the `streaming_callback` will be passed to the tool invocation if the tool supports it.
            This allows tools to stream their results back to the client.
            Note that this requires the tool to have a `streaming_callback` parameter in its `invoke` method signature.
            If False, the `streaming_callback` will not be passed to the tool invocation.
            If None, the value from the constructor will be used.
        :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
            If set, it will override the `tools` parameter provided during initialization.
        :returns:
            A dictionary with the key `tool_messages` containing a list of ChatMessage objects with tool role.
            Each ChatMessage objects wraps the result of a tool invocation.

        :raises ToolNotFoundException:
            If the tool is not found in the list of available tools and `raise_on_failure` is True.
        :raises ToolInvocationError:
            If the tool invocation fails and `raise_on_failure` is True.
        :raises StringConversionError:
            If the conversion of the tool result to a string fails and `raise_on_failure` is True.
        :raises ToolOutputMergeError:
            If merging tool outputs into state fails and `raise_on_failure` is True.
        """
        # Runtime `tools` takes precedence over the tools configured in the constructor.
        tools_with_names = self._tools_with_names
        if tools is not None:
            tools_with_names = self._validate_and_prepare_tools(tools)
            logger.debug(
                f"For this invocation, overriding constructor tools with: {', '.join(tools_with_names.keys())}"
            )

        if state is None:
            state = State(schema={})

        resolved_enable_streaming_passthrough = (
            enable_streaming_callback_passthrough
            if enable_streaming_callback_passthrough is not None
            else self.enable_streaming_callback_passthrough
        )

        # Only keep messages with tool calls
        messages_with_tool_calls = [message for message in messages if message.tool_calls]
        streaming_callback = select_streaming_callback(
            init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True
        )

        if not messages_with_tool_calls:
            return {"tool_messages": [], "state": state}

        # 1) Collect all tool calls and their parameters for parallel execution
        tool_messages = []
        tool_calls, tool_call_params, error_messages = self._prepare_tool_call_params(
            messages_with_tool_calls=messages_with_tool_calls,
            state=state,
            streaming_callback=streaming_callback,
            enable_streaming_passthrough=resolved_enable_streaming_passthrough,
            tools_with_names=tools_with_names,
        )
        tool_messages.extend(error_messages)

        if not tool_call_params:
            return {"tool_messages": tool_messages, "state": state}

        # 2) Execute valid tool calls in parallel
        # The running event loop is loop-invariant: resolve it once instead of once per tool call.
        loop = asyncio.get_running_loop()
        tool_call_tasks = []
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            for params in tool_call_params:
                # Bind the current contextvars so the tool runs with the caller's context in the worker thread.
                callable_ = ToolInvoker._make_context_bound_invoke(params["tool_to_invoke"], params["final_args"])
                tool_call_tasks.append(loop.run_in_executor(executor, callable_))

            # 3) Gather and process results: handle errors and merge outputs into state
            tool_results = await asyncio.gather(*tool_call_tasks)
            for tool_result, tool_call in zip(tool_results, tool_calls):
                # a) This is an error, create error Tool message
                if isinstance(tool_result, ToolInvocationError):
                    if self.raise_on_failure:
                        raise tool_result
                    logger.error("{error_exception}", error_exception=tool_result)
                    tool_messages.append(
                        ChatMessage.from_tool(tool_result=str(tool_result), origin=tool_call, error=True)
                    )
                else:
                    # b) In case of success, merge outputs into state
                    try:
                        tool_to_invoke = tools_with_names[tool_call.tool_name]
                        self._merge_tool_outputs(tool=tool_to_invoke, result=tool_result, state=state)
                        tool_messages.append(
                            self._prepare_tool_result_message(
                                result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke
                            )
                        )
                    except Exception as e:
                        error = ToolOutputMergeError.from_exception(tool_name=tool_call.tool_name, error=e)
                        if self.raise_on_failure:
                            raise error from e
                        logger.error("{error_exception}", error_exception=error)
                        tool_messages.append(
                            ChatMessage.from_tool(tool_result=str(error), origin=tool_call, error=True)
                        )

                # c) Handle streaming callback
                if streaming_callback is not None:
                    await streaming_callback(
                        self._create_tool_result_streaming_chunk(tool_messages=tool_messages, tool_call=tool_call)
                    )

        # 4) We stream one more chunk that contains a finish_reason if tool_messages were generated
        if len(tool_messages) > 0 and streaming_callback is not None:
            await streaming_callback(
                StreamingChunk(
                    content="", finish_reason="tool_call_results", meta={"finish_reason": "tool_call_results"}
                )
            )

        return {"tool_messages": tool_messages, "state": state}
753

754
    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        # A callable cannot be stored directly; serialize it to its import path, or keep None.
        serialized_callback = (
            serialize_callable(self.streaming_callback) if self.streaming_callback is not None else None
        )

        return default_to_dict(
            self,
            tools=serialize_tools_or_toolset(self.tools),
            raise_on_failure=self.raise_on_failure,
            convert_result_to_json_string=self.convert_result_to_json_string,
            streaming_callback=serialized_callback,
            enable_streaming_callback_passthrough=self.enable_streaming_callback_passthrough,
            max_workers=self.max_workers,
        )

776
    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ToolInvoker":
        """
        Deserializes the component from a dictionary.

        :param data:
            The dictionary to deserialize from.
        :returns:
            The deserialized component.
        """
        init_params = data["init_parameters"]
        # Rebuild Tool/Toolset objects from their serialized form, mutating init_params in place.
        deserialize_tools_or_toolset_inplace(init_params, key="tools")
        # Restore the streaming callback from its import path, if one was serialized.
        serialized_callback = init_params.get("streaming_callback")
        if serialized_callback is not None:
            init_params["streaming_callback"] = deserialize_callable(serialized_callback)
        return default_from_dict(cls, data)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc