deepset-ai / haystack · build 15165238383

21 May 2025 02:42PM UTC · coverage: 90.404% (-0.04% from 90.443%)

Pull Request #9275: feat: return common type in SuperComponent type compatibility check
Merge 82e69fe2c into 17432f710 (via GitHub web-flow)

11135 of 12317 relevant lines covered (90.4%) · 0.9 hits per line

Source File

haystack/tools/parameters_schema_utils.py: 94.87% covered
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import collections
from dataclasses import MISSING, fields, is_dataclass
from inspect import getdoc
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union, get_args, get_origin

from pydantic import BaseModel, Field, create_model

from haystack import logging
from haystack.dataclasses import ChatMessage
from haystack.lazy_imports import LazyImport

with LazyImport(message="Run 'pip install docstring-parser'") as docstring_parser_import:
    from docstring_parser import parse
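# Note: docstring-parser is an optional dependency. LazyImport defers the ImportError
# until docstring_parser_import.check() is called (inside _get_param_descriptions below),
# so this module can be imported even when docstring-parser is not installed.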


logger = logging.getLogger(__name__)


def _get_param_descriptions(method: Callable) -> Tuple[str, Dict[str, str]]:
    """
    Extracts parameter descriptions from the method's docstring using docstring_parser.

    :param method: The method to extract parameter descriptions from.
    :returns:
        A tuple including the short description of the method and a dictionary mapping parameter names to their
        descriptions.
    """
    docstring = getdoc(method)
    if not docstring:
        return "", {}

    docstring_parser_import.check()
    parsed_doc = parse(docstring)
    param_descriptions = {}
    for param in parsed_doc.params:
        if not param.description:
            logger.warning(
                "Missing description for parameter '%s'. Please add a description in the component's "
                "run() method docstring using the format ':param %%s: <description>'. "
                "This description helps the LLM understand how to use this parameter." % param.arg_name
            )
        param_descriptions[param.arg_name] = param.description.strip() if param.description else ""
    return parsed_doc.short_description or "", param_descriptions
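# Usage sketch (illustrative only; `search_run` and its docstring are hypothetical,
# not part of this module). Given a Haystack-style docstring such as:
#
#     def search_run(query: str, top_k: int = 5):
#         """
#         Run a search over the indexed documents.
#
#         :param query: The search query text.
#         :param top_k: The number of results to return.
#         """
#
# _get_param_descriptions(search_run) would return roughly:
#     ("Run a search over the indexed documents.",
#      {"query": "The search query text.", "top_k": "The number of results to return."})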


def _get_component_param_descriptions(component: Any) -> Tuple[str, Dict[str, str]]:
    """
    Get parameter descriptions from a component, handling both regular Components and SuperComponents.

    For regular components, this extracts descriptions from the run method's docstring.
    For SuperComponents, this extracts descriptions from the underlying pipeline components.

    :param component: The component to extract parameter descriptions from
    :returns: A tuple of (short_description, param_descriptions)
    """
    from haystack.core.super_component.super_component import _SuperComponent

    # Get descriptions from the component's run method
    short_desc, param_descriptions = _get_param_descriptions(component.run)

    # If it's a SuperComponent, enhance the descriptions from the original components
    if isinstance(component, _SuperComponent):
        # Collect descriptions from components in the pipeline
        component_descriptions = []
        processed_components = set()

        # First gather descriptions from all components that have inputs mapped
        for super_param_name, pipeline_paths in component.input_mapping.items():
            # Collect descriptions from all mapped components
            descriptions = []
            for path in pipeline_paths:
                try:
                    # Get the component and socket this input is mapped from
                    comp_name, socket_name = component._split_component_path(path)
                    pipeline_component = component.pipeline.get_component(comp_name)

                    # Get run method descriptions for this component
                    run_desc, run_param_descriptions = _get_param_descriptions(pipeline_component.run)

                    # Don't add the same component description multiple times
                    if comp_name not in processed_components:
                        processed_components.add(comp_name)
                        if run_desc:
                            component_descriptions.append(f"'{comp_name}': {run_desc}")

                    # Add parameter description if available
                    if input_param_mapping := run_param_descriptions.get(socket_name):
                        descriptions.append(f"Provided to the '{comp_name}' component as: '{input_param_mapping}'")
                except Exception as e:
                    logger.debug(f"Error extracting description for {super_param_name} from {path}: {str(e)}")

            # We handle not only one-to-one mappings of input parameter descriptions, but also one-to-many
            # mappings, e.g. for a combined_input parameter description:
            # super_comp = SuperComponent(
            #   pipeline=pipeline,
            #   input_mapping={"combined_input": ["comp_a.query", "comp_b.text"]},
            # )
            if descriptions:
                param_descriptions[super_param_name] = ", and ".join(descriptions) + "."

        # We also create a combined description for the SuperComponent based on its components
        if component_descriptions:
            short_desc = f"A component that combines: {', '.join(component_descriptions)}"

    return short_desc, param_descriptions
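# Usage sketch (illustrative only; comp_a, comp_b, and their docstrings are hypothetical).
# With the SuperComponent from the comment above, i.e.
#     input_mapping={"combined_input": ["comp_a.query", "comp_b.text"]},
# and assuming comp_a.run documents ":param query: The search query." while comp_b.run
# documents ":param text: The text to rank.", the returned param_descriptions would
# contain roughly:
#     param_descriptions["combined_input"] == (
#         "Provided to the 'comp_a' component as: 'The search query.', and "
#         "Provided to the 'comp_b' component as: 'The text to rank.'."
#     )
# and short_desc would become "A component that combines: 'comp_a': <comp_a short
# description>, 'comp_b': <comp_b short description>".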


def _dataclass_to_pydantic_model(dc_type: Any) -> type[BaseModel]:
    """
    Convert a Python dataclass to an equivalent Pydantic model.

    :param dc_type: The dataclass type to convert.
    :returns:
        A dynamically generated Pydantic model class with fields and types derived from the dataclass definition.
        Field descriptions are extracted from docstrings when available.
    """
    _, param_descriptions = _get_param_descriptions(dc_type)
    cls = dc_type if isinstance(dc_type, type) else dc_type.__class__

    field_defs: Dict[str, Any] = {}
    for field in fields(dc_type):
        f_type = field.type if isinstance(field.type, str) else _resolve_type(field.type)
        default = field.default if field.default is not MISSING else ...
        default = field.default_factory() if callable(field.default_factory) else default

        # Special handling for ChatMessage since pydantic doesn't allow for field names with leading underscores
        field_name = field.name
        if dc_type is ChatMessage and field_name.startswith("_"):
            # We remove the underscore since ChatMessage.from_dict does allow for field names without the underscore
            field_name = field_name[1:]

        description = param_descriptions.get(field_name, f"Field '{field_name}' of '{cls.__name__}'.")
        field_defs[field_name] = (f_type, Field(default, description=description))

    model = create_model(cls.__name__, **field_defs)
    return model
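# Usage sketch (illustrative only; Person is a hypothetical dataclass, not part of this module):
#
#     @dataclass
#     class Person:
#         """
#         A person record.
#
#         :param name: The person's full name.
#         :param age: The person's age in years.
#         """
#         name: str
#         age: int = 30
#
# _dataclass_to_pydantic_model(Person) would return a Pydantic model named "Person" with a
# required `name: str` field and an `age: int` field defaulting to 30, each carrying the
# description parsed from the class docstring above.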


def _resolve_type(_type: Any) -> Any:
    """
    Recursively resolve and convert complex type annotations, transforming dataclasses into Pydantic-compatible types.

    This function walks through nested type annotations (e.g., List, Dict, Union) and converts any dataclass types
    it encounters into corresponding Pydantic models.

    :param _type: The type annotation to resolve. If the type is a dataclass, it will be converted to a Pydantic model.
        For generic types (like List[SomeDataclass]), the inner types are also resolved recursively.

    :returns:
        A fully resolved type, with all dataclass types converted to Pydantic models
    """
    if is_dataclass(_type):
        return _dataclass_to_pydantic_model(_type)

    origin = get_origin(_type)
    args = get_args(_type)

    if origin is list:
        return List[_resolve_type(args[0]) if args else Any]  # type: ignore[misc]

    if origin is collections.abc.Sequence:
        return Sequence[_resolve_type(args[0]) if args else Any]  # type: ignore[misc]

    if origin is Union:
        return Union[tuple(_resolve_type(a) for a in args)]  # type: ignore[misc]

    if origin is dict:
        return Dict[args[0] if args else Any, _resolve_type(args[1]) if args else Any]  # type: ignore[misc]

    return _type
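# Usage sketch (illustrative only; Person refers to the hypothetical dataclass from the
# sketch above):
#     _resolve_type(List[Person])        -> List[<Pydantic model for Person>]
#     _resolve_type(Dict[str, Person])   -> Dict[str, <Pydantic model for Person>]
#     _resolve_type(Union[str, Person])  -> Union[str, <Pydantic model for Person>]
#     _resolve_type(int)                 -> int  (non-generic, non-dataclass types pass through unchanged)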