deepset-ai / haystack / 10577969486

27 Aug 2024 12:20PM UTC coverage: 90.201% (+0.02%) from 90.185%
push

github

web-flow
feat: Extend core component machinery to support an async run method (experimental) (#8279)

* feat: Extend core component machinery to support an async run method

* Add reno

* Fix incorrect docstring

* Make `async_run` a coroutine

* Make `supports_async` a dunder field

7005 of 7766 relevant lines covered (90.2%)

0.9 hits per line

Source File
97.93
haystack/core/component/component.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
"""
6
Attributes:
7

8
    component: Marks a class as a component. Any class decorated with `@component` can be used by a Pipeline.
9

10
All components must follow the contract below. This docstring is the source of truth for the component contract.
11

12
<hr>
13

14
`@component` decorator
15

16
All component classes must be decorated with the `@component` decorator. This allows Haystack to discover them.
17

18
<hr>
19

20
`__init__(self, **kwargs)`
21

22
Optional method.
23

24
Components may have an `__init__` method where they define:
25

26
- `self.init_parameters = {same parameters that the __init__ method received}`:
27
    In this dictionary you can store any state the components wish to be persisted when they are saved.
28
    These values will be given to the `__init__` method of a new instance when the pipeline is loaded.
29
    Note that by default the `@component` decorator saves the arguments automatically.
30
    However, if a component sets its own `init_parameters` manually in `__init__()`, that will be used instead.
31
    Note: all of the values contained here **must be JSON serializable**. Serialize them manually if needed.
32

33
Components should take only "basic" Python types as parameters of their `__init__` function, or iterables and
34
dictionaries containing only such values. Anything else (objects, functions, etc.) will raise an exception at init
35
time. If such values are needed, consider serializing them to a string.
36

37
_(TODO explain how to use classes and functions in init. In the meantime see `test/components/test_accumulate.py`)_
38

39
The `__init__` must be extremely lightweight, because it's a frequent operation during the construction and
40
validation of the pipeline. If a component has heavy state to initialize (models, backends, etc.), refer to
41
the `warm_up()` method.
42

43
<hr>
44

45
`warm_up(self)`
46

47
Optional method.
48

49
This method is called by Pipeline before the graph execution. Make sure to avoid double-initializations,
50
because Pipeline will not keep track of which components it called `warm_up()` on.
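
A minimal sketch of one way to guard against double initialization. It uses plain Python only; the
`SlowModelComponent` class, its `_load_model` helper and the `load_count` counter are hypothetical
names chosen for illustration and are not part of Haystack:

```python
from typing import Any, Dict, Optional


class SlowModelComponent:
    """Hypothetical component-like class; plain Python, no Haystack imports."""

    def __init__(self, model_name: str):
        # Keep __init__ cheap: only store configuration.
        self.model_name = model_name
        self._model: Optional[Dict[str, Any]] = None
        self.load_count = 0

    def _load_model(self) -> Dict[str, Any]:
        # Stand-in for an expensive load (model weights, backend connection, ...).
        self.load_count += 1
        return {"name": self.model_name}

    def warm_up(self) -> None:
        # Guard against double initialization: the caller may invoke this more than once.
        if self._model is None:
            self._model = self._load_model()

    def run(self, text: str) -> Dict[str, Any]:
        if self._model is None:
            raise RuntimeError("Call warm_up() before run().")
        return {"result": f"{self._model['name']}: {text}"}


comp = SlowModelComponent("tiny-model")
comp.warm_up()
comp.warm_up()  # second call is a no-op
```

Checking a private attribute for `None` keeps the guard local to the component, so it works no matter
how many times the caller invokes `warm_up()`.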
51

52
<hr>
53

54
`run(self, data)`
55

56
Mandatory method.
57

58
This is the method where the main functionality of the component should be carried out. It's called by
59
`Pipeline.run()`.
60

61
When the component should run, Pipeline calls this method with keyword arguments matching the component's input
62
sockets, which are derived from the `run()` signature or declared with `component.set_input_types()`. These contain:
63

64
- all the input values coming from other components connected to it,
65
- if any is missing, the default value of the corresponding `run()` parameter, if it has one.
66

67
`run()` must return a dictionary whose keys match the outputs declared with `@component.output_types` or
68
`component.set_output_types()`.
69

70
"""
71

72
import inspect
1✔
73
import sys
1✔
74
from collections.abc import Callable
1✔
75
from contextlib import contextmanager
1✔
76
from contextvars import ContextVar
1✔
77
from copy import deepcopy
1✔
78
from dataclasses import dataclass
1✔
79
from types import new_class
1✔
80
from typing import Any, Dict, Optional, Protocol, Type, runtime_checkable
1✔
81

82
from haystack import logging
1✔
83
from haystack.core.errors import ComponentError
1✔
84

85
from .sockets import Sockets
1✔
86
from .types import InputSocket, OutputSocket, _empty
1✔
87

88
logger = logging.getLogger(__name__)
1✔
89

90

91
@dataclass
1✔
92
class PreInitHookPayload:
1✔
93
    """
94
    Payload for the hook called before a component instance is initialized.
95

96
    :param callback:
97
        Receives the following inputs: component class and init parameter keyword args.
98
    :param in_progress:
99
        Flag to indicate if the hook is currently being executed.
100
        Used to prevent it from being called recursively (if the component's constructor
101
        instantiates another component).
102
    """
103

104
    callback: Callable
1✔
105
    in_progress: bool = False
1✔
106

107

108
_COMPONENT_PRE_INIT_HOOK: ContextVar[Optional[PreInitHookPayload]] = ContextVar("component_pre_init_hook", default=None)
1✔
109

110

111
@contextmanager
1✔
112
def _hook_component_init(callback: Callable):
1✔
113
    """
114
    Context manager to set a callback that will be invoked before a component's constructor is called.
115

116
    The callback receives the component class and the init parameters (as keyword arguments) and can modify the init
117
    parameters in place.
118

119
    :param callback:
120
        Callback function to invoke.
121
    """
122
    token = _COMPONENT_PRE_INIT_HOOK.set(PreInitHookPayload(callback))
1✔
123
    try:
1✔
124
        yield
1✔
125
    finally:
126
        _COMPONENT_PRE_INIT_HOOK.reset(token)
1✔
127

128

129
@runtime_checkable
1✔
130
class Component(Protocol):
1✔
131
    """
132
    Note this is only used by type checking tools.
133

134
    In order to implement the `Component` protocol, custom components need to
135
    have a `run` method. The signature of the method and its return value
136
    won't be checked, i.e. classes with the following methods:
137

138
        def run(self, param: str) -> Dict[str, Any]:
139
            ...
140

141
    and
142

143
        def run(self, **kwargs):
144
            ...
145

146
    will both be considered as respecting the protocol. This makes the type
147
    checking much weaker, but we have other places where we ensure code is
148
    dealing with actual Components.
149

150
    The protocol is runtime checkable so it'll be possible to assert:
151

152
        isinstance(MyComponent(), Component)
153
    """
154

155
    # This is the most reliable way to define the protocol for the `run` method.
156
    # Defining a method doesn't work as different Components will have different
157
    # arguments. Even defining here a method with `**kwargs` doesn't work as the
158
    # expected signature must be identical.
159
    # This makes most Language Servers and type checkers happy and shows fewer errors.
160
    # NOTE: This check can be removed when we drop Python 3.8 support.
161
    if sys.version_info >= (3, 9):
1✔
162
        run: Callable[..., Dict[str, Any]]
×
163
    else:
164
        run: Callable
1✔
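
A minimal sketch of how a runtime-checkable protocol with an attribute-style `run` member behaves.
The `HasRun`, `Upper`, `Anything` and `NoRun` names are hypothetical; the point is that `isinstance`
only checks that a `run` attribute exists, not its signature:

```python
from typing import Any, Callable, Dict, Protocol, runtime_checkable


@runtime_checkable
class HasRun(Protocol):
    # Attribute-style member: only the presence of `run` is checked at runtime.
    run: Callable[..., Dict[str, Any]]


class Upper:
    def run(self, text: str) -> Dict[str, Any]:
        return {"text": text.upper()}


class Anything:
    def run(self, **kwargs):
        return kwargs


class NoRun:
    pass


# Two different `run` signatures pass; the class without `run` does not.
checks = [isinstance(obj, HasRun) for obj in (Upper(), Anything(), NoRun())]
```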
165

166

167
class ComponentMeta(type):
1✔
168
    @staticmethod
1✔
169
    def _positional_to_kwargs(cls_type, args) -> Dict[str, Any]:
1✔
170
        """
171
        Convert positional arguments to keyword arguments based on the signature of the `__init__` method.
172
        """
173
        init_signature = inspect.signature(cls_type.__init__)
1✔
174
        init_params = {name: info for name, info in init_signature.parameters.items() if name != "self"}
1✔
175

176
        out = {}
1✔
177
        for arg, (name, info) in zip(args, init_params.items()):
1✔
178
            if info.kind == inspect.Parameter.VAR_POSITIONAL:
1✔
179
                raise ComponentError(
1✔
180
                    "Pre-init hooks do not support components with variadic positional args in their init method"
181
                )
182

183
            assert info.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.POSITIONAL_ONLY)
1✔
184
            out[name] = arg
1✔
185
        return out
1✔
186

187
    @staticmethod
1✔
188
    def _parse_and_set_output_sockets(instance: Any):
1✔
189
        has_async_run = hasattr(instance, "async_run")
1✔
190

191
        # If `component.set_output_types()` was called in the component constructor,
192
        # `__haystack_output__` is already populated, no need to do anything.
193
        if not hasattr(instance, "__haystack_output__"):
1✔
194
            # If that's not the case, we need to populate `__haystack_output__`
195
            #
196
            # If either of the run methods were decorated, they'll have a field assigned that
197
            # stores the output specification. If both run methods were decorated, we ensure that
198
            # outputs are the same. We deepcopy the content of the cache to transfer ownership from
199
            # the class method to the actual instance, so that different instances of the same class
200
            # won't share this data.
201

202
            run_output_types = getattr(instance.run, "_output_types_cache", {})
1✔
203
            async_run_output_types = getattr(instance.async_run, "_output_types_cache", {}) if has_async_run else {}
1✔
204

205
            if has_async_run and run_output_types != async_run_output_types:
1✔
206
                raise ComponentError("Output type specifications of 'run' and 'async_run' methods must be the same")
1✔
207
            output_types_cache = run_output_types
1✔
208

209
            instance.__haystack_output__ = Sockets(instance, deepcopy(output_types_cache), OutputSocket)
1✔
210

211
    @staticmethod
1✔
212
    def _parse_and_set_input_sockets(component_cls: Type, instance: Any):
1✔
213
        def inner(method, sockets):
1✔
214
            run_signature = inspect.signature(method)
1✔
215

216
            # First is 'self' and it doesn't matter.
217
            for param in list(run_signature.parameters)[1:]:
1✔
218
                if run_signature.parameters[param].kind not in (
1✔
219
                    inspect.Parameter.VAR_POSITIONAL,
220
                    inspect.Parameter.VAR_KEYWORD,
221
                ):  # ignore variable args
222
                    socket_kwargs = {"name": param, "type": run_signature.parameters[param].annotation}
1✔
223
                    if run_signature.parameters[param].default != inspect.Parameter.empty:
1✔
224
                        socket_kwargs["default_value"] = run_signature.parameters[param].default
1✔
225
                    sockets[param] = InputSocket(**socket_kwargs)
1✔
226

227
        # Create the sockets if set_input_types() wasn't called in the constructor.
228
        # If it was called and there are some parameters also in the `run()` method, these take precedence.
229
        if not hasattr(instance, "__haystack_input__"):
1✔
230
            instance.__haystack_input__ = Sockets(instance, {}, InputSocket)
1✔
231

232
        inner(getattr(component_cls, "run"), instance.__haystack_input__)
1✔
233

234
        # Ensure that the sockets are the same for the async method, if it exists.
235
        async_run = getattr(component_cls, "async_run", None)
1✔
236
        if async_run is not None:
1✔
237
            run_sockets = Sockets(instance, {}, InputSocket)
1✔
238
            async_run_sockets = Sockets(instance, {}, InputSocket)
1✔
239

240
            # Can't use the sockets from above as they might contain
241
            # values set with set_input_types().
242
            inner(getattr(component_cls, "run"), run_sockets)
1✔
243
            inner(async_run, async_run_sockets)
1✔
244
            if async_run_sockets != run_sockets:
1✔
245
                raise ComponentError("Parameters of 'run' and 'async_run' methods must be the same")
×
246

247
    def __call__(cls, *args, **kwargs):
1✔
248
        """
249
        This method is called when clients instantiate a Component and runs before __new__ and __init__.
250
        """
251
        # This will call __new__ then __init__, giving us back the Component instance
252
        pre_init_hook = _COMPONENT_PRE_INIT_HOOK.get()
1✔
253
        if pre_init_hook is None or pre_init_hook.in_progress:
1✔
254
            instance = super().__call__(*args, **kwargs)
1✔
255
        else:
256
            try:
1✔
257
                pre_init_hook.in_progress = True
1✔
258
                named_positional_args = ComponentMeta._positional_to_kwargs(cls, args)
1✔
259
                assert (
1✔
260
                    set(named_positional_args.keys()).intersection(kwargs.keys()) == set()
261
                ), "positional and keyword arguments overlap"
262
                kwargs.update(named_positional_args)
1✔
263
                pre_init_hook.callback(cls, kwargs)
1✔
264
                instance = super().__call__(**kwargs)
1✔
265
            finally:
266
                pre_init_hook.in_progress = False
1✔
267

268
        # Before returning, we have the chance to modify the newly created
269
        # Component instance, so we take the chance and set up the I/O sockets
270
        has_async_run = hasattr(instance, "async_run")
1✔
271
        if has_async_run and not inspect.iscoroutinefunction(instance.async_run):
1✔
272
            raise ComponentError(f"Method 'async_run' of component '{cls.__name__}' must be a coroutine")
1✔
273
        instance.__haystack_supports_async__ = has_async_run
1✔
274

275
        ComponentMeta._parse_and_set_input_sockets(cls, instance)
1✔
276
        ComponentMeta._parse_and_set_output_sockets(instance)
1✔
277

278
        # Since a Component can't be used in multiple Pipelines at the same time
279
        # we need to know if it's already owned by a Pipeline when adding it to one.
280
        # We use this flag to check that.
281
        instance.__haystack_added_to_pipeline__ = None
1✔
282

283
        # Only Components with variadic inputs can be greedy. If the user set the greedy flag
284
        # to True, but the component doesn't have a variadic input, we log a warning.
285
        # We can have this information only at instance creation time, so we do it here.
286
        is_variadic = any(socket.is_variadic for socket in instance.__haystack_input__._sockets_dict.values())
1✔
287
        if not is_variadic and cls.__haystack_is_greedy__:
1✔
288
            logger.warning(
1✔
289
                "Component '{component}' has no variadic input, but it's marked as greedy. "
290
                "This is not supported and can lead to unexpected behavior.",
291
                component=cls.__name__,
292
            )
293

294
        return instance
1✔
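
The post-construction steps in `ComponentMeta.__call__` can be sketched with the standard library
alone. This is a simplified illustration under stated assumptions, not Haystack's implementation:
`SketchMeta`, `Adder`, `BadAsync` and the `inputs` attribute are hypothetical, plain dictionaries stand
in for sockets, and `TypeError` stands in for `ComponentError`:

```python
import inspect
from typing import Any, Dict


class SketchMeta(type):
    """Hypothetical metaclass; a simplified stand-in for the pattern above."""

    def __call__(cls, *args, **kwargs):
        # type.__call__ runs __new__ and __init__ and returns the instance.
        instance = super().__call__(*args, **kwargs)

        # Reject an `async_run` that is not defined with `async def`.
        async_run = getattr(instance, "async_run", None)
        if async_run is not None and not inspect.iscoroutinefunction(async_run):
            raise TypeError(f"'async_run' of '{cls.__name__}' must be a coroutine")

        # Derive input specs from the `run` signature, skipping `self` and variadic parameters.
        inputs: Dict[str, Dict[str, Any]] = {}
        params = inspect.signature(cls.run).parameters
        for name in list(params)[1:]:
            param = params[name]
            if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
                continue
            spec: Dict[str, Any] = {"type": param.annotation}
            if param.default is not inspect.Parameter.empty:
                spec["default"] = param.default
            inputs[name] = spec
        instance.inputs = inputs
        return instance


class Adder(metaclass=SketchMeta):
    def run(self, a: int, b: int = 1, **extra):
        return {"sum": a + b}


class BadAsync(metaclass=SketchMeta):
    def run(self):
        return {}

    def async_run(self):  # not a coroutine function
        return {}


adder = Adder()
try:
    BadAsync()
    bad_async_rejected = False
except TypeError:
    bad_async_rejected = True
```

Doing this work in the metaclass `__call__` rather than in `__init__` means component authors do not
have to call any base-class constructor for the inputs to be registered.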
295

296

297
def _component_repr(component: Component) -> str:
1✔
298
    """
299
    All Components override their __repr__ method with this one.
300

301
    It prints the component name and the input/output sockets.
302
    """
303
    result = object.__repr__(component)
1✔
304
    if pipeline := getattr(component, "__haystack_added_to_pipeline__", None):
1✔
305
        # This Component has been added in a Pipeline, let's get the name from there.
306
        result += f"\n{pipeline.get_component_name(component)}"
1✔
307

308
    # We're explicitly ignoring the type here because we're sure that the component
309
    # has the __haystack_input__ and __haystack_output__ attributes at this point
310
    return f"{result}\n{component.__haystack_input__}\n{component.__haystack_output__}"  # type: ignore[attr-defined]
1✔
311

312

313
class _Component:
1✔
314
    """
315
    See module's docstring.
316

317
    Args:
318
        class_: the class that Haystack should use as a component.
319
        serializable: whether to check, at init time, if the component can be saved with
320
        `save_pipelines()`.
321

322
    Returns:
323
        A class that can be recognized as a component.
324

325
    Raises:
326
        ComponentError: if the class provided has no `run()` method or otherwise doesn't respect the component contract.
327
    """
328

329
    def __init__(self):
1✔
330
        self.registry = {}
1✔
331

332
    def set_input_type(
1✔
333
        self,
334
        instance,
335
        name: str,
336
        type: Any,  # noqa: A002
337
        default: Any = _empty,
338
    ):
339
        """
340
        Add a single input socket to the component instance.
341

342
        :param instance: Component instance where the input type will be added.
343
        :param name: name of the input socket.
344
        :param type: type of the input socket.
345
        :param default: default value of the input socket, defaults to _empty
346
        """
347
        if not hasattr(instance, "__haystack_input__"):
1✔
348
            instance.__haystack_input__ = Sockets(instance, {}, InputSocket)
×
349
        instance.__haystack_input__[name] = InputSocket(name=name, type=type, default_value=default)
1✔
350

351
    def set_input_types(self, instance, **types):
1✔
352
        """
353
        Method that specifies the input types when 'kwargs' is passed to the run method.
354

355
        Use as:
356

357
        ```python
358
        @component
359
        class MyComponent:
360

361
            def __init__(self, value: int):
362
                component.set_input_types(self, value_1=str, value_2=str)
363
                ...
364

365
            @component.output_types(output_1=int, output_2=str)
366
            def run(self, **kwargs):
367
                return {"output_1": kwargs["value_1"], "output_2": ""}
368
        ```
369

370
        Note that if the `run()` method also specifies some parameters, those will take precedence.
371

372
        For example:
373

374
        ```python
375
        @component
376
        class MyComponent:
377

378
            def __init__(self, value: int):
379
                component.set_input_types(self, value_1=str, value_2=str)
380
                ...
381

382
            @component.output_types(output_1=int, output_2=str)
383
            def run(self, value_0: str, value_1: Optional[str] = None, **kwargs):
384
                return {"output_1": value_1, "output_2": ""}
385
        ```
386

387
        would add a mandatory `value_0` parameter, make the `value_1`
388
        parameter optional with a default None, and keep the `value_2`
389
        parameter mandatory as specified in `set_input_types`.
390

391
        """
392
        instance.__haystack_input__ = Sockets(
1✔
393
            instance, {name: InputSocket(name=name, type=type_) for name, type_ in types.items()}, InputSocket
394
        )
395

396
    def set_output_types(self, instance, **types):
1✔
397
        """
398
        Method that specifies the output types when the 'run' method is not decorated with 'component.output_types'.
399

400
        Use as:
401

402
        ```python
403
        @component
404
        class MyComponent:
405

406
            def __init__(self, value: int):
407
                component.set_output_types(self, output_1=int, output_2=str)
408
                ...
409

410
            # no decorators here
411
            def run(self, value: int):
412
                return {"output_1": 1, "output_2": "2"}
413
        ```
414
        """
415
        instance.__haystack_output__ = Sockets(
1✔
416
            instance, {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()}, OutputSocket
417
        )
418

419
    def output_types(self, **types):
1✔
420
        """
421
        Decorator factory that specifies the output types of a component.
422

423
        Use as:
424

425
        ```python
426
        @component
427
        class MyComponent:
428
            @component.output_types(output_1=int, output_2=str)
429
            def run(self, value: int):
430
                return {"output_1": 1, "output_2": "2"}
431
        ```
432
        """
433

434
        def output_types_decorator(run_method):
1✔
435
            """
436
            Decorator that sets the output types of the decorated method.
437

438
            This happens at class creation time, and since we don't have the decorated
439
            class available here, we temporarily store the output types as an attribute of
440
            the decorated method. The ComponentMeta metaclass will use this data to create
441
            sockets at instance creation time.
442
            """
443
            method_name = run_method.__name__
1✔
444
            if method_name not in ("run", "async_run"):
1✔
445
                raise ComponentError("'output_types' decorator can only be used on 'run' and `async_run` methods")
1✔
446

447
            setattr(
1✔
448
                run_method,
449
                "_output_types_cache",
450
                {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()},
451
            )
452
            return run_method
1✔
453

454
        return output_types_decorator
1✔
455

456
    def _component(self, cls, is_greedy: bool = False):
1✔
457
        """
458
        Decorator validating the structure of the component and registering it in the components registry.
459
        """
460
        logger.debug("Registering {component} as a component", component=cls)
1✔
461

462
        # Check for required methods and fail as soon as possible
463
        if not hasattr(cls, "run"):
1✔
464
            raise ComponentError(f"{cls.__name__} must have a 'run()' method. See the docs for more information.")
1✔
465

466
        def copy_class_namespace(namespace):
1✔
467
            """
468
            This is the callback that `types.new_class` will use to populate the newly created class.
469

470
            Simply copy the whole namespace from the decorated class.
471
            """
472
            for key, val in dict(cls.__dict__).items():
1✔
473
                # __dict__ and __weakref__ are class-bound, we should let Python recreate them.
474
                if key in ("__dict__", "__weakref__"):
1✔
475
                    continue
1✔
476
                namespace[key] = val
1✔
477

478
        # Recreate the decorated component class so it uses our metaclass.
479
        # We must explicitly redefine the type of the class to make sure language servers
480
        # and type checkers understand that the class is of the correct type.
481
        # mypy doesn't like that we do this though so we explicitly ignore the type check.
482
        cls: cls.__name__ = new_class(cls.__name__, cls.__bases__, {"metaclass": ComponentMeta}, copy_class_namespace)  # type: ignore[no-redef]
1✔
483

484
        # Save the component in the class registry (for deserialization)
485
        class_path = f"{cls.__module__}.{cls.__name__}"
1✔
486
        if class_path in self.registry:
1✔
487
            # Corner case, but it may occur easily in notebooks when re-running cells.
488
            logger.debug(
1✔
489
                "Component {component} is already registered. Previous imported from '{module_name}', "
490
                "new imported from '{new_module_name}'",
491
                component=class_path,
492
                module_name=self.registry[class_path],
493
                new_module_name=cls,
494
            )
495
        self.registry[class_path] = cls
1✔
496
        logger.debug("Registered Component {component}", component=cls)
1✔
497

498
        # Override the __repr__ method with a default one
499
        cls.__repr__ = _component_repr
1✔
500

501
        # The greedy flag can be True only if the component has a variadic input.
502
        # At this point of the lifetime of the component, we can't reliably know if it has a variadic input.
503
        # So we set it to whatever the user specified; at instance creation we'll log a warning if needed,
504
        # since we'll have access to the input sockets and can check if any of them is variadic.
505
        setattr(cls, "__haystack_is_greedy__", is_greedy)
1✔
506

507
        return cls
1✔
508

509
    def __call__(self, cls: Optional[type] = None, is_greedy: bool = False):
1✔
510
        # We must wrap the call to the decorator in a function for it to work
511
        # correctly with or without parens
512
        def wrap(cls):
1✔
513
            return self._component(cls, is_greedy=is_greedy)
1✔
514

515
        if cls:
1✔
516
            # Decorator is called without parens
517
            return wrap(cls)
1✔
518

519
        # Decorator is called with parens
520
        return wrap
1✔
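
A minimal sketch of the "with or without parens" decorator pattern used by `_Component.__call__`.
`Registrar`, `register`, `flag`, `Plain` and `Flagged` are hypothetical names for illustration; the
sketch checks `cls is not None` explicitly instead of relying on truthiness:

```python
from typing import Callable, Optional, Union


class Registrar:
    """Hypothetical decorator object usable as `@register` or `@register(flag=True)`."""

    def __init__(self):
        self.flags = {}

    def __call__(self, cls: Optional[type] = None, flag: bool = False) -> Union[type, Callable[[type], type]]:
        def wrap(target: type) -> type:
            self.flags[target.__name__] = flag
            return target

        if cls is not None:
            # Bare form: Python passes the class as the first argument.
            return wrap(cls)
        # Parenthesized form: return the real decorator.
        return wrap


register = Registrar()


@register
class Plain:
    pass


@register(flag=True)
class Flagged:
    pass
```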
521

522

523
component = _Component()
1✔
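
Taken together, the decorator stores output metadata on the method, recreates the class with a
metaclass through `types.new_class`, and the metaclass copies that metadata onto each instance. A
rough end-to-end sketch of this flow, using only the standard library, is shown here. `SketchMeta`,
`output_types`, `as_component`, `Counter`, `_outputs` and `outputs` are hypothetical names, and plain
dictionaries stand in for sockets:

```python
from copy import deepcopy
from types import new_class


class SketchMeta(type):
    def __call__(cls, *args, **kwargs):
        instance = super().__call__(*args, **kwargs)
        # Copy the spec stored on the method so instances do not share one dict.
        instance.outputs = deepcopy(getattr(cls.run, "_outputs", {}))
        return instance


def output_types(**types):
    """Decorator factory: stash the declared outputs on the method itself."""

    def decorator(method):
        method._outputs = dict(types)
        return method

    return decorator


def as_component(cls):
    """Recreate `cls` with SketchMeta as its metaclass, copying its namespace."""

    def copy_namespace(namespace):
        for key, value in dict(cls.__dict__).items():
            # Let Python recreate these class-bound descriptors.
            if key not in ("__dict__", "__weakref__"):
                namespace[key] = value

    return new_class(cls.__name__, cls.__bases__, {"metaclass": SketchMeta}, copy_namespace)


@as_component
class Counter:
    @output_types(count=int)
    def run(self, text: str):
        return {"count": len(text)}


first, second = Counter(), Counter()
first.outputs["extra"] = str  # mutating one instance's copy
```

Because each instance receives its own deep copy, changing `first.outputs` leaves `second.outputs`
and the metadata stored on `Counter.run` untouched.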