• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 15210934031

23 May 2025 01:01PM UTC coverage: 90.056% (-0.03%) from 90.087%
15210934031

Pull #9434

github

web-flow
Merge f2e68af13 into d8cc6f733
Pull Request #9434: fix: Fix invoker to work when using dataclass with from_dict but dataclass…

11338 of 12590 relevant lines covered (90.06%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.39
haystack/tools/component_tool.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
1✔
6

7
from pydantic import Field, TypeAdapter, create_model
1✔
8

9
from haystack import logging
1✔
10
from haystack.core.component import Component
1✔
11
from haystack.core.serialization import (
1✔
12
    component_from_dict,
13
    component_to_dict,
14
    generate_qualified_class_name,
15
    import_class_by_name,
16
)
17
from haystack.tools import Tool
1✔
18
from haystack.tools.errors import SchemaGenerationError
1✔
19
from haystack.tools.from_function import _remove_title_from_schema
1✔
20
from haystack.tools.parameters_schema_utils import _get_component_param_descriptions, _resolve_type
1✔
21
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
1✔
22

23
# Module-level logger obtained from Haystack's logging wrapper (`from haystack import logging` above).
logger = logging.getLogger(__name__)
1✔
24

25

26
class ComponentTool(Tool):
    """
    A Tool that wraps Haystack components, allowing them to be used as tools by LLMs.

    ComponentTool automatically generates LLM-compatible tool schemas from component input sockets,
    which are derived from the component's `run` method signature and type hints.

    Key features:
    - Automatic LLM tool calling schema generation from component input sockets
    - Type conversion and validation for component inputs
    - Support for types:
        - Dataclasses
        - Lists of dataclasses
        - Basic types (str, int, float, bool, dict)
        - Lists of basic types
    - Automatic name generation from component class name
    - Description extraction from component docstrings

    To use ComponentTool, you first need a Haystack component - either an existing one or a new one you create.
    You can create a ComponentTool from the component by passing the component to the ComponentTool constructor.
    Below is an example of creating a ComponentTool from an existing SerperDevWebSearch component.

    ```python
    from haystack import component, Pipeline
    from haystack.tools import ComponentTool
    from haystack.components.websearch import SerperDevWebSearch
    from haystack.utils import Secret
    from haystack.components.tools.tool_invoker import ToolInvoker
    from haystack.components.generators.chat import OpenAIChatGenerator
    from haystack.dataclasses import ChatMessage

    # Create a SerperDev search component
    search = SerperDevWebSearch(api_key=Secret.from_env_var("SERPERDEV_API_KEY"), top_k=3)

    # Create a tool from the component
    tool = ComponentTool(
        component=search,
        name="web_search",  # Optional: defaults to "serper_dev_web_search"
        description="Search the web for current information on any topic"  # Optional: defaults to component docstring
    )

    # Create pipeline with OpenAIChatGenerator and ToolInvoker
    pipeline = Pipeline()
    pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool]))
    pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool]))

    # Connect components
    pipeline.connect("llm.replies", "tool_invoker.messages")

    message = ChatMessage.from_user("Use the web search tool to find information about Nikola Tesla")

    # Run pipeline
    result = pipeline.run({"llm": {"messages": [message]}})

    print(result)
    ```
    """

    def __init__(
        self,
        component: Component,
        name: Optional[str] = None,
        description: Optional[str] = None,
        parameters: Optional[Dict[str, Any]] = None,
        *,
        outputs_to_string: Optional[Dict[str, Union[str, Callable[[Any], str]]]] = None,
        inputs_from_state: Optional[Dict[str, str]] = None,
        outputs_to_state: Optional[Dict[str, Dict[str, Union[str, Callable]]]] = None,
    ):
        """
        Create a Tool instance from a Haystack component.

        :param component: The Haystack component to wrap as a tool.
        :param name: Optional name for the tool (defaults to snake_case of component class name).
        :param description: Optional description (defaults to component's docstring).
        :param parameters:
            A JSON schema defining the parameters expected by the Tool.
            Will fall back to the parameters defined in the component's run method signature if not provided.
        :param outputs_to_string:
            Optional dictionary defining how tool outputs should be converted into a string.
            If the source is provided only the specified output key is sent to the handler.
            If the source is omitted the whole tool result is sent to the handler.
            Example: {
                "source": "docs", "handler": format_documents
            }
        :param inputs_from_state:
            Optional dictionary mapping state keys to tool parameter names.
            Example: {"repository": "repo"} maps state's "repository" to tool's "repo" parameter.
            Parameters listed as values here are excluded from the generated LLM schema.
        :param outputs_to_state:
            Optional dictionary defining how tool outputs map to keys within state as well as optional handlers.
            If the source is provided only the specified output key is sent to the handler.
            Example: {
                "documents": {"source": "docs", "handler": custom_handler}
            }
            If the source is omitted the whole tool result is sent to the handler.
            Example: {
                "documents": {"handler": custom_handler}
            }
        :raises ValueError: If the component is invalid or schema generation fails.
        """
        if not isinstance(component, Component):
            message = (
                f"Object {component!r} is not a Haystack component. "
                "Use ComponentTool only with Haystack component instances."
            )
            raise ValueError(message)

        if getattr(component, "__haystack_added_to_pipeline__", None):
            msg = (
                "Component has been added to a pipeline and can't be used to create a ComponentTool. "
                "Create ComponentTool from a non-pipeline component instead."
            )
            raise ValueError(msg)

        # Keep the user-supplied schema (possibly None) so to_dict can round-trip it without
        # baking in the auto-generated schema.
        self._unresolved_parameters = parameters
        # Create the tools schema from the component run method parameters
        tool_schema = parameters or self._create_tool_parameters_schema(component, inputs_from_state or {})

        def component_invoker(**kwargs):
            """
            Invokes the component using keyword arguments provided by the LLM function calling/tool-generated response.

            :param kwargs: The keyword arguments to invoke the component with.
            :returns: The result of the component invocation.
            """
            converted_kwargs = {}
            input_sockets = component.__haystack_input__._sockets_dict  # type: ignore[attr-defined]
            for param_name, param_value in kwargs.items():
                param_type = input_sockets[param_name].type

                # Check if the type (or list element type) has from_dict
                target_type = get_args(param_type)[0] if get_origin(param_type) is list else param_type
                if hasattr(target_type, "from_dict"):
                    # Rebuild objects from the JSON dicts produced by the LLM; values that are not dicts
                    # (e.g. already-constructed instances) are passed through unchanged.
                    if isinstance(param_value, list):
                        resolved_param_value = [
                            target_type.from_dict(item) if isinstance(item, dict) else item for item in param_value
                        ]
                    elif isinstance(param_value, dict):
                        resolved_param_value = target_type.from_dict(param_value)
                    else:
                        resolved_param_value = param_value
                else:
                    # Let TypeAdapter handle both single values and lists
                    type_adapter = TypeAdapter(param_type)
                    resolved_param_value = type_adapter.validate_python(param_value)

                converted_kwargs[param_name] = resolved_param_value
            logger.debug(f"Invoking component {type(component)} with kwargs: {converted_kwargs}")
            return component.run(**converted_kwargs)

        # Generate a name for the tool if not provided
        if not name:
            class_name = component.__class__.__name__
            # Convert camelCase/PascalCase to snake_case; consecutive capitals are not split
            name = "".join(
                [
                    "_" + c.lower() if c.isupper() and i > 0 and not class_name[i - 1].isupper() else c.lower()
                    for i, c in enumerate(class_name)
                ]
            ).lstrip("_")

        description = description or component.__doc__ or name

        # Create the Tool instance with the component invoker as the function to be called and the schema
        super().__init__(
            name=name,
            description=description,
            parameters=tool_schema,
            function=component_invoker,
            inputs_from_state=inputs_from_state,
            outputs_to_state=outputs_to_state,
            outputs_to_string=outputs_to_string,
        )
        self._component = component

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the ComponentTool to a dictionary.

        Callable handlers in `outputs_to_state` and `outputs_to_string` are replaced by their
        serialized import paths; all other configuration keys (e.g. "source") are preserved.

        :returns: A dictionary with the qualified class name under "type" and the tool configuration under "data".
        """
        serialized_component = component_to_dict(obj=self._component, name=self.name)

        serialized = {
            "component": serialized_component,
            "name": self.name,
            "description": self.description,
            "parameters": self._unresolved_parameters,
            "outputs_to_string": self.outputs_to_string,
            "inputs_from_state": self.inputs_from_state,
            "outputs_to_state": self.outputs_to_state,
        }

        if self.outputs_to_state is not None:
            serialized_outputs = {}
            for key, config in self.outputs_to_state.items():
                # Copy so serialization never replaces the live handler on this tool instance
                serialized_config = config.copy()
                if "handler" in config:
                    serialized_config["handler"] = serialize_callable(config["handler"])
                serialized_outputs[key] = serialized_config
            serialized["outputs_to_state"] = serialized_outputs

        if self.outputs_to_string is not None and self.outputs_to_string.get("handler") is not None:
            # Replace only the handler: keeping the dict shape (and its "source" key) is what
            # from_dict expects when it reads outputs_to_string back.
            serialized_outputs_to_string = self.outputs_to_string.copy()
            serialized_outputs_to_string["handler"] = serialize_callable(self.outputs_to_string["handler"])
            serialized["outputs_to_string"] = serialized_outputs_to_string

        return {"type": generate_qualified_class_name(type(self)), "data": serialized}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Tool":
        """
        Deserializes the ComponentTool from a dictionary.

        The input dictionary is not modified; deserialized handlers are placed in fresh copies.

        :param data: A dictionary in the format produced by `to_dict`.
        :returns: The deserialized ComponentTool.
        """
        inner_data = data["data"]
        component_class = import_class_by_name(inner_data["component"]["type"])
        component = component_from_dict(cls=component_class, data=inner_data["component"], name=inner_data["name"])

        outputs_to_state = inner_data.get("outputs_to_state", None)
        if outputs_to_state:
            deserialized_outputs = {}
            for key, config in outputs_to_state.items():
                deserialized_config = config.copy()
                if "handler" in config:
                    deserialized_config["handler"] = deserialize_callable(config["handler"])
                deserialized_outputs[key] = deserialized_config
            outputs_to_state = deserialized_outputs

        outputs_to_string = inner_data.get("outputs_to_string", None)
        if outputs_to_string is not None and outputs_to_string.get("handler") is not None:
            # Work on a copy so repeated deserialization of the same data keeps working
            outputs_to_string = outputs_to_string.copy()
            outputs_to_string["handler"] = deserialize_callable(outputs_to_string["handler"])

        return cls(
            component=component,
            name=inner_data["name"],
            description=inner_data["description"],
            parameters=inner_data.get("parameters", None),
            outputs_to_string=outputs_to_string,
            inputs_from_state=inner_data.get("inputs_from_state", None),
            outputs_to_state=outputs_to_state,
        )

    def _create_tool_parameters_schema(self, component: Component, inputs_from_state: Dict[str, Any]) -> Dict[str, Any]:
        """
        Creates an OpenAI tools schema from a component's run method parameters.

        :param component: The component to create the schema from.
        :param inputs_from_state:
            Mapping of state keys to tool parameter names. Parameters that appear as values are
            filled from state at invocation time, so they are left out of the schema shown to the LLM.
        :raises SchemaGenerationError: If schema generation fails
        :returns: OpenAI tools schema for the component's run method parameters.
        """
        component_run_description, param_descriptions = _get_component_param_descriptions(component)

        # The mapping is state_key -> parameter name, so the parameters to hide are the mapping's values
        state_injected_params = set(inputs_from_state.values()) if inputs_from_state else set()

        # collect fields (types and defaults) and descriptions from function parameters
        fields: Dict[str, Any] = {}

        for input_name, socket in component.__haystack_input__._sockets_dict.items():  # type: ignore[attr-defined]
            if input_name in state_injected_params:
                continue
            input_type = socket.type
            description = param_descriptions.get(input_name, f"Input '{input_name}' for the component.")

            # if the parameter has not a default value, Pydantic requires an Ellipsis (...)
            # to explicitly indicate that the parameter is required
            default = ... if socket.is_mandatory else socket.default_value
            resolved_type = _resolve_type(input_type)
            fields[input_name] = (resolved_type, Field(default=default, description=description))

        try:
            model = create_model(component.run.__name__, __doc__=component_run_description, **fields)
            parameters_schema = model.model_json_schema()
        except Exception as e:
            raise SchemaGenerationError(
                f"Failed to create JSON schema for the run method of Component '{component.__class__.__name__}'"
            ) from e

        # we don't want to include title keywords in the schema, as they contain redundant information
        # there is no programmatic way to prevent Pydantic from adding them, so we remove them later
        # see https://github.com/pydantic/pydantic/discussions/8504
        _remove_title_from_schema(parameters_schema)

        return parameters_schema
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc