
deepset-ai / haystack / 14195113562

01 Apr 2025 12:04PM UTC coverage: 90.004% (-0.09%) from 90.091%
feat: Add `outputs_to_string` to Tool and ComponentTool (#9152)

* Add outputs_to_string to Tool and ComponentTool
* Doc string and fix tests
* Add reno
* Fix mypy
10282 of 11424 relevant lines covered (90.0%)

0.9 hits per line

Source File: haystack/tools/component_tool.py (93.67% covered)
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

from copy import copy, deepcopy
from dataclasses import fields, is_dataclass
from inspect import getdoc
from typing import Any, Callable, Dict, Optional, Union, get_args, get_origin

from pydantic import TypeAdapter

from haystack import logging
from haystack.core.component import Component
from haystack.core.serialization import (
    component_from_dict,
    component_to_dict,
    generate_qualified_class_name,
    import_class_by_name,
)
from haystack.lazy_imports import LazyImport
from haystack.tools import Tool
from haystack.tools.errors import SchemaGenerationError
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable

with LazyImport(message="Run 'pip install docstring-parser'") as docstring_parser_import:
    from docstring_parser import parse


logger = logging.getLogger(__name__)


class ComponentTool(Tool):
    """
    A Tool that wraps Haystack components, allowing them to be used as tools by LLMs.

    ComponentTool automatically generates LLM-compatible tool schemas from component input sockets,
    which are derived from the component's `run` method signature and type hints.


    Key features:
    - Automatic LLM tool calling schema generation from component input sockets
    - Type conversion and validation for component inputs
    - Support for types:
        - Dataclasses
        - Lists of dataclasses
        - Basic types (str, int, float, bool, dict)
        - Lists of basic types
    - Automatic name generation from component class name
    - Description extraction from component docstrings

    To use ComponentTool, you first need a Haystack component - either an existing one or a new one you create.
    You can create a ComponentTool from the component by passing the component to the ComponentTool constructor.
    Below is an example of creating a ComponentTool from an existing SerperDevWebSearch component.

    ```python
    from haystack import component, Pipeline
    from haystack.tools import ComponentTool
    from haystack.components.websearch import SerperDevWebSearch
    from haystack.utils import Secret
    from haystack.components.tools.tool_invoker import ToolInvoker
    from haystack.components.generators.chat import OpenAIChatGenerator
    from haystack.dataclasses import ChatMessage

    # Create a SerperDev search component
    search = SerperDevWebSearch(api_key=Secret.from_env_var("SERPERDEV_API_KEY"), top_k=3)

    # Create a tool from the component
    tool = ComponentTool(
        component=search,
        name="web_search",  # Optional: defaults to "serper_dev_web_search"
        description="Search the web for current information on any topic"  # Optional: defaults to component docstring
    )

    # Create pipeline with OpenAIChatGenerator and ToolInvoker
    pipeline = Pipeline()
    pipeline.add_component("llm", OpenAIChatGenerator(model="gpt-4o-mini", tools=[tool]))
    pipeline.add_component("tool_invoker", ToolInvoker(tools=[tool]))

    # Connect components
    pipeline.connect("llm.replies", "tool_invoker.messages")

    message = ChatMessage.from_user("Use the web search tool to find information about Nikola Tesla")

    # Run pipeline
    result = pipeline.run({"llm": {"messages": [message]}})

    print(result)
    ```

    """

    def __init__(
        self,
        component: Component,
        name: Optional[str] = None,
        description: Optional[str] = None,
        parameters: Optional[Dict[str, Any]] = None,
        *,
        outputs_to_string: Optional[Dict[str, Union[str, Callable[[Any], str]]]] = None,
        inputs_from_state: Optional[Dict[str, str]] = None,
        outputs_to_state: Optional[Dict[str, Dict[str, Union[str, Callable]]]] = None,
    ):
        """
        Create a Tool instance from a Haystack component.

        :param component: The Haystack component to wrap as a tool.
        :param name: Optional name for the tool (defaults to snake_case of component class name).
        :param description: Optional description (defaults to component's docstring).
        :param parameters:
            A JSON schema defining the parameters expected by the Tool.
            Will fall back to the parameters defined in the component's run method signature if not provided.
        :param outputs_to_string:
            Optional dictionary defining how the tool's outputs should be converted into a string.
            If the source is provided, only the specified output key is sent to the handler.
            If the source is omitted, the whole tool result is sent to the handler.
            Example: {
                "source": "docs", "handler": format_documents
            }
        :param inputs_from_state:
            Optional dictionary mapping state keys to tool parameter names.
            Example: {"repository": "repo"} maps state's "repository" to tool's "repo" parameter.
        :param outputs_to_state:
            Optional dictionary defining how tool outputs map to keys within state as well as optional handlers.
            If the source is provided, only the specified output key is sent to the handler.
            Example: {
                "documents": {"source": "docs", "handler": custom_handler}
            }
            If the source is omitted, the whole tool result is sent to the handler.
            Example: {
                "documents": {"handler": custom_handler}
            }
        :raises ValueError: If the component is invalid or schema generation fails.
        """
        if not isinstance(component, Component):
            message = (
                f"Object {component!r} is not a Haystack component. "
                "Use ComponentTool only with Haystack component instances."
            )
            raise ValueError(message)

        if getattr(component, "__haystack_added_to_pipeline__", None):
            msg = (
                "Component has been added to a pipeline and can't be used to create a ComponentTool. "
                "Create ComponentTool from a non-pipeline component instead."
            )
            raise ValueError(msg)

        self._unresolved_parameters = parameters
        # Create the tools schema from the component run method parameters
        tool_schema = parameters or self._create_tool_parameters_schema(component, inputs_from_state or {})

        def component_invoker(**kwargs):
            """
            Invokes the component using keyword arguments provided by the LLM function calling/tool-generated response.

            :param kwargs: The keyword arguments to invoke the component with.
            :returns: The result of the component invocation.
            """
            converted_kwargs = {}
            input_sockets = component.__haystack_input__._sockets_dict
            for param_name, param_value in kwargs.items():
                param_type = input_sockets[param_name].type

                # Check if the type (or list element type) has from_dict
                target_type = get_args(param_type)[0] if get_origin(param_type) is list else param_type
                if hasattr(target_type, "from_dict"):
                    if isinstance(param_value, list):
                        param_value = [target_type.from_dict(item) for item in param_value if isinstance(item, dict)]
                    elif isinstance(param_value, dict):
                        param_value = target_type.from_dict(param_value)
                else:
                    # Let TypeAdapter handle both single values and lists
                    type_adapter = TypeAdapter(param_type)
                    param_value = type_adapter.validate_python(param_value)

                converted_kwargs[param_name] = param_value
            logger.debug(f"Invoking component {type(component)} with kwargs: {converted_kwargs}")
            return component.run(**converted_kwargs)

        # Generate a name for the tool if not provided
        if not name:
            class_name = component.__class__.__name__
            # Convert camelCase/PascalCase to snake_case
            name = "".join(
                [
                    "_" + c.lower() if c.isupper() and i > 0 and not class_name[i - 1].isupper() else c.lower()
                    for i, c in enumerate(class_name)
                ]
            ).lstrip("_")

        description = description or component.__doc__ or name

        # Create the Tool instance with the component invoker as the function to be called and the schema
        super().__init__(
            name=name,
            description=description,
            parameters=tool_schema,
            function=component_invoker,
            inputs_from_state=inputs_from_state,
            outputs_to_state=outputs_to_state,
            outputs_to_string=outputs_to_string,
        )
        self._component = component

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the ComponentTool to a dictionary.
        """
        serialized_component = component_to_dict(obj=self._component, name=self.name)

        serialized = {
            "component": serialized_component,
            "name": self.name,
            "description": self.description,
            "parameters": self._unresolved_parameters,
            "outputs_to_string": self.outputs_to_string,
            "inputs_from_state": self.inputs_from_state,
            "outputs_to_state": self.outputs_to_state,
        }

        if self.outputs_to_state is not None:
            serialized_outputs = {}
            for key, config in self.outputs_to_state.items():
                serialized_config = config.copy()
                if "handler" in config:
                    serialized_config["handler"] = serialize_callable(config["handler"])
                serialized_outputs[key] = serialized_config
            serialized["outputs_to_state"] = serialized_outputs

        if self.outputs_to_string is not None and self.outputs_to_string.get("handler") is not None:
            # Serialize only the handler callable, keeping the rest of the mapping
            # (e.g. "source") intact so that from_dict can restore it.
            serialized["outputs_to_string"] = self.outputs_to_string.copy()
            serialized["outputs_to_string"]["handler"] = serialize_callable(self.outputs_to_string["handler"])

        return {"type": generate_qualified_class_name(type(self)), "data": serialized}
1✔
234

235
    @classmethod
1✔
236
    def from_dict(cls, data: Dict[str, Any]) -> "Tool":
1✔
237
        """
238
        Deserializes the ComponentTool from a dictionary.
239
        """
240
        inner_data = data["data"]
1✔
241
        component_class = import_class_by_name(inner_data["component"]["type"])
1✔
242
        component = component_from_dict(cls=component_class, data=inner_data["component"], name=inner_data["name"])
1✔
243

244
        if "outputs_to_state" in inner_data and inner_data["outputs_to_state"]:
1✔
245
            deserialized_outputs = {}
1✔
246
            for key, config in inner_data["outputs_to_state"].items():
1✔
247
                deserialized_config = config.copy()
1✔
248
                if "handler" in config:
1✔
249
                    deserialized_config["handler"] = deserialize_callable(config["handler"])
1✔
250
                deserialized_outputs[key] = deserialized_config
1✔
251
            inner_data["outputs_to_state"] = deserialized_outputs
1✔
252

253
        if (
1✔
254
            inner_data.get("outputs_to_string") is not None
255
            and inner_data["outputs_to_string"].get("handler") is not None
256
        ):
257
            inner_data["outputs_to_string"]["handler"] = deserialize_callable(
×
258
                inner_data["outputs_to_string"]["handler"]
259
            )
260

261
        return cls(
1✔
262
            component=component,
263
            name=inner_data["name"],
264
            description=inner_data["description"],
265
            parameters=inner_data.get("parameters", None),
266
            outputs_to_string=inner_data.get("outputs_to_string", None),
267
            inputs_from_state=inner_data.get("inputs_from_state", None),
268
            outputs_to_state=inner_data.get("outputs_to_state", None),
269
        )
270

271
    def _create_tool_parameters_schema(self, component: Component, inputs_from_state: Dict[str, Any]) -> Dict[str, Any]:
        """
        Creates an OpenAI tools schema from a component's run method parameters.

        :param component: The component to create the schema from.
        :param inputs_from_state: Mapping of state keys to tool parameter names; these inputs are
            excluded from the generated schema.
        :raises SchemaGenerationError: If schema generation fails.
        :returns: OpenAI tools schema for the component's run method parameters.
        """
        properties = {}
        required = []

        param_descriptions = self._get_param_descriptions(component.run)

        for input_name, socket in component.__haystack_input__._sockets_dict.items():  # type: ignore[attr-defined]
            if inputs_from_state is not None and input_name in inputs_from_state:
                continue
            input_type = socket.type
            description = param_descriptions.get(input_name, f"Input '{input_name}' for the component.")

            try:
                property_schema = self._create_property_schema(input_type, description)
            except Exception as e:
                raise SchemaGenerationError(
                    f"Error processing input '{input_name}': {e}. "
                    f"Schema generation supports basic types (str, int, float, bool, dict), dataclasses, "
                    f"and lists of these types as input types for component's run method."
                ) from e

            properties[input_name] = property_schema

            # Use socket.is_mandatory to check if the input is required
            if socket.is_mandatory:
                required.append(input_name)

        parameters_schema = {"type": "object", "properties": properties}

        if required:
            parameters_schema["required"] = required

        return parameters_schema

    @staticmethod
    def _get_param_descriptions(method: Callable) -> Dict[str, str]:
        """
        Extracts parameter descriptions from the method's docstring using docstring_parser.

        :param method: The method to extract parameter descriptions from.
        :returns: A dictionary mapping parameter names to their descriptions.
        """
        docstring = getdoc(method)
        if not docstring:
            return {}

        docstring_parser_import.check()
        parsed_doc = parse(docstring)
        param_descriptions = {}
        for param in parsed_doc.params:
            if not param.description:
                logger.warning(
                    "Missing description for parameter '%s'. Please add a description in the component's "
                    "run() method docstring using the format ':param %%s: <description>'. "
                    "This description helps the LLM understand how to use this parameter." % param.arg_name
                )
            param_descriptions[param.arg_name] = param.description.strip() if param.description else ""
        return param_descriptions

    @staticmethod
    def _is_nullable_type(python_type: Any) -> bool:
        """
        Checks if the type is a Union with NoneType (i.e., Optional).

        :param python_type: The Python type to check.
        :returns: True if the type is a Union with NoneType, False otherwise.
        """
        origin = get_origin(python_type)
        if origin is Union:
            return type(None) in get_args(python_type)
        return False

    def _create_list_schema(self, item_type: Any, description: str) -> Dict[str, Any]:
        """
        Creates a schema for a list type.

        :param item_type: The type of items in the list.
        :param description: The description of the list.
        :returns: A dictionary representing the list schema.
        """
        items_schema = self._create_property_schema(item_type, "")
        items_schema.pop("description", None)
        return {"type": "array", "description": description, "items": items_schema}

    def _create_dataclass_schema(self, python_type: Any, description: str) -> Dict[str, Any]:
        """
        Creates a schema for a dataclass.

        :param python_type: The dataclass type.
        :param description: The description of the dataclass.
        :returns: A dictionary representing the dataclass schema.
        """
        schema = {"type": "object", "description": description, "properties": {}}
        cls = python_type if isinstance(python_type, type) else python_type.__class__
        for field in fields(cls):
            field_description = f"Field '{field.name}' of '{cls.__name__}'."
            if isinstance(schema["properties"], dict):
                schema["properties"][field.name] = self._create_property_schema(field.type, field_description)
        return schema

    @staticmethod
    def _create_basic_type_schema(python_type: Any, description: str) -> Dict[str, Any]:
        """
        Creates a schema for a basic Python type.

        :param python_type: The Python type.
        :param description: The description of the type.
        :returns: A dictionary representing the basic type schema.
        """
        type_mapping = {str: "string", int: "integer", float: "number", bool: "boolean", dict: "object"}
        return {"type": type_mapping.get(python_type, "string"), "description": description}

    def _create_property_schema(self, python_type: Any, description: str, default: Any = None) -> Dict[str, Any]:
        """
        Creates a property schema for a given Python type, recursively if necessary.

        :param python_type: The Python type to create a property schema for.
        :param description: The description of the property.
        :param default: The default value of the property.
        :returns: A dictionary representing the property schema.
        :raises SchemaGenerationError: If schema generation fails, e.g., for unsupported types like Pydantic v2 models
        """
        nullable = self._is_nullable_type(python_type)
        if nullable:
            non_none_types = [t for t in get_args(python_type) if t is not type(None)]
            python_type = non_none_types[0] if non_none_types else str

        origin = get_origin(python_type)
        if origin is list:
            schema = self._create_list_schema(get_args(python_type)[0] if get_args(python_type) else Any, description)
        elif is_dataclass(python_type):
            schema = self._create_dataclass_schema(python_type, description)
        elif hasattr(python_type, "model_validate"):
            raise SchemaGenerationError(
                f"Pydantic models (e.g. {python_type.__name__}) are not supported as input types for "
                f"component's run method."
            )
        else:
            schema = self._create_basic_type_schema(python_type, description)

        if default is not None:
            schema["default"] = default

        return schema

    def __deepcopy__(self, memo: Dict[Any, Any]) -> "ComponentTool":
        # Jinja2 templates throw an Exception when we deepcopy them (see https://github.com/pallets/jinja/issues/758)
        # When we use a ComponentTool in a pipeline at runtime, we deepcopy the tool
        # We overwrite ComponentTool.__deepcopy__ to fix this until a more comprehensive fix is merged.
        # We track the issue here: https://github.com/deepset-ai/haystack/issues/9011
        result = copy(self)

        # Add the object to the memo dictionary to handle circular references
        memo[id(self)] = result

        # Deep copy all attributes with exception handling
        for key, value in self.__dict__.items():
            try:
                # Try to deep copy the attribute
                setattr(result, key, deepcopy(value, memo))
            except TypeError:
                # Fall back to using the original attribute for components that use Jinja2-templates
                logger.debug(
                    "deepcopy of ComponentTool {tool_name} failed. Using original attribute '{attribute}' instead.",
                    tool_name=self.name,
                    attribute=key,
                )
                setattr(result, key, getattr(self, key))

        return result
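
Below is a minimal usage sketch of the `outputs_to_string` parameter that this commit adds, assembled from the constructor docstring above. The `format_documents` handler and the assumption that `SerperDevWebSearch` exposes its results under a `documents` output key are illustrative, not taken from this file.

```python
from typing import List

from haystack.components.websearch import SerperDevWebSearch
from haystack.dataclasses import Document
from haystack.tools import ComponentTool
from haystack.utils import Secret


def format_documents(docs: List[Document]) -> str:
    # Illustrative handler: join the retrieved documents into one string for the LLM.
    return "\n\n".join(doc.content or "" for doc in docs)


search = SerperDevWebSearch(api_key=Secret.from_env_var("SERPERDEV_API_KEY"), top_k=3)

tool = ComponentTool(
    component=search,
    name="web_search",
    # With "source", only that output key is passed to the handler;
    # without it, the whole run() result dict would be passed instead.
    outputs_to_string={"source": "documents", "handler": format_documents},
)

# tool.invoke() returns the component's raw run() output; the string conversion is
# presumably applied by the tool-calling machinery (e.g. ToolInvoker) when it builds
# the tool result message.
result = tool.invoke(query="Who was Nikola Tesla?")
```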