• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 13387999127

18 Feb 2025 10:03AM UTC coverage: 91.147%. Remained the same
13387999127

Pull #8870

github

web-flow
Merge 208ac314d into 0409e5da8
Pull Request #8870: fix: ComponentTool description should not be truncated

9420 of 10335 relevant lines covered (91.15%)

0.91 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.5
haystack/tools/component_tool.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
from dataclasses import fields, is_dataclass
1✔
6
from inspect import getdoc
1✔
7
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
1✔
8

9
from pydantic import TypeAdapter
1✔
10

11
from haystack import logging
1✔
12
from haystack.core.component import Component
1✔
13
from haystack.core.serialization import (
1✔
14
    component_from_dict,
15
    component_to_dict,
16
    generate_qualified_class_name,
17
    import_class_by_name,
18
)
19
from haystack.lazy_imports import LazyImport
1✔
20
from haystack.tools import Tool
1✔
21
from haystack.tools.errors import SchemaGenerationError
1✔
22

23
with LazyImport(message="Run 'pip install docstring-parser'") as docstring_parser_import:
1✔
24
    from docstring_parser import parse
1✔
25

26

27
logger = logging.getLogger(__name__)
1✔
28

29

30
class ComponentTool(Tool):
1✔
31
    """
32
    A Tool that wraps Haystack components, allowing them to be used as tools by LLMs.
33

34
    ComponentTool automatically generates LLM-compatible tool schemas from component input sockets,
35
    which are derived from the component's `run` method signature and type hints.
36

37

38
    Key features:
39
    - Automatic LLM tool calling schema generation from component input sockets
40
    - Type conversion and validation for component inputs
41
    - Support for types:
42
        - Dataclasses
43
        - Lists of dataclasses
44
        - Basic types (str, int, float, bool, dict)
45
        - Lists of basic types
46
    - Automatic name generation from component class name
47
    - Description extraction from component docstrings
48

49
    To use ComponentTool, you first need a Haystack component - either an existing one or a new one you create.
50
    You can create a ComponentTool from the component by passing the component to the ComponentTool constructor.
51
    Below is an example of creating a ComponentTool from an existing SerperDevWebSearch component.
52

53
    ```python
54
    from haystack import component, Pipeline
55
    from haystack.tools import ComponentTool
56
    from haystack.components.websearch import SerperDevWebSearch
57
    from haystack.utils import Secret
58
    from haystack.components.tools.tool_invoker import ToolInvoker
59
    from haystack.components.generators.chat import OpenAIChatGenerator
60
    from haystack.dataclasses import ChatMessage
61

62
    # Create a SerperDev search component
63
    search = SerperDevWebSearch(api_key=Secret.from_env_var("SERPERDEV_API_KEY"), top_k=3)
64

65
    # Create a tool from the component
66
    tool = ComponentTool(
67
        component=search,
68
        name="web_search",  # Optional: defaults to "serper_dev_web_search"
69
        description="Search the web for current information on any topic"  # Optional: defaults to component docstring
70
    )
71

72
    # Create pipeline with OpenAIChatGenerator and ToolInvoker
73
    pipeline = Pipeline()
74
    pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool]))
75
    pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool]))
76

77
    # Connect components
78
    pipeline.connect("llm.replies", "tool_invoker.messages")
79

80
    message = ChatMessage.from_user("Use the web search tool to find information about Nikola Tesla")
81

82
    # Run pipeline
83
    result = pipeline.run({"llm": {"messages": [message]}})
84

85
    print(result)
86
    ```
87

88
    """
89

90
    def __init__(self, component: Component, name: Optional[str] = None, description: Optional[str] = None):
        """
        Wrap a Haystack component so that it can be exposed to an LLM as a tool.

        :param component: The Haystack component instance to wrap.
        :param name: Tool name. When omitted, it is derived from the component class name converted to snake_case.
        :param description: Tool description. When omitted, the component's docstring is used, falling back to
            the tool name if the component has no docstring.
        :raises ValueError: If `component` is not a Haystack component instance or has already been added
            to a pipeline.
        :raises SchemaGenerationError: If the parameters schema cannot be generated from the component inputs.
        """
        if not isinstance(component, Component):
            raise ValueError(
                f"Object {component!r} is not a Haystack component. "
                "Use ComponentTool only with Haystack component instances."
            )

        already_in_pipeline = getattr(component, "__haystack_added_to_pipeline__", None)
        if already_in_pipeline:
            raise ValueError(
                "Component has been added to a pipeline and can't be used to create a ComponentTool. "
                "Create ComponentTool from a non-pipeline component instead."
            )

        # The parameters schema is built from the component's input sockets
        parameters_schema = self._create_tool_parameters_schema(component)

        def component_invoker(**kwargs):
            """
            Runs the wrapped component with the keyword arguments produced by the LLM tool call.

            Each argument is converted to the type declared by the matching input socket before the call.

            :param kwargs: The keyword arguments to invoke the component with.
            :returns: The result of the component invocation.
            """
            sockets = component.__haystack_input__._sockets_dict
            prepared = {}
            for key, raw in kwargs.items():
                declared = sockets[key].type

                # For list types, the conversion hook is looked up on the element type
                if get_origin(declared) is list:
                    element_type = get_args(declared)[0]
                else:
                    element_type = declared

                if not hasattr(element_type, "from_dict"):
                    # No from_dict available: pydantic validates single values and lists alike
                    prepared[key] = TypeAdapter(declared).validate_python(raw)
                    continue

                if isinstance(raw, list):
                    # Only dict items are converted; any other item is left out of the resulting list
                    prepared[key] = [element_type.from_dict(entry) for entry in raw if isinstance(entry, dict)]
                elif isinstance(raw, dict):
                    prepared[key] = element_type.from_dict(raw)
                else:
                    # Neither a list nor a dict: the value is passed through untouched
                    prepared[key] = raw

            logger.debug(f"Invoking component {type(component)} with kwargs: {prepared}")
            return component.run(**prepared)

        if not name:
            # Derive a snake_case name from the class name: an underscore is inserted before an
            # uppercase letter unless it is the first character or follows another uppercase letter.
            class_name = component.__class__.__name__
            pieces = []
            previous = ""
            for letter in class_name:
                if letter.isupper() and previous and not previous.isupper():
                    pieces.append("_")
                pieces.append(letter.lower())
                previous = letter
            name = "".join(pieces).lstrip("_")

        tool_description = description or component.__doc__ or name

        # The invoker closure is the function the Tool calls; the component itself is kept for serialization
        super().__init__(name, tool_description, parameters_schema, component_invoker)
        self._component = component
1✔
160

161
    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the ComponentTool to a dictionary.

        The invoker function is not serialized; only the name, description, parameters schema and the
        serialized wrapped component are stored.

        :returns: A dictionary with the qualified class name under "type" and the serialized fields under "data".
        """
        payload = {
            "name": self.name,
            "description": self.description,
            "parameters": self.parameters,
            "component": component_to_dict(obj=self._component, name=self.name),
        }
        return {"type": generate_qualified_class_name(type(self)), "data": payload}
1✔
169

170
    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Tool":
        """
        Deserializes the ComponentTool from a dictionary.

        The wrapped component is re-imported by its qualified class name and rebuilt from its serialized
        form, then a new tool is created around it with the stored name and description.

        :param data: Dictionary holding the serialized fields under the "data" key.
        :returns: The deserialized tool.
        """
        payload = data["data"]
        component_spec = payload["component"]
        component_type = import_class_by_name(component_spec["type"])
        restored = component_from_dict(cls=component_type, data=component_spec, name=payload["name"])
        return cls(component=restored, name=payload["name"], description=payload["description"])
1✔
179

180
    def _create_tool_parameters_schema(self, component: Component) -> Dict[str, Any]:
1✔
181
        """
182
        Creates an OpenAI tools schema from a component's run method parameters.
183

184
        :param component: The component to create the schema from.
185
        :raises SchemaGenerationError: If schema generation fails
186
        :returns: OpenAI tools schema for the component's run method parameters.
187
        """
188
        properties = {}
1✔
189
        required = []
1✔
190

191
        param_descriptions = self._get_param_descriptions(component.run)
1✔
192

193
        for input_name, socket in component.__haystack_input__._sockets_dict.items():  # type: ignore[attr-defined]
1✔
194
            input_type = socket.type
1✔
195
            description = param_descriptions.get(input_name, f"Input '{input_name}' for the component.")
1✔
196

197
            try:
1✔
198
                property_schema = self._create_property_schema(input_type, description)
1✔
199
            except Exception as e:
×
200
                raise SchemaGenerationError(
×
201
                    f"Error processing input '{input_name}': {e}. "
202
                    f"Schema generation supports basic types (str, int, float, bool, dict), dataclasses, "
203
                    f"and lists of these types as input types for component's run method."
204
                ) from e
205

206
            properties[input_name] = property_schema
1✔
207

208
            # Use socket.is_mandatory to check if the input is required
209
            if socket.is_mandatory:
1✔
210
                required.append(input_name)
1✔
211

212
        parameters_schema = {"type": "object", "properties": properties}
1✔
213

214
        if required:
1✔
215
            parameters_schema["required"] = required
1✔
216

217
        return parameters_schema
1✔
218

219
    @staticmethod
1✔
220
    def _get_param_descriptions(method: Callable) -> Dict[str, str]:
1✔
221
        """
222
        Extracts parameter descriptions from the method's docstring using docstring_parser.
223

224
        :param method: The method to extract parameter descriptions from.
225
        :returns: A dictionary mapping parameter names to their descriptions.
226
        """
227
        docstring = getdoc(method)
1✔
228
        if not docstring:
1✔
229
            return {}
×
230

231
        docstring_parser_import.check()
1✔
232
        parsed_doc = parse(docstring)
1✔
233
        param_descriptions = {}
1✔
234
        for param in parsed_doc.params:
1✔
235
            if not param.description:
1✔
236
                logger.warning(
×
237
                    "Missing description for parameter '%s'. Please add a description in the component's "
238
                    "run() method docstring using the format ':param %%s: <description>'. "
239
                    "This description helps the LLM understand how to use this parameter." % param.arg_name
240
                )
241
            param_descriptions[param.arg_name] = param.description.strip() if param.description else ""
1✔
242
        return param_descriptions
1✔
243

244
    @staticmethod
1✔
245
    def _is_nullable_type(python_type: Any) -> bool:
1✔
246
        """
247
        Checks if the type is a Union with NoneType (i.e., Optional).
248

249
        :param python_type: The Python type to check.
250
        :returns: True if the type is a Union with NoneType, False otherwise.
251
        """
252
        origin = get_origin(python_type)
1✔
253
        if origin is Union:
1✔
254
            return type(None) in get_args(python_type)
1✔
255
        return False
1✔
256

257
    def _create_list_schema(self, item_type: Any, description: str) -> Dict[str, Any]:
1✔
258
        """
259
        Creates a schema for a list type.
260

261
        :param item_type: The type of items in the list.
262
        :param description: The description of the list.
263
        :returns: A dictionary representing the list schema.
264
        """
265
        items_schema = self._create_property_schema(item_type, "")
1✔
266
        items_schema.pop("description", None)
1✔
267
        return {"type": "array", "description": description, "items": items_schema}
1✔
268

269
    def _create_dataclass_schema(self, python_type: Any, description: str) -> Dict[str, Any]:
1✔
270
        """
271
        Creates a schema for a dataclass.
272

273
        :param python_type: The dataclass type.
274
        :param description: The description of the dataclass.
275
        :returns: A dictionary representing the dataclass schema.
276
        """
277
        schema = {"type": "object", "description": description, "properties": {}}
1✔
278
        cls = python_type if isinstance(python_type, type) else python_type.__class__
1✔
279
        for field in fields(cls):
1✔
280
            field_description = f"Field '{field.name}' of '{cls.__name__}'."
1✔
281
            if isinstance(schema["properties"], dict):
1✔
282
                schema["properties"][field.name] = self._create_property_schema(field.type, field_description)
1✔
283
        return schema
1✔
284

285
    @staticmethod
1✔
286
    def _create_basic_type_schema(python_type: Any, description: str) -> Dict[str, Any]:
1✔
287
        """
288
        Creates a schema for a basic Python type.
289

290
        :param python_type: The Python type.
291
        :param description: The description of the type.
292
        :returns: A dictionary representing the basic type schema.
293
        """
294
        type_mapping = {str: "string", int: "integer", float: "number", bool: "boolean", dict: "object"}
1✔
295
        return {"type": type_mapping.get(python_type, "string"), "description": description}
1✔
296

297
    def _create_property_schema(self, python_type: Any, description: str, default: Any = None) -> Dict[str, Any]:
1✔
298
        """
299
        Creates a property schema for a given Python type, recursively if necessary.
300

301
        :param python_type: The Python type to create a property schema for.
302
        :param description: The description of the property.
303
        :param default: The default value of the property.
304
        :returns: A dictionary representing the property schema.
305
        :raises SchemaGenerationError: If schema generation fails, e.g., for unsupported types like Pydantic v2 models
306
        """
307
        nullable = self._is_nullable_type(python_type)
1✔
308
        if nullable:
1✔
309
            non_none_types = [t for t in get_args(python_type) if t is not type(None)]
1✔
310
            python_type = non_none_types[0] if non_none_types else str
1✔
311

312
        origin = get_origin(python_type)
1✔
313
        if origin is list:
1✔
314
            schema = self._create_list_schema(get_args(python_type)[0] if get_args(python_type) else Any, description)
1✔
315
        elif is_dataclass(python_type):
1✔
316
            schema = self._create_dataclass_schema(python_type, description)
1✔
317
        elif hasattr(python_type, "model_validate"):
1✔
318
            raise SchemaGenerationError(
×
319
                f"Pydantic models (e.g. {python_type.__name__}) are not supported as input types for "
320
                f"component's run method."
321
            )
322
        else:
323
            schema = self._create_basic_type_schema(python_type, description)
1✔
324

325
        if default is not None:
1✔
326
            schema["default"] = default
×
327

328
        return schema
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc