
deepset-ai / haystack / 11552785866

28 Oct 2024 11:15AM UTC coverage: 90.548% (+0.02%) from 90.524%
11552785866

push

github

web-flow
build: unpin `numpy` + use Python 3.9 in CI (#8492)

* try unpinning numpy

* try python 3.9

* release note

7558 of 8347 relevant lines covered (90.55%)

0.91 hits per line

Source File

haystack/core/component/component.py (file coverage: 97.58%)
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

"""
Attributes:

    component: Marks a class as a component. Any class decorated with `@component` can be used by a Pipeline.

All components must follow the contract below. This docstring is the source of truth for the component contract.

<hr>

`@component` decorator

All component classes must be decorated with the `@component` decorator. The decorator validates the class and
registers it in `component.registry`, which is used to deserialize components when a pipeline is loaded.
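To make the decorator's job concrete, here is a stdlib-only sketch modeled loosely on `_Component._component`
further down in this module: it fails fast when `run` is missing, recreates the class with a metaclass through
`types.new_class`, and records it in a registry keyed by `module.ClassName`. The names `toy_component`, `ToyMeta`
and `REGISTRY` are hypothetical; the real metaclass (`ComponentMeta`) also sets up input and output sockets,
which this sketch omits.

```python
import types

# Hypothetical registry, standing in for `component.registry`.
REGISTRY = {}


class ToyMeta(type):
    # Hypothetical metaclass: the real ComponentMeta sets up I/O sockets at this point.
    def __call__(cls, *args, **kwargs):
        instance = super().__call__(*args, **kwargs)
        instance.created_by_toy_meta = True
        return instance


def toy_component(cls):
    # Fail as early as possible, like `_component` does.
    if not hasattr(cls, "run"):
        raise TypeError(f"{cls.__name__} must have a 'run()' method")

    def copy_namespace(namespace):
        # Copy everything except the class-bound __dict__ and __weakref__ descriptors.
        for key, val in dict(cls.__dict__).items():
            if key not in ("__dict__", "__weakref__"):
                namespace[key] = val

    new_cls = types.new_class(cls.__name__, cls.__bases__, {"metaclass": ToyMeta}, copy_namespace)
    REGISTRY[f"{new_cls.__module__}.{new_cls.__name__}"] = new_cls
    return new_cls


@toy_component
class Greeter:
    def run(self, name: str):
        return {"greeting": f"Hello, {name}!"}
```

Recreating the class, rather than mutating it, is what lets the metaclass take effect on classes written without one.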

<hr>

`__init__(self, **kwargs)`

Optional method.

Components may have an `__init__` method where they define:

- `self.init_parameters = {same parameters that the __init__ method received}`:
    In this dictionary you can store any state the component wishes to persist when it is saved.
    These values will be given to the `__init__` method of a new instance when the pipeline is loaded.
    Note that by default the `@component` decorator saves the arguments automatically.
    However, if a component sets its own `init_parameters` manually in `__init__()`, that will be used instead.
    Note: all of the values contained here **must be JSON serializable**. Serialize them manually if needed.

Components should take only "basic" Python types as parameters of their `__init__` function, or iterables and
dictionaries containing only such values. Anything else (objects, functions, etc.) will raise an exception at init
time. If such values are needed, consider serializing them to a string.
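The JSON requirement can be checked eagerly. A minimal sketch, assuming a hypothetical helper
`check_init_parameters` (not part of Haystack) that round-trips the dictionary through `json.dumps`, plus a
hypothetical `callable_to_path` that stores a function as its import path string instead of the function object:

```python
import json
from typing import Any, Callable, Dict


def check_init_parameters(params: Dict[str, Any]) -> str:
    # json.dumps raises TypeError for objects it can't encode (functions, custom classes, ...).
    try:
        return json.dumps(params)
    except TypeError as exc:
        raise ValueError(f"init parameters are not JSON serializable: {exc}") from exc


def callable_to_path(fn: Callable) -> str:
    # Store a function as its import path instead of the object itself.
    return f"{fn.__module__}.{fn.__qualname__}"


def to_upper(text: str) -> str:
    return text.upper()


ok = check_init_parameters({"top_k": 3, "labels": ["a", "b"], "options": {"strict": True}})
serialized = check_init_parameters({"postprocess": callable_to_path(to_upper)})
```

Passing `to_upper` itself instead of its path would make `check_init_parameters` raise a `ValueError`.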

_(TODO explain how to use classes and functions in init. In the meantime see `test/components/test_accumulate.py`)_

The `__init__` must be extremely lightweight, because it's a frequent operation during the construction and
validation of the pipeline. If a component has some heavy state to initialize (models, backends, etc.), refer to
the `warm_up()` method.

<hr>

`warm_up(self)`

Optional method.

This method is called by Pipeline before the graph execution. Make sure it is safe to call more than once (avoid
double initializations), because Pipeline does not keep track of which components it called `warm_up()` on.
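A sketch of the split between a cheap `__init__` and an idempotent `warm_up()`. It uses a plain class so it runs
without Haystack (a real component would also carry the `@component` decorator), and `load_model` is a
hypothetical stand-in for an expensive operation; the returned embedding is a toy value.

```python
from typing import Any, Dict, Optional


def load_model(name: str) -> Dict[str, Any]:
    # Hypothetical stand-in for an expensive load (model weights, DB connection, ...).
    return {"name": name}


class Embedder:
    def __init__(self, model: str = "small"):
        # Keep __init__ cheap: only store configuration.
        self.model = model
        self._backend: Optional[Dict[str, Any]] = None
        self.load_count = 0

    def warm_up(self):
        # Guard against double initialization: the Pipeline does not track
        # which components were already warmed up.
        if self._backend is None:
            self._backend = load_model(self.model)
            self.load_count += 1

    def run(self, text: str):
        if self._backend is None:
            raise RuntimeError("Call warm_up() before run()")
        # Toy "embedding": the text length as a single float.
        return {"embedding": [float(len(text))], "model": self._backend["name"]}
```

Calling `warm_up()` twice loads the backend only once.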

<hr>

`run(self, ...)`

Mandatory method.

This is the method where the main functionality of the component should be carried out. It's called by
`Pipeline.run()`.

The input sockets of a component are derived from the parameters of `run()` (plus any sockets added with
`component.set_input_types()` or `component.set_input_type()` when `run()` accepts `**kwargs`). When the
component should run, Pipeline calls this method with keyword arguments containing:

- all the input values coming from other components connected to it, or passed directly to `Pipeline.run()`,
- for any input that received no value, the default value of the corresponding input socket, if it has one.

`run()` must return a dictionary whose keys match the output sockets declared with `@component.output_types`
(or with `component.set_output_types()` in `__init__`).
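Input sockets are derived from the signature of `run()` (see `_parse_and_set_input_sockets` further down in this
module). A stdlib-only sketch of that mapping follows; `ToySocket` and `sockets_from_run` are hypothetical
simplifications of `InputSocket` and the helper's inner function. As in the real code, `self`, `*args` and
`**kwargs` are skipped, and a parameter with a default becomes an optional socket.

```python
import inspect
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

_MISSING = object()  # hypothetical sentinel, playing the role of `_empty`


@dataclass
class ToySocket:
    name: str
    type: Any
    default_value: Any = _MISSING

    @property
    def is_mandatory(self) -> bool:
        return self.default_value is _MISSING


def sockets_from_run(cls) -> Dict[str, ToySocket]:
    sockets = {}
    for name, param in inspect.signature(cls.run).parameters.items():
        # Ignore `self`, *args and **kwargs, like the real helper.
        if name == "self" or param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
            continue
        socket = ToySocket(name=name, type=param.annotation)
        if param.default is not inspect.Parameter.empty:
            socket.default_value = param.default
        sockets[name] = socket
    return sockets


class Joiner:
    def run(self, documents: List[str], separator: Optional[str] = None, **kwargs):
        sep = " " if separator is None else separator
        return {"text": sep.join(documents)}
```

Here `documents` becomes a mandatory socket and `separator` an optional one defaulting to `None`.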

"""

import inspect
import sys
import warnings
from collections.abc import Callable
from contextlib import contextmanager
from contextvars import ContextVar
from copy import deepcopy
from dataclasses import dataclass
from types import new_class
from typing import Any, Dict, Optional, Protocol, Type, runtime_checkable

from haystack import logging
from haystack.core.errors import ComponentError

from .sockets import Sockets
from .types import InputSocket, OutputSocket, _empty

logger = logging.getLogger(__name__)


@dataclass
class PreInitHookPayload:
    """
    Payload for the hook called before a component instance is initialized.

    :param callback:
        Receives the following inputs: component class and init parameter keyword args.
    :param in_progress:
        Flag to indicate if the hook is currently being executed.
        Used to prevent it from being called recursively (if the component's constructor
        instantiates another component).
    """

    callback: Callable
    in_progress: bool = False


_COMPONENT_PRE_INIT_HOOK: ContextVar[Optional[PreInitHookPayload]] = ContextVar("component_pre_init_hook", default=None)


@contextmanager
def _hook_component_init(callback: Callable):
    """
    Context manager to set a callback that will be invoked before a component's constructor is called.

    The callback receives the component class and the init parameters (as keyword arguments) and can modify the init
    parameters in place.

    :param callback:
        Callback function to invoke.
    """
    token = _COMPONENT_PRE_INIT_HOOK.set(PreInitHookPayload(callback))
    try:
        yield
    finally:
        _COMPONENT_PRE_INIT_HOOK.reset(token)


@runtime_checkable
class Component(Protocol):
    """
    Note this is mainly used by type checking tools.

    In order to implement the `Component` protocol, custom components need to
    have a `run` method. The signature of the method and its return value
    won't be checked, i.e. classes with the following methods:

        def run(self, param: str) -> Dict[str, Any]:
            ...

    and

        def run(self, **kwargs):
            ...

    will both be considered as respecting the protocol. This makes the type
    checking much weaker, but we have other places where we ensure code is
    dealing with actual Components.

    The protocol is runtime checkable so it'll be possible to assert:

        isinstance(MyComponent, Component)
    """

    # This is the most reliable way to define the protocol for the `run` method.
    # Defining a method doesn't work as different Components will have different
    # arguments. Even defining here a method with `**kwargs` doesn't work as the
    # expected signature must be identical.
    # This makes most Language Servers and type checkers happy and shows fewer errors.
    # NOTE: This check can be removed when we drop Python 3.8 support.
    if sys.version_info >= (3, 9):
        run: Callable[..., Dict[str, Any]]
    else:
        run: Callable


class ComponentMeta(type):
    @staticmethod
    def _positional_to_kwargs(cls_type, args) -> Dict[str, Any]:
        """
        Convert positional arguments to keyword arguments based on the signature of the `__init__` method.
        """
        init_signature = inspect.signature(cls_type.__init__)
        init_params = {name: info for name, info in init_signature.parameters.items() if name != "self"}

        out = {}
        for arg, (name, info) in zip(args, init_params.items()):
            if info.kind == inspect.Parameter.VAR_POSITIONAL:
                raise ComponentError(
                    "Pre-init hooks do not support components with variadic positional args in their init method"
                )

            assert info.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.POSITIONAL_ONLY)
            out[name] = arg
        return out

    @staticmethod
    def _parse_and_set_output_sockets(instance: Any):
        has_async_run = hasattr(instance, "run_async")

        # If `component.set_output_types()` was called in the component constructor,
        # `__haystack_output__` is already populated, no need to do anything.
        if not hasattr(instance, "__haystack_output__"):
            # If that's not the case, we need to populate `__haystack_output__`
            #
            # If either of the run methods were decorated, they'll have a field assigned that
            # stores the output specification. If both run methods were decorated, we ensure that
            # outputs are the same. We deepcopy the content of the cache to transfer ownership from
            # the class method to the actual instance, so that different instances of the same class
            # won't share this data.

            run_output_types = getattr(instance.run, "_output_types_cache", {})
            async_run_output_types = getattr(instance.run_async, "_output_types_cache", {}) if has_async_run else {}

            if has_async_run and run_output_types != async_run_output_types:
                raise ComponentError("Output type specifications of 'run' and 'run_async' methods must be the same")
            output_types_cache = run_output_types

            instance.__haystack_output__ = Sockets(instance, deepcopy(output_types_cache), OutputSocket)

    @staticmethod
    def _parse_and_set_input_sockets(component_cls: Type, instance: Any):
        def inner(method, sockets):
            from inspect import Parameter

            run_signature = inspect.signature(method)

            for param_name, param_info in run_signature.parameters.items():
                if param_name == "self" or param_info.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
                    continue

                socket_kwargs = {"name": param_name, "type": param_info.annotation}
                # Identity check: `!=` could misbehave on defaults with custom __eq__ (e.g. arrays).
                if param_info.default is not Parameter.empty:
                    socket_kwargs["default_value"] = param_info.default

                new_socket = InputSocket(**socket_kwargs)

                # Also ensure that new sockets don't override existing ones.
                existing_socket = sockets.get(param_name)
                if existing_socket is not None and existing_socket != new_socket:
                    raise ComponentError(
                        "set_input_types()/set_input_type() cannot override the parameters of the 'run' method"
                    )

                sockets[param_name] = new_socket

            return run_signature

        # Create the sockets if set_input_types() wasn't called in the constructor.
        if not hasattr(instance, "__haystack_input__"):
            instance.__haystack_input__ = Sockets(instance, {}, InputSocket)

        inner(getattr(component_cls, "run"), instance.__haystack_input__)

        # Ensure that the sockets are the same for the async method, if it exists.
        async_run = getattr(component_cls, "run_async", None)
        if async_run is not None:
            run_sockets = Sockets(instance, {}, InputSocket)
            async_run_sockets = Sockets(instance, {}, InputSocket)

            # Can't use the sockets from above as they might contain
            # values set with set_input_types().
            run_sig = inner(getattr(component_cls, "run"), run_sockets)
            async_run_sig = inner(async_run, async_run_sockets)

            if async_run_sockets != run_sockets or run_sig != async_run_sig:
                raise ComponentError("Parameters of 'run' and 'run_async' methods must be the same")

    def __call__(cls, *args, **kwargs):
        """
        This method is called when clients instantiate a Component and runs before __new__ and __init__.
        """
        # This will call __new__ then __init__, giving us back the Component instance
        pre_init_hook = _COMPONENT_PRE_INIT_HOOK.get()
        if pre_init_hook is None or pre_init_hook.in_progress:
            instance = super().__call__(*args, **kwargs)
        else:
            try:
                pre_init_hook.in_progress = True
                named_positional_args = ComponentMeta._positional_to_kwargs(cls, args)
                assert (
                    set(named_positional_args.keys()).intersection(kwargs.keys()) == set()
                ), "positional and keyword arguments overlap"
                kwargs.update(named_positional_args)
                pre_init_hook.callback(cls, kwargs)
                instance = super().__call__(**kwargs)
            finally:
                pre_init_hook.in_progress = False

        # Before returning, we have the chance to modify the newly created
        # Component instance, so we take the chance and set up the I/O sockets
        has_async_run = hasattr(instance, "run_async")
        if has_async_run and not inspect.iscoroutinefunction(instance.run_async):
            raise ComponentError(f"Method 'run_async' of component '{cls.__name__}' must be a coroutine")
        instance.__haystack_supports_async__ = has_async_run

        ComponentMeta._parse_and_set_input_sockets(cls, instance)
        ComponentMeta._parse_and_set_output_sockets(instance)

        # Since a Component can't be used in multiple Pipelines at the same time
        # we need to know if it's already owned by a Pipeline when adding it to one.
        # We use this flag to check that.
        instance.__haystack_added_to_pipeline__ = None

        return instance


def _component_repr(component: Component) -> str:
    """
    All Components override their __repr__ method with this one.

    It prints the component name and the input/output sockets.
    """
    result = object.__repr__(component)
    if pipeline := getattr(component, "__haystack_added_to_pipeline__", None):
        # This Component has been added in a Pipeline, let's get the name from there.
        result += f"\n{pipeline.get_component_name(component)}"

    # We're explicitly ignoring the type here because we're sure that the component
    # has the __haystack_input__ and __haystack_output__ attributes at this point
    return (
        f'{result}\n{getattr(component, "__haystack_input__", "<invalid_input_sockets>")}'
        f'\n{getattr(component, "__haystack_output__", "<invalid_output_sockets>")}'
    )


def _component_run_has_kwargs(component_cls: Type) -> bool:
    run_method = getattr(component_cls, "run", None)
    if run_method is None:
        return False
    else:
        return any(
            param.kind == inspect.Parameter.VAR_KEYWORD for param in inspect.signature(run_method).parameters.values()
        )


class _Component:
    """
    See module's docstring.

    Args:
        cls: the class that should be used as a component.
        is_greedy: deprecated and scheduled for removal in version '2.7.0'; use a 'GreedyVariadic' input instead.

    Returns:
        A class that can be recognized as a component.

    Raises:
        ComponentError: if the class provided has no `run()` method or otherwise doesn't respect the component contract.
    """

    def __init__(self):
        self.registry = {}

    def set_input_type(
        self,
        instance,
        name: str,
        type: Any,  # noqa: A002
        default: Any = _empty,
    ):
        """
        Add a single input socket to the component instance.

        Replaces any existing input socket with the same name.

        :param instance: Component instance where the input type will be added.
        :param name: name of the input socket.
        :param type: type of the input socket.
        :param default: default value of the input socket, defaults to _empty
        """
        if not _component_run_has_kwargs(instance.__class__):
            raise ComponentError(
                "Cannot set input types on a component that doesn't have a kwargs parameter in the 'run' method"
            )

        if not hasattr(instance, "__haystack_input__"):
            instance.__haystack_input__ = Sockets(instance, {}, InputSocket)
        instance.__haystack_input__[name] = InputSocket(name=name, type=type, default_value=default)

    def set_input_types(self, instance, **types):
        """
        Method that specifies the input types when 'kwargs' is passed to the run method.

        Use as:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int):
                component.set_input_types(self, value_1=str, value_2=str)
                ...

            @component.output_types(output_1=str, output_2=str)
            def run(self, **kwargs):
                return {"output_1": kwargs["value_1"], "output_2": ""}
        ```

        Note that if the `run()` method also declares some parameters explicitly, they become input sockets too.
        A parameter declared both in `run()` and in `set_input_types` must have the same definition in both places,
        otherwise a `ComponentError` is raised when the component is created.

        For example:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int):
                component.set_input_types(self, value_2=str)
                ...

            @component.output_types(output_1=str, output_2=str)
            def run(self, value_0: str, value_1: Optional[str] = None, **kwargs):
                return {"output_1": value_1 or value_0, "output_2": kwargs["value_2"]}
        ```

        would add a mandatory `value_0` parameter and an optional `value_1` parameter with a default of None,
        and keep the `value_2` parameter mandatory as specified in `set_input_types`.

        """
        if not _component_run_has_kwargs(instance.__class__):
            raise ComponentError(
                "Cannot set input types on a component that doesn't have a kwargs parameter in the 'run' method"
            )

        instance.__haystack_input__ = Sockets(
            instance, {name: InputSocket(name=name, type=type_) for name, type_ in types.items()}, InputSocket
        )

    def set_output_types(self, instance, **types):
        """
        Method that specifies the output types when the 'run' method is not decorated with 'component.output_types'.

        Use as:

        ```python
        @component
        class MyComponent:

            def __init__(self, value: int):
                component.set_output_types(self, output_1=int, output_2=str)
                ...

            # no decorators here
            def run(self, value: int):
                return {"output_1": 1, "output_2": "2"}
        ```
        """
        has_decorator = hasattr(instance.run, "_output_types_cache")
        if has_decorator:
            raise ComponentError(
                "Cannot call `set_output_types` on a component that already has "
                "the 'output_types' decorator on its `run` method"
            )

        instance.__haystack_output__ = Sockets(
            instance, {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()}, OutputSocket
        )

    def output_types(self, **types):
        """
        Decorator factory that specifies the output types of a component.

        Use as:

        ```python
        @component
        class MyComponent:
            @component.output_types(output_1=int, output_2=str)
            def run(self, value: int):
                return {"output_1": 1, "output_2": "2"}
        ```
        """

        def output_types_decorator(run_method):
            """
            Decorator that sets the output types of the decorated method.

            This happens at class creation time, and since we don't have the decorated
            class available here, we temporarily store the output types as an attribute of
            the decorated method. The ComponentMeta metaclass will use this data to create
            sockets at instance creation time.
            """
            method_name = run_method.__name__
            if method_name not in ("run", "run_async"):
                raise ComponentError("'output_types' decorator can only be used on 'run' and 'run_async' methods")

            setattr(
                run_method,
                "_output_types_cache",
                {name: OutputSocket(name=name, type=type_) for name, type_ in types.items()},
            )
            return run_method

        return output_types_decorator

    def _component(self, cls, is_greedy: Optional[bool] = None):
        """
        Decorator validating the structure of the component and registering it in the components registry.
        """
        logger.debug("Registering {component} as a component", component=cls)

        if is_greedy is not None:
            msg = (
                "The 'is_greedy' argument is deprecated and will be removed in version '2.7.0'. "
                "Change the 'Variadic' input of your Component to 'GreedyVariadic' instead."
            )
            warnings.warn(msg, DeprecationWarning)
        else:
            is_greedy = False

        # Check for required methods and fail as soon as possible
        if not hasattr(cls, "run"):
            raise ComponentError(f"{cls.__name__} must have a 'run()' method. See the docs for more information.")

        def copy_class_namespace(namespace):
            """
            This is the callback that `types.new_class` will use to populate the newly created class.

            Simply copy the whole namespace from the decorated class.
            """
            for key, val in dict(cls.__dict__).items():
                # __dict__ and __weakref__ are class-bound, we should let Python recreate them.
                if key in ("__dict__", "__weakref__"):
                    continue
                namespace[key] = val

        # Recreate the decorated component class so it uses our metaclass.
        # We must explicitly redefine the type of the class to make sure language servers
        # and type checkers understand that the class is of the correct type.
        # mypy doesn't like that we do this though so we explicitly ignore the type check.
        cls: cls.__name__ = new_class(cls.__name__, cls.__bases__, {"metaclass": ComponentMeta}, copy_class_namespace)  # type: ignore[no-redef]

        # Save the component in the class registry (for deserialization)
        class_path = f"{cls.__module__}.{cls.__name__}"
        if class_path in self.registry:
            # Corner case, but it may occur easily in notebooks when re-running cells.
            logger.debug(
                "Component {component} is already registered. Previously imported from '{module_name}', "
                "newly imported from '{new_module_name}'",
                component=class_path,
                module_name=self.registry[class_path],
                new_module_name=cls,
            )
        self.registry[class_path] = cls
        logger.debug("Registered Component {component}", component=cls)

        # Override the __repr__ method with a default one
        cls.__repr__ = _component_repr

        return cls

    def __call__(self, cls: Optional[type] = None, is_greedy: Optional[bool] = None):
        # We must wrap the call to the decorator in a function for it to work
        # correctly with or without parens
        def wrap(cls):
            return self._component(cls, is_greedy=is_greedy)

        if cls:
            # Decorator is called without parens
            return wrap(cls)

        # Decorator is called with parens
        return wrap


component = _Component()