• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 15022600357

14 May 2025 01:54PM UTC coverage: 90.465% (+0.05%) from 90.417%
15022600357

Pull #9342

github

web-flow
Merge 2fd50f281 into 42b378950
Pull Request #9342: Fix component tool parameters

10930 of 12082 relevant lines covered (90.47%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.41
haystack/tools/component_tool.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
1✔
6

7
from pydantic import Field, TypeAdapter, create_model
1✔
8

9
from haystack import logging
1✔
10
from haystack.core.component import Component
1✔
11
from haystack.core.serialization import (
1✔
12
    component_from_dict,
13
    component_to_dict,
14
    generate_qualified_class_name,
15
    import_class_by_name,
16
)
17
from haystack.tools import Tool
1✔
18
from haystack.tools.errors import SchemaGenerationError
1✔
19
from haystack.tools.from_function import _remove_title_from_schema
1✔
20
from haystack.tools.parameters_schema_utils import _get_param_descriptions, _resolve_type
1✔
21
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
1✔
22

23
# Module-level logger obtained through Haystack's own `logging` module (see `from haystack import logging`),
# not the standard-library `logging` package.
logger = logging.getLogger(__name__)
1✔
24

25

26
class ComponentTool(Tool):
    """
    A Tool that wraps Haystack components, allowing them to be used as tools by LLMs.

    ComponentTool automatically generates LLM-compatible tool schemas from component input sockets,
    which are derived from the component's `run` method signature and type hints.

    Key features:
    - Automatic LLM tool calling schema generation from component input sockets
    - Type conversion and validation for component inputs
    - Support for types:
        - Dataclasses
        - Lists of dataclasses
        - Basic types (str, int, float, bool, dict)
        - Lists of basic types
    - Automatic name generation from component class name
    - Description extraction from component docstrings

    To use ComponentTool, you first need a Haystack component - either an existing one or a new one you create.
    You can create a ComponentTool from the component by passing the component to the ComponentTool constructor.
    Below is an example of creating a ComponentTool from an existing SerperDevWebSearch component.

    ```python
    from haystack import component, Pipeline
    from haystack.tools import ComponentTool
    from haystack.components.websearch import SerperDevWebSearch
    from haystack.utils import Secret
    from haystack.components.tools.tool_invoker import ToolInvoker
    from haystack.components.generators.chat import OpenAIChatGenerator
    from haystack.dataclasses import ChatMessage

    # Create a SerperDev search component
    search = SerperDevWebSearch(api_key=Secret.from_env_var("SERPERDEV_API_KEY"), top_k=3)

    # Create a tool from the component
    tool = ComponentTool(
        component=search,
        name="web_search",  # Optional: defaults to "serper_dev_web_search"
        description="Search the web for current information on any topic"  # Optional: defaults to component docstring
    )

    # Create pipeline with OpenAIChatGenerator and ToolInvoker
    pipeline = Pipeline()
    pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool]))
    pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool]))

    # Connect components
    pipeline.connect("llm.replies", "tool_invoker.messages")

    message = ChatMessage.from_user("Use the web search tool to find information about Nikola Tesla")

    # Run pipeline
    result = pipeline.run({"llm": {"messages": [message]}})

    print(result)
    ```

    """

    def __init__(
        self,
        component: Component,
        name: Optional[str] = None,
        description: Optional[str] = None,
        parameters: Optional[Dict[str, Any]] = None,
        *,
        outputs_to_string: Optional[Dict[str, Union[str, Callable[[Any], str]]]] = None,
        inputs_from_state: Optional[Dict[str, str]] = None,
        outputs_to_state: Optional[Dict[str, Dict[str, Union[str, Callable]]]] = None,
    ):
        """
        Create a Tool instance from a Haystack component.

        :param component: The Haystack component to wrap as a tool.
        :param name: Optional name for the tool (defaults to snake_case of component class name).
        :param description: Optional description (defaults to component's docstring).
        :param parameters:
            A JSON schema defining the parameters expected by the Tool.
            Will fall back to the parameters defined in the component's run method signature if not provided.
        :param outputs_to_string:
            Optional dictionary defining how the tool's outputs should be converted into a string.
            If the source is provided only the specified output key is sent to the handler.
            If the source is omitted the whole tool result is sent to the handler.
            Example: {
                "source": "docs", "handler": format_documents
            }
        :param inputs_from_state:
            Optional dictionary mapping state keys to tool parameter names.
            Example: {"repository": "repo"} maps state's "repository" to tool's "repo" parameter.
        :param outputs_to_state:
            Optional dictionary defining how tool outputs map to keys within state as well as optional handlers.
            If the source is provided only the specified output key is sent to the handler.
            Example: {
                "documents": {"source": "docs", "handler": custom_handler}
            }
            If the source is omitted the whole tool result is sent to the handler.
            Example: {
                "documents": {"handler": custom_handler}
            }
        :raises ValueError: If the object is not a Haystack component or has already been added to a pipeline.
        :raises SchemaGenerationError: If no `parameters` are given and the schema cannot be generated
            from the component's run method.
        """
        if not isinstance(component, Component):
            message = (
                f"Object {component!r} is not a Haystack component. "
                "Use ComponentTool only with Haystack component instances."
            )
            raise ValueError(message)

        if getattr(component, "__haystack_added_to_pipeline__", None):
            msg = (
                "Component has been added to a pipeline and can't be used to create a ComponentTool. "
                "Create ComponentTool from a non-pipeline component instead."
            )
            raise ValueError(msg)

        # Keep the user-supplied schema (possibly None) so `to_dict` serializes what the user passed,
        # not the schema generated from the component.
        self._unresolved_parameters = parameters
        # Create the tools schema from the component run method parameters
        tool_schema = parameters or self._create_tool_parameters_schema(component, inputs_from_state or {})

        def component_invoker(**kwargs):
            """
            Invokes the component using keyword arguments provided by the LLM function calling/tool-generated response.

            Each argument is converted to the type declared by the matching input socket: types exposing
            `from_dict` (or lists of them) are rebuilt from dicts, everything else is validated/coerced by Pydantic.

            :param kwargs: The keyword arguments to invoke the component with.
            :returns: The result of the component invocation.
            """
            converted_kwargs = {}
            input_sockets = component.__haystack_input__._sockets_dict
            for param_name, param_value in kwargs.items():
                param_type = input_sockets[param_name].type

                # Check if the type (or list element type) has from_dict
                target_type = get_args(param_type)[0] if get_origin(param_type) is list else param_type
                if hasattr(target_type, "from_dict"):
                    if isinstance(param_value, list):
                        # Only dict items can be rebuilt with from_dict; other items are skipped.
                        param_value = [target_type.from_dict(item) for item in param_value if isinstance(item, dict)]
                    elif isinstance(param_value, dict):
                        param_value = target_type.from_dict(param_value)
                else:
                    # Let TypeAdapter handle both single values and lists
                    type_adapter = TypeAdapter(param_type)
                    param_value = type_adapter.validate_python(param_value)

                converted_kwargs[param_name] = param_value
            # Keyword interpolation defers building the message (and repr-ing the kwargs) to the logging backend.
            logger.debug(
                "Invoking component {component_type} with kwargs: {component_kwargs}",
                component_type=type(component),
                component_kwargs=converted_kwargs,
            )
            return component.run(**converted_kwargs)

        # Generate a name for the tool if not provided
        if not name:
            class_name = component.__class__.__name__
            # Convert camelCase/PascalCase to snake_case; an underscore is inserted only before an
            # uppercase letter that follows a non-uppercase one, so runs of capitals stay together.
            name = "".join(
                [
                    "_" + c.lower() if c.isupper() and i > 0 and not class_name[i - 1].isupper() else c.lower()
                    for i, c in enumerate(class_name)
                ]
            ).lstrip("_")

        description = description or component.__doc__ or name

        # Create the Tool instance with the component invoker as the function to be called and the schema
        super().__init__(
            name=name,
            description=description,
            parameters=tool_schema,
            function=component_invoker,
            inputs_from_state=inputs_from_state,
            outputs_to_state=outputs_to_state,
            outputs_to_string=outputs_to_string,
        )
        self._component = component

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the ComponentTool to a dictionary.

        Callable handlers inside `outputs_to_state` and `outputs_to_string` are replaced by their serialized
        form; the surrounding dict structure (e.g. the "source" key) is preserved so `from_dict` can restore it.

        :returns: A dictionary with the qualified class name under "type" and the tool data under "data".
        """
        serialized_component = component_to_dict(obj=self._component, name=self.name)

        serialized = {
            "component": serialized_component,
            "name": self.name,
            "description": self.description,
            "parameters": self._unresolved_parameters,
            "outputs_to_string": self.outputs_to_string,
            "inputs_from_state": self.inputs_from_state,
            "outputs_to_state": self.outputs_to_state,
        }

        if self.outputs_to_state is not None:
            serialized_outputs = {}
            for key, config in self.outputs_to_state.items():
                serialized_config = config.copy()
                if "handler" in config:
                    serialized_config["handler"] = serialize_callable(config["handler"])
                serialized_outputs[key] = serialized_config
            serialized["outputs_to_state"] = serialized_outputs

        if self.outputs_to_string is not None and self.outputs_to_string.get("handler") is not None:
            # Replace only the handler; keeping the dict shape (and keys such as "source") is what
            # `from_dict` expects when it reads `outputs_to_string["handler"]`.
            serialized["outputs_to_string"] = {
                **self.outputs_to_string,
                "handler": serialize_callable(self.outputs_to_string["handler"]),
            }

        return {"type": generate_qualified_class_name(type(self)), "data": serialized}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Tool":
        """
        Deserializes the ComponentTool from a dictionary.

        The input dictionary is not modified, so the same data can be deserialized more than once.

        :param data: A dictionary produced by `to_dict`.
        :returns: The deserialized ComponentTool.
        """
        inner_data = data["data"]
        component_class = import_class_by_name(inner_data["component"]["type"])
        component = component_from_dict(cls=component_class, data=inner_data["component"], name=inner_data["name"])

        # Build deserialized copies instead of writing back into `data`: mutating the caller's dict would
        # make a second `from_dict` call try to deserialize already-deserialized callables.
        outputs_to_state = inner_data.get("outputs_to_state", None)
        if outputs_to_state:
            deserialized_outputs = {}
            for key, config in outputs_to_state.items():
                deserialized_config = config.copy()
                if "handler" in config:
                    deserialized_config["handler"] = deserialize_callable(config["handler"])
                deserialized_outputs[key] = deserialized_config
            outputs_to_state = deserialized_outputs

        outputs_to_string = inner_data.get("outputs_to_string", None)
        if isinstance(outputs_to_string, str):
            # A previous `to_dict` implementation stored only the serialized handler as a plain string;
            # accept that form instead of failing on `.get`.
            outputs_to_string = {"handler": outputs_to_string}
        if outputs_to_string is not None and outputs_to_string.get("handler") is not None:
            outputs_to_string = {
                **outputs_to_string,
                "handler": deserialize_callable(outputs_to_string["handler"]),
            }

        return cls(
            component=component,
            name=inner_data["name"],
            description=inner_data["description"],
            parameters=inner_data.get("parameters", None),
            outputs_to_string=outputs_to_string,
            inputs_from_state=inner_data.get("inputs_from_state", None),
            outputs_to_state=outputs_to_state,
        )

    def _create_tool_parameters_schema(self, component: Component, inputs_from_state: Dict[str, Any]) -> Dict[str, Any]:
        """
        Creates an OpenAI tools schema from a component's run method parameters.

        :param component: The component to create the schema from.
        :param inputs_from_state:
            Mapping whose keys name inputs supplied from state; those inputs are excluded from the schema
            because the LLM does not need to provide them.
        :raises SchemaGenerationError: If schema generation fails
        :returns: OpenAI tools schema for the component's run method parameters.
        """
        component_run_description, param_descriptions = _get_param_descriptions(component.run)

        # collect fields (types and defaults) and descriptions from function parameters
        fields: Dict[str, Any] = {}

        for input_name, socket in component.__haystack_input__._sockets_dict.items():  # type: ignore[attr-defined]
            if inputs_from_state is not None and input_name in inputs_from_state:
                continue
            input_type = socket.type
            description = param_descriptions.get(input_name, f"Input '{input_name}' for the component.")

            # if the parameter has not a default value, Pydantic requires an Ellipsis (...)
            # to explicitly indicate that the parameter is required
            default = ... if socket.is_mandatory else socket.default_value
            resolved_type = _resolve_type(input_type)
            fields[input_name] = (resolved_type, Field(default=default, description=description))

        try:
            model = create_model(component.run.__name__, __doc__=component_run_description, **fields)
            parameters_schema = model.model_json_schema()
        except Exception as e:
            raise SchemaGenerationError(
                f"Failed to create JSON schema for the run method of Component '{component.__class__.__name__}'"
            ) from e

        # we don't want to include title keywords in the schema, as they contain redundant information
        # there is no programmatic way to prevent Pydantic from adding them, so we remove them later
        # see https://github.com/pydantic/pydantic/discussions/8504
        _remove_title_from_schema(parameters_schema)

        return parameters_schema
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc