• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 18595249452

17 Oct 2025 02:08PM UTC coverage: 92.22% (+0.02%) from 92.2%
18595249452

Pull #9886

github

web-flow
Merge ad30d1879 into cc4f024af
Pull Request #9886: feat: Update tools param to Optional[Union[list[Union[Tool, Toolset]], Toolset]]

13382 of 14511 relevant lines covered (92.22%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

95.4
haystack/tools/parameters_schema_utils.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import collections
1✔
6
from dataclasses import MISSING, fields, is_dataclass
1✔
7
from inspect import getdoc
1✔
8
from typing import Any, Callable, Sequence, Union, get_args, get_origin
1✔
9

10
from docstring_parser import parse
1✔
11
from pydantic import BaseModel, Field, create_model
1✔
12

13
from haystack import logging
1✔
14
from haystack.dataclasses import ChatMessage
1✔
15

16
logger = logging.getLogger(__name__)
1✔
17

18

19
# Schema placeholder models for Tool and Toolset
# These are used during JSON schema generation to represent non-serializable types
class _ToolSchemaPlaceholder(BaseModel):
    """Placeholder model representing a Tool for JSON schema generation."""

    # A real Tool holds a Callable, which cannot be expressed in JSON Schema; this
    # stand-in exposes only the JSON-serializable surface of a Tool.
    name: str = Field(description="Name of the tool")
    description: str = Field(description="Description of the tool")
    parameters: dict[str, Any] = Field(description="JSON schema of the tool parameters")
27

28

29
class _ToolsetSchemaPlaceholder(BaseModel):
    """Placeholder model representing a Toolset for JSON schema generation."""

    # A Toolset is represented purely as the list of its tools' schema placeholders.
    tools: list[_ToolSchemaPlaceholder] = Field(description="List of tools in the toolset")
33

34

35
def _get_param_descriptions(method: Callable) -> tuple[str, dict[str, str]]:
    """
    Parse a method's docstring and pull out its summary and per-parameter descriptions.

    :param method: The callable whose docstring is parsed with docstring_parser.
    :returns:
        A tuple of the method's short description and a dict mapping each documented
        parameter name to its stripped description. A parameter documented without a
        description maps to an empty string and triggers a warning.
    """
    docstring = getdoc(method)
    if not docstring:
        return "", {}

    parsed = parse(docstring)
    descriptions: dict[str, str] = {}
    for parameter in parsed.params:
        if parameter.description:
            descriptions[parameter.arg_name] = parameter.description.strip()
        else:
            # Warn so component authors know the LLM will see an undocumented parameter.
            logger.warning(
                "Missing description for parameter '%s'. Please add a description in the component's "
                "run() method docstring using the format ':param %%s: <description>'. "
                "This description helps the LLM understand how to use this parameter." % parameter.arg_name
            )
            descriptions[parameter.arg_name] = ""

    return parsed.short_description or "", descriptions
59

60

61
def _get_component_param_descriptions(component: Any) -> tuple[str, dict[str, str]]:
    """
    Get parameter descriptions from a component, handling both regular Components and SuperComponents.

    For regular components, this extracts descriptions from the run method's docstring.
    For SuperComponents, this extracts descriptions from the underlying pipeline components.

    :param component: The component to extract parameter descriptions from
    :returns: A tuple of (short_description, param_descriptions)
    """
    # NOTE(review): imported locally, presumably to avoid a circular import at module load — confirm
    from haystack.core.super_component.super_component import _SuperComponent

    # Get descriptions from the component's run method
    short_desc, param_descriptions = _get_param_descriptions(component.run)

    # If it's a SuperComponent, enhance the descriptions from the original components
    if isinstance(component, _SuperComponent):
        # Collect descriptions from components in the pipeline
        component_descriptions: list[str] = []
        # Tracks components already summarized, so each contributes its run() description only once
        processed_components: set[str] = set()

        # First gather descriptions from all components that have inputs mapped
        for super_param_name, pipeline_paths in component.input_mapping.items():
            # Collect descriptions from all mapped components
            descriptions: list[str] = []
            for path in pipeline_paths:
                try:
                    # Get the component and socket this input is mapped from
                    comp_name, socket_name = component._split_component_path(path)
                    pipeline_component = component.pipeline.get_component(comp_name)

                    # Get run method descriptions for this component
                    run_desc, run_param_descriptions = _get_param_descriptions(pipeline_component.run)

                    # Don't add the same component description multiple times
                    if comp_name not in processed_components:
                        processed_components.add(comp_name)
                        if run_desc:
                            component_descriptions.append(f"'{comp_name}': {run_desc}")

                    # Add parameter description if available
                    if input_param_mapping := run_param_descriptions.get(socket_name):
                        descriptions.append(f"Provided to the '{comp_name}' component as: '{input_param_mapping}'")
                except Exception as e:
                    # Best-effort: a failure on one path is logged at debug level and skipped
                    # rather than aborting description extraction for the whole SuperComponent.
                    logger.debug(f"Error extracting description for {super_param_name} from {path}: {str(e)}")

            # We don't only handle a one to one description mapping of input parameters, but a one to many mapping.
            # i.e. for a combined_input parameter description:
            # super_comp = SuperComponent(
            #   pipeline=pipeline,
            #   input_mapping={"combined_input": ["comp_a.query", "comp_b.text"]},
            # )
            if descriptions:
                param_descriptions[super_param_name] = ", and ".join(descriptions) + "."

        # We also create a combined description for the SuperComponent based on its components
        if component_descriptions:
            short_desc = f"A component that combines: {', '.join(component_descriptions)}"

    return short_desc, param_descriptions
121

122

123
def _dataclass_to_pydantic_model(dc_type: Any) -> type[BaseModel]:
    """
    Build a Pydantic model that mirrors the given dataclass.

    Each dataclass field becomes a Pydantic field of the same (recursively resolved) type,
    carrying over defaults and default factories. Field descriptions are taken from the
    dataclass docstring when available.

    :param dc_type: The dataclass (type or instance) to convert.
    :returns:
        A dynamically created Pydantic model class with fields and types derived from the
        dataclass definition.
    """
    _, param_docs = _get_param_descriptions(dc_type)
    dataclass_cls = dc_type if isinstance(dc_type, type) else dc_type.__class__

    model_fields: dict[str, Any] = {}
    for dc_field in fields(dc_type):
        # Leave string (forward-reference) annotations untouched; resolve everything else.
        resolved_type = dc_field.type if isinstance(dc_field.type, str) else _resolve_type(dc_field.type)

        # A default_factory takes precedence; otherwise use the plain default,
        # or mark the field required (...) when neither is set.
        if callable(dc_field.default_factory):
            default_value = dc_field.default_factory()
        elif dc_field.default is not MISSING:
            default_value = dc_field.default
        else:
            default_value = ...

        # Pydantic doesn't allow field names with leading underscores, so ChatMessage's
        # private fields are exposed without the underscore (ChatMessage.from_dict accepts
        # the underscore-free names).
        field_name = dc_field.name
        if dc_type is ChatMessage and field_name.startswith("_"):
            field_name = field_name[1:]

        description = param_docs.get(field_name, f"Field '{field_name}' of '{dataclass_cls.__name__}'.")
        model_fields[field_name] = (resolved_type, Field(default_value, description=description))

    return create_model(dataclass_cls.__name__, **model_fields)
152

153

154
def _resolve_type(_type: Any) -> Any:  # noqa: PLR0911  # pylint: disable=too-many-return-statements
    """
    Recursively resolve a type annotation into a Pydantic/JSON-Schema-friendly equivalent.

    Dataclasses are converted to dynamically created Pydantic models, Tool/Toolset are
    swapped for schema placeholder models, and generic containers (list, Sequence, Union,
    dict) have their inner types resolved recursively.

    :param _type: The type annotation to resolve. Generic types (like list[SomeDataclass])
        have their inner types resolved recursively.
    :returns:
        A fully resolved type, with all dataclass types converted to Pydantic models.
    """
    # Imported here: Tool and Toolset contain Callables which cannot be serialized to
    # JSON Schema, so they are substituted with serializable placeholder models.
    from haystack.tools.tool import Tool
    from haystack.tools.toolset import Toolset

    if _type is Tool:
        return _ToolSchemaPlaceholder
    if _type is Toolset:
        return _ToolsetSchemaPlaceholder
    if is_dataclass(_type):
        return _dataclass_to_pydantic_model(_type)

    origin = get_origin(_type)
    type_args = get_args(_type)

    if origin is list:
        item_type = _resolve_type(type_args[0]) if type_args else Any
        return list[item_type]  # type: ignore[misc]

    if origin is collections.abc.Sequence:
        item_type = _resolve_type(type_args[0]) if type_args else Any
        return Sequence[item_type]  # type: ignore[misc]

    if origin is Union:
        resolved_members = tuple(_resolve_type(member) for member in type_args)
        return Union[resolved_members]

    if origin is dict:
        # Only the value type is resolved; the key type is passed through unchanged.
        key_type = type_args[0] if type_args else Any
        value_type = _resolve_type(type_args[1]) if type_args else Any
        return dict[key_type, value_type]  # type: ignore[misc]

    # Anything else (primitives, plain classes, unparameterized generics) passes through.
    return _type
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc