deepset-ai / haystack / 14793425610

02 May 2025 10:24AM UTC coverage: 90.513%. Remained the same
Merge 01e63f4a2 into e3f9da13d
Pull Request #9329: experimenting with py.typed

10896 of 12038 relevant lines covered (90.51%)

0.91 hits per line

Source file: haystack/core/component/component.py (99.44% covered)

# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

"""
Attributes:

    component: Marks a class as a component. Any class decorated with `@component` can be used by a Pipeline.

All components must follow the contract below. This docstring is the source of truth for the component contract.

<hr>

`@component` decorator

All component classes must be decorated with the `@component` decorator. This allows Haystack to discover them.
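
As a rough illustration of how a decorator can make classes discoverable, the sketch below records every decorated
class in a dictionary keyed by its import path. It uses only the standard library; `register`, `REGISTRY`, and
`Greeter` are hypothetical names chosen for this example and are not part of Haystack.

```python
# Minimal sketch of decorator-based discovery; not Haystack's actual code.
REGISTRY = {}  # hypothetical registry: "module.ClassName" -> class


def register(cls):
    # Fail early when the decorated class has no `run` method.
    if not hasattr(cls, "run"):
        raise TypeError(f"{cls.__name__} must have a 'run()' method")
    REGISTRY[f"{cls.__module__}.{cls.__name__}"] = cls
    return cls


@register
class Greeter:
    def run(self, name: str):
        return {"greeting": f"Hello, {name}!"}


# Class names of everything registered so far.
registered_names = [path.rsplit(".", 1)[-1] for path in REGISTRY]
```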

<hr>

`__init__(self, **kwargs)`

Optional method.

Components may have an `__init__` method where they define:

- `self.init_parameters = {same parameters that the __init__ method received}`:
    In this dictionary you can store any state the component wants to persist when it is saved.
    These values will be given to the `__init__` method of a new instance when the pipeline is loaded.
    Note that by default the `@component` decorator saves the arguments automatically.
    However, if a component sets its own `init_parameters` manually in `__init__()`, that will be used instead.
    Note: all of the values contained here **must be JSON serializable**. Serialize them manually if needed.

Components should take only "basic" Python types as parameters of their `__init__` function, or iterables and
dictionaries containing only such values. Anything else (objects, functions, etc.) will raise an exception at init
time. If such values are needed, consider serializing them to a string.
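
One way to check whether a set of init parameters meets this requirement is to round-trip it through `json`. The
helpers below are a hypothetical sketch (neither `is_json_serializable` nor `function_path` exists in Haystack);
the second one shows a possible string form for a function, namely its import path.

```python
import json


def is_json_serializable(params: dict) -> bool:
    # Hypothetical helper: True if `params` can be dumped to JSON.
    try:
        json.dumps(params)
    except (TypeError, ValueError):
        return False
    return True


def function_path(func) -> str:
    # Represent a function as a plain string built from its module and name.
    return f"{func.__module__}.{func.__qualname__}"


good = {"top_k": 5, "separators": [",", ";"], "converter": function_path(json.dumps)}
bad = {"converter": json.dumps}  # a function object cannot be dumped to JSON
```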

_(TODO explain how to use classes and functions in init. In the meantime see `test/components/test_accumulate.py`)_

The `__init__` must be extremely lightweight, because it's a frequent operation during the construction and
validation of the pipeline. If a component has some heavy state to initialize (models, backends, etc.), refer to
the `warm_up()` method.

<hr>

`warm_up(self)`

Optional method.

This method is called by Pipeline before the graph execution. Make sure to avoid double-initializations,
because Pipeline will not keep track of which components it called `warm_up()` on.
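
A minimal sketch of such a guard, assuming a hypothetical `Embedder` class whose "model" is a plain dictionary
standing in for an expensive resource. The class is not decorated here so that the snippet runs without Haystack.

```python
class Embedder:
    # Hypothetical component-like class used only for illustration.
    def __init__(self, model_name: str):
        self.model_name = model_name  # cheap: only stores configuration
        self._model = None
        self.load_count = 0

    def warm_up(self):
        # Guard against double initialization: load the model only once.
        if self._model is None:
            self._model = {"name": self.model_name}  # stand-in for a heavy load
            self.load_count += 1

    def run(self, text: str):
        if self._model is None:
            raise RuntimeError("Call warm_up() before run()")
        return {"length": len(text)}


embedder = Embedder("tiny-model")
embedder.warm_up()
embedder.warm_up()  # the second call does nothing
```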

<hr>

`run(self, data)`

Mandatory method.

This is the method where the main functionality of the component should be carried out. It's called by
`Pipeline.run()`.

When the component should run, Pipeline will call this method with an instance of the dataclass returned by the
method decorated with `@component.input`. This dataclass contains:

- all the input values coming from other components connected to it,
- if any is missing, the corresponding value defined in `self.defaults`, if it exists.

`run()` must return a single instance of the dataclass declared through the method decorated with
`@component.output`.

"""

import inspect
from collections.abc import Callable, Coroutine
from contextlib import contextmanager
from contextvars import ContextVar
from copy import deepcopy
from dataclasses import dataclass
from types import new_class
from typing import Any, Dict, Optional, Protocol, Type, TypeVar, Union, runtime_checkable

from typing_extensions import ParamSpec

from haystack import logging
from haystack.core.errors import ComponentError

from .sockets import Sockets
from .types import InputSocket, OutputSocket, _empty

logger = logging.getLogger(__name__)

RunParamsT = ParamSpec("RunParamsT")
SyncRunReturnT = TypeVar("SyncRunReturnT", bound=Dict[str, Any])
AsyncRunReturnT = TypeVar("AsyncRunReturnT", bound=Coroutine[Any, Any, Dict[str, Any]])
RunReturnT = Union[SyncRunReturnT, AsyncRunReturnT]

ClassT = TypeVar("ClassT", bound=type)


@dataclass
class PreInitHookPayload:
    """
    Payload for the hook called before a component instance is initialized.

    :param callback:
        Receives the following inputs: component class and init parameter keyword args.
    :param in_progress:
        Flag to indicate if the hook is currently being executed.
        Used to prevent it from being called recursively (if the component's constructor
        instantiates another component).
    """

    callback: Callable
    in_progress: bool = False


_COMPONENT_PRE_INIT_HOOK: ContextVar[Optional[PreInitHookPayload]] = ContextVar("component_pre_init_hook", default=None)


@contextmanager
def _hook_component_init(callback: Callable):
    """
    Context manager to set a callback that will be invoked before a component's constructor is called.

    The callback receives the component class and the init parameters (as keyword arguments) and can modify the init
    parameters in place.
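
    The set/reset pattern can be sketched with the standard library alone. In this illustration `_HOOK`, `hook`,
    and `add_default` are hypothetical names, and the callback is invoked by hand rather than by a metaclass.

    ```python
    from contextlib import contextmanager
    from contextvars import ContextVar

    _HOOK = ContextVar("hook", default=None)  # hypothetical stand-in variable


    @contextmanager
    def hook(callback):
        token = _HOOK.set(callback)
        try:
            yield
        finally:
            _HOOK.reset(token)  # restore the previous value, even on error


    def add_default(cls, init_params):
        init_params.setdefault("top_k", 10)  # modify the kwargs in place


    params = {"model": "tiny"}
    with hook(add_default):
        _HOOK.get()(dict, params)  # call the active callback manually
    hook_after_exit = _HOOK.get()
    ```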

    :param callback:
        Callback function to invoke.
    """
    token = _COMPONENT_PRE_INIT_HOOK.set(PreInitHookPayload(callback))
    try:
        yield
    finally:
        _COMPONENT_PRE_INIT_HOOK.reset(token)


@runtime_checkable
class Component(Protocol):
    """
    Note this is only used by type checking tools.

    In order to implement the `Component` protocol, custom components need to
    have a `run` method. The signature of the method and its return value
    won't be checked, i.e. classes with the following methods:

        def run(self, param: str) -> Dict[str, Any]:
            ...

    and

        def run(self, **kwargs):
            ...

    will both be considered as respecting the protocol. This makes the type
    checking much weaker, but we have other places where we ensure code is
    dealing with actual Components.

    The protocol is runtime checkable so it'll be possible to assert:

        isinstance(MyComponent, Component)
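
    A self-contained sketch of the same idea, using a hypothetical `HasRun` protocol in place of `Component`
    so that it runs without Haystack. Only the presence of a `run` attribute is checked at runtime.

    ```python
    from typing import Any, Callable, Dict, Protocol, runtime_checkable


    @runtime_checkable
    class HasRun(Protocol):  # hypothetical stand-in for `Component`
        run: Callable[..., Dict[str, Any]]


    class Typed:
        def run(self, param: str) -> Dict[str, Any]:
            return {"value": param}


    class Untyped:
        def run(self, **kwargs):
            return kwargs


    class NoRun:
        pass


    # One isinstance() result per class, in the order listed.
    checks = [isinstance(obj, HasRun) for obj in (Typed(), Untyped(), NoRun())]
    ```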
    """

    # This is the most reliable way to define the protocol for the `run` method.
    # Defining a method doesn't work as different Components will have different
    # arguments. Even defining here a method with `**kwargs` doesn't work as the
    # expected signature must be identical.
    # This makes most Language Servers and type checkers happy and shows fewer errors.
    run: Callable[..., Dict[str, Any]]


class ComponentMeta(type):
    @staticmethod
    def _positional_to_kwargs(cls_type, args) -> Dict[str, Any]:
        """
        Convert positional arguments to keyword arguments based on the signature of the `__init__` method.
        """
        init_signature = inspect.signature(cls_type.__init__)
        init_params = {name: info for name, info in init_signature.parameters.items() if name != "self"}

        out = {}
        for arg, (name, info) in zip(args, init_params.items()):
            if info.kind == inspect.Parameter.VAR_POSITIONAL:
                raise ComponentError(
                    "Pre-init hooks do not support components with variadic positional args in their init method"
                )

            assert info.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.POSITIONAL_ONLY)
            out[name] = arg
        return out

    @staticmethod
    def _parse_and_set_output_sockets(instance: Any):
        has_async_run = hasattr(instance, "run_async")

        # If `component.set_output_types()` was called in the component constructor,
        # `__haystack_output__` is already populated, no need to do anything.
        if not hasattr(instance, "__haystack_output__"):
            # If that's not the case, we need to populate `__haystack_output__`
            #
            # If either of the run methods were decorated, they'll have a field assigned that
            # stores the output specification. If both run methods were decorated, we ensure that
            # outputs are the same. We deepcopy the content of the cache to transfer ownership from
            # the class method to the actual instance, so that different instances of the same class
            # won't share this data.

            run_output_types = getattr(instance.run, "_output_types_cache", {})
            async_run_output_types = getattr(instance.run_async, "_output_types_cache", {}) if has_async_run else {}

            if has_async_run and run_output_types != async_run_output_types:
                raise ComponentError("Output type specifications of 'run' and 'run_async' methods must be the same")
            output_types_cache = run_output_types

            instance.__haystack_output__ = Sockets(instance, deepcopy(output_types_cache), OutputSocket)

    @staticmethod
    def _parse_and_set_input_sockets(component_cls: Type, instance: Any):
        def inner(method, sockets):
            from inspect import Parameter

            run_signature = inspect.signature(method)

            for param_name, param_info in run_signature.parameters.items():
                if param_name == "self" or param_info.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
                    continue

                socket_kwargs = {"name": param_name, "type": param_info.annotation}
                if param_info.default != Parameter.empty:
                    socket_kwargs["default_value"] = param_info.default

                new_socket = InputSocket(**socket_kwargs)

                # Also ensure that new sockets don't override existing ones.
                existing_socket = sockets.get(param_name)
                if existing_socket is not None and existing_socket != new_socket:
                    raise ComponentError(
                        "set_input_types()/set_input_type() cannot override the parameters of the 'run' method"
                    )

                sockets[param_name] = new_socket

            return run_signature

        # Create the sockets if set_input_types() wasn't called in the constructor.
        if not hasattr(instance, "__haystack_input__"):
            instance.__haystack_input__ = Sockets(instance, {}, InputSocket)

        inner(getattr(component_cls, "run"), instance.__haystack_input__)

        # Ensure that the sockets are the same for the async method, if it exists.
        async_run = getattr(component_cls, "run_async", None)
        if async_run is not None:
            run_sockets = Sockets(instance, {}, InputSocket)
            async_run_sockets = Sockets(instance, {}, InputSocket)

            # Can't use the sockets from above as they might contain
            # values set with set_input_types().
            run_sig = inner(getattr(component_cls, "run"), run_sockets)
            async_run_sig = inner(async_run, async_run_sockets)

            if async_run_sockets != run_sockets or run_sig != async_run_sig:
                sig_diff = _compare_run_methods_signatures(run_sig, async_run_sig)
                raise ComponentError(
                    f"Parameters of 'run' and 'run_async' methods must be the same.\nDifferences found:\n{sig_diff}"
                )

    def __call__(cls, *args, **kwargs):
        """
        This method is called when clients instantiate a Component and runs before __new__ and __init__.
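
        A standard-library sketch of this ordering, with hypothetical `TracingMeta` and `Widget` classes that
        record each step in a list. It illustrates the general mechanism only, not the socket setup done here.

        ```python
        events = []


        class TracingMeta(type):  # hypothetical metaclass for illustration
            def __call__(cls, *args, **kwargs):
                events.append("meta.__call__")
                # type.__call__ runs cls.__new__ and then cls.__init__
                instance = super().__call__(*args, **kwargs)
                instance.tagged = True  # post-process the new instance
                return instance


        class Widget(metaclass=TracingMeta):
            def __new__(cls, *args, **kwargs):
                events.append("__new__")
                return super().__new__(cls)

            def __init__(self, size):
                events.append("__init__")
                self.size = size


        widget = Widget(3)
        ```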
        """
        # This will call __new__ then __init__, giving us back the Component instance
        pre_init_hook = _COMPONENT_PRE_INIT_HOOK.get()
        if pre_init_hook is None or pre_init_hook.in_progress:
            instance = super().__call__(*args, **kwargs)
        else:
            try:
                pre_init_hook.in_progress = True
                named_positional_args = ComponentMeta._positional_to_kwargs(cls, args)
                assert set(named_positional_args.keys()).intersection(kwargs.keys()) == set(), (
                    "positional and keyword arguments overlap"
                )
                kwargs.update(named_positional_args)
                pre_init_hook.callback(cls, kwargs)
                instance = super().__call__(**kwargs)
            finally:
                pre_init_hook.in_progress = False

        # Before returning, we have the chance to modify the newly created
        # Component instance, so we take the chance and set up the I/O sockets
        has_async_run = hasattr(instance, "run_async")
        if has_async_run and not inspect.iscoroutinefunction(instance.run_async):
            raise ComponentError(f"Method 'run_async' of component '{cls.__name__}' must be a coroutine")
        instance.__haystack_supports_async__ = has_async_run

        ComponentMeta._parse_and_set_input_sockets(cls, instance)
        ComponentMeta._parse_and_set_output_sockets(instance)

        # Since a Component can't be used in multiple Pipelines at the same time
        # we need to know if it's already owned by a Pipeline when adding it to one.
        # We use this flag to check that.
        instance.__haystack_added_to_pipeline__ = None

        return instance


def _component_repr(component: Component) -> str:
    """
    All Components override their __repr__ method with this one.

    It prints the component name and the input/output sockets.
    """
    result = object.__repr__(component)
    if pipeline := getattr(component, "__haystack_added_to_pipeline__", None):
        # This Component has been added in a Pipeline, let's get the name from there.
        result += f"\n{pipeline.get_component_name(component)}"

    # We're explicitly ignoring the type here because we're sure that the component
    # has the __haystack_input__ and __haystack_output__ attributes at this point
    return (
        f"{result}\n{getattr(component, '__haystack_input__', '<invalid_input_sockets>')}"
        f"\n{getattr(component, '__haystack_output__', '<invalid_output_sockets>')}"
    )


def _component_run_has_kwargs(component_cls: Type) -> bool:
    run_method = getattr(component_cls, "run", None)
    if run_method is None:
        return False
    else:
        return any(
            param.kind == inspect.Parameter.VAR_KEYWORD for param in inspect.signature(run_method).parameters.values()
        )


def _compare_run_methods_signatures(run_sig: inspect.Signature, async_run_sig: inspect.Signature) -> str:
    """
    Builds a detailed error message with the differences between the signatures of the run and run_async methods.
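
    The kind of comparison involved can be sketched with `inspect` alone. The two functions below are
    hypothetical examples, and the snippet only collects mismatching parameter names instead of building messages.

    ```python
    import inspect
    from typing import Optional


    def run(self, query: str, top_k: int = 5):
        return {}


    async def run_async(self, query: str, top_k: Optional[int] = None):
        return {}


    run_params = inspect.signature(run).parameters
    async_params = inspect.signature(run_async).parameters

    # Collect the names of parameters whose annotation or default differ.
    mismatched = [
        name
        for name, param in run_params.items()
        if param.annotation != async_params[name].annotation or param.default != async_params[name].default
    ]
    ```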

    :param run_sig: The signature of the run method
    :param async_run_sig: The signature of the run_async method

    :returns:
        A detailed error message if signatures don't match, empty string if they do
    """
    differences = []
    run_params = list(run_sig.parameters.items())
    async_params = list(async_run_sig.parameters.items())

    if len(run_params) != len(async_params):
        differences.append(
            f"Different number of parameters: run has {len(run_params)}, run_async has {len(async_params)}"
        )

    for (run_name, run_param), (async_name, async_param) in zip(run_params, async_params):
        if run_name != async_name:
            differences.append(f"Parameter name mismatch: {run_name} vs {async_name}")

        if run_param.annotation != async_param.annotation:
            differences.append(
                f"Parameter '{run_name}' type mismatch: {run_param.annotation} vs {async_param.annotation}"
            )

        if run_param.default != async_param.default:
            differences.append(
                f"Parameter '{run_name}' default value mismatch: {run_param.default} vs {async_param.default}"
            )

        if run_param.kind != async_param.kind:
            differences.append(
                f"Parameter '{run_name}' kind (POSITIONAL, KEYWORD, etc.) mismatch: "
                f"{run_param.kind} vs {async_param.kind}"
            )

    return "\n".join(differences)


class _Component:
    """
    See module's docstring.

    Args:
        cls: the class that should be used as a component.

    Returns:
        A class that can be recognized as a component.

    Raises:
        ComponentError: if the class provided has no `run()` method or otherwise doesn't respect the component contract.
    """

    def __init__(self):
        self.registry = {}

    def set_input_type(
        self,
        instance,
        name: str,
        type: Any,  # noqa: A002
        default: Any = _empty,
    ):
        """
        Add a single input socket to the component instance.

        Replaces any existing input socket with the same name.

        :param instance: Component instance where the input type will be added.
        :param name: name of the input socket.
        :param type: type of the input socket.
        :param default: default value of the input socket, defaults to _empty
        """
        if not _component_run_has_kwargs(instance.__class__):
            raise ComponentError(
                "Cannot set input types on a component that doesn't have a kwargs parameter in the 'run' method"
            )

        if not hasattr(instance, "__haystack_input__"):
            instance.__haystack_input__ = Sockets(instance, {}, InputSocket)
        instance.__haystack_input__[name] = InputSocket(name=name, type=type, default_value=default)

    def set_input_types(self, instance, **types):
        """
        Method that specifies the input types when 'kwargs' is passed to the run method.

        Use as:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int):
                component.set_input_types(self, value_1=str, value_2=str)
                ...

            @component.output_types(output_1=int, output_2=str)
            def run(self, **kwargs):
                return {"output_1": kwargs["value_1"], "output_2": ""}
        ```

        Note that if the `run()` method also specifies some parameters, those will take precedence.

        For example:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int):
                component.set_input_types(self, value_1=str, value_2=str)
                ...

            @component.output_types(output_1=int, output_2=str)
            def run(self, value_0: str, value_1: Optional[str] = None, **kwargs):
                # `value_1` is a named parameter here, so it never appears in `kwargs`.
                return {"output_1": value_1, "output_2": ""}
        ```

        would add a mandatory `value_0` parameter, make the `value_1`
        parameter optional with a default of None, and keep the `value_2`
        parameter mandatory as specified in `set_input_types`.

        """
        if not _component_run_has_kwargs(instance.__class__):
            raise ComponentError(
                "Cannot set input types on a component that doesn't have a kwargs parameter in the 'run' method"
            )

        instance.__haystack_input__ = Sockets(
            instance, {name: InputSocket(name=name, type=type_) for name, type_ in types.items()}, InputSocket
        )

    def set_output_types(self, instance, **types):
        """
        Method that specifies the output types when the 'run' method is not decorated with 'component.output_types'.

        Use as:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int):
                component.set_output_types(self, output_1=int, output_2=str)
                ...

            # no decorators here
            def run(self, value: int):
                return {"output_1": 1, "output_2": "2"}
        ```
        """
        has_decorator = hasattr(instance.run, "_output_types_cache")
        if has_decorator:
            raise ComponentError(
                "Cannot call `set_output_types` on a component that already has "
                "the 'output_types' decorator on its `run` method"
            )

        instance.__haystack_output__ = Sockets(
            instance, {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()}, OutputSocket
        )

    def output_types(
        self, **types: Any
    ) -> Callable[[Callable[RunParamsT, RunReturnT]], Callable[RunParamsT, RunReturnT]]:
        """
        Decorator factory that specifies the output types of a component.

        Use as:
        ```python
        @component
        class MyComponent:
            @component.output_types(output_1=int, output_2=str)
            def run(self, value: int):
                return {"output_1": 1, "output_2": "2"}
        ```
        """

        def output_types_decorator(run_method: Callable[RunParamsT, RunReturnT]) -> Callable[RunParamsT, RunReturnT]:
            """
            Decorator that sets the output types of the decorated method.

            This happens at class creation time, and since we don't have the decorated
            class available here, we temporarily store the output types as an attribute of
            the decorated method. The ComponentMeta metaclass will use this data to create
            sockets at instance creation time.
            """
            method_name = run_method.__name__
            if method_name not in ("run", "run_async"):
                raise ComponentError("'output_types' decorator can only be used on 'run' and 'run_async' methods")

            setattr(
                run_method,
                "_output_types_cache",
                {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()},
            )
            return run_method

        return output_types_decorator

    def _component(self, cls: Type[ClassT]) -> Type[ClassT]:
        """
        Decorator validating the structure of the component and registering it in the components registry.
        """
        logger.debug("Registering {component} as a component", component=cls)

        # Check for required methods and fail as soon as possible
        if not hasattr(cls, "run"):
            raise ComponentError(f"{cls.__name__} must have a 'run()' method. See the docs for more information.")

        def copy_class_namespace(namespace):
            """
            This is the callback that `types.new_class` will use to populate the newly created class.

            Simply copy the whole namespace from the decorated class.
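
            A standalone sketch of this technique, assuming hypothetical `Meta` and `Original` classes. It shows
            how the callback passed to `new_class` lets the rebuilt class keep its attributes under a new metaclass.

            ```python
            from types import new_class


            class Meta(type):  # hypothetical metaclass
                pass


            class Original:
                answer = 42

                def run(self):
                    return {"answer": self.answer}


            def copy_namespace(namespace):
                for key, val in dict(Original.__dict__).items():
                    if key in ("__dict__", "__weakref__"):
                        continue  # let Python recreate these descriptors
                    namespace[key] = val


            Rebuilt = new_class("Original", Original.__bases__, {"metaclass": Meta}, copy_namespace)
            rebuilt_output = Rebuilt().run()
            ```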
            """
            for key, val in dict(cls.__dict__).items():
                # __dict__ and __weakref__ are class-bound, we should let Python recreate them.
                if key in ("__dict__", "__weakref__"):
                    continue
                namespace[key] = val

        # Recreate the decorated component class so it uses our metaclass.
        # We must explicitly redefine the type of the class to make sure language servers
        # and type checkers understand that the class is of the correct type.
        # mypy doesn't like that we do this though so we explicitly ignore the type check.
        new_cls: cls.__name__ = new_class(
            cls.__name__, cls.__bases__, {"metaclass": ComponentMeta}, copy_class_namespace
        )

        # Save the component in the class registry (for deserialization)
        class_path = f"{new_cls.__module__}.{new_cls.__name__}"
        if class_path in self.registry:
            # Corner case, but it may occur easily in notebooks when re-running cells.
            # Adjacent string literals keep the message free of the indentation whitespace
            # that a backslash continuation inside the string would embed.
            logger.debug(
                "Component {component} is already registered. Previous imported from '{module_name}', "
                "new imported from '{new_module_name}'",
                component=class_path,
                module_name=self.registry[class_path],
                new_module_name=new_cls,
            )
        self.registry[class_path] = new_cls
        logger.debug("Registered Component {component}", component=new_cls)

        # Override the __repr__ method with a default one
        new_cls.__repr__ = _component_repr

        return new_cls

    def __call__(self, cls: Optional[type] = None):
        # We must wrap the call to the decorator in a function for it to work
        # correctly with or without parens
        def wrap(cls):
            return self._component(cls)

        if cls:
            # Decorator is called without parens
            return wrap(cls)

        # Decorator is called with parens
        return wrap


component = _Component()