• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 14972474692

12 May 2025 12:40PM UTC coverage: 90.412% (-0.005%) from 90.417%
14972474692

Pull #9342

github

web-flow
Merge 00e0bde48 into f233e06f0
Pull Request #9342: Fix component tool parameters

10938 of 12098 relevant lines covered (90.41%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.48
haystack/tools/component_tool.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
1✔
6

7
from pydantic import TypeAdapter
1✔
8

9
from haystack import logging
1✔
10
from haystack.core.component import Component
1✔
11
from haystack.core.serialization import (
1✔
12
    component_from_dict,
13
    component_to_dict,
14
    generate_qualified_class_name,
15
    import_class_by_name,
16
)
17
from haystack.tools import Tool
1✔
18
from haystack.tools.errors import SchemaGenerationError
1✔
19
from haystack.tools.property_schema_utils import _create_property_schema, _get_param_descriptions
1✔
20
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
1✔
21

22
logger = logging.getLogger(__name__)
1✔
23

24

25
class ComponentTool(Tool):
    """
    A Tool that wraps Haystack components, allowing them to be used as tools by LLMs.

    ComponentTool automatically generates LLM-compatible tool schemas from component input sockets,
    which are derived from the component's `run` method signature and type hints.


    Key features:
    - Automatic LLM tool calling schema generation from component input sockets
    - Type conversion and validation for component inputs
    - Support for types:
        - Dataclasses
        - Lists of dataclasses
        - Basic types (str, int, float, bool, dict)
        - Lists of basic types
    - Automatic name generation from component class name
    - Description extraction from component docstrings

    To use ComponentTool, you first need a Haystack component - either an existing one or a new one you create.
    You can create a ComponentTool from the component by passing the component to the ComponentTool constructor.
    Below is an example of creating a ComponentTool from an existing SerperDevWebSearch component.

    ```python
    from haystack import component, Pipeline
    from haystack.tools import ComponentTool
    from haystack.components.websearch import SerperDevWebSearch
    from haystack.utils import Secret
    from haystack.components.tools.tool_invoker import ToolInvoker
    from haystack.components.generators.chat import OpenAIChatGenerator
    from haystack.dataclasses import ChatMessage

    # Create a SerperDev search component
    search = SerperDevWebSearch(api_key=Secret.from_env_var("SERPERDEV_API_KEY"), top_k=3)

    # Create a tool from the component
    tool = ComponentTool(
        component=search,
        name="web_search",  # Optional: defaults to "serper_dev_web_search"
        description="Search the web for current information on any topic"  # Optional: defaults to component docstring
    )

    # Create pipeline with OpenAIChatGenerator and ToolInvoker
    pipeline = Pipeline()
    pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool]))
    pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool]))

    # Connect components
    pipeline.connect("llm.replies", "tool_invoker.messages")

    message = ChatMessage.from_user("Use the web search tool to find information about Nikola Tesla")

    # Run pipeline
    result = pipeline.run({"llm": {"messages": [message]}})

    print(result)
    ```

    """

    def __init__(
        self,
        component: Component,
        name: Optional[str] = None,
        description: Optional[str] = None,
        parameters: Optional[Dict[str, Any]] = None,
        *,
        outputs_to_string: Optional[Dict[str, Union[str, Callable[[Any], str]]]] = None,
        inputs_from_state: Optional[Dict[str, str]] = None,
        outputs_to_state: Optional[Dict[str, Dict[str, Union[str, Callable]]]] = None,
    ):
        """
        Create a Tool instance from a Haystack component.

        :param component: The Haystack component to wrap as a tool.
        :param name: Optional name for the tool (defaults to snake_case of component class name).
        :param description: Optional description (defaults to component's docstring).
        :param parameters:
            A JSON schema defining the parameters expected by the Tool.
            Will fall back to the parameters defined in the component's run method signature if not provided.
        :param outputs_to_string:
            Optional dictionary defining how a tool outputs should be converted into a string.
            If the source is provided only the specified output key is sent to the handler.
            If the source is omitted the whole tool result is sent to the handler.
            Example: {
                "source": "docs", "handler": format_documents
            }
        :param inputs_from_state:
            Optional dictionary mapping state keys to tool parameter names.
            Example: {"repository": "repo"} maps state's "repository" to tool's "repo" parameter.
        :param outputs_to_state:
            Optional dictionary defining how tool outputs map to keys within state as well as optional handlers.
            If the source is provided only the specified output key is sent to the handler.
            Example: {
                "documents": {"source": "docs", "handler": custom_handler}
            }
            If the source is omitted the whole tool result is sent to the handler.
            Example: {
                "documents": {"handler": custom_handler}
            }
        :raises ValueError: If the object is not a Haystack component or was already added to a pipeline.
        :raises SchemaGenerationError: If no `parameters` are given and a schema cannot be generated
            for one of the component's inputs.
        """
        if not isinstance(component, Component):
            message = (
                f"Object {component!r} is not a Haystack component. "
                "Use ComponentTool only with Haystack component instances."
            )
            raise ValueError(message)

        if getattr(component, "__haystack_added_to_pipeline__", None):
            msg = (
                "Component has been added to a pipeline and can't be used to create a ComponentTool. "
                "Create ComponentTool from a non-pipeline component instead."
            )
            raise ValueError(msg)

        # Remember what the caller passed (possibly None) so to_dict() serializes the user-supplied
        # schema rather than the auto-generated one.
        self._unresolved_parameters = parameters
        # Create the tools schema from the component run method parameters
        tool_schema = parameters or self._create_tool_parameters_schema(component, inputs_from_state or {})

        def component_invoker(**kwargs):
            """
            Invokes the component using keyword arguments provided by the LLM function calling/tool-generated response.

            :param kwargs: The keyword arguments to invoke the component with.
            :returns: The result of the component invocation.
            """
            converted_kwargs = {}
            input_sockets = component.__haystack_input__._sockets_dict
            for param_name, param_value in kwargs.items():
                param_type = input_sockets[param_name].type

                # Check if the type (or list element type) has from_dict
                target_type = get_args(param_type)[0] if get_origin(param_type) is list else param_type
                if hasattr(target_type, "from_dict"):
                    if isinstance(param_value, list):
                        # Convert dict items; keep every other item as-is instead of silently
                        # dropping it, so the component never receives a truncated list.
                        param_value = [
                            target_type.from_dict(item) if isinstance(item, dict) else item for item in param_value
                        ]
                    elif isinstance(param_value, dict):
                        param_value = target_type.from_dict(param_value)
                else:
                    # Let TypeAdapter handle both single values and lists
                    param_value = TypeAdapter(param_type).validate_python(param_value)

                converted_kwargs[param_name] = param_value
            logger.debug(f"Invoking component {type(component)} with kwargs: {converted_kwargs}")
            return component.run(**converted_kwargs)

        # Generate a name for the tool if not provided
        if not name:
            class_name = component.__class__.__name__
            # Convert camelCase/PascalCase to snake_case
            name = "".join(
                [
                    "_" + c.lower() if c.isupper() and i > 0 and not class_name[i - 1].isupper() else c.lower()
                    for i, c in enumerate(class_name)
                ]
            ).lstrip("_")

        description = description or component.__doc__ or name

        # Create the Tool instance with the component invoker as the function to be called and the schema
        super().__init__(
            name=name,
            description=description,
            parameters=tool_schema,
            function=component_invoker,
            inputs_from_state=inputs_from_state,
            outputs_to_state=outputs_to_state,
            outputs_to_string=outputs_to_string,
        )
        self._component = component

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the ComponentTool to a dictionary.

        Handlers found in `outputs_to_state` and `outputs_to_string` are passed through
        `serialize_callable`; the attributes on the instance are left untouched.

        :returns: Dictionary with the qualified class name under "type" and the serialized fields under "data".
        """
        serialized_component = component_to_dict(obj=self._component, name=self.name)

        serialized = {
            "component": serialized_component,
            "name": self.name,
            "description": self.description,
            "parameters": self._unresolved_parameters,
            "outputs_to_string": self.outputs_to_string,
            "inputs_from_state": self.inputs_from_state,
            "outputs_to_state": self.outputs_to_state,
        }

        if self.outputs_to_state is not None:
            serialized_outputs = {}
            for key, config in self.outputs_to_state.items():
                # Copy so the instance's own config keeps its live callable
                serialized_config = config.copy()
                if "handler" in config:
                    serialized_config["handler"] = serialize_callable(config["handler"])
                serialized_outputs[key] = serialized_config
            serialized["outputs_to_state"] = serialized_outputs

        if self.outputs_to_string is not None and self.outputs_to_string.get("handler") is not None:
            # Keep the dict shape (including an optional "source" key) and replace only the handler;
            # from_dict() looks the handler up under the "handler" key of this dict.
            serialized["outputs_to_string"] = {
                **self.outputs_to_string,
                "handler": serialize_callable(self.outputs_to_string["handler"]),
            }

        return {"type": generate_qualified_class_name(type(self)), "data": serialized}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Tool":
        """
        Deserializes the ComponentTool from a dictionary.

        The input dictionary is not modified: deserialized handlers are placed in fresh dictionaries.

        :param data: Dictionary with the serialized fields under the "data" key, as produced by `to_dict`.
        :returns: The deserialized ComponentTool.
        """
        inner_data = data["data"]
        component_class = import_class_by_name(inner_data["component"]["type"])
        component = component_from_dict(cls=component_class, data=inner_data["component"], name=inner_data["name"])

        outputs_to_state = inner_data.get("outputs_to_state", None)
        if outputs_to_state:
            deserialized_outputs = {}
            for key, config in outputs_to_state.items():
                deserialized_config = config.copy()
                if "handler" in config:
                    deserialized_config["handler"] = deserialize_callable(config["handler"])
                deserialized_outputs[key] = deserialized_config
            outputs_to_state = deserialized_outputs

        outputs_to_string = inner_data.get("outputs_to_string", None)
        if isinstance(outputs_to_string, str):
            # Legacy payload: an earlier to_dict() stored only the serialized handler here.
            # The original "source" key (if any) was not stored and cannot be recovered.
            outputs_to_string = {"handler": deserialize_callable(outputs_to_string)}
        elif outputs_to_string is not None and outputs_to_string.get("handler") is not None:
            outputs_to_string = {**outputs_to_string, "handler": deserialize_callable(outputs_to_string["handler"])}

        return cls(
            component=component,
            name=inner_data["name"],
            description=inner_data["description"],
            parameters=inner_data.get("parameters", None),
            outputs_to_string=outputs_to_string,
            inputs_from_state=inner_data.get("inputs_from_state", None),
            outputs_to_state=outputs_to_state,
        )

    def _create_tool_parameters_schema(self, component: Component, inputs_from_state: Dict[str, Any]) -> Dict[str, Any]:
        """
        Creates an OpenAI tools schema from a component's run method parameters.

        :param component: The component to create the schema from.
        :param inputs_from_state:
            Mapping whose keys are compared against the component's input names; any input whose
            name appears as a key is left out of the generated schema.
        :raises SchemaGenerationError: If schema generation fails
        :returns: OpenAI tools schema for the component's run method parameters.
        """
        properties = {}
        required = []

        param_descriptions = _get_param_descriptions(component.run)

        for input_name, socket in component.__haystack_input__._sockets_dict.items():  # type: ignore[attr-defined]
            if inputs_from_state is not None and input_name in inputs_from_state:
                continue
            input_type = socket.type
            description = param_descriptions.get(input_name, f"Input '{input_name}' for the component.")

            try:
                property_schema = _create_property_schema(input_type, description)
            except Exception as e:
                raise SchemaGenerationError(
                    f"Error processing input '{input_name}': {e}. "
                    f"Schema generation supports basic types (str, int, float, bool, dict), dataclasses, "
                    f"and lists of these types as input types for component's run method."
                ) from e

            properties[input_name] = property_schema

            # Use socket.is_mandatory to check if the input is required
            if socket.is_mandatory:
                required.append(input_name)

        parameters_schema = {"type": "object", "properties": properties}

        # Only emit "required" when at least one input is mandatory
        if required:
            parameters_schema["required"] = required

        return parameters_schema
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc