• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 20304787960

17 Dec 2025 01:39PM UTC coverage: 92.156% (+0.002%) from 92.154%
20304787960

Pull #10256

github

web-flow
Merge b69f7a4d9 into 99d506b3e
Pull Request #10256: feat: Component tool state validation

14157 of 15362 relevant lines covered (92.16%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

94.51
haystack/tools/component_tool.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
from typing import Any, Callable, Optional, Union, get_args, get_origin
1✔
6

7
from pydantic import Field, TypeAdapter, create_model
1✔
8

9
from haystack import logging
1✔
10
from haystack.core.component import Component
1✔
11
from haystack.core.serialization import (
1✔
12
    component_from_dict,
13
    component_to_dict,
14
    generate_qualified_class_name,
15
    import_class_by_name,
16
)
17
from haystack.tools import Tool
1✔
18
from haystack.tools.errors import SchemaGenerationError
1✔
19
from haystack.tools.from_function import _remove_title_from_schema
1✔
20
from haystack.tools.parameters_schema_utils import _get_component_param_descriptions, _resolve_type
1✔
21
from haystack.tools.tool import _deserialize_outputs_to_state, _serialize_outputs_to_state
1✔
22
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
1✔
23

24
# Module-level logger from haystack's logging wrapper; used to emit a debug record each time a wrapped
# component is invoked through a ComponentTool.
logger = logging.getLogger(__name__)
1✔
25

26

27
class ComponentTool(Tool):
    """
    A Tool that wraps Haystack components, allowing them to be used as tools by LLMs.

    ComponentTool automatically generates LLM-compatible tool schemas from component input sockets,
    which are derived from the component's `run` method signature and type hints.

    Key features:
    - Automatic LLM tool calling schema generation from component input sockets
    - Type conversion and validation for component inputs
    - Support for types:
        - Dataclasses
        - Lists of dataclasses
        - Basic types (str, int, float, bool, dict)
        - Lists of basic types
    - Automatic name generation from component class name
    - Description extraction from component docstrings

    To use ComponentTool, you first need a Haystack component - either an existing one or a new one you create.
    You can create a ComponentTool from the component by passing the component to the ComponentTool constructor.
    Below is an example of creating a ComponentTool from an existing SerperDevWebSearch component.

    ## Usage Example:

    ```python
    from haystack import component, Pipeline
    from haystack.tools import ComponentTool
    from haystack.components.websearch import SerperDevWebSearch
    from haystack.utils import Secret
    from haystack.components.tools.tool_invoker import ToolInvoker
    from haystack.components.generators.chat import OpenAIChatGenerator
    from haystack.dataclasses import ChatMessage

    # Create a SerperDev search component
    search = SerperDevWebSearch(api_key=Secret.from_env_var("SERPERDEV_API_KEY"), top_k=3)

    # Create a tool from the component
    tool = ComponentTool(
        component=search,
        name="web_search",  # Optional: defaults to "serper_dev_web_search"
        description="Search the web for current information on any topic"  # Optional: defaults to component docstring
    )

    # Create pipeline with OpenAIChatGenerator and ToolInvoker
    pipeline = Pipeline()
    pipeline.add_component("llm", OpenAIChatGenerator(tools=[tool]))
    pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool]))

    # Connect components
    pipeline.connect("llm.replies", "tool_invoker.messages")

    message = ChatMessage.from_user("Use the web search tool to find information about Nikola Tesla")

    # Run pipeline
    result = pipeline.run({"llm": {"messages": [message]}})

    print(result)
    ```

    """

    def __init__(
        self,
        component: Component,
        name: Optional[str] = None,
        description: Optional[str] = None,
        parameters: Optional[dict[str, Any]] = None,
        *,
        outputs_to_string: Optional[dict[str, Union[str, Callable[[Any], str]]]] = None,
        inputs_from_state: Optional[dict[str, str]] = None,
        outputs_to_state: Optional[dict[str, dict[str, Union[str, Callable]]]] = None,
    ) -> None:
        """
        Create a Tool instance from a Haystack component.

        :param component: The Haystack component to wrap as a tool.
        :param name: Optional name for the tool (defaults to snake_case of component class name).
        :param description: Optional description (defaults to component's docstring).
        :param parameters:
            A JSON schema defining the parameters expected by the Tool.
            Will fall back to the parameters defined in the component's run method signature if not provided.
        :param outputs_to_string:
            Optional dictionary defining how a tool's outputs should be converted into a string.
            If the source is provided only the specified output key is sent to the handler.
            If the source is omitted the whole tool result is sent to the handler.
            Example:
            ```python
            {
                "source": "docs", "handler": format_documents
            }
            ```
        :param inputs_from_state:
            Optional dictionary mapping state keys to tool parameter names.
            Example: `{"repository": "repo"}` maps state's "repository" to tool's "repo" parameter.
        :param outputs_to_state:
            Optional dictionary defining how tool outputs map to keys within state as well as optional handlers.
            If the source is provided only the specified output key is sent to the handler.
            Example:
            ```python
            {
                "documents": {"source": "docs", "handler": custom_handler}
            }
            ```
            If the source is omitted the whole tool result is sent to the handler.
            Example:
            ```python
            {
                "documents": {"handler": custom_handler}
            }
            ```
        :raises ValueError: If the component is invalid or schema generation fails.
        """
        if not isinstance(component, Component):
            message = (
                f"Object {component!r} is not a Haystack component. "
                "Use ComponentTool only with Haystack component instances."
            )
            raise ValueError(message)

        if getattr(component, "__haystack_added_to_pipeline__", None):
            msg = (
                "Component has been added to a pipeline and can't be used to create a ComponentTool. "
                "Create ComponentTool from a non-pipeline component instead."
            )
            raise ValueError(msg)

        # Keep the user-supplied schema (possibly None) so to_dict() can serialize exactly what was passed in;
        # a generated schema is then regenerated from the component on deserialization.
        self._unresolved_parameters = parameters
        # Create the tools schema from the component run method parameters
        tool_schema = parameters or self._create_tool_parameters_schema(component, inputs_from_state or {})

        def component_invoker(**kwargs):
            """
            Invokes the component using keyword arguments provided by the LLM function calling/tool-generated response.

            :param kwargs: The keyword arguments to invoke the component with.
            :returns: The result of the component invocation.
            """
            converted_kwargs = {}
            input_sockets = component.__haystack_input__._sockets_dict  # type: ignore[attr-defined]
            for param_name, param_value in kwargs.items():
                param_type = input_sockets[param_name].type

                # Check if the type (or list element type) has from_dict
                target_type = get_args(param_type)[0] if get_origin(param_type) is list else param_type
                if hasattr(target_type, "from_dict"):
                    # Rebuild objects from the plain dicts found in the arguments; values that are not dicts
                    # are passed through unchanged.
                    if isinstance(param_value, list):
                        resolved_param_value = [
                            target_type.from_dict(item) if isinstance(item, dict) else item for item in param_value
                        ]
                    elif isinstance(param_value, dict):
                        resolved_param_value = target_type.from_dict(param_value)
                    else:
                        resolved_param_value = param_value
                else:
                    # Let TypeAdapter handle both single values and lists
                    type_adapter = TypeAdapter(param_type)
                    resolved_param_value = type_adapter.validate_python(param_value)

                converted_kwargs[param_name] = resolved_param_value
            logger.debug(f"Invoking component {type(component)} with kwargs: {converted_kwargs}")
            return component.run(**converted_kwargs)

        # Generate a name for the tool if not provided
        if not name:
            class_name = component.__class__.__name__
            # Convert camelCase/PascalCase to snake_case. An underscore is inserted only before an uppercase
            # letter whose predecessor is not uppercase, so runs of capitals stay together
            # (e.g. "SerperDevWebSearch" -> "serper_dev_web_search", "HTMLToDocument" -> "htmlto_document").
            name = "".join(
                [
                    "_" + c.lower() if c.isupper() and i > 0 and not class_name[i - 1].isupper() else c.lower()
                    for i, c in enumerate(class_name)
                ]
            ).lstrip("_")

        description = description or component.__doc__ or name

        # Store component before calling super().__init__() so _get_valid_outputs() can access it
        self._component = component
        self._is_warmed_up = False

        # Create the Tool instance with the component invoker as the function to be called and the schema
        super().__init__(
            name=name,
            description=description,
            parameters=tool_schema,
            function=component_invoker,
            inputs_from_state=inputs_from_state,
            outputs_to_state=outputs_to_state,
            outputs_to_string=outputs_to_string,
        )

    def _get_valid_inputs(self) -> set[str]:
        """
        Return valid input parameter names from the component's input sockets.

        Used to validate `inputs_from_state` against the component's actual inputs.
        This ensures users don't reference non-existent component inputs.

        :returns: Set of component input socket names.
        """
        return set(self._component.__haystack_input__._sockets_dict.keys())  # type: ignore[attr-defined]

    def _get_valid_outputs(self) -> set[str]:
        """
        Return valid output names from the component's output sockets.

        Used to validate `outputs_to_state` against the component's actual outputs.
        This ensures users don't reference non-existent component outputs.

        :returns: Set of component output socket names.
        """
        return set(self._component.__haystack_output__._sockets_dict.keys())  # type: ignore[attr-defined]

    def warm_up(self):
        """
        Prepare the ComponentTool for use.

        Calls the wrapped component's `warm_up()` (if it defines one) at most once per ComponentTool instance.
        """
        if not self._is_warmed_up:
            if hasattr(self._component, "warm_up"):
                self._component.warm_up()
            self._is_warmed_up = True

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the ComponentTool to a dictionary.

        :returns: A dictionary with the fully qualified class name under `type` and the serialized
            init parameters under `data`.
        """
        serialized: dict[str, Any] = {
            "component": component_to_dict(obj=self._component, name=self.name),
            "name": self.name,
            "description": self.description,
            "parameters": self._unresolved_parameters,
            "inputs_from_state": self.inputs_from_state,
            "outputs_to_state": _serialize_outputs_to_state(self.outputs_to_state) if self.outputs_to_state else None,
        }

        if not self.outputs_to_string:
            serialized["outputs_to_string"] = None
        else:
            # Shallow copy so the tool's own configuration is not modified in place. Configurations without a
            # handler (e.g. only a "source" key) are kept rather than dropped, so they survive a round-trip.
            serialized_outputs_to_string = dict(self.outputs_to_string)
            if serialized_outputs_to_string.get("handler") is not None:
                serialized_outputs_to_string["handler"] = serialize_callable(self.outputs_to_string["handler"])
            serialized["outputs_to_string"] = serialized_outputs_to_string

        return {"type": generate_qualified_class_name(type(self)), "data": serialized}

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ComponentTool":
        """
        Deserializes the ComponentTool from a dictionary.

        The serialized `outputs_to_state` and `outputs_to_string` entries in `data` are left untouched;
        deserialized copies are passed to the constructor instead.

        :param data: A dictionary in the format produced by `to_dict`.
        :returns: The deserialized ComponentTool.
        """
        inner_data = data["data"]
        component_class = import_class_by_name(inner_data["component"]["type"])
        component = component_from_dict(cls=component_class, data=inner_data["component"], name=inner_data["name"])

        outputs_to_state = inner_data.get("outputs_to_state", None)
        if outputs_to_state:
            # _deserialize_outputs_to_state returns a new mapping; it is not written back into `data`.
            outputs_to_state = _deserialize_outputs_to_state(outputs_to_state)

        outputs_to_string = inner_data.get("outputs_to_string", None)
        if outputs_to_string is not None and outputs_to_string.get("handler") is not None:
            # Build a new dict so the caller's serialized handler string is not replaced by a callable.
            outputs_to_string = {
                **outputs_to_string,
                "handler": deserialize_callable(outputs_to_string["handler"]),
            }

        return cls(
            component=component,
            name=inner_data["name"],
            description=inner_data["description"],
            parameters=inner_data.get("parameters", None),
            outputs_to_string=outputs_to_string,
            inputs_from_state=inner_data.get("inputs_from_state", None),
            outputs_to_state=outputs_to_state,
        )

    def _create_tool_parameters_schema(self, component: Component, inputs_from_state: dict[str, Any]) -> dict[str, Any]:
        """
        Creates an OpenAI tools schema from a component's run method parameters.

        :param component: The component to create the schema from.
        :param inputs_from_state: Mapping of state keys to component input names. Inputs that appear as values
            here are supplied from state, so they are left out of the schema the LLM sees.
        :raises SchemaGenerationError: If schema generation fails
        :returns: OpenAI tools schema for the component's run method parameters.
        """
        component_run_description, param_descriptions = _get_component_param_descriptions(component)

        # Input names fed from state; built once instead of on every loop iteration.
        state_fed_inputs = set(inputs_from_state.values()) if inputs_from_state is not None else set()

        # collect fields (types and defaults) and descriptions from function parameters
        fields: dict[str, Any] = {}

        for input_name, socket in component.__haystack_input__._sockets_dict.items():  # type: ignore[attr-defined]
            if input_name in state_fed_inputs:
                continue
            input_type = socket.type
            description = param_descriptions.get(input_name, f"Input '{input_name}' for the component.")

            # if the parameter has not a default value, Pydantic requires an Ellipsis (...)
            # to explicitly indicate that the parameter is required
            default = ... if socket.is_mandatory else socket.default_value
            resolved_type = _resolve_type(input_type)
            fields[input_name] = (resolved_type, Field(default=default, description=description))

        parameters_schema: dict[str, Any] = {}
        try:
            model = create_model(component.run.__name__, __doc__=component_run_description, **fields)
            parameters_schema = model.model_json_schema()
        except Exception as e:
            raise SchemaGenerationError(
                f"Failed to create JSON schema for the run method of Component '{component.__class__.__name__}'"
            ) from e

        # we don't want to include title keywords in the schema, as they contain redundant information
        # there is no programmatic way to prevent Pydantic from adding them, so we remove them later
        # see https://github.com/pydantic/pydantic/discussions/8504
        _remove_title_from_schema(parameters_schema)

        return parameters_schema
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc