deepset-ai / haystack / 8109583845

01 Mar 2024 09:56AM UTC coverage: 89.912% (+0.006%) from 89.906%
push

github

web-flow
fix: Update `Component` protocol to fix some type checking issues (#7270)

* Update Component protocol to fix some type checking issues

* Add release notes

* Fix logline in test

* Fix run type definition

5285 of 5878 relevant lines covered (89.91%)

0.9 hits per line

Source File

haystack/core/component/component.py (98.78% covered; the uncovered line is `run: Callable[..., Dict[str, Any]]` in the `sys.version_info >= (3, 9)` branch)
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
"""
    Attributes:

        component: Marks a class as a component. Any class decorated with `@component` can be used by a Pipeline.

    All components must follow the contract below. This docstring is the source of truth for the component contract.

    <hr>

    `@component` decorator

    All component classes must be decorated with the `@component` decorator. This allows Canals to discover them.

    <hr>

    `__init__(self, **kwargs)`

    Optional method.

    Components may have an `__init__` method where they define:

    - `self.init_parameters = {same parameters that the __init__ method received}`:
        In this dictionary you can store any state the component wants persisted when it is saved.
        These values will be given to the `__init__` method of a new instance when the pipeline is loaded.
        Note that by default the `@component` decorator saves the arguments automatically.
        However, if a component sets its own `init_parameters` manually in `__init__()`, that will be used instead.
        Note: all of the values contained here **must be JSON serializable**. Serialize them manually if needed.

    Components should take only "basic" Python types as parameters of their `__init__` function, or iterables and
    dictionaries containing only such values. Anything else (objects, functions, etc.) will raise an exception at init
    time. If such values are needed, consider serializing them to a string.

    _(TODO explain how to use classes and functions in init. In the meantime see `test/components/test_accumulate.py`)_

    The `__init__` must be extremely lightweight, because it's a frequent operation during the construction and
    validation of the pipeline. If a component has some heavy state to initialize (models, backends, etc.) refer to
    the `warm_up()` method.

    <hr>

    `warm_up(self)`

    Optional method.

    This method is called by Pipeline before the graph execution. Make sure to avoid double-initializations,
    because Pipeline will not keep track of which components it called `warm_up()` on.

    <hr>

    `run(self, data)`

    Mandatory method.

    This is the method where the main functionality of the component should be carried out. It's called by
    `Pipeline.run()`.

    When the component should run, Pipeline will call this method with an instance of the dataclass returned by the
    method decorated with `@component.input`. This dataclass contains:

    - all the input values coming from other components connected to it,
    - if any is missing, the corresponding value defined in `self.defaults`, if it exists.

    `run()` must return a single instance of the dataclass declared through the method decorated with
    `@component.output`.

"""

import inspect
import sys
from collections.abc import Callable
from copy import deepcopy
from types import new_class
from typing import Any, Dict, Optional, Protocol, runtime_checkable

from haystack import logging
from haystack.core.errors import ComponentError

from .sockets import Sockets
from .types import InputSocket, OutputSocket, _empty

logger = logging.getLogger(__name__)


@runtime_checkable
class Component(Protocol):
    """
    Note this is only used by type checking tools.

    In order to implement the `Component` protocol, custom components need to
    have a `run` method. The signature of the method and its return value
    won't be checked, i.e. classes with the following methods:

        def run(self, param: str) -> Dict[str, Any]:
            ...

    and

        def run(self, **kwargs):
            ...

    will both be considered as respecting the protocol. This makes the type
    checking much weaker, but we have other places where we ensure code is
    dealing with actual Components.

    The protocol is runtime checkable so it'll be possible to assert:

        isinstance(MyComponent(), Component)
    """

    # This is the most reliable way to define the protocol for the `run` method.
    # Defining a method doesn't work as different Components will have different
    # arguments. Even defining here a method with `**kwargs` doesn't work as the
    # expected signature must be identical.
    # This makes most Language Servers and type checkers happy and shows fewer errors.
    # NOTE: This check can be removed when we drop Python 3.8 support.
    if sys.version_info >= (3, 9):
        run: Callable[..., Dict[str, Any]]
    else:
        run: Callable
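
The attribute-based protocol above can be tried without Haystack installed. The following is a minimal standard-library sketch, not the library's code: `RunnableLike` and the three toy classes are hypothetical names introduced only for illustration. It shows that a runtime `isinstance` check against such a protocol only verifies that a `run` attribute exists and never looks at its signature.

```python
from typing import Any, Callable, Dict, Protocol, runtime_checkable


@runtime_checkable
class RunnableLike(Protocol):  # hypothetical stand-in for `Component`
    # Declared as an attribute rather than a method, so any signature matches.
    run: Callable[..., Dict[str, Any]]


class Upper:
    def run(self, text: str) -> Dict[str, Any]:
        return {"text": text.upper()}


class AcceptsAnything:
    def run(self, **kwargs):
        return dict(kwargs)


class NoRun:
    pass


# The check is performed on instances, and only the presence of `run` matters.
checks = {
    "Upper": isinstance(Upper(), RunnableLike),
    "AcceptsAnything": isinstance(AcceptsAnything(), RunnableLike),
    "NoRun": isinstance(NoRun(), RunnableLike),
}
print(checks)
```

Because the runtime check is this weak, it is best treated as a quick sanity check; the stricter validation happens elsewhere, as the docstring above notes.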


class ComponentMeta(type):
    def __call__(cls, *args, **kwargs):
        """
        This method is called when clients instantiate a Component and
        runs before __new__ and __init__.
        """
        # This will call __new__ then __init__, giving us back the Component instance
        instance = super().__call__(*args, **kwargs)

        # Before returning, we have the chance to modify the newly created
        # Component instance, so we take the chance and set up the I/O sockets

        # If `component.set_output_types()` was called in the component constructor,
        # `__haystack_output__` is already populated, no need to do anything.
        if not hasattr(instance, "__haystack_output__"):
            # If that's not the case, we need to populate `__haystack_output__`
            #
            # If the `run` method was decorated, it has a `_output_types_cache` field assigned
            # that stores the output specification.
            # We deepcopy the content of the cache to transfer ownership from the class method
            # to the actual instance, so that different instances of the same class won't share this data.
            instance.__haystack_output__ = Sockets(
                instance, deepcopy(getattr(instance.run, "_output_types_cache", {})), OutputSocket
            )

        # Create the sockets if set_input_types() wasn't called in the constructor.
        # If it was called and there are some parameters also in the `run()` method, these take precedence.
        if not hasattr(instance, "__haystack_input__"):
            instance.__haystack_input__ = Sockets(instance, {}, InputSocket)
        run_signature = inspect.signature(getattr(cls, "run"))
        for param in list(run_signature.parameters)[1:]:  # First is 'self' and it doesn't matter.
            if run_signature.parameters[param].kind not in (
                inspect.Parameter.VAR_POSITIONAL,
                inspect.Parameter.VAR_KEYWORD,
            ):  # ignore variable args
                socket_kwargs = {"name": param, "type": run_signature.parameters[param].annotation}
                if run_signature.parameters[param].default != inspect.Parameter.empty:
                    socket_kwargs["default_value"] = run_signature.parameters[param].default
                instance.__haystack_input__[param] = InputSocket(**socket_kwargs)

        # Since a Component can't be used in multiple Pipelines at the same time
        # we need to know if it's already owned by a Pipeline when adding it to one.
        # We use this flag to check that.
        instance.__haystack_added_to_pipeline__ = None

        # Only Components with variadic inputs can be greedy. If the user set the greedy flag
        # to True, but the component doesn't have a variadic input, we log a warning.
        # We can have this information only at instance creation time, so we do it here.
        is_variadic = any(socket.is_variadic for socket in instance.__haystack_input__._sockets_dict.values())
        if not is_variadic and cls.__haystack_is_greedy__:
            logger.warning(
                "Component '{component}' has no variadic input, but it's marked as greedy. "
                "This is not supported and can lead to unexpected behavior.",
                component=cls.__name__,
            )

        return instance
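
The signature inspection performed in `ComponentMeta.__call__` can be explored in isolation. The sketch below uses only the standard library and makes several assumptions: `SocketMeta`, `Greeter`, and the plain-dict `input_sockets` attribute are hypothetical stand-ins for the metaclass, a component, and the `Sockets`/`InputSocket` objects. It is an illustration of the technique, not the library's implementation.

```python
import inspect
from typing import Any, Dict


class SocketMeta(type):  # hypothetical stand-in for `ComponentMeta`
    def __call__(cls, *args, **kwargs):
        # Runs `__new__` and `__init__`, then lets us post-process the instance.
        instance = super().__call__(*args, **kwargs)

        sockets: Dict[str, Dict[str, Any]] = {}
        signature = inspect.signature(cls.run)
        # Skip the first parameter (`self`), as the original code does.
        for name in list(signature.parameters)[1:]:
            param = signature.parameters[name]
            # Variadic parameters (*args, **kwargs) do not become sockets.
            if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
                continue
            socket: Dict[str, Any] = {"type": param.annotation}
            if param.default is not inspect.Parameter.empty:
                socket["default_value"] = param.default
            sockets[name] = socket

        # A plain dict replaces the real `Sockets` container in this sketch.
        instance.input_sockets = sockets
        return instance


class Greeter(metaclass=SocketMeta):
    def run(self, name: str, punctuation: str = "!", **extras):
        return {"greeting": f"Hello {name}{punctuation}"}


greeter = Greeter()
print(greeter.input_sockets)
```

Doing this work in the metaclass `__call__` rather than in `__init__` means component authors do not need to call a base-class constructor for the sockets to be set up.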


def _component_repr(component: Component) -> str:
    """
    All Components override their __repr__ method with this one.
    It prints the component name and the input/output sockets.
    """
    result = object.__repr__(component)
    if pipeline := getattr(component, "__haystack_added_to_pipeline__"):
        # This Component has been added in a Pipeline, let's get the name from there.
        result += f"\n{pipeline.get_component_name(component)}"

    # We're explicitly ignoring the type here because we're sure that the component
    # has the __haystack_input__ and __haystack_output__ attributes at this point
    return f"{result}\n{component.__haystack_input__}\n{component.__haystack_output__}"  # type: ignore[attr-defined]


class _Component:
    """
    See module's docstring.

    Args:
        class_: the class that Canals should use as a component.
        serializable: whether to check, at init time, if the component can be saved with
        `save_pipelines()`.

    Returns:
        A class that can be recognized as a component.

    Raises:
        ComponentError: if the class provided has no `run()` method or otherwise doesn't respect the component contract.
    """

    def __init__(self):
        self.registry = {}

    def set_input_type(self, instance, name: str, type: Any, default: Any = _empty):
        """
        Add a single input socket to the component instance.

        :param instance: Component instance where the input type will be added.
        :param name: name of the input socket.
        :param type: type of the input socket.
        :param default: default value of the input socket, defaults to _empty
        """
        if not hasattr(instance, "__haystack_input__"):
            instance.__haystack_input__ = Sockets(instance, {}, InputSocket)
        instance.__haystack_input__[name] = InputSocket(name=name, type=type, default_value=default)

    def set_input_types(self, instance, **types):
        """
        Method that specifies the input types when 'kwargs' is passed to the run method.

        Use as:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int):
                component.set_input_types(self, value_1=str, value_2=str)
                ...

            @component.output_types(output_1=int, output_2=str)
            def run(self, **kwargs):
                return {"output_1": kwargs["value_1"], "output_2": ""}
        ```

        Note that if the `run()` method also specifies some parameters, those will take precedence.

        For example:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int):
                component.set_input_types(self, value_1=str, value_2=str)
                ...

            @component.output_types(output_1=int, output_2=str)
            def run(self, value_0: str, value_1: Optional[str] = None, **kwargs):
                return {"output_1": value_1, "output_2": ""}
        ```

        would add a mandatory `value_0` parameter, make the `value_1`
        parameter optional with a default None, and keep the `value_2`
        parameter mandatory as specified in `set_input_types`.

        """
        instance.__haystack_input__ = Sockets(
            instance, {name: InputSocket(name=name, type=type_) for name, type_ in types.items()}, InputSocket
        )

    def set_output_types(self, instance, **types):
        """
        Method that specifies the output types when the 'run' method is not decorated
        with 'component.output_types'.

        Use as:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int):
                component.set_output_types(self, output_1=int, output_2=str)
                ...

            # no decorators here
            def run(self, value: int):
                return {"output_1": 1, "output_2": "2"}
        ```
        """
        instance.__haystack_output__ = Sockets(
            instance, {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()}, OutputSocket
        )

    def output_types(self, **types):
        """
        Decorator factory that specifies the output types of a component.

        Use as:

        ```python
        @component
        class MyComponent:
            @component.output_types(output_1=int, output_2=str)
            def run(self, value: int):
                return {"output_1": 1, "output_2": "2"}
        ```
        """

        def output_types_decorator(run_method):
            """
            This happens at class creation time, and since we don't have the decorated
            class available here, we temporarily store the output types as an attribute of
            the decorated method. The ComponentMeta metaclass will use this data to create
            sockets at instance creation time.
            """
            setattr(
                run_method,
                "_output_types_cache",
                {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()},
            )
            return run_method

        return output_types_decorator

    def _component(self, cls, is_greedy: bool = False):
        """
        Decorator validating the structure of the component and registering it in the components registry.
        """
        logger.debug("Registering {component} as a component", component=cls)

        # Check for required methods and fail as soon as possible
        if not hasattr(cls, "run"):
            raise ComponentError(f"{cls.__name__} must have a 'run()' method. See the docs for more information.")

        def copy_class_namespace(namespace):
            """
            This is the callback that `types.new_class` will use
            to populate the newly created class. We just copy
            the whole namespace from the decorated class.
            """
            for key, val in dict(cls.__dict__).items():
                # __dict__ and __weakref__ are class-bound, we should let Python recreate them.
                if key in ("__dict__", "__weakref__"):
                    continue
                namespace[key] = val

        # Recreate the decorated component class so it uses our metaclass.
        # We must explicitly redefine the type of the class to make sure language servers
        # and type checkers understand that the class is of the correct type.
        # mypy doesn't like that we do this though so we explicitly ignore the type check.
        cls: cls.__name__ = new_class(cls.__name__, cls.__bases__, {"metaclass": ComponentMeta}, copy_class_namespace)  # type: ignore[no-redef]

        # Save the component in the class registry (for deserialization)
        class_path = f"{cls.__module__}.{cls.__name__}"
        if class_path in self.registry:
            # Corner case, but it may occur easily in notebooks when re-running cells.
            logger.debug(
                "Component {component} is already registered. Previous imported from '{module}', new imported from '{new_module}'",
                component=class_path,
                module=self.registry[class_path],
                new_module=cls,
            )
        self.registry[class_path] = cls
        logger.debug("Registered Component {component}", component=cls)

        # Override the __repr__ method with a default one
        cls.__repr__ = _component_repr

        # The greedy flag can be True only if the component has a variadic input.
        # At this point of the lifetime of the component, we can't reliably know if it has a variadic input.
        # So we set it to whatever the user specified; during instance creation we'll log a warning if needed,
        # since we'll have access to the input sockets and can check if any of them is variadic.
        setattr(cls, "__haystack_is_greedy__", is_greedy)

        return cls

    def __call__(self, cls: Optional[type] = None, is_greedy: bool = False):
        # We must wrap the call to the decorator in a function for it to work
        # correctly with or without parens
        def wrap(cls):
            return self._component(cls, is_greedy=is_greedy)

        if cls:
            # Decorator is called without parens
            return wrap(cls)

        # Decorator is called with parens
        return wrap


component = _Component()
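
The two techniques that `_Component` combines, a decorator usable with or without parentheses and a class rebuilt under a metaclass with `types.new_class`, can be sketched with the standard library alone. Everything named here (`TaggingMeta`, `_Marker`, `marker`, `Plain`, `Greedy`, `created_by_meta`) is hypothetical and exists only for this illustration; the real decorator does more validation and socket setup than this sketch.

```python
from types import new_class
from typing import Optional


class TaggingMeta(type):  # hypothetical stand-in for `ComponentMeta`
    def __call__(cls, *args, **kwargs):
        instance = super().__call__(*args, **kwargs)
        instance.created_by_meta = True  # post-process every new instance
        return instance


class _Marker:  # hypothetical stand-in for `_Component`
    def __init__(self):
        self.registry = {}

    def _mark(self, cls, is_greedy: bool = False):
        if not hasattr(cls, "run"):
            raise TypeError(f"{cls.__name__} must have a 'run()' method.")

        def copy_namespace(namespace):
            # Copy everything except the class-bound slots Python recreates itself.
            for key, value in dict(cls.__dict__).items():
                if key not in ("__dict__", "__weakref__"):
                    namespace[key] = value

        # Rebuild the decorated class so that it uses the metaclass.
        new_cls = new_class(cls.__name__, cls.__bases__, {"metaclass": TaggingMeta}, copy_namespace)
        new_cls.is_greedy = is_greedy
        self.registry[f"{new_cls.__module__}.{new_cls.__name__}"] = new_cls
        return new_cls

    def __call__(self, cls: Optional[type] = None, is_greedy: bool = False):
        def wrap(cls):
            return self._mark(cls, is_greedy=is_greedy)

        # `@marker` passes the class directly; `@marker(...)` passes nothing yet.
        return wrap(cls) if cls is not None else wrap


marker = _Marker()


@marker
class Plain:
    def run(self):
        return {"ok": True}


@marker(is_greedy=True)
class Greedy:
    def run(self):
        return {"ok": True}


print(type(Plain).__name__, Plain.is_greedy, Greedy.is_greedy)
```

Rebuilding the class instead of asking users to declare a metaclass keeps the public surface to a single decorator, at the cost of the decorated name referring to a new class object.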