deepset-ai / haystack / 14733019619

29 Apr 2025 01:53PM UTC coverage: 90.513% (+0.01%) from 90.5%
push

github

web-flow
feat: validation function for `run()` and `run_async()` parameters signature for (custom) components (#9322)

* adding tests

* adding release notes

* small improvements

10896 of 12038 relevant lines covered (90.51%)

0.91 hits per line

Source File

99.44
haystack/core/component/component.py
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

"""
Attributes:

    component: Marks a class as a component. Any class decorated with `@component` can be used by a Pipeline.

All components must follow the contract below. This docstring is the source of truth for the component contract.

<hr>

`@component` decorator

All component classes must be decorated with the `@component` decorator. This allows Haystack to discover them.
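
As a rough illustration of what "discovery" means here, the sketch below keeps a registry keyed by `module.ClassName`, in the spirit of `_Component.registry` further down in this module. The names `register`, `REGISTRY`, and `Greeter` are hypothetical, and the real decorator also rebuilds the class with the `ComponentMeta` metaclass, which this sketch omits.

```python
from typing import Any, Dict

# Hypothetical stand-in for the registry kept by the real decorator.
REGISTRY: Dict[str, type] = {}


def register(cls: type) -> type:
    # Fail early when the mandatory `run` method is missing.
    if not hasattr(cls, "run"):
        raise TypeError(f"{cls.__name__} must have a 'run()' method")
    # Key the class by its import path so it can be looked up by name later.
    REGISTRY[f"{cls.__module__}.{cls.__name__}"] = cls
    return cls


@register
class Greeter:
    def run(self, name: str) -> Dict[str, Any]:
        return {"greeting": f"Hello, {name}!"}


class_path = f"{Greeter.__module__}.{Greeter.__name__}"
```

Looking up `REGISTRY[class_path]` returns the `Greeter` class itself, which is what makes rebuilding a component from its stored path possible.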

<hr>

`__init__(self, **kwargs)`

Optional method.

Components may have an `__init__` method where they define:

- `self.init_parameters = {same parameters that the __init__ method received}`:
    In this dictionary you can store any state the component needs to persist when it is saved.
    These values will be given to the `__init__` method of a new instance when the pipeline is loaded.
    Note that by default the `@component` decorator saves the arguments automatically.
    However, if a component sets its own `init_parameters` manually in `__init__()`, that will be used instead.
    Note: all of the values contained here **must be JSON serializable**. Serialize them manually if needed.

Components should take only "basic" Python types as parameters of their `__init__` function, or iterables and
dictionaries containing only such values. Anything else (objects, functions, etc.) will raise an exception at init
time. If such values are needed, consider serializing them to a string.

_(TODO explain how to use classes and functions in init. In the meantime see `test/components/test_accumulate.py`)_

The `__init__` must be extremely lightweight, because it's a frequent operation during the construction and
validation of the pipeline. If a component has some heavy state to initialize (models, backends, etc.), refer to
the `warm_up()` method.
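
A minimal sketch of an `__init__` that follows these rules. The class is left undecorated so the snippet runs without Haystack installed; in real use it would carry `@component`. The names `Scorer`, `model_name`, and `weights` are illustrative only.

```python
import json
from typing import Any, Dict, Optional


class Scorer:
    def __init__(self, model_name: str, weights: Optional[Dict[str, float]] = None):
        # Only basic types (a str and a dict of floats) are accepted and stored.
        self.model_name = model_name
        self.weights = weights or {}
        # Heavy state is deliberately not created here; that belongs in `warm_up()`.
        self.model = None

    def run(self, text: str) -> Dict[str, Any]:
        return {"score": sum(self.weights.get(tok, 0.0) for tok in text.split())}


scorer = Scorer(model_name="toy", weights={"good": 1.0, "bad": -1.0})
# The init parameters survive a JSON round trip, so a new instance can be rebuilt from them.
saved = json.dumps({"model_name": scorer.model_name, "weights": scorer.weights})
restored = Scorer(**json.loads(saved))
```

Because every init parameter is a basic type, the saved form needs no custom serialization step.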

<hr>

`warm_up(self)`

Optional method.

This method is called by Pipeline before the graph execution. Make sure to avoid double-initializations,
because Pipeline will not keep track of which components it called `warm_up()` on.
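
A minimal sketch of a `warm_up()` that is safe to call more than once. The class is left undecorated so the snippet runs without Haystack installed; `Embedder`, `_backend`, and `load_count` are hypothetical names, and the list assignment stands in for loading a real model.

```python
from typing import Any, Dict, List, Optional


class Embedder:
    def __init__(self, dimension: int = 3):
        self.dimension = dimension
        self._backend: Optional[List[float]] = None  # heavy state, created lazily
        self.load_count = 0  # only here to make the guard observable

    def warm_up(self) -> None:
        # Guard against double initialization: the caller may invoke this more than once.
        if self._backend is not None:
            return
        self._backend = [0.0] * self.dimension  # stand-in for loading a model
        self.load_count += 1

    def run(self, text: str) -> Dict[str, Any]:
        if self._backend is None:
            raise RuntimeError("Call warm_up() before run()")
        return {"embedding": [float(len(text))] * self.dimension}


embedder = Embedder()
embedder.warm_up()
embedder.warm_up()  # second call is a no-op
```

The guard lives inside the component because, as stated above, the caller does not remember which components it has already warmed up.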

<hr>

`run(self, **kwargs)`

Mandatory method.

This is the method where the main functionality of the component should be carried out. It's called by
`Pipeline.run()`.

When the component should run, Pipeline calls this method with keyword arguments. Each parameter of `run()`
becomes an input socket of the component, and the arguments contain:

- all the input values coming from other components connected to it,
- if any is missing, the default value declared for that parameter in the `run()` signature, if it exists.

`run()` must return a dictionary whose keys are the outputs declared with the `@component.output_types`
decorator or with `component.set_output_types()`.
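
A minimal sketch of the `run()` shape that the code in this module works with: parameters of `run()` become inputs (those with a default are optional), the return value is a dictionary, and an optional `run_async()` coroutine has to mirror the `run()` signature. The class is left undecorated so the snippet runs without Haystack installed, and `Upper` is a hypothetical name.

```python
import asyncio
import inspect
from typing import Any, Dict


class Upper:
    def run(self, text: str, repeat: int = 1) -> Dict[str, Any]:
        return {"result": text.upper() * repeat}

    async def run_async(self, text: str, repeat: int = 1) -> Dict[str, Any]:
        # Same parameters as `run`; only the execution model differs.
        return self.run(text=text, repeat=repeat)


# Parameters with a default are optional inputs; the rest are mandatory.
run_params = inspect.signature(Upper.run).parameters
mandatory = [
    name
    for name, param in run_params.items()
    if name != "self" and param.default is inspect.Parameter.empty
]

# Comparing the two signatures is the kind of check the metaclass performs at instantiation.
signatures_match = inspect.signature(Upper.run) == inspect.signature(Upper.run_async)
sync_out = Upper().run(text="ab", repeat=2)
async_out = asyncio.run(Upper().run_async(text="ab", repeat=2))
```

If the two signatures differed in parameter names, annotations, defaults, or kinds, the metaclass in this module would raise a `ComponentError` listing the differences.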

"""

import inspect
from collections.abc import Callable, Coroutine
from contextlib import contextmanager
from contextvars import ContextVar
from copy import deepcopy
from dataclasses import dataclass
from types import new_class
from typing import Any, Dict, Optional, Protocol, Type, TypeVar, Union, runtime_checkable

from typing_extensions import ParamSpec

from haystack import logging
from haystack.core.errors import ComponentError

from .sockets import Sockets
from .types import InputSocket, OutputSocket, _empty

logger = logging.getLogger(__name__)

RunParamsT = ParamSpec("RunParamsT")
SyncRunReturnT = TypeVar("SyncRunReturnT", bound=Dict[str, Any])
AsyncRunReturnT = TypeVar("AsyncRunReturnT", bound=Coroutine[Any, Any, Dict[str, Any]])
RunReturnT = Union[SyncRunReturnT, AsyncRunReturnT]


@dataclass
class PreInitHookPayload:
    """
    Payload for the hook called before a component instance is initialized.

    :param callback:
        Receives the following inputs: component class and init parameter keyword args.
    :param in_progress:
        Flag to indicate if the hook is currently being executed.
        Used to prevent it from being called recursively (if the component's constructor
        instantiates another component).
    """

    callback: Callable
    in_progress: bool = False


_COMPONENT_PRE_INIT_HOOK: ContextVar[Optional[PreInitHookPayload]] = ContextVar("component_pre_init_hook", default=None)


@contextmanager
def _hook_component_init(callback: Callable):
    """
    Context manager to set a callback that will be invoked before a component's constructor is called.

    The callback receives the component class and the init parameters (as keyword arguments) and can modify the init
    parameters in place.

    :param callback:
        Callback function to invoke.
    """
    token = _COMPONENT_PRE_INIT_HOOK.set(PreInitHookPayload(callback))
    try:
        yield
    finally:
        _COMPONENT_PRE_INIT_HOOK.reset(token)


@runtime_checkable
class Component(Protocol):
    """
    Note this is only used by type checking tools.

    In order to implement the `Component` protocol, custom components need to
    have a `run` method. The signature of the method and its return value
    won't be checked, i.e. classes with the following methods:

        def run(self, param: str) -> Dict[str, Any]:
            ...

    and

        def run(self, **kwargs):
            ...

    will both be considered as respecting the protocol. This makes the type
    checking much weaker, but we have other places where we ensure code is
    dealing with actual Components.

    The protocol is runtime checkable so it'll be possible to assert:

        isinstance(MyComponent(), Component)
    """

    # This is the most reliable way to define the protocol for the `run` method.
    # Defining a method doesn't work as different Components will have different
    # arguments. Even defining here a method with `**kwargs` doesn't work as the
    # expected signature must be identical.
    # This makes most Language Servers and type checkers happy and shows fewer errors.
    run: Callable[..., Dict[str, Any]]


class ComponentMeta(type):
    @staticmethod
    def _positional_to_kwargs(cls_type, args) -> Dict[str, Any]:
        """
        Convert positional arguments to keyword arguments based on the signature of the `__init__` method.
        """
        init_signature = inspect.signature(cls_type.__init__)
        init_params = {name: info for name, info in init_signature.parameters.items() if name != "self"}

        out = {}
        for arg, (name, info) in zip(args, init_params.items()):
            if info.kind == inspect.Parameter.VAR_POSITIONAL:
                raise ComponentError(
                    "Pre-init hooks do not support components with variadic positional args in their init method"
                )

            assert info.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.POSITIONAL_ONLY)
            out[name] = arg
        return out

    @staticmethod
    def _parse_and_set_output_sockets(instance: Any):
        has_async_run = hasattr(instance, "run_async")

        # If `component.set_output_types()` was called in the component constructor,
        # `__haystack_output__` is already populated, no need to do anything.
        if not hasattr(instance, "__haystack_output__"):
            # If that's not the case, we need to populate `__haystack_output__`
            #
            # If either of the run methods were decorated, they'll have a field assigned that
            # stores the output specification. If both run methods were decorated, we ensure that
            # outputs are the same. We deepcopy the content of the cache to transfer ownership from
            # the class method to the actual instance, so that different instances of the same class
            # won't share this data.

            run_output_types = getattr(instance.run, "_output_types_cache", {})
            async_run_output_types = getattr(instance.run_async, "_output_types_cache", {}) if has_async_run else {}

            if has_async_run and run_output_types != async_run_output_types:
                raise ComponentError("Output type specifications of 'run' and 'run_async' methods must be the same")
            output_types_cache = run_output_types

            instance.__haystack_output__ = Sockets(instance, deepcopy(output_types_cache), OutputSocket)

    @staticmethod
    def _parse_and_set_input_sockets(component_cls: Type, instance: Any):
        def inner(method, sockets):
            from inspect import Parameter

            run_signature = inspect.signature(method)

            for param_name, param_info in run_signature.parameters.items():
                if param_name == "self" or param_info.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
                    continue

                socket_kwargs = {"name": param_name, "type": param_info.annotation}
                if param_info.default != Parameter.empty:
                    socket_kwargs["default_value"] = param_info.default

                new_socket = InputSocket(**socket_kwargs)

                # Also ensure that new sockets don't override existing ones.
                existing_socket = sockets.get(param_name)
                if existing_socket is not None and existing_socket != new_socket:
                    raise ComponentError(
                        "set_input_types()/set_input_type() cannot override the parameters of the 'run' method"
                    )

                sockets[param_name] = new_socket

            return run_signature

        # Create the sockets if set_input_types() wasn't called in the constructor.
        if not hasattr(instance, "__haystack_input__"):
            instance.__haystack_input__ = Sockets(instance, {}, InputSocket)

        inner(getattr(component_cls, "run"), instance.__haystack_input__)

        # Ensure that the sockets are the same for the async method, if it exists.
        async_run = getattr(component_cls, "run_async", None)
        if async_run is not None:
            run_sockets = Sockets(instance, {}, InputSocket)
            async_run_sockets = Sockets(instance, {}, InputSocket)

            # Can't use the sockets from above as they might contain
            # values set with set_input_types().
            run_sig = inner(getattr(component_cls, "run"), run_sockets)
            async_run_sig = inner(async_run, async_run_sockets)

            if async_run_sockets != run_sockets or run_sig != async_run_sig:
                sig_diff = _compare_run_methods_signatures(run_sig, async_run_sig)
                raise ComponentError(
                    f"Parameters of 'run' and 'run_async' methods must be the same.\nDifferences found:\n{sig_diff}"
                )

    def __call__(cls, *args, **kwargs):
        """
        This method is called when clients instantiate a Component and runs before __new__ and __init__.
        """
        # This will call __new__ then __init__, giving us back the Component instance
        pre_init_hook = _COMPONENT_PRE_INIT_HOOK.get()
        if pre_init_hook is None or pre_init_hook.in_progress:
            instance = super().__call__(*args, **kwargs)
        else:
            try:
                pre_init_hook.in_progress = True
                named_positional_args = ComponentMeta._positional_to_kwargs(cls, args)
                assert set(named_positional_args.keys()).intersection(kwargs.keys()) == set(), (
                    "positional and keyword arguments overlap"
                )
                kwargs.update(named_positional_args)
                pre_init_hook.callback(cls, kwargs)
                instance = super().__call__(**kwargs)
            finally:
                pre_init_hook.in_progress = False

        # Before returning, we have the chance to modify the newly created
        # Component instance, so we take the chance and set up the I/O sockets
        has_async_run = hasattr(instance, "run_async")
        if has_async_run and not inspect.iscoroutinefunction(instance.run_async):
            raise ComponentError(f"Method 'run_async' of component '{cls.__name__}' must be a coroutine")
        instance.__haystack_supports_async__ = has_async_run

        ComponentMeta._parse_and_set_input_sockets(cls, instance)
        ComponentMeta._parse_and_set_output_sockets(instance)

        # Since a Component can't be used in multiple Pipelines at the same time
        # we need to know if it's already owned by a Pipeline when adding it to one.
        # We use this flag to check that.
        instance.__haystack_added_to_pipeline__ = None

        return instance


def _component_repr(component: Component) -> str:
    """
    All Components override their __repr__ method with this one.

    It prints the component name and the input/output sockets.
    """
    result = object.__repr__(component)
    if pipeline := getattr(component, "__haystack_added_to_pipeline__", None):
        # This Component has been added in a Pipeline, let's get the name from there.
        result += f"\n{pipeline.get_component_name(component)}"

    # We're explicitly ignoring the type here because we're sure that the component
    # has the __haystack_input__ and __haystack_output__ attributes at this point
    return (
        f"{result}\n{getattr(component, '__haystack_input__', '<invalid_input_sockets>')}"
        f"\n{getattr(component, '__haystack_output__', '<invalid_output_sockets>')}"
    )


def _component_run_has_kwargs(component_cls: Type) -> bool:
    run_method = getattr(component_cls, "run", None)
    if run_method is None:
        return False  # not covered by tests (marked × in this coverage report)
    else:
        return any(
            param.kind == inspect.Parameter.VAR_KEYWORD for param in inspect.signature(run_method).parameters.values()
        )


def _compare_run_methods_signatures(run_sig: inspect.Signature, async_run_sig: inspect.Signature) -> str:
    """
    Builds a detailed error message with the differences between the signatures of the run and run_async methods.

    :param run_sig: The signature of the run method
    :param async_run_sig: The signature of the run_async method

    :returns:
        A detailed error message if signatures don't match, empty string if they do
    """
    differences = []
    run_params = list(run_sig.parameters.items())
    async_params = list(async_run_sig.parameters.items())

    if len(run_params) != len(async_params):
        differences.append(
            f"Different number of parameters: run has {len(run_params)}, run_async has {len(async_params)}"
        )

    for (run_name, run_param), (async_name, async_param) in zip(run_params, async_params):
        if run_name != async_name:
            differences.append(f"Parameter name mismatch: {run_name} vs {async_name}")

        if run_param.annotation != async_param.annotation:
            differences.append(
                f"Parameter '{run_name}' type mismatch: {run_param.annotation} vs {async_param.annotation}"
            )

        if run_param.default != async_param.default:
            differences.append(
                f"Parameter '{run_name}' default value mismatch: {run_param.default} vs {async_param.default}"
            )

        if run_param.kind != async_param.kind:
            differences.append(
                f"Parameter '{run_name}' kind (POSITIONAL, KEYWORD, etc.) mismatch: "
                f"{run_param.kind} vs {async_param.kind}"
            )

    return "\n".join(differences)


class _Component:
    """
    See module's docstring.

    Args:
        cls: the class that should be used as a component.

    Returns:
        A class that can be recognized as a component.

    Raises:
        ComponentError: if the class provided has no `run()` method or otherwise doesn't respect the component contract.
    """

    def __init__(self):
        self.registry = {}

    def set_input_type(
        self,
        instance,
        name: str,
        type: Any,  # noqa: A002
        default: Any = _empty,
    ):
        """
        Add a single input socket to the component instance.

        Replaces any existing input socket with the same name.

        :param instance: Component instance where the input type will be added.
        :param name: name of the input socket.
        :param type: type of the input socket.
        :param default: default value of the input socket, defaults to _empty
        """
        if not _component_run_has_kwargs(instance.__class__):
            raise ComponentError(
                "Cannot set input types on a component that doesn't have a kwargs parameter in the 'run' method"
            )

        if not hasattr(instance, "__haystack_input__"):
            instance.__haystack_input__ = Sockets(instance, {}, InputSocket)
        instance.__haystack_input__[name] = InputSocket(name=name, type=type, default_value=default)

    def set_input_types(self, instance, **types):
        """
        Method that specifies the input types when 'kwargs' is passed to the run method.

        Use as:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int):
                component.set_input_types(self, value_1=str, value_2=str)
                ...

            @component.output_types(output_1=int, output_2=str)
            def run(self, **kwargs):
                return {"output_1": kwargs["value_1"], "output_2": ""}
        ```

        Note that the parameters named in the `run()` signature become input sockets too. `set_input_types()`
        cannot override them: declaring the same name in both places with a different type or default raises
        a `ComponentError`.

        For example:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int):
                component.set_input_types(self, value_2=str)
                ...

            @component.output_types(output_1=int, output_2=str)
            def run(self, value_0: str, value_1: Optional[str] = None, **kwargs):
                return {"output_1": kwargs["value_2"], "output_2": ""}
        ```

        would add a mandatory `value_0` parameter, add an optional `value_1`
        parameter with a default of None, and keep the `value_2`
        parameter mandatory as specified in `set_input_types`.

        """
        if not _component_run_has_kwargs(instance.__class__):
            raise ComponentError(
                "Cannot set input types on a component that doesn't have a kwargs parameter in the 'run' method"
            )

        instance.__haystack_input__ = Sockets(
            instance, {name: InputSocket(name=name, type=type_) for name, type_ in types.items()}, InputSocket
        )

    def set_output_types(self, instance, **types):
        """
        Method that specifies the output types when the 'run' method is not decorated with 'component.output_types'.

        Use as:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int):
                component.set_output_types(self, output_1=int, output_2=str)
                ...

            # no decorators here
            def run(self, value: int):
                return {"output_1": 1, "output_2": "2"}
        ```
        """
        has_decorator = hasattr(instance.run, "_output_types_cache")
        if has_decorator:
            raise ComponentError(
                "Cannot call `set_output_types` on a component that already has "
                "the 'output_types' decorator on its `run` method"
            )

        instance.__haystack_output__ = Sockets(
            instance, {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()}, OutputSocket
        )

    def output_types(
        self, **types: Any
    ) -> Callable[[Callable[RunParamsT, RunReturnT]], Callable[RunParamsT, RunReturnT]]:
        """
        Decorator factory that specifies the output types of a component.

        Use as:
        ```python
        @component
        class MyComponent:
            @component.output_types(output_1=int, output_2=str)
            def run(self, value: int):
                return {"output_1": 1, "output_2": "2"}
        ```
        """

        def output_types_decorator(run_method: Callable[RunParamsT, RunReturnT]) -> Callable[RunParamsT, RunReturnT]:
            """
            Decorator that sets the output types of the decorated method.

            This happens at class creation time, and since we don't have the decorated
            class available here, we temporarily store the output types as an attribute of
            the decorated method. The ComponentMeta metaclass will use this data to create
            sockets at instance creation time.
            """
            method_name = run_method.__name__
            if method_name not in ("run", "run_async"):
                raise ComponentError("'output_types' decorator can only be used on 'run' and 'run_async' methods")

            setattr(
                run_method,
                "_output_types_cache",
                {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()},
            )
            return run_method

        return output_types_decorator

    def _component(self, cls: Any):
        """
        Decorator validating the structure of the component and registering it in the components registry.
        """
        logger.debug("Registering {component} as a component", component=cls)

        # Check for required methods and fail as soon as possible
        if not hasattr(cls, "run"):
            raise ComponentError(f"{cls.__name__} must have a 'run()' method. See the docs for more information.")

        def copy_class_namespace(namespace):
            """
            This is the callback that `types.new_class` will use to populate the newly created class.

            Simply copy the whole namespace from the decorated class.
            """
            for key, val in dict(cls.__dict__).items():
                # __dict__ and __weakref__ are class-bound, we should let Python recreate them.
                if key in ("__dict__", "__weakref__"):
                    continue
                namespace[key] = val

        # Recreate the decorated component class so it uses our metaclass.
        # We must explicitly redefine the type of the class to make sure language servers
        # and type checkers understand that the class is of the correct type.
        # mypy doesn't like that we do this though so we explicitly ignore the type check.
        new_cls: cls.__name__ = new_class(
            cls.__name__, cls.__bases__, {"metaclass": ComponentMeta}, copy_class_namespace
        )  # type: ignore[no-redef]

        # Save the component in the class registry (for deserialization)
        class_path = f"{new_cls.__module__}.{new_cls.__name__}"
        if class_path in self.registry:
            # Corner case, but it may occur easily in notebooks when re-running cells.
            # Adjacent string literals avoid embedding the source indentation in the log message.
            logger.debug(
                "Component {component} is already registered. Previous imported from '{module_name}', "
                "new imported from '{new_module_name}'",
                component=class_path,
                module_name=self.registry[class_path],
                new_module_name=new_cls,
            )
        self.registry[class_path] = new_cls
        logger.debug("Registered Component {component}", component=new_cls)

        # Override the __repr__ method with a default one
        new_cls.__repr__ = _component_repr

        return new_cls

    def __call__(self, cls: Optional[type] = None):
        # We must wrap the call to the decorator in a function for it to work
        # correctly with or without parens
        def wrap(cls):
            return self._component(cls)

        if cls:
            # Decorator is called without parens
            return wrap(cls)

        # Decorator is called with parens
        return wrap


component = _Component()