deepset-ai / haystack / 13972131258

20 Mar 2025 02:43PM UTC coverage: 90.021% (-0.03%) from 90.054%

Pull #9069 (github, web-flow): Merge 8371761b0 into 67ab3788e
Pull Request #9069: refactor!: `ChatMessage` serialization-deserialization updates

9833 of 10923 relevant lines covered (90.02%), 0.9 hits per line

Source file: haystack/core/component/component.py (99.38%)

# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

"""
Attributes:

    component: Marks a class as a component. Any class decorated with `@component` can be used by a Pipeline.

All components must follow the contract below. This docstring is the source of truth for the component contract.

<hr>

`@component` decorator

All component classes must be decorated with the `@component` decorator. This allows Haystack to discover them.
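
For illustration only, the self-contained sketch below shows how a class decorator can make classes discoverable. It rejects a class without a `run()` method and records the others in a dictionary keyed by import path. `REGISTRY`, `register` and `ContractError` are hypothetical names, not Haystack API. The real `@component` decorator also recreates the class with the `ComponentMeta` metaclass.

```python
from typing import Any, Dict

# Hypothetical registry mapping "module.ClassName" to the decorated class.
REGISTRY: Dict[str, type] = {}


class ContractError(Exception):
    # Hypothetical error for classes that break the contract.
    pass


def register(cls: type) -> type:
    # Fail as early as possible if the mandatory `run()` method is missing.
    if not hasattr(cls, "run"):
        raise ContractError(f"{cls.__name__} must have a 'run()' method")
    # Key by import path so the class can be found again, e.g. when deserializing.
    REGISTRY[f"{cls.__module__}.{cls.__name__}"] = cls
    return cls


@register
class Greeter:
    def run(self, name: str) -> Dict[str, Any]:
        return {"greeting": f"Hello, {name}!"}
```
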

<hr>

`__init__(self, **kwargs)`

Optional method.

Components may have an `__init__` method where they define:

- `self.init_parameters = {same parameters that the __init__ method received}`:
    In this dictionary you can store any state the component wishes to persist when it is saved.
    These values are passed to the `__init__` method of a new instance when the pipeline is loaded.
    Note that by default the `@component` decorator saves the arguments automatically.
    However, if a component sets its own `init_parameters` manually in `__init__()`, those are used instead.
    Note: all of the values contained here **must be JSON serializable**. Serialize them manually if needed.
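
Here is a minimal sketch of the automatic capture described above. It assumes a hypothetical `capture_init_parameters` class decorator and is not how Haystack implements it. The decorator binds the constructor arguments, including defaults, with `inspect.signature`. It keeps any `init_parameters` the component stored itself. It also calls `json.dumps`, so that non-serializable values fail early with a `TypeError`.

```python
import functools
import inspect
import json
from typing import Any


def capture_init_parameters(cls: type) -> type:
    original_init = cls.__init__
    signature = inspect.signature(original_init)

    @functools.wraps(original_init)
    def wrapped_init(self: Any, *args: Any, **kwargs: Any) -> None:
        original_init(self, *args, **kwargs)
        # Only record the arguments if the component didn't set init_parameters itself.
        if not hasattr(self, "init_parameters"):
            bound = signature.bind(self, *args, **kwargs)
            bound.apply_defaults()
            self.init_parameters = {k: v for k, v in bound.arguments.items() if k != "self"}
        # Raises TypeError if a value is not JSON serializable.
        json.dumps(self.init_parameters)

    cls.__init__ = wrapped_init
    return cls


@capture_init_parameters
class Retriever:
    def __init__(self, top_k: int = 10, index: str = "default"):
        self.top_k = top_k
        self.index = index


@capture_init_parameters
class Writer:
    def __init__(self, path: str):
        # A manually stored value takes precedence over the captured arguments.
        self.init_parameters = {"path": path.strip()}
```
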

Components should take only "basic" Python types as parameters of their `__init__` function, or iterables and
dictionaries containing only such values. Anything else (objects, functions, etc.) will raise an exception at init
time. If such values are needed, consider serializing them to a string.
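
For functions, one way to follow this advice is to store the callable's import path as a string and import it again when loading. This sketch uses only the standard library. `serialize_callable` and `deserialize_callable` are hypothetical helpers, and they only handle module-level callables (not lambdas or nested functions).

```python
import importlib
from typing import Callable


def serialize_callable(func: Callable) -> str:
    # "<module>.<qualified name>" is enough to re-import a module-level function.
    return f"{func.__module__}.{func.__qualname__}"


def deserialize_callable(path: str) -> Callable:
    # Split on the last dot: everything before it is the module to import.
    module_name, _, attr_name = path.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)
```

For example, `serialize_callable(math.sqrt)` returns the JSON-friendly string `"math.sqrt"`, and `deserialize_callable("math.sqrt")` gives back the original function.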

_(TODO explain how to use classes and functions in init. In the meantime see `test/components/test_accumulate.py`)_

The `__init__` method must be extremely lightweight, because it runs frequently during the construction and
validation of the pipeline. If a component has some heavy state to initialize (models, backends, etc.), refer to
the `warm_up()` method.

<hr>

`warm_up(self)`

Optional method.

The Pipeline calls this method before executing the graph. Make sure to avoid double initialization,
because the Pipeline does not keep track of which components it has already called `warm_up()` on.
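
Here is a minimal sketch of this pattern. It assumes a hypothetical `Embedder` class whose `_load_model()` method stands in for expensive setup. `__init__` only stores configuration, and `warm_up()` returns early if the resource already exists, so calling it twice loads the model only once.

```python
from typing import Any, Dict, List, Optional


class Embedder:
    def __init__(self, model_name: str = "tiny-model"):
        # Keep __init__ cheap: store configuration only.
        self.model_name = model_name
        self._model: Optional[Dict[str, Any]] = None
        self.load_count = 0

    def _load_model(self) -> Dict[str, Any]:
        # Hypothetical stand-in for loading weights or opening a connection.
        self.load_count += 1
        return {"name": self.model_name}

    def warm_up(self) -> None:
        # Guard against double initialization, since the caller doesn't track it.
        if self._model is None:
            self._model = self._load_model()

    def run(self, text: str) -> Dict[str, List[float]]:
        if self._model is None:
            raise RuntimeError("warm_up() must be called before run()")
        # Placeholder "embedding": a one-dimensional vector holding the text length.
        return {"embedding": [float(len(text))]}
```
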

<hr>

`run(self, ...)`

Mandatory method.

This is the method where the main functionality of the component is carried out. It's called by
`Pipeline.run()`.

The named parameters of `run()` define the component's input sockets (together with any sockets declared with
`component.set_input_types()` when `run()` accepts `**kwargs`). When the component runs, the Pipeline calls this
method with keyword arguments containing:

- all the input values coming from other components connected to it,
- for any input that is missing, the default value declared for that parameter in `run()`, if it exists.
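
To make the input side concrete, this sketch reads inputs from a `run()` signature with `inspect.signature`. It works in the same spirit as `ComponentMeta._parse_and_set_input_sockets` later in this module. `self`, `*args` and `**kwargs` are skipped, and a parameter with a default becomes an optional input. `describe_inputs` and its `(annotation, is_mandatory)` tuples are hypothetical simplifications of `InputSocket`.

```python
import inspect
from typing import Any, Callable, Dict, Tuple


def describe_inputs(run_method: Callable) -> Dict[str, Tuple[Any, bool]]:
    # Map each input name to (annotation, is_mandatory).
    inputs: Dict[str, Tuple[Any, bool]] = {}
    for name, param in inspect.signature(run_method).parameters.items():
        if name == "self" or param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
            continue
        inputs[name] = (param.annotation, param.default is inspect.Parameter.empty)
    return inputs


class Joiner:
    def run(self, left: str, right: str = "", **kwargs: Any) -> Dict[str, Any]:
        return {"joined": left + right}
```

Here `left` is mandatory and `right` is optional, while `**kwargs` contributes no input of its own.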

`run()` must return a dictionary that maps output names, as declared with the `@component.output_types`
decorator or with `component.set_output_types()`, to their values.
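
Here is a hypothetical sketch of the output side. It does not reproduce the checks Haystack itself performs. `declare_outputs` stores the declared output names and types on the method, similar in spirit to the `_output_types_cache` attribute set by `component.output_types`. `check_outputs` rejects undeclared keys and mismatches against plain classes.

```python
from typing import Any, Callable, Dict


def declare_outputs(**types: Any) -> Callable[[Callable], Callable]:
    # Store the output specification on the method itself.
    def decorator(run_method: Callable) -> Callable:
        setattr(run_method, "declared_outputs", dict(types))
        return run_method

    return decorator


def check_outputs(run_method: Callable, result: Dict[str, Any]) -> None:
    declared: Dict[str, Any] = getattr(run_method, "declared_outputs", {})
    unexpected = set(result) - set(declared)
    if unexpected:
        raise ValueError(f"Undeclared outputs: {sorted(unexpected)}")
    for name, value in result.items():
        expected = declared[name]
        # Only plain classes are checked; generics such as List[str] are skipped.
        if isinstance(expected, type) and not isinstance(value, expected):
            raise TypeError(f"Output '{name}' should be of type {expected.__name__}")


class WordCounter:
    @declare_outputs(count=int)
    def run(self, text: str) -> Dict[str, Any]:
        return {"count": len(text.split())}
```
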

"""

import inspect
from collections.abc import Callable
from contextlib import contextmanager
from contextvars import ContextVar
from copy import deepcopy
from dataclasses import dataclass
from types import new_class
from typing import Any, Dict, Optional, Protocol, Type, TypeVar, runtime_checkable

from typing_extensions import ParamSpec

from haystack import logging
from haystack.core.errors import ComponentError

from .sockets import Sockets
from .types import InputSocket, OutputSocket, _empty

logger = logging.getLogger(__name__)

P = ParamSpec("P")
R = TypeVar("R", bound=Dict[str, Any])


@dataclass
class PreInitHookPayload:
    """
    Payload for the hook called before a component instance is initialized.

    :param callback:
        Receives the following inputs: component class and init parameter keyword args.
    :param in_progress:
        Flag to indicate if the hook is currently being executed.
        Used to prevent it from being called recursively (if the component's constructor
        instantiates another component).
    """

    callback: Callable
    in_progress: bool = False


_COMPONENT_PRE_INIT_HOOK: ContextVar[Optional[PreInitHookPayload]] = ContextVar("component_pre_init_hook", default=None)


@contextmanager
def _hook_component_init(callback: Callable):
    """
    Context manager to set a callback that will be invoked before a component's constructor is called.

    The callback receives the component class and the init parameters (as a dictionary of keyword arguments) and can
    modify the init parameters in place.

    :param callback:
        Callback function to invoke.
    """
    token = _COMPONENT_PRE_INIT_HOOK.set(PreInitHookPayload(callback))
    try:
        yield
    finally:
        _COMPONENT_PRE_INIT_HOOK.reset(token)


@runtime_checkable
class Component(Protocol):
    """
    Note this is only used by type checking tools.

    In order to implement the `Component` protocol, custom components need to
    have a `run` method. The signature of the method and its return value
    won't be checked, i.e. classes with the following methods:

        def run(self, param: str) -> Dict[str, Any]:
            ...

    and

        def run(self, **kwargs):
            ...

    will both be considered as respecting the protocol. This makes the type
    checking much weaker, but we have other places where we ensure code is
    dealing with actual Components.

    The protocol is runtime checkable so it'll be possible to assert:

        isinstance(MyComponent, Component)
    """

    # This is the most reliable way to define the protocol for the `run` method.
    # Defining a method doesn't work as different Components will have different
    # arguments. Even defining here a method with `**kwargs` doesn't work as the
    # expected signature must be identical.
    # This makes most Language Servers and type checkers happy and shows fewer errors.
    run: Callable[..., Dict[str, Any]]


class ComponentMeta(type):
    @staticmethod
    def _positional_to_kwargs(cls_type, args) -> Dict[str, Any]:
        """
        Convert positional arguments to keyword arguments based on the signature of the `__init__` method.
        """
        init_signature = inspect.signature(cls_type.__init__)
        init_params = {name: info for name, info in init_signature.parameters.items() if name != "self"}

        out = {}
        for arg, (name, info) in zip(args, init_params.items()):
            if info.kind == inspect.Parameter.VAR_POSITIONAL:
                raise ComponentError(
                    "Pre-init hooks do not support components with variadic positional args in their init method"
                )

            assert info.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.POSITIONAL_ONLY)
            out[name] = arg
        return out

    @staticmethod
    def _parse_and_set_output_sockets(instance: Any):
        has_async_run = hasattr(instance, "run_async")

        # If `component.set_output_types()` was called in the component constructor,
        # `__haystack_output__` is already populated, no need to do anything.
        if not hasattr(instance, "__haystack_output__"):
            # If that's not the case, we need to populate `__haystack_output__`
            #
            # If either run method was decorated, it has an attribute that
            # stores the output specification. If both run methods were decorated, we ensure that
            # outputs are the same. We deepcopy the content of the cache to transfer ownership from
            # the class method to the actual instance, so that different instances of the same class
            # won't share this data.

            run_output_types = getattr(instance.run, "_output_types_cache", {})
            async_run_output_types = getattr(instance.run_async, "_output_types_cache", {}) if has_async_run else {}

            if has_async_run and run_output_types != async_run_output_types:
                raise ComponentError("Output type specifications of 'run' and 'run_async' methods must be the same")
            output_types_cache = run_output_types

            instance.__haystack_output__ = Sockets(instance, deepcopy(output_types_cache), OutputSocket)

    @staticmethod
    def _parse_and_set_input_sockets(component_cls: Type, instance: Any):
        def inner(method, sockets):
            from inspect import Parameter

            run_signature = inspect.signature(method)

            for param_name, param_info in run_signature.parameters.items():
                if param_name == "self" or param_info.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
                    continue

                socket_kwargs = {"name": param_name, "type": param_info.annotation}
                if param_info.default != Parameter.empty:
                    socket_kwargs["default_value"] = param_info.default

                new_socket = InputSocket(**socket_kwargs)

                # Also ensure that new sockets don't override existing ones.
                existing_socket = sockets.get(param_name)
                if existing_socket is not None and existing_socket != new_socket:
                    raise ComponentError(
                        "set_input_types()/set_input_type() cannot override the parameters of the 'run' method"
                    )

                sockets[param_name] = new_socket

            return run_signature

        # Create the sockets if set_input_types() wasn't called in the constructor.
        if not hasattr(instance, "__haystack_input__"):
            instance.__haystack_input__ = Sockets(instance, {}, InputSocket)

        inner(getattr(component_cls, "run"), instance.__haystack_input__)

        # Ensure that the sockets are the same for the async method, if it exists.
        async_run = getattr(component_cls, "run_async", None)
        if async_run is not None:
            run_sockets = Sockets(instance, {}, InputSocket)
            async_run_sockets = Sockets(instance, {}, InputSocket)

            # Can't use the sockets from above as they might contain
            # values set with set_input_types().
            run_sig = inner(getattr(component_cls, "run"), run_sockets)
            async_run_sig = inner(async_run, async_run_sockets)

            if async_run_sockets != run_sockets or run_sig != async_run_sig:
                raise ComponentError("Parameters of 'run' and 'run_async' methods must be the same")

    def __call__(cls, *args, **kwargs):
        """
        This method is called when clients instantiate a Component and runs before __new__ and __init__.
        """
        # This will call __new__ then __init__, giving us back the Component instance
        pre_init_hook = _COMPONENT_PRE_INIT_HOOK.get()
        if pre_init_hook is None or pre_init_hook.in_progress:
            instance = super().__call__(*args, **kwargs)
        else:
            try:
                pre_init_hook.in_progress = True
                named_positional_args = ComponentMeta._positional_to_kwargs(cls, args)
                assert set(named_positional_args.keys()).intersection(kwargs.keys()) == set(), (
                    "positional and keyword arguments overlap"
                )
                kwargs.update(named_positional_args)
                pre_init_hook.callback(cls, kwargs)
                instance = super().__call__(**kwargs)
            finally:
                pre_init_hook.in_progress = False

        # Before returning, we have the chance to modify the newly created
        # Component instance, so we take the chance and set up the I/O sockets
        has_async_run = hasattr(instance, "run_async")
        if has_async_run and not inspect.iscoroutinefunction(instance.run_async):
            raise ComponentError(f"Method 'run_async' of component '{cls.__name__}' must be a coroutine")
        instance.__haystack_supports_async__ = has_async_run

        ComponentMeta._parse_and_set_input_sockets(cls, instance)
        ComponentMeta._parse_and_set_output_sockets(instance)

        # Since a Component can't be used in multiple Pipelines at the same time
        # we need to know if it's already owned by a Pipeline when adding it to one.
        # We use this flag to check that.
        instance.__haystack_added_to_pipeline__ = None

        return instance


def _component_repr(component: Component) -> str:
    """
    All Components override their __repr__ method with this one.

    It prints the component name and the input/output sockets.
    """
    result = object.__repr__(component)
    if pipeline := getattr(component, "__haystack_added_to_pipeline__", None):
        # This Component has been added in a Pipeline, let's get the name from there.
        result += f"\n{pipeline.get_component_name(component)}"

    # We're explicitly ignoring the type here because we're sure that the component
    # has the __haystack_input__ and __haystack_output__ attributes at this point
    return (
        f"{result}\n{getattr(component, '__haystack_input__', '<invalid_input_sockets>')}"
        f"\n{getattr(component, '__haystack_output__', '<invalid_output_sockets>')}"
    )


def _component_run_has_kwargs(component_cls: Type) -> bool:
    run_method = getattr(component_cls, "run", None)
    if run_method is None:
        return False
    else:
        return any(
            param.kind == inspect.Parameter.VAR_KEYWORD for param in inspect.signature(run_method).parameters.values()
        )


class _Component:
    """
    See module's docstring.

    Args:
        cls: the class that should be used as a component.

    Returns:
        A class that can be recognized as a component.

    Raises:
        ComponentError: if the class provided has no `run()` method or otherwise doesn't respect the component contract.
    """

    def __init__(self):
        self.registry = {}

    def set_input_type(
        self,
        instance,
        name: str,
        type: Any,  # noqa: A002
        default: Any = _empty,
    ):
        """
        Add a single input socket to the component instance.

        Replaces any existing input socket with the same name.

        :param instance: Component instance where the input type will be added.
        :param name: name of the input socket.
        :param type: type of the input socket.
        :param default: default value of the input socket, defaults to _empty
        """
        if not _component_run_has_kwargs(instance.__class__):
            raise ComponentError(
                "Cannot set input types on a component that doesn't have a kwargs parameter in the 'run' method"
            )

        if not hasattr(instance, "__haystack_input__"):
            instance.__haystack_input__ = Sockets(instance, {}, InputSocket)
        instance.__haystack_input__[name] = InputSocket(name=name, type=type, default_value=default)

    def set_input_types(self, instance, **types):
        """
        Method that specifies the input types when the `run` method accepts `**kwargs`.

        Use as:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int):
                component.set_input_types(self, value_1=str, value_2=str)
                ...

            @component.output_types(output_1=str, output_2=str)
            def run(self, **kwargs):
                return {"output_1": kwargs["value_1"], "output_2": ""}
        ```

        Parameters declared explicitly in the `run()` method become input sockets as well. They must not redefine a
        socket declared with `set_input_types()` using a different type or default value; if they do, creating the
        component raises a `ComponentError`.

        For example:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int):
                component.set_input_types(self, value_1=str, value_2=str)
                ...

            @component.output_types(output_1=str, output_2=str)
            def run(self, value_0: str, value_3: Optional[str] = None, **kwargs):
                return {"output_1": kwargs["value_1"], "output_2": value_0}
        ```

        would add a mandatory `value_0` parameter and an optional `value_3` parameter with a default of `None`,
        and keep the `value_1` and `value_2` parameters mandatory as specified in `set_input_types`.

        """
        if not _component_run_has_kwargs(instance.__class__):
            raise ComponentError(
                "Cannot set input types on a component that doesn't have a kwargs parameter in the 'run' method"
            )

        instance.__haystack_input__ = Sockets(
            instance, {name: InputSocket(name=name, type=type_) for name, type_ in types.items()}, InputSocket
        )

    def set_output_types(self, instance, **types):
        """
        Method that specifies the output types when the 'run' method is not decorated with 'component.output_types'.

        Use as:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int):
                component.set_output_types(self, output_1=int, output_2=str)
                ...

            # no decorators here
            def run(self, value: int):
                return {"output_1": 1, "output_2": "2"}
        ```
        """
        has_decorator = hasattr(instance.run, "_output_types_cache")
        if has_decorator:
            raise ComponentError(
                "Cannot call `set_output_types` on a component that already has "
                "the 'output_types' decorator on its `run` method"
            )

        instance.__haystack_output__ = Sockets(
            instance, {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()}, OutputSocket
        )

    def output_types(self, **types: Any) -> Callable[[Callable[P, R]], Callable[P, R]]:
        """
        Decorator factory that specifies the output types of a component.

        Use as:

        ```python
        @component
        class MyComponent:
            @component.output_types(output_1=int, output_2=str)
            def run(self, value: int):
                return {"output_1": 1, "output_2": "2"}
        ```
        """

        def output_types_decorator(run_method: Callable[P, R]) -> Callable[P, R]:
            """
            Decorator that sets the output types of the decorated method.

            This happens at class creation time, and since we don't have the decorated
            class available here, we temporarily store the output types as an attribute of
            the decorated method. The ComponentMeta metaclass will use this data to create
            sockets at instance creation time.
            """
            method_name = run_method.__name__
            if method_name not in ("run", "run_async"):
                raise ComponentError("'output_types' decorator can only be used on 'run' and 'run_async' methods")

            setattr(
                run_method,
                "_output_types_cache",
                {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()},
            )
            return run_method

        return output_types_decorator

    def _component(self, cls: Any):
        """
        Decorator validating the structure of the component and registering it in the components registry.
        """
        logger.debug("Registering {component} as a component", component=cls)

        # Check for required methods and fail as soon as possible
        if not hasattr(cls, "run"):
            raise ComponentError(f"{cls.__name__} must have a 'run()' method. See the docs for more information.")

        def copy_class_namespace(namespace):
            """
            This is the callback that `types.new_class` will use to populate the newly created class.

            Simply copy the whole namespace from the decorated class.
            """
            for key, val in dict(cls.__dict__).items():
                # __dict__ and __weakref__ are class-bound, we should let Python recreate them.
                if key in ("__dict__", "__weakref__"):
                    continue
                namespace[key] = val

        # Recreate the decorated component class so it uses our metaclass.
        # We must explicitly redefine the type of the class to make sure language servers
        # and type checkers understand that the class is of the correct type.
        # mypy doesn't like that we do this though so we explicitly ignore the type check.
        new_cls: cls.__name__ = new_class(
            cls.__name__, cls.__bases__, {"metaclass": ComponentMeta}, copy_class_namespace
        )  # type: ignore[no-redef]

        # Save the component in the class registry (for deserialization)
        class_path = f"{new_cls.__module__}.{new_cls.__name__}"
        if class_path in self.registry:
            # Corner case, but it may occur easily in notebooks when re-running cells.
            logger.debug(
                "Component {component} is already registered. Previously imported from '{module_name}', "
                "now imported from '{new_module_name}'",
                component=class_path,
                module_name=self.registry[class_path],
                new_module_name=new_cls,
            )
        self.registry[class_path] = new_cls
        logger.debug("Registered Component {component}", component=new_cls)

        # Override the __repr__ method with a default one
        new_cls.__repr__ = _component_repr

        return new_cls

    def __call__(self, cls: Optional[type] = None):
        # We must wrap the call to the decorator in a function for it to work
        # correctly with or without parens
        def wrap(cls):
            return self._component(cls)

        if cls:
            # Decorator is called without parens
            return wrap(cls)

        # Decorator is called with parens
        return wrap


component = _Component()