• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 18647136217

20 Oct 2025 08:52AM UTC coverage: 92.179% (-0.04%) from 92.22%
18647136217

Pull #9856

github

web-flow
Merge dc9eda57a into 1de94413c
Pull Request #9856: Add Tools warm_up

13425 of 14564 relevant lines covered (92.18%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.98
haystack/components/tools/tool_invoker.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import asyncio
1✔
6
import contextvars
1✔
7
import inspect
1✔
8
import json
1✔
9
from concurrent.futures import ThreadPoolExecutor
1✔
10
from functools import partial
1✔
11
from typing import Any, Callable, Optional
1✔
12

13
from haystack.components.agents import State
1✔
14
from haystack.core.component.component import component
1✔
15
from haystack.core.component.sockets import Sockets
1✔
16
from haystack.core.serialization import default_from_dict, default_to_dict, logging
1✔
17
from haystack.dataclasses import ChatMessage, ToolCall
1✔
18
from haystack.dataclasses.streaming_chunk import StreamingCallbackT, StreamingChunk, select_streaming_callback
1✔
19
from haystack.tools import (
1✔
20
    ComponentTool,
21
    Tool,
22
    ToolsType,
23
    _check_duplicate_tool_names,
24
    deserialize_tools_or_toolset_inplace,
25
    flatten_tools_or_toolsets,
26
    serialize_tools_or_toolset,
27
    warm_up_tools,
28
)
29
from haystack.tools.errors import ToolInvocationError
1✔
30
from haystack.tracing.utils import _serializable_value
1✔
31
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
1✔
32

33
logger = logging.getLogger(__name__)
1✔
34

35

36
class ToolInvokerError(Exception):
    """Root exception type for all errors raised by the ToolInvoker component."""

    def __init__(self, message: str):
        # Delegate message storage to the standard Exception machinery.
        super().__init__(message)
41

42

43
class ToolNotFoundException(ToolInvokerError):
    """Raised when a requested tool is missing from the set of available tools."""

    def __init__(self, tool_name: str, available_tools: list[str]):
        available = ", ".join(available_tools)
        super().__init__(f"Tool '{tool_name}' not found. Available tools: {available}")
49

50

51
class StringConversionError(ToolInvokerError):
    """Raised when converting a tool's result to a string fails."""

    def __init__(self, tool_name: str, conversion_function: str, error: Exception):
        super().__init__(
            f"Failed to convert tool result from tool {tool_name} using '{conversion_function}'. Error: {error}"
        )
57

58

59
class ToolOutputMergeError(ToolInvokerError):
    """Raised when merging tool outputs into the shared State fails."""

    @classmethod
    def from_exception(cls, tool_name: str, error: Exception) -> "ToolOutputMergeError":
        """
        Build a ToolOutputMergeError that wraps the underlying exception.
        """
        return cls(f"Failed to merge tool outputs from tool {tool_name} into State: {error}")
69

70

71
@component
1✔
72
class ToolInvoker:
1✔
73
    """
74
    Invokes tools based on prepared tool calls and returns the results as a list of ChatMessage objects.
75

76
    Also handles reading/writing from a shared `State`.
77
    At initialization, the ToolInvoker component is provided with a list of available tools.
78
    At runtime, the component processes a list of ChatMessage object containing tool calls
79
    and invokes the corresponding tools.
80
    The results of the tool invocations are returned as a list of ChatMessage objects with tool role.
81

82
    Usage example:
83
    ```python
84
    from haystack.dataclasses import ChatMessage, ToolCall
85
    from haystack.tools import Tool
86
    from haystack.components.tools import ToolInvoker
87

88
    # Tool definition
89
    def dummy_weather_function(city: str):
90
        return f"The weather in {city} is 20 degrees."
91

92
    parameters = {"type": "object",
93
                "properties": {"city": {"type": "string"}},
94
                "required": ["city"]}
95

96
    tool = Tool(name="weather_tool",
97
                description="A tool to get the weather",
98
                function=dummy_weather_function,
99
                parameters=parameters)
100

101
    # Usually, the ChatMessage with tool_calls is generated by a Language Model
102
    # Here, we create it manually for demonstration purposes
103
    tool_call = ToolCall(
104
        tool_name="weather_tool",
105
        arguments={"city": "Berlin"}
106
    )
107
    message = ChatMessage.from_assistant(tool_calls=[tool_call])
108

109
    # ToolInvoker initialization and run
110
    invoker = ToolInvoker(tools=[tool])
111
    result = invoker.run(messages=[message])
112

113
    print(result)
114
    ```
115

116
    ```
117
    >>  {
118
    >>      'tool_messages': [
119
    >>          ChatMessage(
120
    >>              _role=<ChatRole.TOOL: 'tool'>,
121
    >>              _content=[
122
    >>                  ToolCallResult(
123
    >>                      result='"The weather in Berlin is 20 degrees."',
124
    >>                      origin=ToolCall(
125
    >>                          tool_name='weather_tool',
126
    >>                          arguments={'city': 'Berlin'},
127
    >>                          id=None
128
    >>                      )
129
    >>                  )
130
    >>              ],
131
    >>              _meta={}
132
    >>          )
133
    >>      ]
134
    >>  }
135
    ```
136

137
    Usage example with a Toolset:
138
    ```python
139
    from haystack.dataclasses import ChatMessage, ToolCall
140
    from haystack.tools import Tool, Toolset
141
    from haystack.components.tools import ToolInvoker
142

143
    # Tool definition
144
    def dummy_weather_function(city: str):
145
        return f"The weather in {city} is 20 degrees."
146

147
    parameters = {"type": "object",
148
                "properties": {"city": {"type": "string"}},
149
                "required": ["city"]}
150

151
    tool = Tool(name="weather_tool",
152
                description="A tool to get the weather",
153
                function=dummy_weather_function,
154
                parameters=parameters)
155

156
    # Create a Toolset
157
    toolset = Toolset([tool])
158

159
    # Usually, the ChatMessage with tool_calls is generated by a Language Model
160
    # Here, we create it manually for demonstration purposes
161
    tool_call = ToolCall(
162
        tool_name="weather_tool",
163
        arguments={"city": "Berlin"}
164
    )
165
    message = ChatMessage.from_assistant(tool_calls=[tool_call])
166

167
    # ToolInvoker initialization and run with Toolset
168
    invoker = ToolInvoker(tools=toolset)
169
    result = invoker.run(messages=[message])
170

171
    print(result)
    ```
172
    """
173

174
    def __init__(
        self,
        tools: ToolsType,
        raise_on_failure: bool = True,
        convert_result_to_json_string: bool = False,
        streaming_callback: Optional[StreamingCallbackT] = None,
        *,
        enable_streaming_callback_passthrough: bool = False,
        max_workers: int = 4,
    ):
        """
        Initialize the ToolInvoker component.

        :param tools:
            A list of Tool and/or Toolset objects, or a Toolset instance that can resolve tools.
        :param raise_on_failure:
            If True, raise an exception on errors (tool not found, tool invocation errors,
            tool result conversion errors). If False, return a ChatMessage object with
            `error=True` and a description of the error in `result`.
        :param convert_result_to_json_string:
            If True, convert the tool invocation result to a string using `json.dumps`.
            If False, convert it using `str`.
        :param streaming_callback:
            A callback function invoked to emit tool results. The result is only emitted
            once it becomes available — it is not streamed incrementally in real time.
        :param enable_streaming_callback_passthrough:
            If True, pass `streaming_callback` on to the tool invocation when the tool's
            `invoke` method signature has a `streaming_callback` parameter, letting tools
            stream their results back to the client. If False, never pass it through.
        :param max_workers:
            Maximum number of threads in the thread pool executor, which also bounds the
            number of concurrent tool invocations.
        :raises ValueError:
            If no tools are provided or if duplicate tool names are found.
        """
        self.tools = tools
        self.streaming_callback = streaming_callback
        self.enable_streaming_callback_passthrough = enable_streaming_callback_passthrough
        self.max_workers = max_workers
        self.raise_on_failure = raise_on_failure
        self.convert_result_to_json_string = convert_result_to_json_string

        # Resolve and validate tools once so runtime lookups are a plain dict access.
        self._tools_with_names = self._validate_and_prepare_tools(tools)
        # Flipped by warm_up() so warm-up work runs at most once.
        self._is_warmed_up = False
221

222
    @staticmethod
    def _make_context_bound_invoke(tool_to_invoke: Tool, final_args: dict[str, Any]) -> Callable[[], Any]:
        """
        Build a zero-argument callable that invokes the tool under the caller's contextvars Context.

        The current Context is copied here, in the caller's thread, and the tool is invoked
        via `ctx.run(...)` inside the executor. Python's contextvars do not propagate
        automatically to worker threads (or to threadpool tasks spawned via run_in_executor),
        so without this, nested tool calls would lose their parent trace/span and appear as
        separate roots. Running under the copied Context preserves span parentage, consistent
        tagging, and reliable log/trace correlation in both sync and async paths.

        The returned callable catches ToolInvocationError and returns it instead of raising,
        so parallel execution can collect failures from all futures.
        """
        captured_context = contextvars.copy_context()

        def _invoke_in_context() -> Any:
            bound_call = partial(tool_to_invoke.invoke, **final_args)
            try:
                return captured_context.run(bound_call)
            except ToolInvocationError as invocation_error:
                return invocation_error

        return _invoke_in_context
244

245
    @staticmethod
    def _validate_and_prepare_tools(tools: ToolsType) -> dict[str, Tool]:
        """
        Validate the provided tools and build a name -> Tool lookup table.

        :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.

        :returns: A dictionary mapping tool names to Tool instances.
        :raises ValueError: If no tools are provided or if duplicate tool names are found.
        """
        if not tools:
            raise ValueError("ToolInvoker requires at least one tool.")

        converted_tools = flatten_tools_or_toolsets(tools)
        _check_duplicate_tool_names(converted_tools)

        # Defensive second pass over the flattened names; also feeds the return mapping.
        tool_names = [tool.name for tool in converted_tools]
        duplicated_names = {name for name in tool_names if tool_names.count(name) > 1}
        if duplicated_names:
            raise ValueError(f"Duplicate tool names found: {duplicated_names}")

        return dict(zip(tool_names, converted_tools))
267

268
    def _default_output_to_string_handler(self, result: Any) -> str:
        """
        Default handler for converting a tool result to a string.

        The result is first converted to a serializable value (Haystack objects expose
        `to_dict()`), which matters for two reasons:
        - with `convert_result_to_json_string`, Haystack objects must become
          JSON-serializable dicts before `json.dumps`;
        - with plain `str()`, a dict rendering is preferable to relying on `__repr__`.

        :param result: The tool result to convert to a string.
        :returns: The converted tool result as a string.
        """
        serializable = _serializable_value(result)

        if not self.convert_result_to_json_string:
            return str(serializable)

        try:
            # ensure_ascii=False keeps special chars like emojis intact.
            return json.dumps(serializable, ensure_ascii=False)
        except Exception as error:
            # Not JSON serializable: log and fall back to str() of the raw result.
            # Fix: the kwarg must be named `error` to fill the `{error}` placeholder
            # in the log template (was passed as `err`, leaving the placeholder empty).
            logger.warning(
                "Tool result is not JSON serializable. Falling back to str conversion. "
                "Result: {result}\nError: {error}",
                result=result,
                error=error,
            )
            return str(result)
298

299
    def _prepare_tool_result_message(self, result: Any, tool_call: ToolCall, tool_to_invoke: Tool) -> ChatMessage:
        """
        Build a tool-role ChatMessage from a tool invocation result.

        :param result:
            The tool result.
        :param tool_call:
            The ToolCall object containing the tool name and arguments.
        :param tool_to_invoke:
            The Tool object that was invoked.
        :returns:
            A ChatMessage object containing the tool result as a string.
        :raises StringConversionError:
            If the conversion of the tool result to a string fails and `raise_on_failure` is True.
        """
        outputs_config = tool_to_invoke.outputs_to_string or {}
        source_key = outputs_config.get("source")

        # Fall back to the default handler when the tool does not supply one.
        to_string = outputs_config.get("handler", self._default_output_to_string_handler)

        # A "source" key narrows the conversion to a sub-field of the result.
        payload = result if source_key is None else result.get(source_key)

        try:
            return ChatMessage.from_tool(tool_result=to_string(payload), origin=tool_call)
        except Exception as e:
            error = StringConversionError(tool_call.tool_name, to_string.__name__, e)
            if self.raise_on_failure:
                raise error from e
            logger.error("{error_exception}", error_exception=error)
            return ChatMessage.from_tool(tool_result=str(error), origin=tool_call, error=True)
334

335
    @staticmethod
    def _get_func_params(tool: Tool) -> set:
        """
        Return the set of parameter names accepted by the tool's invoke function.

        ComponentTool wraps its function in a shim that accepts arbitrary kwargs, so for
        those we read the wrapped component's input sockets instead of inspecting the
        function signature.
        """
        if not isinstance(tool, ComponentTool):
            return set(inspect.signature(tool.function).parameters.keys())

        # mypy doesn't know that ComponentMeta always adds __haystack_input__ to Component.
        assert hasattr(tool._component, "__haystack_input__") and isinstance(
            tool._component.__haystack_input__, Sockets
        )
        return set(tool._component.__haystack_input__._sockets_dict.keys())
354

355
    @staticmethod
    def _inject_state_args(tool: Tool, llm_args: dict[str, Any], state: State) -> dict[str, Any]:
        """
        Merge LLM-provided arguments (llm_args) with values pulled from the shared State.

        Precedence, highest first:
          - LLM-provided arguments (state never overwrites them)
          - explicit `tool.inputs_from_state` mappings, when defined
          - function-signature name matching against state keys
        """
        merged_args = dict(llm_args)  # LLM-provided arguments always win
        func_params = ToolInvoker._get_func_params(tool)

        # Either an explicit {"state_key": "tool_param_name"} mapping on the tool,
        # or an identity mapping over the function's own parameter names.
        if hasattr(tool, "inputs_from_state") and isinstance(tool.inputs_from_state, dict):
            param_mappings = tool.inputs_from_state
        else:
            param_mappings = {name: name for name in func_params}

        # Fill gaps from state, never overriding what the LLM supplied.
        for state_key, param_name in param_mappings.items():
            if param_name not in merged_args and state.has(state_key):
                merged_args[param_name] = state.get(state_key)

        return merged_args
381

382
    @staticmethod
    def _merge_tool_outputs(tool: Tool, result: Any, state: State) -> None:
        """
        Merge a tool's result into the shared State according to its `outputs_to_state` mapping.

        Processing:
        1. If `result` is not a dictionary, nothing is stored into state.
        2. If the tool does not define a dict-valued `outputs_to_state` mapping, nothing is
           stored into state.
        3. Otherwise, for each `{state_key: config}` entry the value stored is either the
           sub-field named by `config["source"]` or the whole result dict, written via
           `state.set` with `config["handler"]` as an optional handler override.

        :param tool: Tool instance whose optional `outputs_to_state` mapping guides the merge.
        :param result: The output from tool execution; only dictionaries are merged.
        :param state: The global State object that receives the outputs.
        """
        # Non-dict results never participate in state merging.
        if not isinstance(result, dict):
            return

        # Without a valid `outputs_to_state` mapping there is nothing to merge.
        if not (hasattr(tool, "outputs_to_state") and isinstance(tool.outputs_to_state, dict)):
            return

        for state_key, config in tool.outputs_to_state.items():
            # "source" selects a sub-field of the result; absent/None means the whole dict.
            source_key = config.get("source", None)
            value = result.get(source_key) if source_key else result
            state.set(state_key, value, handler_override=config.get("handler"))
423

424
    @staticmethod
    def _create_tool_result_streaming_chunk(tool_messages: list[ChatMessage], tool_call: ToolCall) -> StreamingChunk:
        """Build the StreamingChunk announcing the most recently appended tool result."""
        latest_result = tool_messages[-1].tool_call_results[0]
        return StreamingChunk(
            content="",
            index=len(tool_messages) - 1,
            tool_call_result=latest_result,
            start=True,
            meta={"tool_result": latest_result.result, "tool_call": tool_call},
        )
434

435
    def _prepare_tool_call_params(
        self,
        *,
        messages_with_tool_calls: list[ChatMessage],
        state: State,
        streaming_callback: Optional[StreamingCallbackT],
        enable_streaming_passthrough: bool,
        tools_with_names: dict[str, Tool],
    ) -> tuple[list[ToolCall], list[dict[str, Any]], list[ChatMessage]]:
        """
        Resolve each tool call into invocation parameters, collecting error messages on the way.

        :param messages_with_tool_calls: Messages containing tool calls to process.
        :param state: The current state used for argument injection.
        :param streaming_callback: Optional streaming callback to inject into tool calls.
        :param enable_streaming_passthrough: Whether to pass the streaming callback to tools.
        :param tools_with_names: Mapping of tool name to Tool used to resolve the calls.
        :returns: Tuple of (tool_calls, tool_call_params, error_messages)
        :raises ToolNotFoundException: If a called tool is unknown and `raise_on_failure` is True.
        """
        tool_calls: list[ToolCall] = []
        tool_call_params: list[dict[str, Any]] = []
        error_messages: list[ChatMessage] = []

        for message in messages_with_tool_calls:
            for tool_call in message.tool_calls:
                tool_name = tool_call.tool_name

                # Unknown tool: either raise or record an error message and move on.
                if tool_name not in tools_with_names:
                    error = ToolNotFoundException(tool_name, list(tools_with_names.keys()))
                    if self.raise_on_failure:
                        raise error
                    logger.error("{error_exception}", error_exception=error)
                    error_messages.append(
                        ChatMessage.from_tool(tool_result=str(error), origin=tool_call, error=True)
                    )
                    continue

                tool_to_invoke = tools_with_names[tool_name]

                # Combine LLM-provided arguments with values injected from state.
                final_args = self._inject_state_args(tool_to_invoke, tool_call.arguments.copy(), state)

                # Forward the streaming callback only when passthrough is enabled, the tool
                # can accept it, and the arguments do not already carry one.
                should_pass_callback = (
                    enable_streaming_passthrough
                    and streaming_callback is not None
                    and "streaming_callback" not in final_args
                    and "streaming_callback" in self._get_func_params(tool_to_invoke)
                )
                if should_pass_callback:
                    final_args["streaming_callback"] = streaming_callback

                tool_call_params.append({"tool_to_invoke": tool_to_invoke, "final_args": final_args})
                tool_calls.append(tool_call)

        return tool_calls, tool_call_params, error_messages
489

490
    def warm_up(self):
        """
        Warm up the tools registered in this invoker.

        Idempotent: the underlying `warm_up_tools` call runs at most once per instance.
        """
        if self._is_warmed_up:
            return
        warm_up_tools(self.tools)
        self._is_warmed_up = True
500

501
    @component.output_types(tool_messages=list[ChatMessage], state=State)
    def run(
        self,
        messages: list[ChatMessage],
        state: Optional[State] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
        *,
        enable_streaming_callback_passthrough: Optional[bool] = None,
        tools: Optional[ToolsType] = None,
    ) -> dict[str, Any]:
        """
        Processes ChatMessage objects containing tool calls and invokes the corresponding tools, if available.

        :param messages:
            A list of ChatMessage objects.
        :param state: The runtime state that should be used by the tools.
        :param streaming_callback: A callback function that will be called to emit tool results.
            Note that the result is only emitted once it becomes available — it is not
            streamed incrementally in real time.
        :param enable_streaming_callback_passthrough:
            If True, the `streaming_callback` will be passed to the tool invocation if the tool supports it.
            This allows tools to stream their results back to the client.
            Note that this requires the tool to have a `streaming_callback` parameter in its `invoke` method signature.
            If False, the `streaming_callback` will not be passed to the tool invocation.
            If None, the value from the constructor will be used.
        :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
            If set, it will override the `tools` parameter provided during initialization.
        :returns:
            A dictionary with the key `tool_messages` containing a list of ChatMessage objects with tool role.
            Each ChatMessage object wraps the result of a tool invocation.

        :raises ToolNotFoundException:
            If the tool is not found in the list of available tools and `raise_on_failure` is True.
        :raises ToolInvocationError:
            If the tool invocation fails and `raise_on_failure` is True.
        :raises StringConversionError:
            If the conversion of the tool result to a string fails and `raise_on_failure` is True.
        :raises ToolOutputMergeError:
            If merging tool outputs into state fails and `raise_on_failure` is True.
        """
        tools_with_names = self._tools_with_names
        if tools is not None:
            tools_with_names = self._validate_and_prepare_tools(tools)
            # Lazy placeholder formatting for consistency with the other logger calls
            # in this module (was an eager f-string).
            logger.debug(
                "For this invocation, overriding constructor tools with: {tool_names}",
                tool_names=", ".join(tools_with_names.keys()),
            )

        if state is None:
            state = State(schema={})

        # Runtime flag overrides the constructor default only when explicitly provided.
        resolved_enable_streaming_passthrough = (
            enable_streaming_callback_passthrough
            if enable_streaming_callback_passthrough is not None
            else self.enable_streaming_callback_passthrough
        )

        # Only keep messages with tool calls
        messages_with_tool_calls = [message for message in messages if message.tool_calls]
        streaming_callback = select_streaming_callback(
            init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=False
        )

        if not messages_with_tool_calls:
            return {"tool_messages": [], "state": state}

        # 1) Collect all tool calls and their parameters for parallel execution
        tool_messages = []
        tool_calls, tool_call_params, error_messages = self._prepare_tool_call_params(
            messages_with_tool_calls=messages_with_tool_calls,
            state=state,
            streaming_callback=streaming_callback,
            enable_streaming_passthrough=resolved_enable_streaming_passthrough,
            tools_with_names=tools_with_names,
        )
        tool_messages.extend(error_messages)

        if not tool_call_params:
            return {"tool_messages": tool_messages, "state": state}

        # 2) Execute valid tool calls in parallel
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = []
            for params in tool_call_params:
                # Bind the caller's contextvars Context so traces/spans propagate to worker threads.
                callable_ = self._make_context_bound_invoke(params["tool_to_invoke"], params["final_args"])
                futures.append(executor.submit(callable_))

            # 3) Gather and process results: handle errors and merge outputs into state
            for future, tool_call in zip(futures, tool_calls):
                result = future.result()

                if isinstance(result, ToolInvocationError):
                    # a) This is an error, create error Tool message
                    if self.raise_on_failure:
                        raise result
                    logger.error("{error_exception}", error_exception=result)
                    tool_messages.append(ChatMessage.from_tool(tool_result=str(result), origin=tool_call, error=True))
                else:
                    # b) In case of success, merge outputs into state
                    try:
                        tool_to_invoke = tools_with_names[tool_call.tool_name]
                        self._merge_tool_outputs(tool=tool_to_invoke, result=result, state=state)
                        tool_messages.append(
                            self._prepare_tool_result_message(
                                result=result, tool_call=tool_call, tool_to_invoke=tool_to_invoke
                            )
                        )
                    except Exception as e:
                        error = ToolOutputMergeError.from_exception(tool_name=tool_call.tool_name, error=e)
                        if self.raise_on_failure:
                            raise error from e
                        logger.error("{error_exception}", error_exception=error)
                        tool_messages.append(
                            ChatMessage.from_tool(tool_result=str(error), origin=tool_call, error=True)
                        )

                # c) Handle streaming callback
                if streaming_callback is not None:
                    streaming_callback(
                        self._create_tool_result_streaming_chunk(tool_messages=tool_messages, tool_call=tool_call)
                    )

        # We stream one more chunk that contains a finish_reason if tool_messages were generated
        if len(tool_messages) > 0 and streaming_callback is not None:
            streaming_callback(
                StreamingChunk(
                    content="", finish_reason="tool_call_results", meta={"finish_reason": "tool_call_results"}
                )
            )

        return {"tool_messages": tool_messages, "state": state}
631

632
    @component.output_types(tool_messages=list[ChatMessage], state=State)
    async def run_async(
        self,
        messages: list[ChatMessage],
        state: Optional[State] = None,
        streaming_callback: Optional[StreamingCallbackT] = None,
        *,
        enable_streaming_callback_passthrough: Optional[bool] = None,
        tools: Optional[ToolsType] = None,
    ) -> dict[str, Any]:
        """
        Asynchronously processes ChatMessage objects containing tool calls.

        Multiple tool calls are performed concurrently: each valid call is dispatched to a
        ThreadPoolExecutor via the running event loop and the results are gathered together.

        :param messages:
            A list of ChatMessage objects.
        :param state: The runtime state that should be used by the tools.
        :param streaming_callback: An asynchronous callback function that will be called to emit tool results.
            Note that the result is only emitted once it becomes available — it is not
            streamed incrementally in real time.
        :param enable_streaming_callback_passthrough:
            If True, the `streaming_callback` will be passed to the tool invocation if the tool supports it.
            This allows tools to stream their results back to the client.
            Note that this requires the tool to have a `streaming_callback` parameter in its `invoke` method signature.
            If False, the `streaming_callback` will not be passed to the tool invocation.
            If None, the value from the constructor will be used.
        :param tools: A list of Tool and/or Toolset objects, or a single Toolset for which the model can prepare calls.
            If set, it will override the `tools` parameter provided during initialization.
        :returns:
            A dictionary with the key `tool_messages` containing a list of ChatMessage objects with tool role.
            Each ChatMessage objects wraps the result of a tool invocation.

        :raises ToolNotFoundException:
            If the tool is not found in the list of available tools and `raise_on_failure` is True.
        :raises ToolInvocationError:
            If the tool invocation fails and `raise_on_failure` is True.
        :raises StringConversionError:
            If the conversion of the tool result to a string fails and `raise_on_failure` is True.
        :raises ToolOutputMergeError:
            If merging tool outputs into state fails and `raise_on_failure` is True.
        """

        # Runtime `tools` overrides the constructor-provided tools for this invocation only.
        tools_with_names = self._tools_with_names
        if tools is not None:
            tools_with_names = self._validate_and_prepare_tools(tools)
            logger.debug(
                f"For this invocation, overriding constructor tools with: {', '.join(tools_with_names.keys())}"
            )

        # Tools may read/write state; fall back to an empty-schema State when none is supplied.
        if state is None:
            state = State(schema={})

        # Runtime flag wins over the constructor default when explicitly provided.
        resolved_enable_streaming_passthrough = (
            enable_streaming_callback_passthrough
            if enable_streaming_callback_passthrough is not None
            else self.enable_streaming_callback_passthrough
        )

        # Only keep messages with tool calls
        messages_with_tool_calls = [message for message in messages if message.tool_calls]
        # requires_async=True: the selected callback is awaited below, so it must be a coroutine function.
        streaming_callback = select_streaming_callback(
            init_callback=self.streaming_callback, runtime_callback=streaming_callback, requires_async=True
        )

        if not messages_with_tool_calls:
            return {"tool_messages": [], "state": state}

        # 1) Collect all tool calls and their parameters for parallel execution
        # Invalid calls (e.g. unknown tool names) come back pre-rendered as error messages.
        tool_messages = []
        tool_calls, tool_call_params, error_messages = self._prepare_tool_call_params(
            messages_with_tool_calls=messages_with_tool_calls,
            state=state,
            streaming_callback=streaming_callback,
            enable_streaming_passthrough=resolved_enable_streaming_passthrough,
            tools_with_names=tools_with_names,
        )
        tool_messages.extend(error_messages)

        if not tool_call_params:
            return {"tool_messages": tool_messages, "state": state}

        # 2) Execute valid tool calls in parallel
        # Each invocation is bound to the current context (see _make_context_bound_invoke) and
        # run in a thread so synchronous tools don't block the event loop.
        tool_call_tasks = []
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            for params in tool_call_params:
                loop = asyncio.get_running_loop()
                callable_ = ToolInvoker._make_context_bound_invoke(params["tool_to_invoke"], params["final_args"])
                tool_call_tasks.append(loop.run_in_executor(executor, callable_))

            # 3) Gather and process results: handle errors and merge outputs into state
            # NOTE(review): gather is called without return_exceptions=True, yet results are
            # checked with isinstance(..., ToolInvocationError) — presumably the bound callable
            # returns (rather than raises) ToolInvocationError; confirm in _make_context_bound_invoke.
            tool_results = await asyncio.gather(*tool_call_tasks)
            for tool_result, tool_call in zip(tool_results, tool_calls):
                # a) This is an error, create error Tool message
                if isinstance(tool_result, ToolInvocationError):
                    if self.raise_on_failure:
                        raise tool_result
                    logger.error("{error_exception}", error_exception=tool_result)
                    tool_messages.append(
                        ChatMessage.from_tool(tool_result=str(tool_result), origin=tool_call, error=True)
                    )
                else:
                    # b) In case of success, merge outputs into state
                    try:
                        tool_to_invoke = tools_with_names[tool_call.tool_name]
                        self._merge_tool_outputs(tool=tool_to_invoke, result=tool_result, state=state)
                        tool_messages.append(
                            self._prepare_tool_result_message(
                                result=tool_result, tool_call=tool_call, tool_to_invoke=tool_to_invoke
                            )
                        )
                    except Exception as e:
                        # Merge/conversion failures are wrapped so callers can catch one error type.
                        error = ToolOutputMergeError.from_exception(tool_name=tool_call.tool_name, error=e)
                        if self.raise_on_failure:
                            raise error from e
                        logger.error("{error_exception}", error_exception=error)
                        tool_messages.append(
                            ChatMessage.from_tool(tool_result=str(error), origin=tool_call, error=True)
                        )

                # c) Handle streaming callback
                # Emitted per tool call, after its (success or error) message was appended.
                if streaming_callback is not None:
                    await streaming_callback(
                        self._create_tool_result_streaming_chunk(tool_messages=tool_messages, tool_call=tool_call)
                    )

        # 4) We stream one more chunk that contains a finish_reason if tool_messages were generated
        if len(tool_messages) > 0 and streaming_callback is not None:
            await streaming_callback(
                StreamingChunk(
                    content="", finish_reason="tool_call_results", meta={"finish_reason": "tool_call_results"}
                )
            )

        return {"tool_messages": tool_messages, "state": state}
766

767
    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        # A streaming callback is stored as its import path; None stays None.
        serialized_callback = (
            serialize_callable(self.streaming_callback) if self.streaming_callback is not None else None
        )

        return default_to_dict(
            self,
            tools=serialize_tools_or_toolset(self.tools),
            raise_on_failure=self.raise_on_failure,
            convert_result_to_json_string=self.convert_result_to_json_string,
            streaming_callback=serialized_callback,
            enable_streaming_callback_passthrough=self.enable_streaming_callback_passthrough,
            max_workers=self.max_workers,
        )
788

789
    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ToolInvoker":
        """
        Deserializes the component from a dictionary.

        :param data:
            The dictionary to deserialize from.
        :returns:
            The deserialized component.
        """
        init_params = data["init_parameters"]
        # Rehydrate Tool/Toolset objects in place from their serialized form.
        deserialize_tools_or_toolset_inplace(init_params, key="tools")
        # A serialized streaming callback is an import path string; resolve it back to a callable.
        callback_path = init_params.get("streaming_callback")
        if callback_path is not None:
            init_params["streaming_callback"] = deserialize_callable(callback_path)
        return default_from_dict(cls, data)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc