• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 14126378592

28 Mar 2025 09:55AM UTC coverage: 90.173% (+0.008%) from 90.165%
14126378592

Pull #9126

github

web-flow
Merge 64e16ca41 into 657d09d7f
Pull Request #9126: refactor: use `token` in `SASEvaluator` and `SentenceTransformersDiversityRanker`

10140 of 11245 relevant lines covered (90.17%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

94.81
haystack/tools/component_tool.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
from copy import copy, deepcopy
1✔
6
from dataclasses import fields, is_dataclass
1✔
7
from inspect import getdoc
1✔
8
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin
1✔
9

10
from pydantic import TypeAdapter
1✔
11

12
from haystack import logging
1✔
13
from haystack.core.component import Component
1✔
14
from haystack.core.serialization import (
1✔
15
    component_from_dict,
16
    component_to_dict,
17
    generate_qualified_class_name,
18
    import_class_by_name,
19
)
20
from haystack.lazy_imports import LazyImport
1✔
21
from haystack.tools import Tool
1✔
22
from haystack.tools.errors import SchemaGenerationError
1✔
23
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
1✔
24

25
with LazyImport(message="Run 'pip install docstring-parser'") as docstring_parser_import:
1✔
26
    from docstring_parser import parse
1✔
27

28

29
logger = logging.getLogger(__name__)
1✔
30

31

32
class ComponentTool(Tool):
1✔
33
    """
34
    A Tool that wraps Haystack components, allowing them to be used as tools by LLMs.
35

36
    ComponentTool automatically generates LLM-compatible tool schemas from component input sockets,
37
    which are derived from the component's `run` method signature and type hints.
38

39

40
    Key features:
41
    - Automatic LLM tool calling schema generation from component input sockets
42
    - Type conversion and validation for component inputs
43
    - Support for types:
44
        - Dataclasses
45
        - Lists of dataclasses
46
        - Basic types (str, int, float, bool, dict)
47
        - Lists of basic types
48
    - Automatic name generation from component class name
49
    - Description extraction from component docstrings
50

51
    To use ComponentTool, you first need a Haystack component - either an existing one or a new one you create.
52
    You can create a ComponentTool from the component by passing the component to the ComponentTool constructor.
53
    Below is an example of creating a ComponentTool from an existing SerperDevWebSearch component.
54

55
    ```python
56
    from haystack import component, Pipeline
57
    from haystack.tools import ComponentTool
58
    from haystack.components.websearch import SerperDevWebSearch
59
    from haystack.utils import Secret
60
    from haystack.components.tools.tool_invoker import ToolInvoker
61
    from haystack.components.generators.chat import OpenAIChatGenerator
62
    from haystack.dataclasses import ChatMessage
63

64
    # Create a SerperDev search component
65
    search = SerperDevWebSearch(api_key=Secret.from_env_var("SERPERDEV_API_KEY"), top_k=3)
66

67
    # Create a tool from the component
68
    tool = ComponentTool(
69
        component=search,
70
        name="web_search",  # Optional: defaults to "serper_dev_web_search"
71
        description="Search the web for current information on any topic"  # Optional: defaults to component docstring
72
    )
73

74
    # Create pipeline with OpenAIChatGenerator and ToolInvoker
75
    pipeline = Pipeline()
76
    pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool]))
77
    pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool]))
78

79
    # Connect components
80
    pipeline.connect("llm.replies", "tool_invoker.messages")
81

82
    message = ChatMessage.from_user("Use the web search tool to find information about Nikola Tesla")
83

84
    # Run pipeline
85
    result = pipeline.run({"llm": {"messages": [message]}})
86

87
    print(result)
88
    ```
89

90
    """
91

92
    def __init__(
        self,
        component: Component,
        name: Optional[str] = None,
        description: Optional[str] = None,
        parameters: Optional[Dict[str, Any]] = None,
        *,
        inputs_from_state: Optional[Dict[str, Any]] = None,
        outputs_to_state: Optional[Dict[str, Any]] = None,
    ):
        """
        Create a Tool instance from a Haystack component.

        :param component: The Haystack component to wrap as a tool.
        :param name: Optional name for the tool (defaults to snake_case of component class name).
        :param description: Optional description (defaults to component's docstring).
        :param parameters:
            A JSON schema defining the parameters expected by the Tool.
            Will fall back to the parameters defined in the component's run method signature if not provided.
        :param inputs_from_state:
            Optional dictionary mapping state keys to tool parameter names.
            Example: {"repository": "repo"} maps state's "repository" to tool's "repo" parameter.
        :param outputs_to_state:
            Optional dictionary defining how tool outputs map to keys within state as well as optional handlers.
            Example: {
                "documents": {"source": "docs", "handler": custom_handler},
                "message": {"source": "summary", "handler": format_summary}
            }
        :raises ValueError: If the component is invalid or schema generation fails.
        """
        # Reject anything that is not a Haystack component instance: schema generation
        # below relies on component-only attributes (__haystack_input__, run()).
        if not isinstance(component, Component):
            message = (
                f"Object {component!r} is not a Haystack component. "
                "Use ComponentTool only with Haystack component instances."
            )
            raise ValueError(message)

        # A component already attached to a pipeline must not double as a tool:
        # the pipeline owns its invocation and sockets.
        if getattr(component, "__haystack_added_to_pipeline__", None):
            msg = (
                "Component has been added to a pipeline and can't be used to create a ComponentTool. "
                "Create ComponentTool from a non-pipeline component instead."
            )
            raise ValueError(msg)

        # Keep the caller-supplied schema (or None) so to_dict() can round-trip it verbatim.
        self._unresolved_parameters = parameters
        # Create the tools schema from the component run method parameters
        tool_schema = parameters or self._create_tool_parameters_schema(component, inputs_from_state or {})

        def component_invoker(**kwargs):
            """
            Invokes the component using keyword arguments provided by the LLM function calling/tool-generated response.

            :param kwargs: The keyword arguments to invoke the component with.
            :returns: The result of the component invocation.
            """
            converted_kwargs = {}
            input_sockets = component.__haystack_input__._sockets_dict
            for param_name, param_value in kwargs.items():
                param_type = input_sockets[param_name].type

                # Check if the type (or list element type) has from_dict
                target_type = get_args(param_type)[0] if get_origin(param_type) is list else param_type
                if hasattr(target_type, "from_dict"):
                    if isinstance(param_value, list):
                        # NOTE(review): non-dict list items are silently dropped here — presumably
                        # deliberate best-effort handling of malformed LLM output; confirm upstream.
                        param_value = [target_type.from_dict(item) for item in param_value if isinstance(item, dict)]
                    elif isinstance(param_value, dict):
                        param_value = target_type.from_dict(param_value)
                else:
                    # Let TypeAdapter handle both single values and lists
                    type_adapter = TypeAdapter(param_type)
                    param_value = type_adapter.validate_python(param_value)

                converted_kwargs[param_name] = param_value
            logger.debug(f"Invoking component {type(component)} with kwargs: {converted_kwargs}")
            return component.run(**converted_kwargs)

        # Generate a name for the tool if not provided
        if not name:
            class_name = component.__class__.__name__
            # Convert camelCase/PascalCase to snake_case
            # (an underscore is inserted before an uppercase letter only when the previous
            # character is lowercase, so acronym runs like "LLM" stay together).
            name = "".join(
                [
                    "_" + c.lower() if c.isupper() and i > 0 and not class_name[i - 1].isupper() else c.lower()
                    for i, c in enumerate(class_name)
                ]
            ).lstrip("_")

        # Fall back from explicit description -> component docstring -> tool name.
        description = description or component.__doc__ or name

        # Create the Tool instance with the component invoker as the function to be called and the schema
        super().__init__(
            name=name,
            description=description,
            parameters=tool_schema,
            function=component_invoker,
            inputs_from_state=inputs_from_state,
            outputs_to_state=outputs_to_state,
        )
        self._component = component
191

192
    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the ComponentTool to a dictionary.

        :returns:
            Dictionary with the tool's fully qualified class name under "type" and its
            serialized state under "data".
        """
        data = {
            "component": component_to_dict(obj=self._component, name=self.name),
            "name": self.name,
            "description": self.description,
            "parameters": self._unresolved_parameters,
            "inputs_from_state": self.inputs_from_state,
        }

        if self.outputs_to_state is not None:
            # Handlers are callables and must be converted to importable paths before serialization.
            data["outputs_to_state"] = {
                key: {**config, "handler": serialize_callable(config["handler"])}
                if "handler" in config
                else config.copy()
                for key, config in self.outputs_to_state.items()
            }

        return {"type": generate_qualified_class_name(type(self)), "data": data}
216

217
    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Tool":
        """
        Deserializes the ComponentTool from a dictionary.

        :param data: Dictionary produced by `to_dict`, with the serialized tool under "data".
        :returns: The deserialized ComponentTool.
        """
        inner_data = data["data"]
        component_class = import_class_by_name(inner_data["component"]["type"])
        component = component_from_dict(cls=component_class, data=inner_data["component"], name=inner_data["name"])

        # Rebuild handler callables from their serialized import paths.
        # Build a new mapping instead of writing back into `inner_data`, so the
        # caller's input dictionary is not mutated as a side effect.
        outputs_to_state = inner_data.get("outputs_to_state")
        if outputs_to_state:
            deserialized_outputs = {}
            for key, config in outputs_to_state.items():
                deserialized_config = config.copy()
                if "handler" in config:
                    deserialized_config["handler"] = deserialize_callable(config["handler"])
                deserialized_outputs[key] = deserialized_config
            outputs_to_state = deserialized_outputs

        return cls(
            component=component,
            name=inner_data["name"],
            description=inner_data["description"],
            parameters=inner_data.get("parameters", None),
            inputs_from_state=inner_data.get("inputs_from_state", None),
            outputs_to_state=outputs_to_state,
        )
243

244
    def _create_tool_parameters_schema(self, component: Component, inputs_from_state: Dict[str, Any]) -> Dict[str, Any]:
1✔
245
        """
246
        Creates an OpenAI tools schema from a component's run method parameters.
247

248
        :param component: The component to create the schema from.
249
        :raises SchemaGenerationError: If schema generation fails
250
        :returns: OpenAI tools schema for the component's run method parameters.
251
        """
252
        properties = {}
1✔
253
        required = []
1✔
254

255
        param_descriptions = self._get_param_descriptions(component.run)
1✔
256

257
        for input_name, socket in component.__haystack_input__._sockets_dict.items():  # type: ignore[attr-defined]
1✔
258
            if inputs_from_state is not None and input_name in inputs_from_state:
1✔
259
                continue
1✔
260
            input_type = socket.type
1✔
261
            description = param_descriptions.get(input_name, f"Input '{input_name}' for the component.")
1✔
262

263
            try:
1✔
264
                property_schema = self._create_property_schema(input_type, description)
1✔
265
            except Exception as e:
×
266
                raise SchemaGenerationError(
×
267
                    f"Error processing input '{input_name}': {e}. "
268
                    f"Schema generation supports basic types (str, int, float, bool, dict), dataclasses, "
269
                    f"and lists of these types as input types for component's run method."
270
                ) from e
271

272
            properties[input_name] = property_schema
1✔
273

274
            # Use socket.is_mandatory to check if the input is required
275
            if socket.is_mandatory:
1✔
276
                required.append(input_name)
1✔
277

278
        parameters_schema = {"type": "object", "properties": properties}
1✔
279

280
        if required:
1✔
281
            parameters_schema["required"] = required
1✔
282

283
        return parameters_schema
1✔
284

285
    @staticmethod
1✔
286
    def _get_param_descriptions(method: Callable) -> Dict[str, str]:
1✔
287
        """
288
        Extracts parameter descriptions from the method's docstring using docstring_parser.
289

290
        :param method: The method to extract parameter descriptions from.
291
        :returns: A dictionary mapping parameter names to their descriptions.
292
        """
293
        docstring = getdoc(method)
1✔
294
        if not docstring:
1✔
295
            return {}
×
296

297
        docstring_parser_import.check()
1✔
298
        parsed_doc = parse(docstring)
1✔
299
        param_descriptions = {}
1✔
300
        for param in parsed_doc.params:
1✔
301
            if not param.description:
1✔
302
                logger.warning(
×
303
                    "Missing description for parameter '%s'. Please add a description in the component's "
304
                    "run() method docstring using the format ':param %%s: <description>'. "
305
                    "This description helps the LLM understand how to use this parameter." % param.arg_name
306
                )
307
            param_descriptions[param.arg_name] = param.description.strip() if param.description else ""
1✔
308
        return param_descriptions
1✔
309

310
    @staticmethod
1✔
311
    def _is_nullable_type(python_type: Any) -> bool:
1✔
312
        """
313
        Checks if the type is a Union with NoneType (i.e., Optional).
314

315
        :param python_type: The Python type to check.
316
        :returns: True if the type is a Union with NoneType, False otherwise.
317
        """
318
        origin = get_origin(python_type)
1✔
319
        if origin is Union:
1✔
320
            return type(None) in get_args(python_type)
1✔
321
        return False
1✔
322

323
    def _create_list_schema(self, item_type: Any, description: str) -> Dict[str, Any]:
1✔
324
        """
325
        Creates a schema for a list type.
326

327
        :param item_type: The type of items in the list.
328
        :param description: The description of the list.
329
        :returns: A dictionary representing the list schema.
330
        """
331
        items_schema = self._create_property_schema(item_type, "")
1✔
332
        items_schema.pop("description", None)
1✔
333
        return {"type": "array", "description": description, "items": items_schema}
1✔
334

335
    def _create_dataclass_schema(self, python_type: Any, description: str) -> Dict[str, Any]:
1✔
336
        """
337
        Creates a schema for a dataclass.
338

339
        :param python_type: The dataclass type.
340
        :param description: The description of the dataclass.
341
        :returns: A dictionary representing the dataclass schema.
342
        """
343
        schema = {"type": "object", "description": description, "properties": {}}
1✔
344
        cls = python_type if isinstance(python_type, type) else python_type.__class__
1✔
345
        for field in fields(cls):
1✔
346
            field_description = f"Field '{field.name}' of '{cls.__name__}'."
1✔
347
            if isinstance(schema["properties"], dict):
1✔
348
                schema["properties"][field.name] = self._create_property_schema(field.type, field_description)
1✔
349
        return schema
1✔
350

351
    @staticmethod
1✔
352
    def _create_basic_type_schema(python_type: Any, description: str) -> Dict[str, Any]:
1✔
353
        """
354
        Creates a schema for a basic Python type.
355

356
        :param python_type: The Python type.
357
        :param description: The description of the type.
358
        :returns: A dictionary representing the basic type schema.
359
        """
360
        type_mapping = {str: "string", int: "integer", float: "number", bool: "boolean", dict: "object"}
1✔
361
        return {"type": type_mapping.get(python_type, "string"), "description": description}
1✔
362

363
    def _create_property_schema(self, python_type: Any, description: str, default: Any = None) -> Dict[str, Any]:
1✔
364
        """
365
        Creates a property schema for a given Python type, recursively if necessary.
366

367
        :param python_type: The Python type to create a property schema for.
368
        :param description: The description of the property.
369
        :param default: The default value of the property.
370
        :returns: A dictionary representing the property schema.
371
        :raises SchemaGenerationError: If schema generation fails, e.g., for unsupported types like Pydantic v2 models
372
        """
373
        nullable = self._is_nullable_type(python_type)
1✔
374
        if nullable:
1✔
375
            non_none_types = [t for t in get_args(python_type) if t is not type(None)]
1✔
376
            python_type = non_none_types[0] if non_none_types else str
1✔
377

378
        origin = get_origin(python_type)
1✔
379
        if origin is list:
1✔
380
            schema = self._create_list_schema(get_args(python_type)[0] if get_args(python_type) else Any, description)
1✔
381
        elif is_dataclass(python_type):
1✔
382
            schema = self._create_dataclass_schema(python_type, description)
1✔
383
        elif hasattr(python_type, "model_validate"):
1✔
384
            raise SchemaGenerationError(
×
385
                f"Pydantic models (e.g. {python_type.__name__}) are not supported as input types for "
386
                f"component's run method."
387
            )
388
        else:
389
            schema = self._create_basic_type_schema(python_type, description)
1✔
390

391
        if default is not None:
1✔
392
            schema["default"] = default
×
393

394
        return schema
1✔
395

396
    def __deepcopy__(self, memo: Dict[Any, Any]) -> "ComponentTool":
1✔
397
        # Jinja2 templates throw an Exception when we deepcopy them (see https://github.com/pallets/jinja/issues/758)
398
        # When we use a ComponentTool in a pipeline at runtime, we deepcopy the tool
399
        # We overwrite ComponentTool.__deepcopy__ to fix this until a more comprehensive fix is merged.
400
        # We track the issue here: https://github.com/deepset-ai/haystack/issues/9011
401
        result = copy(self)
1✔
402

403
        # Add the object to the memo dictionary to handle circular references
404
        memo[id(self)] = result
1✔
405

406
        # Deep copy all attributes with exception handling
407
        for key, value in self.__dict__.items():
1✔
408
            try:
1✔
409
                # Try to deep copy the attribute
410
                setattr(result, key, deepcopy(value, memo))
1✔
411
            except TypeError:
1✔
412
                # Fall back to using the original attribute for components that use Jinja2-templates
413
                logger.debug(
1✔
414
                    "deepcopy of ComponentTool {tool_name} failed. Using original attribute '{attribute}' instead.",
415
                    tool_name=self.name,
416
                    attribute=key,
417
                )
418
                setattr(result, key, getattr(self, key))
1✔
419

420
        return result
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc