• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 15018133817

14 May 2025 10:15AM UTC coverage: 90.467% (+0.05%) from 90.417%
15018133817

Pull #9342

github

web-flow
Merge 6c290fea7 into 42b378950
Pull Request #9342: Fix component tool parameters

10932 of 12084 relevant lines covered (90.47%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.18
haystack/tools/component_tool.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
1✔
6

7
from pydantic import Field, TypeAdapter
1✔
8

9
from haystack import logging
1✔
10
from haystack.core.component import Component
1✔
11
from haystack.core.serialization import (
1✔
12
    component_from_dict,
13
    component_to_dict,
14
    generate_qualified_class_name,
15
    import_class_by_name,
16
)
17
from haystack.tools import Tool
1✔
18
from haystack.tools.parameters_schema_utils import _create_parameters_schema, _get_param_descriptions, _resolve_type
1✔
19
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
1✔
20

21
logger = logging.getLogger(__name__)
1✔
22

23

24
class ComponentTool(Tool):
    """
    A Tool that wraps Haystack components, allowing them to be used as tools by LLMs.

    ComponentTool automatically generates LLM-compatible tool schemas from component input sockets,
    which are derived from the component's `run` method signature and type hints.

    Key features:
    - Automatic LLM tool calling schema generation from component input sockets
    - Type conversion and validation for component inputs
    - Support for types:
        - Dataclasses
        - Lists of dataclasses
        - Basic types (str, int, float, bool, dict)
        - Lists of basic types
    - Automatic name generation from component class name
    - Description extraction from component docstrings

    To use ComponentTool, you first need a Haystack component - either an existing one or a new one you create.
    You can create a ComponentTool from the component by passing the component to the ComponentTool constructor.
    Below is an example of creating a ComponentTool from an existing SerperDevWebSearch component.

    ```python
    from haystack import component, Pipeline
    from haystack.tools import ComponentTool
    from haystack.components.websearch import SerperDevWebSearch
    from haystack.utils import Secret
    from haystack.components.tools.tool_invoker import ToolInvoker
    from haystack.components.generators.chat import OpenAIChatGenerator
    from haystack.dataclasses import ChatMessage

    # Create a SerperDev search component
    search = SerperDevWebSearch(api_key=Secret.from_env_var("SERPERDEV_API_KEY"), top_k=3)

    # Create a tool from the component
    tool = ComponentTool(
        component=search,
        name="web_search",  # Optional: defaults to "serper_dev_web_search"
        description="Search the web for current information on any topic"  # Optional: defaults to component docstring
    )

    # Create pipeline with OpenAIChatGenerator and ToolInvoker
    pipeline = Pipeline()
    pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool]))
    pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool]))

    # Connect components
    pipeline.connect("llm.replies", "tool_invoker.messages")

    message = ChatMessage.from_user("Use the web search tool to find information about Nikola Tesla")

    # Run pipeline
    result = pipeline.run({"llm": {"messages": [message]}})

    print(result)
    ```
    """

    def __init__(
        self,
        component: Component,
        name: Optional[str] = None,
        description: Optional[str] = None,
        parameters: Optional[Dict[str, Any]] = None,
        *,
        outputs_to_string: Optional[Dict[str, Union[str, Callable[[Any], str]]]] = None,
        inputs_from_state: Optional[Dict[str, str]] = None,
        outputs_to_state: Optional[Dict[str, Dict[str, Union[str, Callable]]]] = None,
    ):
        """
        Create a Tool instance from a Haystack component.

        :param component: The Haystack component to wrap as a tool.
        :param name: Optional name for the tool (defaults to snake_case of component class name).
        :param description: Optional description (defaults to component's docstring, then to the tool name).
        :param parameters:
            A JSON schema defining the parameters expected by the Tool.
            Will fall back to the parameters defined in the component's run method signature if not provided
            (or if an empty schema is passed).
        :param outputs_to_string:
            Optional dictionary defining how a tool outputs should be converted into a string.
            If the source is provided only the specified output key is sent to the handler.
            If the source is omitted the whole tool result is sent to the handler.
            Example: {
                "source": "docs", "handler": format_documents
            }
        :param inputs_from_state:
            Optional dictionary mapping state keys to tool parameter names.
            Example: {"repository": "repo"} maps state's "repository" to tool's "repo" parameter.
        :param outputs_to_state:
            Optional dictionary defining how tool outputs map to keys within state as well as optional handlers.
            If the source is provided only the specified output key is sent to the handler.
            Example: {
                "documents": {"source": "docs", "handler": custom_handler}
            }
            If the source is omitted the whole tool result is sent to the handler.
            Example: {
                "documents": {"handler": custom_handler}
            }
        :raises ValueError: If the component is invalid or schema generation fails.
        """
        if not isinstance(component, Component):
            message = (
                f"Object {component!r} is not a Haystack component. "
                "Use ComponentTool only with Haystack component instances."
            )
            raise ValueError(message)

        if getattr(component, "__haystack_added_to_pipeline__", None):
            msg = (
                "Component has been added to a pipeline and can't be used to create a ComponentTool. "
                "Create ComponentTool from a non-pipeline component instead."
            )
            raise ValueError(msg)

        # Remember exactly what the caller passed so that `to_dict` serializes the user-supplied schema
        # (possibly None) rather than the auto-generated one.
        self._unresolved_parameters = parameters
        # Create the tools schema from the component run method parameters
        tool_schema = parameters or self._create_tool_parameters_schema(component, inputs_from_state or {})

        def component_invoker(**kwargs):
            """
            Invokes the component using keyword arguments provided by the LLM function calling/tool-generated response.

            Values whose socket type (or list element type) exposes `from_dict` are built from dictionaries;
            every other value is validated against the socket type with a pydantic `TypeAdapter`.

            :param kwargs: The keyword arguments to invoke the component with.
            :returns: The result of the component invocation.
            """
            converted_kwargs = {}
            input_sockets = component.__haystack_input__._sockets_dict
            for param_name, param_value in kwargs.items():
                param_type = input_sockets[param_name].type

                # Check if the type (or list element type) has from_dict.
                # An unparameterized `List` has a `list` origin but no type arguments, so guard the lookup
                # instead of indexing into an empty tuple.
                type_args = get_args(param_type)
                target_type = type_args[0] if get_origin(param_type) is list and type_args else param_type
                if hasattr(target_type, "from_dict"):
                    if isinstance(param_value, list):
                        # Convert dictionaries and pass every other item through unchanged, so that
                        # items which are not dictionaries are not silently dropped from the list.
                        param_value = [
                            target_type.from_dict(item) if isinstance(item, dict) else item for item in param_value
                        ]
                    elif isinstance(param_value, dict):
                        param_value = target_type.from_dict(param_value)
                else:
                    # Let TypeAdapter handle both single values and lists
                    type_adapter = TypeAdapter(param_type)
                    param_value = type_adapter.validate_python(param_value)

                converted_kwargs[param_name] = param_value
            logger.debug(f"Invoking component {type(component)} with kwargs: {converted_kwargs}")
            return component.run(**converted_kwargs)

        # Generate a name for the tool if not provided
        if not name:
            class_name = component.__class__.__name__
            # Convert camelCase/PascalCase to snake_case: an underscore is inserted before an uppercase
            # letter only when the preceding character is not uppercase as well.
            name = "".join(
                [
                    "_" + c.lower() if c.isupper() and i > 0 and not class_name[i - 1].isupper() else c.lower()
                    for i, c in enumerate(class_name)
                ]
            ).lstrip("_")

        description = description or component.__doc__ or name

        # Create the Tool instance with the component invoker as the function to be called and the schema
        super().__init__(
            name=name,
            description=description,
            parameters=tool_schema,
            function=component_invoker,
            inputs_from_state=inputs_from_state,
            outputs_to_state=outputs_to_state,
            outputs_to_string=outputs_to_string,
        )
        self._component = component

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the ComponentTool to a dictionary.

        Callable handlers inside `outputs_to_state` and `outputs_to_string` are replaced by their serialized
        form; all other entries of those dictionaries (such as "source") are kept as they are.

        :returns: Dictionary with the qualified class name under "type" and the serialized fields under "data".
        """
        serialized_component = component_to_dict(obj=self._component, name=self.name)

        serialized: Dict[str, Any] = {
            "component": serialized_component,
            "name": self.name,
            "description": self.description,
            "parameters": self._unresolved_parameters,
            "outputs_to_string": self.outputs_to_string,
            "inputs_from_state": self.inputs_from_state,
            "outputs_to_state": self.outputs_to_state,
        }

        if self.outputs_to_state is not None:
            serialized_outputs = {}
            for key, config in self.outputs_to_state.items():
                # Work on a copy so the tool's own configuration keeps its callable handler.
                serialized_config = config.copy()
                if "handler" in config:
                    serialized_config["handler"] = serialize_callable(config["handler"])
                serialized_outputs[key] = serialized_config
            serialized["outputs_to_state"] = serialized_outputs

        if self.outputs_to_string is not None and self.outputs_to_string.get("handler") is not None:
            # Keep the dictionary shape (including an optional "source") and serialize only the handler.
            # `from_dict` reads `outputs_to_string["handler"]`, so emitting a bare string here would both
            # lose "source" and make the serialized tool impossible to load again.
            serialized_outputs_to_string = self.outputs_to_string.copy()
            serialized_outputs_to_string["handler"] = serialize_callable(self.outputs_to_string["handler"])
            serialized["outputs_to_string"] = serialized_outputs_to_string

        return {"type": generate_qualified_class_name(type(self)), "data": serialized}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Tool":
        """
        Deserializes the ComponentTool from a dictionary.

        The handler entries of `outputs_to_state` and `outputs_to_string` are deserialized into new
        dictionaries, so the ones stored in `data` are not modified by this method.

        :param data: Dictionary in the format produced by `to_dict`.
        :returns: The deserialized ComponentTool.
        """
        inner_data = data["data"]
        component_class = import_class_by_name(inner_data["component"]["type"])
        component = component_from_dict(cls=component_class, data=inner_data["component"], name=inner_data["name"])

        outputs_to_state = inner_data.get("outputs_to_state", None)
        if outputs_to_state:
            deserialized_outputs = {}
            for key, config in outputs_to_state.items():
                deserialized_config = config.copy()
                if "handler" in config:
                    deserialized_config["handler"] = deserialize_callable(config["handler"])
                deserialized_outputs[key] = deserialized_config
            outputs_to_state = deserialized_outputs

        outputs_to_string = inner_data.get("outputs_to_string", None)
        if outputs_to_string is not None and outputs_to_string.get("handler") is not None:
            # Copy before swapping in the callable so the caller's dictionary stays serialized.
            outputs_to_string = outputs_to_string.copy()
            outputs_to_string["handler"] = deserialize_callable(outputs_to_string["handler"])

        return cls(
            component=component,
            name=inner_data["name"],
            description=inner_data["description"],
            parameters=inner_data.get("parameters", None),
            outputs_to_string=outputs_to_string,
            inputs_from_state=inner_data.get("inputs_from_state", None),
            outputs_to_state=outputs_to_state,
        )

    def _create_tool_parameters_schema(self, component: Component, inputs_from_state: Dict[str, Any]) -> Dict[str, Any]:
        """
        Creates an OpenAI tools schema from a component's run method parameters.

        :param component: The component to create the schema from.
        :param inputs_from_state:
            Mapping whose keys are compared against the component's input names; matching inputs are left
            out of the generated schema.
        :raises SchemaGenerationError: If schema generation fails
        :returns: OpenAI tools schema for the component's run method parameters.
        """
        component_run_description, param_descriptions = _get_param_descriptions(component.run)

        # collect fields (types and defaults) and descriptions from function parameters
        fields: Dict[str, Any] = {}

        for input_name, socket in component.__haystack_input__._sockets_dict.items():  # type: ignore[attr-defined]
            if inputs_from_state is not None and input_name in inputs_from_state:
                continue
            input_type = socket.type
            description = param_descriptions.get(input_name, f"Input '{input_name}' for the component.")

            # if the parameter has not a default value, Pydantic requires an Ellipsis (...)
            # to explicitly indicate that the parameter is required
            default = ... if socket.is_mandatory else socket.default_value
            resolved_type = _resolve_type(input_type)
            fields[input_name] = (resolved_type, Field(default=default, description=description))

        return _create_parameters_schema(component.run.__name__, component_run_description, fields)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc