deepset-ai / haystack / 15463513653

05 Jun 2025 09:28AM UTC coverage: 90.414% (+0.02%) from 90.392%
Build 15463513653 · Event: push · CI: github · Committer: web-flow

chore: Make docstring-parser core dep (#9477)

* Make docstring-parser core dep

* Add reno note

11478 of 12695 relevant lines covered (90.41%)

0.9 hits per line

Source File: haystack/tools/parameters_schema_utils.py (94.67% covered)

# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import collections
from dataclasses import MISSING, fields, is_dataclass
from inspect import getdoc
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union, get_args, get_origin

from docstring_parser import parse
from pydantic import BaseModel, Field, create_model

from haystack import logging
from haystack.dataclasses import ChatMessage

logger = logging.getLogger(__name__)


def _get_param_descriptions(method: Callable) -> Tuple[str, Dict[str, str]]:
    """
    Extracts parameter descriptions from the method's docstring using docstring_parser.

    :param method: The method to extract parameter descriptions from.
    :returns:
        A tuple including the short description of the method and a dictionary mapping parameter names to their
        descriptions.
    """
    docstring = getdoc(method)
    if not docstring:
        return "", {}

    parsed_doc = parse(docstring)
    param_descriptions = {}
    for param in parsed_doc.params:
        if not param.description:
            logger.warning(
                "Missing description for parameter '%s'. Please add a description in the component's "
                "run() method docstring using the format ':param %%s: <description>'. "
                "This description helps the LLM understand how to use this parameter." % param.arg_name
            )
        param_descriptions[param.arg_name] = param.description.strip() if param.description else ""
    return parsed_doc.short_description or "", param_descriptions


def _get_component_param_descriptions(component: Any) -> Tuple[str, Dict[str, str]]:
    """
    Get parameter descriptions from a component, handling both regular Components and SuperComponents.

    For regular components, this extracts descriptions from the run method's docstring.
    For SuperComponents, this extracts descriptions from the underlying pipeline components.

    :param component: The component to extract parameter descriptions from
    :returns: A tuple of (short_description, param_descriptions)
    """
    from haystack.core.super_component.super_component import _SuperComponent

    # Get descriptions from the component's run method
    short_desc, param_descriptions = _get_param_descriptions(component.run)

    # If it's a SuperComponent, enhance the descriptions from the original components
    if isinstance(component, _SuperComponent):
        # Collect descriptions from components in the pipeline
        component_descriptions = []
        processed_components = set()

        # First gather descriptions from all components that have inputs mapped
        for super_param_name, pipeline_paths in component.input_mapping.items():
            # Collect descriptions from all mapped components
            descriptions = []
            for path in pipeline_paths:
                try:
                    # Get the component and socket this input is mapped from
                    comp_name, socket_name = component._split_component_path(path)
                    pipeline_component = component.pipeline.get_component(comp_name)

                    # Get run method descriptions for this component
                    run_desc, run_param_descriptions = _get_param_descriptions(pipeline_component.run)

                    # Don't add the same component description multiple times
                    if comp_name not in processed_components:
                        processed_components.add(comp_name)
                        if run_desc:
                            component_descriptions.append(f"'{comp_name}': {run_desc}")

                    # Add parameter description if available
                    if input_param_mapping := run_param_descriptions.get(socket_name):
                        descriptions.append(f"Provided to the '{comp_name}' component as: '{input_param_mapping}'")
                except Exception as e:
                    logger.debug(f"Error extracting description for {super_param_name} from {path}: {str(e)}")

            # We don't only handle a one-to-one description mapping of input parameters, but a one-to-many mapping.
            # i.e. for a combined_input parameter description:
            # super_comp = SuperComponent(
            #   pipeline=pipeline,
            #   input_mapping={"combined_input": ["comp_a.query", "comp_b.text"]},
            # )
            if descriptions:
                param_descriptions[super_param_name] = ", and ".join(descriptions) + "."

        # We also create a combined description for the SuperComponent based on its components
        if component_descriptions:
            short_desc = f"A component that combines: {', '.join(component_descriptions)}"

    return short_desc, param_descriptions


def _dataclass_to_pydantic_model(dc_type: Any) -> type[BaseModel]:
    """
    Convert a Python dataclass to an equivalent Pydantic model.

    :param dc_type: The dataclass type to convert.
    :returns:
        A dynamically generated Pydantic model class with fields and types derived from the dataclass definition.
        Field descriptions are extracted from docstrings when available.
    """
    _, param_descriptions = _get_param_descriptions(dc_type)
    cls = dc_type if isinstance(dc_type, type) else dc_type.__class__

    field_defs: Dict[str, Any] = {}
    for field in fields(dc_type):
        f_type = field.type if isinstance(field.type, str) else _resolve_type(field.type)
        default = field.default if field.default is not MISSING else ...
        default = field.default_factory() if callable(field.default_factory) else default

        # Special handling for ChatMessage since pydantic doesn't allow for field names with leading underscores
        field_name = field.name
        if dc_type is ChatMessage and field_name.startswith("_"):
            # We remove the underscore since ChatMessage.from_dict does allow for field names without the underscore
            field_name = field_name[1:]

        description = param_descriptions.get(field_name, f"Field '{field_name}' of '{cls.__name__}'.")
        field_defs[field_name] = (f_type, Field(default, description=description))

    model = create_model(cls.__name__, **field_defs)
    return model


def _resolve_type(_type: Any) -> Any:
    """
    Recursively resolve and convert complex type annotations, transforming dataclasses into Pydantic-compatible types.

    This function walks through nested type annotations (e.g., List, Dict, Union) and converts any dataclass types
    it encounters into corresponding Pydantic models.

    :param _type: The type annotation to resolve. If the type is a dataclass, it will be converted to a Pydantic model.
        For generic types (like List[SomeDataclass]), the inner types are also resolved recursively.

    :returns:
        A fully resolved type, with all dataclass types converted to Pydantic models
    """
    if is_dataclass(_type):
        return _dataclass_to_pydantic_model(_type)

    origin = get_origin(_type)
    args = get_args(_type)

    if origin is list:
        return List[_resolve_type(args[0]) if args else Any]  # type: ignore[misc]

    if origin is collections.abc.Sequence:
        return Sequence[_resolve_type(args[0]) if args else Any]  # type: ignore[misc]

    if origin is Union:
        return Union[tuple(_resolve_type(a) for a in args)]  # type: ignore[misc]

    if origin is dict:
        return Dict[args[0] if args else Any, _resolve_type(args[1]) if args else Any]  # type: ignore[misc]

    return _type
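
For orientation, the sketch below shows how two of the helpers above might be used end to end. It is not part of the covered file: the CityInfo dataclass is a hypothetical example, the import of the underscore-prefixed (private) helpers is for illustration only, and model_json_schema() assumes Pydantic v2.

from dataclasses import dataclass, field
from typing import List, Optional

# Private helpers from the module covered above; imported here only for illustration.
from haystack.tools.parameters_schema_utils import _dataclass_to_pydantic_model, _get_param_descriptions


@dataclass
class CityInfo:  # hypothetical example dataclass, not part of Haystack
    """
    Basic information about a city.

    :param name: The city's name.
    :param population: Number of inhabitants, if known.
    :param districts: Names of the city's districts.
    """

    name: str
    population: Optional[int] = None
    districts: List[str] = field(default_factory=list)


# Short description plus per-parameter descriptions parsed from the docstring above.
short_desc, params = _get_param_descriptions(CityInfo)
print(short_desc)            # "Basic information about a city."
print(params["population"])  # "Number of inhabitants, if known."

# Build a Pydantic model whose fields, defaults, and descriptions mirror the dataclass.
CityInfoModel = _dataclass_to_pydantic_model(CityInfo)

# The resulting JSON schema (Pydantic v2) is what can be handed to an LLM as a tool-parameters schema.
print(CityInfoModel.model_json_schema())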