deepset-ai / haystack / 9610677276

21 Jun 2024 08:29AM UTC coverage: 90.181% (+0.01%) from 90.171%
push · github · web-flow
fix: Prevent component pre-init hook from being called recursively (#7894)

7026 of 7791 relevant lines covered (90.18%)
0.9 hits per line

haystack/core/component/component.py · 98.28% coverage

# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

"""
Attributes:

    component: Marks a class as a component. Any class decorated with `@component` can be used by a Pipeline.

All components must follow the contract below. This docstring is the source of truth for the component contract.

<hr>

`@component` decorator

All component classes must be decorated with the `@component` decorator. This allows Canals to discover them.
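
In this module, the decorator makes a class discoverable by recording it in `component.registry`, keyed by its
import path (see `_Component._component`). The following is a minimal standalone sketch of that registry idea. It is
not Haystack's API: `REGISTRY` and `register` are hypothetical names, and the sketch leaves out the metaclass swap
and socket setup that the real decorator performs. It also shows how one decorator can be used both with and without
parentheses, as `_Component.__call__` does.

```python
from typing import Callable, Dict, Optional, Union

# Hypothetical registry mapping "module.ClassName" to the class object.
REGISTRY: Dict[str, type] = {}


def register(cls: Optional[type] = None, *, is_greedy: bool = False) -> Union[type, Callable[[type], type]]:
    # Does the actual work: validate the class, record it, and tag it.
    def wrap(klass: type) -> type:
        if not hasattr(klass, "run"):
            raise TypeError(f"{klass.__name__} must have a 'run()' method")
        REGISTRY[f"{klass.__module__}.{klass.__name__}"] = klass
        klass.is_greedy = is_greedy  # hypothetical attribute, for illustration only
        return klass

    # Bare @register: Python passes the class directly.
    if cls is not None:
        return wrap(cls)
    # @register(...) with arguments: return the decorator to apply next.
    return wrap


@register
class Echo:
    def run(self, text: str):
        return {"text": text}


@register(is_greedy=True)
class Joiner:
    def run(self, *parts: str):
        return {"joined": " ".join(parts)}
```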

<hr>

`__init__(self, **kwargs)`

Optional method.

Components may have an `__init__` method where they define:

- `self.init_parameters = {same parameters that the __init__ method received}`:
    In this dictionary you can store any state the component needs to persist when it is saved.
    These values will be passed to the `__init__` method of a new instance when the pipeline is loaded.
    Note that by default the `@component` decorator saves the arguments automatically.
    However, if a component sets its own `init_parameters` manually in `__init__()`, those will be used instead.
    Note: all of the values stored here **must be JSON serializable**. Serialize them manually if needed.

Components should take only "basic" Python types as parameters of their `__init__` function, or iterables and
dictionaries containing only such values. Anything else (objects, functions, etc.) will raise an exception at init
time. If such values are needed, consider serializing them to a string.
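
One way to check the JSON-serializability requirement is to round-trip the parameters through the standard `json`
module. This is a standalone sketch under that assumption: `check_init_parameters` and `Splitter` are hypothetical
names, and Haystack does not necessarily perform this exact check.

```python
import json
from typing import Any, Dict


def check_init_parameters(params: Dict[str, Any]) -> None:
    # json.dumps raises TypeError for values it cannot encode (objects, functions, ...).
    try:
        json.dumps(params)
    except TypeError as err:
        raise ValueError(f"init parameters are not JSON serializable: {err}") from err


class Splitter:
    # Hypothetical component that accepts only basic types in __init__.
    def __init__(self, split_by: str = "word", split_length: int = 200):
        self.split_by = split_by
        self.split_length = split_length
        # Same parameters that __init__ received, so a new instance can be rebuilt from them.
        self.init_parameters = {"split_by": split_by, "split_length": split_length}
        check_init_parameters(self.init_parameters)


# Simulate a save/load round trip: a new instance is built from the stored parameters.
saved = json.dumps(Splitter(split_length=50).init_parameters)
restored = Splitter(**json.loads(saved))
```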

_(TODO explain how to use classes and functions in init. In the meantime see `test/components/test_accumulate.py`)_

The `__init__` must be extremely lightweight, because instantiating components is a frequent operation during the
construction and validation of the pipeline. If a component has some heavy state to initialize (models, backends,
etc.), move that work to the `warm_up()` method.

<hr>

`warm_up(self)`

Optional method.

This method is called by Pipeline before the graph execution. Make sure to avoid double initialization,
because Pipeline does not keep track of which components it has already called `warm_up()` on.
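
A common way to satisfy both rules is to keep `__init__` to plain attribute assignments and make `warm_up()`
idempotent with a simple guard. This is a standalone sketch: `Embedder`, `_load_model`, and `load_count` are
hypothetical, and the loader only stands in for an expensive model load.

```python
from typing import Any, Dict, List, Optional


class Embedder:
    def __init__(self, model: str = "tiny-model"):
        # Lightweight: only store configuration here, no heavy loading.
        self.model = model
        self._backend: Optional[Dict[str, Any]] = None
        self.load_count = 0  # counts expensive loads, for illustration only

    def _load_model(self) -> Dict[str, Any]:
        # Stand-in for an expensive operation such as loading model weights.
        self.load_count += 1
        return {"name": self.model}

    def warm_up(self) -> None:
        # The Pipeline does not track warm_up() calls, so guard against double initialization.
        if self._backend is None:
            self._backend = self._load_model()

    def run(self, text: str) -> Dict[str, List[float]]:
        if self._backend is None:
            raise RuntimeError("call warm_up() before run()")
        # Toy "embedding" (character codes scaled down) standing in for real inference.
        return {"embedding": [ord(c) / 1000 for c in text]}


embedder = Embedder()
embedder.warm_up()
embedder.warm_up()  # second call is a no-op
```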

<hr>

`run(self, **kwargs)`

Mandatory method.

This is the method where the main functionality of the component is carried out. It's called by
`Pipeline.run()`.

The component's input sockets are derived from the parameters of `run()` (variadic `*args` and `**kwargs` are
ignored), in addition to any declared with `component.set_input_types()`. When the component should run, Pipeline
calls this method with keyword arguments containing:

- all the input values coming from other components connected to it,
- if any is missing, the default value declared in the `run()` signature, if it exists.

`run()` must return a dictionary whose keys match the output types declared with the `@component.output_types`
decorator or with `component.set_output_types()`.
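
The mapping from the `run()` signature to input sockets can be sketched with `inspect` alone. This standalone
sketch mirrors the loop in `ComponentMeta.__call__`, but returns plain dictionaries instead of `InputSocket`
objects. `input_spec` and `Greeter` are hypothetical names.

```python
import inspect
from typing import Any, Dict, Optional


def input_spec(cls: type) -> Dict[str, Dict[str, Any]]:
    spec: Dict[str, Dict[str, Any]] = {}
    # The first parameter of run() is 'self' and is skipped.
    for param in list(inspect.signature(cls.run).parameters.values())[1:]:
        # Variadic *args / **kwargs do not become input sockets.
        if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
            continue
        entry: Dict[str, Any] = {"type": param.annotation}
        # Parameters with a default are optional; the default is recorded.
        if param.default is not inspect.Parameter.empty:
            entry["default_value"] = param.default
        spec[param.name] = entry
    return spec


class Greeter:
    def run(self, name: str, greeting: Optional[str] = "Hello", **kwargs) -> Dict[str, Any]:
        # run() returns a dictionary keyed by output name.
        return {"message": f"{greeting}, {name}!"}
```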

"""

import inspect
import sys
from collections.abc import Callable
from contextlib import contextmanager
from contextvars import ContextVar
from copy import deepcopy
from dataclasses import dataclass
from types import new_class
from typing import Any, Dict, Optional, Protocol, runtime_checkable

from haystack import logging
from haystack.core.errors import ComponentError

from .sockets import Sockets
from .types import InputSocket, OutputSocket, _empty

logger = logging.getLogger(__name__)


@dataclass
class PreInitHookPayload:
    """
    Payload for the hook called before a component instance is initialized.

    :param callback:
        Receives two inputs: the component class and a dictionary of the init parameters (keyword arguments).
    :param in_progress:
        Flag indicating whether the hook is currently being executed.
        Used to prevent it from being called recursively (if the component's constructor
        instantiates another component).
    """

    callback: Callable
    in_progress: bool = False


_COMPONENT_PRE_INIT_HOOK: ContextVar[Optional[PreInitHookPayload]] = ContextVar("component_pre_init_hook", default=None)


@contextmanager
def _hook_component_init(callback: Callable):
    """
    Context manager to set a callback that will be invoked before a component's constructor is called.

    The callback receives the component class and a dictionary of the init parameters (keyword arguments),
    and can modify that dictionary in place.

    :param callback:
        Callback function to invoke.
    """
    token = _COMPONENT_PRE_INIT_HOOK.set(PreInitHookPayload(callback))
    try:
        yield
    finally:
        _COMPONENT_PRE_INIT_HOOK.reset(token)


@runtime_checkable
class Component(Protocol):
    """
    Note: this is only used by type checking tools.

    In order to implement the `Component` protocol, custom components need to
    have a `run` method. The signature of the method and its return value
    won't be checked, so classes with either of the following methods:

        def run(self, param: str) -> Dict[str, Any]:
            ...

    and

        def run(self, **kwargs):
            ...

    will both be considered as respecting the protocol. This makes the type
    checking much weaker, but we have other places where we ensure code is
    dealing with actual Components.

    The protocol is runtime checkable so it'll be possible to assert:

        isinstance(MyComponent, Component)
    """

    # This is the most reliable way to define the protocol for the `run` method.
    # Defining a method doesn't work as different Components will have different
    # arguments. Even defining here a method with `**kwargs` doesn't work as the
    # expected signature must be identical.
    # This makes most Language Servers and type checkers happy and shows fewer errors.
    # NOTE: This check can be removed when we drop Python 3.8 support.
    if sys.version_info >= (3, 9):
        run: Callable[..., Dict[str, Any]]
    else:
        run: Callable


class ComponentMeta(type):
    @staticmethod
    def positional_to_kwargs(cls_type, args) -> Dict[str, Any]:
        """
        Convert positional arguments to keyword arguments based on the signature of the `__init__` method.
        """
        init_signature = inspect.signature(cls_type.__init__)
        init_params = {name: info for name, info in init_signature.parameters.items() if name != "self"}

        out = {}
        for arg, (name, info) in zip(args, init_params.items()):
            if info.kind == inspect.Parameter.VAR_POSITIONAL:
                raise ComponentError(
                    "Pre-init hooks do not support components with variadic positional args in their init method"
                )

            assert info.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.POSITIONAL_ONLY)
            out[name] = arg
        return out

    def __call__(cls, *args, **kwargs):
        """
        This method is called when clients instantiate a Component and runs before __new__ and __init__.
        """
        # `super().__call__()` calls __new__ then __init__, giving us back the Component instance.
        # If a pre-init hook is set and isn't already running (a component's constructor may instantiate
        # other components), the hook can inspect and modify the init parameters first.
        pre_init_hook = _COMPONENT_PRE_INIT_HOOK.get()
        if pre_init_hook is None or pre_init_hook.in_progress:
            instance = super().__call__(*args, **kwargs)
        else:
            try:
                pre_init_hook.in_progress = True
                named_positional_args = ComponentMeta.positional_to_kwargs(cls, args)
                assert (
                    set(named_positional_args.keys()).intersection(kwargs.keys()) == set()
                ), "positional and keyword arguments overlap"
                kwargs.update(named_positional_args)
                pre_init_hook.callback(cls, kwargs)
                instance = super().__call__(**kwargs)
            finally:
                pre_init_hook.in_progress = False

        # Before returning, we have the chance to modify the newly created
        # Component instance, so we take the chance and set up the I/O sockets.

        # If `component.set_output_types()` was called in the component constructor,
        # `__haystack_output__` is already populated, no need to do anything.
        if not hasattr(instance, "__haystack_output__"):
            # If that's not the case, we need to populate `__haystack_output__`
            #
            # If the `run` method was decorated, it has a `_output_types_cache` field assigned
            # that stores the output specification.
            # We deepcopy the content of the cache to transfer ownership from the class method
            # to the actual instance, so that different instances of the same class won't share this data.
            instance.__haystack_output__ = Sockets(
                instance, deepcopy(getattr(instance.run, "_output_types_cache", {})), OutputSocket
            )

        # Create the sockets if set_input_types() wasn't called in the constructor.
        # If it was called and there are some parameters also in the `run()` method, these take precedence.
        if not hasattr(instance, "__haystack_input__"):
            instance.__haystack_input__ = Sockets(instance, {}, InputSocket)
        run_signature = inspect.signature(getattr(cls, "run"))
        for param in list(run_signature.parameters)[1:]:  # First is 'self' and it doesn't matter.
            if run_signature.parameters[param].kind not in (
                inspect.Parameter.VAR_POSITIONAL,
                inspect.Parameter.VAR_KEYWORD,
            ):  # ignore variable args
                socket_kwargs = {"name": param, "type": run_signature.parameters[param].annotation}
                if run_signature.parameters[param].default != inspect.Parameter.empty:
                    socket_kwargs["default_value"] = run_signature.parameters[param].default
                instance.__haystack_input__[param] = InputSocket(**socket_kwargs)

        # Since a Component can't be used in multiple Pipelines at the same time
        # we need to know if it's already owned by a Pipeline when adding it to one.
        # We use this flag to check that.
        instance.__haystack_added_to_pipeline__ = None

        # Only Components with variadic inputs can be greedy. If the user set the greedy flag
        # to True, but the component doesn't have a variadic input, we log a warning.
        # We can have this information only at instance creation time, so we do it here.
        is_variadic = any(socket.is_variadic for socket in instance.__haystack_input__._sockets_dict.values())
        if not is_variadic and cls.__haystack_is_greedy__:
            logger.warning(
                "Component '{component}' has no variadic input, but it's marked as greedy. "
                "This is not supported and can lead to unexpected behavior.",
                component=cls.__name__,
            )

        return instance


def _component_repr(component: Component) -> str:
    """
    All Components override their __repr__ method with this one.

    It returns the default object representation, followed by the component's name in its Pipeline
    (if it has been added to one) and the input/output sockets.
    """
    result = object.__repr__(component)
    if pipeline := getattr(component, "__haystack_added_to_pipeline__"):
        # This Component has been added to a Pipeline, let's get the name from there.
        result += f"\n{pipeline.get_component_name(component)}"

    # We're explicitly ignoring the type here because we're sure that the component
    # has the __haystack_input__ and __haystack_output__ attributes at this point
    return f"{result}\n{component.__haystack_input__}\n{component.__haystack_output__}"  # type: ignore[attr-defined]


class _Component:
    """
    See the module docstring.

    Args:
        cls: the class that Canals should use as a component.
        is_greedy: whether the component is greedy. Only components with variadic inputs can be greedy;
            otherwise a warning is logged when the component is instantiated.

    Returns:
        A class that can be recognized as a component.

    Raises:
        ComponentError: if the class provided has no `run()` method or otherwise doesn't respect the component contract.
    """

    def __init__(self):
        self.registry = {}

    def set_input_type(self, instance, name: str, type: Any, default: Any = _empty):  # noqa: A002
        """
        Add a single input socket to the component instance.

        :param instance: Component instance where the input type will be added.
        :param name: name of the input socket.
        :param type: type of the input socket.
        :param default: default value of the input socket, defaults to _empty
        """
        if not hasattr(instance, "__haystack_input__"):
            instance.__haystack_input__ = Sockets(instance, {}, InputSocket)
        instance.__haystack_input__[name] = InputSocket(name=name, type=type, default_value=default)

    def set_input_types(self, instance, **types):
        """
        Specifies the input types of a component whose `run()` method accepts `**kwargs`.

        Use as:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int):
                component.set_input_types(self, value_1=str, value_2=str)
                ...

            @component.output_types(output_1=int, output_2=str)
            def run(self, **kwargs):
                return {"output_1": kwargs["value_1"], "output_2": ""}
        ```

        Note that if the `run()` method also specifies some parameters, those will take precedence.

        For example:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int):
                component.set_input_types(self, value_1=str, value_2=str)
                ...

            @component.output_types(output_1=int, output_2=str)
            def run(self, value_0: str, value_1: Optional[str] = None, **kwargs):
                # value_1 is a named parameter here, so it is not part of kwargs
                return {"output_1": value_1, "output_2": ""}
        ```

        would add a mandatory `value_0` parameter, make the `value_1`
        parameter optional with a default of None, and keep the `value_2`
        parameter mandatory as specified in `set_input_types`.

        """
        instance.__haystack_input__ = Sockets(
            instance, {name: InputSocket(name=name, type=type_) for name, type_ in types.items()}, InputSocket
        )

    def set_output_types(self, instance, **types):
        """
        Specifies the output types when the `run()` method is not decorated with `component.output_types`.

        Use as:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int):
                component.set_output_types(self, output_1=int, output_2=str)
                ...

            # no decorators here
            def run(self, value: int):
                return {"output_1": 1, "output_2": "2"}
        ```
        """
        instance.__haystack_output__ = Sockets(
            instance, {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()}, OutputSocket
        )

    def output_types(self, **types):
        """
        Decorator factory that specifies the output types of a component.

        Use as:

        ```python
        @component
        class MyComponent:
            @component.output_types(output_1=int, output_2=str)
            def run(self, value: int):
                return {"output_1": 1, "output_2": "2"}
        ```
        """

        def output_types_decorator(run_method):
            """
            Decorator that sets the output types of the decorated method.

            This happens at class creation time, and since we don't have the decorated
            class available here, we temporarily store the output types as an attribute of
            the decorated method. The ComponentMeta metaclass will use this data to create
            sockets at instance creation time.
            """
            setattr(
                run_method,
                "_output_types_cache",
                {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()},
            )
            return run_method

        return output_types_decorator

    def _component(self, cls, is_greedy: bool = False):
        """
        Decorator validating the structure of the component and registering it in the component registry.
        """
        logger.debug("Registering {component} as a component", component=cls)

        # Check for required methods and fail as soon as possible
        if not hasattr(cls, "run"):
            raise ComponentError(f"{cls.__name__} must have a 'run()' method. See the docs for more information.")

        def copy_class_namespace(namespace):
            """
            This is the callback that `types.new_class` will use to populate the newly created class.

            Simply copy the whole namespace from the decorated class.
            """
            for key, val in dict(cls.__dict__).items():
                # __dict__ and __weakref__ are class-bound, we should let Python recreate them.
                if key in ("__dict__", "__weakref__"):
                    continue
                namespace[key] = val

        # Recreate the decorated component class so it uses our metaclass.
        # We must explicitly redefine the type of the class to make sure language servers
        # and type checkers understand that the class is of the correct type.
        # mypy doesn't like that we do this though so we explicitly ignore the type check.
        cls: cls.__name__ = new_class(cls.__name__, cls.__bases__, {"metaclass": ComponentMeta}, copy_class_namespace)  # type: ignore[no-redef]

        # Save the component in the class registry (for deserialization)
        class_path = f"{cls.__module__}.{cls.__name__}"
        if class_path in self.registry:
            # Corner case, but it may occur easily in notebooks when re-running cells.
            logger.debug(
                "Component {component} is already registered. Previously imported from '{module_name}', "
                "new imported from '{new_module_name}'",
                component=class_path,
                module_name=self.registry[class_path],
                new_module_name=cls,
            )
        self.registry[class_path] = cls
        logger.debug("Registered Component {component}", component=cls)

        # Override the __repr__ method with a default one
        cls.__repr__ = _component_repr

        # The greedy flag can be True only if the component has a variadic input.
        # At this point of the lifetime of the component, we can't reliably know if it has a variadic input.
        # So we set it to whatever the user specified; at instance creation time, once the input sockets
        # are available, ComponentMeta logs a warning if the flag is set but no input is variadic.
        setattr(cls, "__haystack_is_greedy__", is_greedy)

        return cls

    def __call__(self, cls: Optional[type] = None, is_greedy: bool = False):
        # We must wrap the call to the decorator in a function for it to work
        # correctly with or without parens
        def wrap(cls):
            return self._component(cls, is_greedy=is_greedy)

        if cls:
            # Decorator is called without parens
            return wrap(cls)

        # Decorator is called with parens
        return wrap


component = _Component()