• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 15210934031

23 May 2025 01:01PM UTC coverage: 90.056% (-0.03%) from 90.087%
15210934031

Pull #9434

github

web-flow
Merge f2e68af13 into d8cc6f733
Pull Request #9434: fix: Fix invoker to work when using dataclass with from_dict but dataclass…

11338 of 12590 relevant lines covered (90.06%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

93.51
haystack/core/pipeline/base.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
import itertools
1✔
6
from collections import defaultdict
1✔
7
from datetime import datetime
1✔
8
from enum import IntEnum
1✔
9
from pathlib import Path
1✔
10
from typing import Any, Dict, Iterator, List, Optional, Set, TextIO, Tuple, Type, TypeVar, Union
1✔
11

12
import networkx  # type:ignore
1✔
13

14
from haystack import logging, tracing
1✔
15
from haystack.core.component import Component, InputSocket, OutputSocket, component
1✔
16
from haystack.core.errors import (
1✔
17
    DeserializationError,
18
    PipelineComponentsBlockedError,
19
    PipelineConnectError,
20
    PipelineDrawingError,
21
    PipelineError,
22
    PipelineMaxComponentRuns,
23
    PipelineUnmarshalError,
24
    PipelineValidationError,
25
)
26
from haystack.core.pipeline.component_checks import (
1✔
27
    _NO_OUTPUT_PRODUCED,
28
    all_predecessors_executed,
29
    are_all_lazy_variadic_sockets_resolved,
30
    are_all_sockets_ready,
31
    can_component_run,
32
    is_any_greedy_socket_ready,
33
    is_socket_lazy_variadic,
34
)
35
from haystack.core.pipeline.utils import (
1✔
36
    FIFOPriorityQueue,
37
    _deepcopy_with_exceptions,
38
    args_deprecated,
39
    parse_connect_string,
40
)
41
from haystack.core.serialization import DeserializationCallbacks, component_from_dict, component_to_dict
1✔
42
from haystack.core.type_utils import _type_name, _types_are_compatible
1✔
43
from haystack.marshal import Marshaller, YamlMarshaller
1✔
44
from haystack.utils import is_in_jupyter, type_serialization
1✔
45

46
from .descriptions import find_pipeline_inputs, find_pipeline_outputs
1✔
47
from .draw import _to_mermaid_image
1✔
48
from .template import PipelineTemplate, PredefinedPipeline
1✔
49

50
DEFAULT_MARSHALLER = YamlMarshaller()
1✔
51

52
# We use a generic type to annotate the return value of class methods,
53
# so that static analyzers won't be confused when derived classes
54
# use those methods.
55
T = TypeVar("T", bound="PipelineBase")
1✔
56

57
logger = logging.getLogger(__name__)
1✔
58

59

60
# Constants for tracing tags
61
_COMPONENT_INPUT = "haystack.component.input"
1✔
62
_COMPONENT_OUTPUT = "haystack.component.output"
1✔
63
_COMPONENT_VISITS = "haystack.component.visits"
1✔
64

65

66
class ComponentPriority(IntEnum):
1✔
67
    HIGHEST = 1
1✔
68
    READY = 2
1✔
69
    DEFER = 3
1✔
70
    DEFER_LAST = 4
1✔
71
    BLOCKED = 5
1✔
72

73

74
class PipelineBase:
1✔
75
    """
76
    Components orchestration engine.
77

78
    Builds a graph of components and orchestrates their execution according to the execution graph.
79
    """
80

81
    def __init__(
1✔
82
        self,
83
        metadata: Optional[Dict[str, Any]] = None,
84
        max_runs_per_component: int = 100,
85
        connection_type_validation: bool = True,
86
    ):
87
        """
88
        Creates the Pipeline.
89

90
        :param metadata:
91
            Arbitrary dictionary to store metadata about this `Pipeline`. Make sure all the values contained in
92
            this dictionary can be serialized and deserialized if you wish to save this `Pipeline` to file.
93
        :param max_runs_per_component:
94
            How many times the `Pipeline` can run the same Component.
95
            If this limit is reached a `PipelineMaxComponentRuns` exception is raised.
96
            If not set defaults to 100 runs per Component.
97
        :param connection_type_validation: Whether the pipeline will validate the types of the connections.
98
            Defaults to True.
99
        """
100
        self._telemetry_runs = 0
1✔
101
        self._last_telemetry_sent: Optional[datetime] = None
1✔
102
        self.metadata = metadata or {}
1✔
103
        self.graph = networkx.MultiDiGraph()
1✔
104
        self._max_runs_per_component = max_runs_per_component
1✔
105
        self._connection_type_validation = connection_type_validation
1✔
106

107
    def __eq__(self, other) -> bool:
1✔
108
        """
109
        Pipeline equality is defined by their type and the equality of their serialized form.
110

111
        Pipelines of the same type share every metadata, node and edge, but they're not required to use
112
        the same node instances: this allows pipeline saved and then loaded back to be equal to themselves.
113
        """
114
        if not isinstance(self, type(other)):
1✔
115
            return False
×
116
        return self.to_dict() == other.to_dict()
1✔
117

118
    def __repr__(self) -> str:
1✔
119
        """
120
        Returns a text representation of the Pipeline.
121
        """
122
        res = f"{object.__repr__(self)}\n"
1✔
123
        if self.metadata:
1✔
124
            res += "🧱 Metadata\n"
1✔
125
            for k, v in self.metadata.items():
1✔
126
                res += f"  - {k}: {v}\n"
1✔
127

128
        res += "🚅 Components\n"
1✔
129
        for name, instance in self.graph.nodes(data="instance"):  # type: ignore # type wrongly defined in networkx
1✔
130
            res += f"  - {name}: {instance.__class__.__name__}\n"
1✔
131

132
        res += "🛤️ Connections\n"
1✔
133
        for sender, receiver, edge_data in self.graph.edges(data=True):
1✔
134
            sender_socket = edge_data["from_socket"].name
1✔
135
            receiver_socket = edge_data["to_socket"].name
1✔
136
            res += f"  - {sender}.{sender_socket} -> {receiver}.{receiver_socket} ({edge_data['conn_type']})\n"
1✔
137

138
        return res
1✔
139

140
    def to_dict(self) -> Dict[str, Any]:
1✔
141
        """
142
        Serializes the pipeline to a dictionary.
143

144
        This is meant to be an intermediate representation but it can be also used to save a pipeline to file.
145

146
        :returns:
147
            Dictionary with serialized data.
148
        """
149
        components = {}
1✔
150
        for name, instance in self.graph.nodes(data="instance"):  # type:ignore
1✔
151
            components[name] = component_to_dict(instance, name)
1✔
152

153
        connections = []
1✔
154
        for sender, receiver, edge_data in self.graph.edges.data():
1✔
155
            sender_socket = edge_data["from_socket"].name
1✔
156
            receiver_socket = edge_data["to_socket"].name
1✔
157
            connections.append({"sender": f"{sender}.{sender_socket}", "receiver": f"{receiver}.{receiver_socket}"})
1✔
158
        return {
1✔
159
            "metadata": self.metadata,
160
            "max_runs_per_component": self._max_runs_per_component,
161
            "components": components,
162
            "connections": connections,
163
            "connection_type_validation": self._connection_type_validation,
164
        }
165

166
    @classmethod
1✔
167
    def from_dict(
1✔
168
        cls: Type[T], data: Dict[str, Any], callbacks: Optional[DeserializationCallbacks] = None, **kwargs
169
    ) -> T:
170
        """
171
        Deserializes the pipeline from a dictionary.
172

173
        :param data:
174
            Dictionary to deserialize from.
175
        :param callbacks:
176
            Callbacks to invoke during deserialization.
177
        :param kwargs:
178
            `components`: a dictionary of {name: instance} to reuse instances of components instead of creating new
179
            ones.
180
        :returns:
181
            Deserialized component.
182
        """
183
        data_copy = _deepcopy_with_exceptions(data)  # to prevent modification of original data
1✔
184
        metadata = data_copy.get("metadata", {})
1✔
185
        max_runs_per_component = data_copy.get("max_runs_per_component", 100)
1✔
186
        connection_type_validation = data_copy.get("connection_type_validation", True)
1✔
187
        pipe = cls(
1✔
188
            metadata=metadata,
189
            max_runs_per_component=max_runs_per_component,
190
            connection_type_validation=connection_type_validation,
191
        )
192
        components_to_reuse = kwargs.get("components", {})
1✔
193
        for name, component_data in data_copy.get("components", {}).items():
1✔
194
            if name in components_to_reuse:
1✔
195
                # Reuse an instance
196
                instance = components_to_reuse[name]
1✔
197
            else:
198
                if "type" not in component_data:
1✔
199
                    raise PipelineError(f"Missing 'type' in component '{name}'")
1✔
200

201
                if component_data["type"] not in component.registry:
1✔
202
                    try:
1✔
203
                        # Import the module first...
204
                        module, _ = component_data["type"].rsplit(".", 1)
1✔
205
                        logger.debug("Trying to import module {module_name}", module_name=module)
1✔
206
                        type_serialization.thread_safe_import(module)
1✔
207
                        # ...then try again
208
                        if component_data["type"] not in component.registry:
1✔
209
                            raise PipelineError(
1✔
210
                                f"Successfully imported module '{module}' but couldn't find "
211
                                f"'{component_data['type']}' in the component registry.\n"
212
                                f"The component might be registered under a different path. "
213
                                f"Here are the registered components:\n {list(component.registry.keys())}\n"
214
                            )
215
                    except (ImportError, PipelineError, ValueError) as e:
1✔
216
                        raise PipelineError(
1✔
217
                            f"Component '{component_data['type']}' (name: '{name}') not imported. Please "
218
                            f"check that the package is installed and the component path is correct."
219
                        ) from e
220

221
                # Create a new one
222
                component_class = component.registry[component_data["type"]]
1✔
223

224
                try:
1✔
225
                    instance = component_from_dict(component_class, component_data, name, callbacks)
1✔
226
                except Exception as e:
1✔
227
                    msg = (
1✔
228
                        f"Couldn't deserialize component '{name}' of class '{component_class.__name__}' "
229
                        f"with the following data: {str(component_data)}. Possible reasons include "
230
                        "malformed serialized data, mismatch between the serialized component and the "
231
                        "loaded one (due to a breaking change, see "
232
                        "https://github.com/deepset-ai/haystack/releases), etc."
233
                    )
234
                    raise DeserializationError(msg) from e
1✔
235
            pipe.add_component(name=name, instance=instance)
1✔
236

237
        for connection in data.get("connections", []):
1✔
238
            if "sender" not in connection:
1✔
239
                raise PipelineError(f"Missing sender in connection: {connection}")
1✔
240
            if "receiver" not in connection:
1✔
241
                raise PipelineError(f"Missing receiver in connection: {connection}")
1✔
242
            pipe.connect(sender=connection["sender"], receiver=connection["receiver"])
1✔
243

244
        return pipe
1✔
245

246
    def dumps(self, marshaller: Marshaller = DEFAULT_MARSHALLER) -> str:
1✔
247
        """
248
        Returns the string representation of this pipeline according to the format dictated by the `Marshaller` in use.
249

250
        :param marshaller:
251
            The Marshaller used to create the string representation. Defaults to `YamlMarshaller`.
252
        :returns:
253
            A string representing the pipeline.
254
        """
255
        return marshaller.marshal(self.to_dict())
1✔
256

257
    def dump(self, fp: TextIO, marshaller: Marshaller = DEFAULT_MARSHALLER):
1✔
258
        """
259
        Writes the string representation of this pipeline to the file-like object passed in the `fp` argument.
260

261
        :param fp:
262
            A file-like object ready to be written to.
263
        :param marshaller:
264
            The Marshaller used to create the string representation. Defaults to `YamlMarshaller`.
265
        """
266
        fp.write(marshaller.marshal(self.to_dict()))
1✔
267

268
    @classmethod
1✔
269
    def loads(
1✔
270
        cls: Type[T],
271
        data: Union[str, bytes, bytearray],
272
        marshaller: Marshaller = DEFAULT_MARSHALLER,
273
        callbacks: Optional[DeserializationCallbacks] = None,
274
    ) -> T:
275
        """
276
        Creates a `Pipeline` object from the string representation passed in the `data` argument.
277

278
        :param data:
279
            The string representation of the pipeline, can be `str`, `bytes` or `bytearray`.
280
        :param marshaller:
281
            The Marshaller used to create the string representation. Defaults to `YamlMarshaller`.
282
        :param callbacks:
283
            Callbacks to invoke during deserialization.
284
        :raises DeserializationError:
285
            If an error occurs during deserialization.
286
        :returns:
287
            A `Pipeline` object.
288
        """
289
        try:
1✔
290
            deserialized_data = marshaller.unmarshal(data)
1✔
291
        except Exception as e:
1✔
292
            raise DeserializationError(
1✔
293
                "Error while unmarshalling serialized pipeline data. This is usually "
294
                "caused by malformed or invalid syntax in the serialized representation."
295
            ) from e
296

297
        return cls.from_dict(deserialized_data, callbacks)
1✔
298

299
    @classmethod
1✔
300
    def load(
1✔
301
        cls: Type[T],
302
        fp: TextIO,
303
        marshaller: Marshaller = DEFAULT_MARSHALLER,
304
        callbacks: Optional[DeserializationCallbacks] = None,
305
    ) -> T:
306
        """
307
        Creates a `Pipeline` object a string representation.
308

309
        The string representation is read from the file-like object passed in the `fp` argument.
310

311

312
        :param fp:
313
            A file-like object ready to be read from.
314
        :param marshaller:
315
            The Marshaller used to create the string representation. Defaults to `YamlMarshaller`.
316
        :param callbacks:
317
            Callbacks to invoke during deserialization.
318
        :raises DeserializationError:
319
            If an error occurs during deserialization.
320
        :returns:
321
            A `Pipeline` object.
322
        """
323
        return cls.loads(fp.read(), marshaller, callbacks)
1✔
324

325
    def add_component(self, name: str, instance: Component) -> None:
1✔
326
        """
327
        Add the given component to the pipeline.
328

329
        Components are not connected to anything by default: use `Pipeline.connect()` to connect components together.
330
        Component names must be unique, but component instances can be reused if needed.
331

332
        :param name:
333
            The name of the component to add.
334
        :param instance:
335
            The component instance to add.
336

337
        :raises ValueError:
338
            If a component with the same name already exists.
339
        :raises PipelineValidationError:
340
            If the given instance is not a component.
341
        """
342
        # Component names are unique
343
        if name in self.graph.nodes:
1✔
344
            raise ValueError(f"A component named '{name}' already exists in this pipeline: choose another name.")
×
345

346
        # Components can't be named `_debug`
347
        if name == "_debug":
1✔
348
            raise ValueError("'_debug' is a reserved name for debug output. Choose another name.")
1✔
349

350
        # Component names can't have "."
351
        if "." in name:
1✔
352
            raise ValueError(f"{name} is an invalid component name, cannot contain '.' (dot) characters.")
1✔
353

354
        # Component instances must be components
355
        if not isinstance(instance, Component):
1✔
356
            raise PipelineValidationError(
×
357
                f"'{type(instance)}' doesn't seem to be a component. Is this class decorated with @component?"
358
            )
359

360
        if getattr(instance, "__haystack_added_to_pipeline__", None):
1✔
361
            msg = (
1✔
362
                "Component has already been added in another Pipeline. Components can't be shared between Pipelines. "
363
                "Create a new instance instead."
364
            )
365
            raise PipelineError(msg)
1✔
366

367
        setattr(instance, "__haystack_added_to_pipeline__", self)
1✔
368

369
        # Add component to the graph, disconnected
370
        logger.debug("Adding component '{component_name}' ({component})", component_name=name, component=instance)
1✔
371
        # We're completely sure the fields exist so we ignore the type error
372
        self.graph.add_node(
1✔
373
            name,
374
            instance=instance,
375
            input_sockets=instance.__haystack_input__._sockets_dict,  # type: ignore[attr-defined]
376
            output_sockets=instance.__haystack_output__._sockets_dict,  # type: ignore[attr-defined]
377
            visits=0,
378
        )
379

380
    def remove_component(self, name: str) -> Component:
1✔
381
        """
382
        Remove and returns component from the pipeline.
383

384
        Remove an existing component from the pipeline by providing its name.
385
        All edges that connect to the component will also be deleted.
386

387
        :param name:
388
            The name of the component to remove.
389
        :returns:
390
            The removed Component instance.
391

392
        :raises ValueError:
393
            If there is no component with that name already in the Pipeline.
394
        """
395

396
        # Check that a component with that name is in the Pipeline
397
        try:
1✔
398
            instance = self.get_component(name)
1✔
399
        except ValueError as exc:
1✔
400
            raise ValueError(
1✔
401
                f"There is no component named '{name}' in the pipeline. The valid component names are: ",
402
                ", ".join(n for n in self.graph.nodes),
403
            ) from exc
404

405
        # Delete component from the graph, deleting all its connections
406
        self.graph.remove_node(name)
1✔
407

408
        # Reset the Component sockets' senders and receivers
409
        input_sockets = instance.__haystack_input__._sockets_dict  # type: ignore[attr-defined]
1✔
410
        for socket in input_sockets.values():
1✔
411
            socket.senders = []
1✔
412

413
        output_sockets = instance.__haystack_output__._sockets_dict  # type: ignore[attr-defined]
1✔
414
        for socket in output_sockets.values():
1✔
415
            socket.receivers = []
1✔
416

417
        # Reset the Component's pipeline reference
418
        setattr(instance, "__haystack_added_to_pipeline__", None)
1✔
419

420
        return instance
1✔
421

422
    def connect(self, sender: str, receiver: str) -> "PipelineBase":  # noqa: PLR0915 PLR0912
1✔
423
        """
424
        Connects two components together.
425

426
        All components to connect must exist in the pipeline.
427
        If connecting to a component that has several output connections, specify the inputs and output names as
428
        'component_name.connections_name'.
429

430
        :param sender:
431
            The component that delivers the value. This can be either just a component name or can be
432
            in the format `component_name.connection_name` if the component has multiple outputs.
433
        :param receiver:
434
            The component that receives the value. This can be either just a component name or can be
435
            in the format `component_name.connection_name` if the component has multiple inputs.
436

437
        :returns:
438
            The Pipeline instance.
439

440
        :raises PipelineConnectError:
441
            If the two components cannot be connected (for example if one of the components is
442
            not present in the pipeline, or the connections don't match by type, and so on).
443
        """
444
        # Edges may be named explicitly by passing 'node_name.edge_name' to connect().
445
        sender_component_name, sender_socket_name = parse_connect_string(sender)
1✔
446
        receiver_component_name, receiver_socket_name = parse_connect_string(receiver)
1✔
447

448
        if sender_component_name == receiver_component_name:
1✔
449
            raise PipelineConnectError("Connecting a Component to itself is not supported.")
1✔
450

451
        # Get the nodes data.
452
        try:
1✔
453
            sender_sockets = self.graph.nodes[sender_component_name]["output_sockets"]
1✔
454
        except KeyError as exc:
1✔
455
            raise ValueError(f"Component named {sender_component_name} not found in the pipeline.") from exc
1✔
456
        try:
1✔
457
            receiver_sockets = self.graph.nodes[receiver_component_name]["input_sockets"]
1✔
458
        except KeyError as exc:
1✔
459
            raise ValueError(f"Component named {receiver_component_name} not found in the pipeline.") from exc
1✔
460

461
        # If the name of either socket is given, get the socket
462
        sender_socket: Optional[OutputSocket] = None
1✔
463
        if sender_socket_name:
1✔
464
            sender_socket = sender_sockets.get(sender_socket_name)
1✔
465
            if not sender_socket:
1✔
466
                raise PipelineConnectError(
1✔
467
                    f"'{sender} does not exist. "
468
                    f"Output connections of {sender_component_name} are: "
469
                    + ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in sender_sockets.items()])
470
                )
471

472
        receiver_socket: Optional[InputSocket] = None
1✔
473
        if receiver_socket_name:
1✔
474
            receiver_socket = receiver_sockets.get(receiver_socket_name)
1✔
475
            if not receiver_socket:
1✔
476
                raise PipelineConnectError(
1✔
477
                    f"'{receiver} does not exist. "
478
                    f"Input connections of {receiver_component_name} are: "
479
                    + ", ".join(
480
                        [f"{name} (type {_type_name(socket.type)})" for name, socket in receiver_sockets.items()]
481
                    )
482
                )
483

484
        # Look for a matching connection among the possible ones.
485
        # Note that if there is more than one possible connection but two sockets match by name, they're paired.
486
        sender_socket_candidates: List[OutputSocket] = (
1✔
487
            [sender_socket] if sender_socket else list(sender_sockets.values())
488
        )
489
        receiver_socket_candidates: List[InputSocket] = (
1✔
490
            [receiver_socket] if receiver_socket else list(receiver_sockets.values())
491
        )
492

493
        # Find all possible connections between these two components
494
        possible_connections = []
1✔
495
        for sender_sock, receiver_sock in itertools.product(sender_socket_candidates, receiver_socket_candidates):
1✔
496
            if _types_are_compatible(sender_sock.type, receiver_sock.type, self._connection_type_validation):
1✔
497
                possible_connections.append((sender_sock, receiver_sock))
1✔
498

499
        # We need this status for error messages, since we might need it in multiple places we calculate it here
500
        status = _connections_status(
1✔
501
            sender_node=sender_component_name,
502
            sender_sockets=sender_socket_candidates,
503
            receiver_node=receiver_component_name,
504
            receiver_sockets=receiver_socket_candidates,
505
        )
506

507
        if not possible_connections:
1✔
508
            # There's no possible connection between these two components
509
            if len(sender_socket_candidates) == len(receiver_socket_candidates) == 1:
1✔
510
                msg = (
1✔
511
                    f"Cannot connect '{sender_component_name}.{sender_socket_candidates[0].name}' with "
512
                    f"'{receiver_component_name}.{receiver_socket_candidates[0].name}': "
513
                    f"their declared input and output types do not match.\n{status}"
514
                )
515
            else:
516
                msg = (
×
517
                    f"Cannot connect '{sender_component_name}' with '{receiver_component_name}': "
518
                    f"no matching connections available.\n{status}"
519
                )
520
            raise PipelineConnectError(msg)
1✔
521

522
        if len(possible_connections) == 1:
1✔
523
            # There's only one possible connection, use it
524
            sender_socket = possible_connections[0][0]
1✔
525
            receiver_socket = possible_connections[0][1]
1✔
526

527
        if len(possible_connections) > 1:
1✔
528
            # There are multiple possible connection, let's try to match them by name
529
            name_matches = [
1✔
530
                (out_sock, in_sock) for out_sock, in_sock in possible_connections if in_sock.name == out_sock.name
531
            ]
532
            if len(name_matches) != 1:
1✔
533
                # There's are either no matches or more than one, we can't pick one reliably
534
                msg = (
1✔
535
                    f"Cannot connect '{sender_component_name}' with "
536
                    f"'{receiver_component_name}': more than one connection is possible "
537
                    "between these components. Please specify the connection name, like: "
538
                    f"pipeline.connect('{sender_component_name}.{possible_connections[0][0].name}', "
539
                    f"'{receiver_component_name}.{possible_connections[0][1].name}').\n{status}"
540
                )
541
                raise PipelineConnectError(msg)
1✔
542

543
            # Get the only possible match
544
            sender_socket = name_matches[0][0]
1✔
545
            receiver_socket = name_matches[0][1]
1✔
546

547
        # Connection must be valid on both sender/receiver sides
548
        if not sender_socket or not receiver_socket or not sender_component_name or not receiver_component_name:
1✔
549
            if sender_component_name and sender_socket:
×
550
                sender_repr = f"{sender_component_name}.{sender_socket.name} ({_type_name(sender_socket.type)})"
×
551
            else:
552
                sender_repr = "input needed"
×
553

554
            if receiver_component_name and receiver_socket:
×
555
                receiver_repr = f"({_type_name(receiver_socket.type)}) {receiver_component_name}.{receiver_socket.name}"
×
556
            else:
557
                receiver_repr = "output"
×
558
            msg = f"Connection must have both sender and receiver: {sender_repr} -> {receiver_repr}"
×
559
            raise PipelineConnectError(msg)
×
560

561
        logger.debug(
1✔
562
            "Connecting '{sender_component}.{sender_socket_name}' to '{receiver_component}.{receiver_socket_name}'",
563
            sender_component=sender_component_name,
564
            sender_socket_name=sender_socket.name,
565
            receiver_component=receiver_component_name,
566
            receiver_socket_name=receiver_socket.name,
567
        )
568

569
        if receiver_component_name in sender_socket.receivers and sender_component_name in receiver_socket.senders:
1✔
570
            # This is already connected, nothing to do
571
            return self
1✔
572

573
        if receiver_socket.senders and not receiver_socket.is_variadic:
1✔
574
            # Only variadic input sockets can receive from multiple senders
575
            msg = (
1✔
576
                f"Cannot connect '{sender_component_name}.{sender_socket.name}' with "
577
                f"'{receiver_component_name}.{receiver_socket.name}': "
578
                f"{receiver_component_name}.{receiver_socket.name} is already connected to {receiver_socket.senders}.\n"
579
            )
580
            raise PipelineConnectError(msg)
1✔
581

582
        # Update the sockets with the new connection
583
        sender_socket.receivers.append(receiver_component_name)
1✔
584
        receiver_socket.senders.append(sender_component_name)
1✔
585

586
        # Create the new connection
587
        self.graph.add_edge(
1✔
588
            sender_component_name,
589
            receiver_component_name,
590
            key=f"{sender_socket.name}/{receiver_socket.name}",
591
            conn_type=_type_name(sender_socket.type),
592
            from_socket=sender_socket,
593
            to_socket=receiver_socket,
594
            mandatory=receiver_socket.is_mandatory,
595
        )
596
        return self
1✔
597

598
    def get_component(self, name: str) -> Component:
1✔
599
        """
600
        Get the component with the specified name from the pipeline.
601

602
        :param name:
603
            The name of the component.
604
        :returns:
605
            The instance of that component.
606

607
        :raises ValueError:
608
            If a component with that name is not present in the pipeline.
609
        """
610
        try:
1✔
611
            return self.graph.nodes[name]["instance"]
1✔
612
        except KeyError as exc:
1✔
613
            raise ValueError(f"Component named {name} not found in the pipeline.") from exc
1✔
614

615
    def get_component_name(self, instance: Component) -> str:
1✔
616
        """
617
        Returns the name of the Component instance if it has been added to this Pipeline or an empty string otherwise.
618

619
        :param instance:
620
            The Component instance to look for.
621
        :returns:
622
            The name of the Component instance.
623
        """
624
        for name, inst in self.graph.nodes(data="instance"):  # type: ignore # type wrongly defined in networkx
1✔
625
            if inst == instance:
1✔
626
                return name
1✔
627
        return ""
1✔
628

629
    def inputs(self, include_components_with_connected_inputs: bool = False) -> Dict[str, Dict[str, Any]]:
1✔
630
        """
631
        Returns a dictionary containing the inputs of a pipeline.
632

633
        Each key in the dictionary corresponds to a component name, and its value is another dictionary that describes
634
        the input sockets of that component, including their types and whether they are optional.
635

636
        :param include_components_with_connected_inputs:
637
            If `False`, only components that have disconnected input edges are
638
            included in the output.
639
        :returns:
640
            A dictionary where each key is a pipeline component name and each value is a dictionary of
641
            inputs sockets of that component.
642
        """
643
        inputs: Dict[str, Dict[str, Any]] = {}
1✔
644
        for component_name, data in find_pipeline_inputs(self.graph, include_components_with_connected_inputs).items():
1✔
645
            sockets_description = {}
1✔
646
            for socket in data:
1✔
647
                sockets_description[socket.name] = {"type": socket.type, "is_mandatory": socket.is_mandatory}
1✔
648
                if not socket.is_mandatory:
1✔
649
                    sockets_description[socket.name]["default_value"] = socket.default_value
1✔
650

651
            if sockets_description:
1✔
652
                inputs[component_name] = sockets_description
1✔
653
        return inputs
1✔
654

655
    def outputs(self, include_components_with_connected_outputs: bool = False) -> Dict[str, Dict[str, Any]]:
1✔
656
        """
657
        Returns a dictionary containing the outputs of a pipeline.
658

659
        Each key in the dictionary corresponds to a component name, and its value is another dictionary that describes
660
        the output sockets of that component.
661

662
        :param include_components_with_connected_outputs:
663
            If `False`, only components that have disconnected output edges are
664
            included in the output.
665
        :returns:
666
            A dictionary where each key is a pipeline component name and each value is a dictionary of
667
            output sockets of that component.
668
        """
669
        outputs = {
1✔
670
            comp: {socket.name: {"type": socket.type} for socket in data}
671
            for comp, data in find_pipeline_outputs(self.graph, include_components_with_connected_outputs).items()
672
            if data
673
        }
674
        return outputs
1✔
675

676
    @args_deprecated
1✔
677
    def show(
1✔
678
        self,
679
        server_url: str = "https://mermaid.ink",
680
        params: Optional[dict] = None,
681
        timeout: int = 30,
682
        super_component_expansion: bool = False,
683
    ) -> None:
684
        """
685
        Display an image representing this `Pipeline` in a Jupyter notebook.
686

687
        This function generates a diagram of the `Pipeline` using a Mermaid server and displays it directly in
688
        the notebook.
689

690
        :param server_url:
691
            The base URL of the Mermaid server used for rendering (default: 'https://mermaid.ink').
692
            See https://github.com/jihchi/mermaid.ink and https://github.com/mermaid-js/mermaid-live-editor for more
693
            info on how to set up your own Mermaid server.
694

695
        :param params:
696
            Dictionary of customization parameters to modify the output. Refer to Mermaid documentation for more details
697
            Supported keys:
698
                - format: Output format ('img', 'svg', or 'pdf'). Default: 'img'.
699
                - type: Image type for /img endpoint ('jpeg', 'png', 'webp'). Default: 'png'.
700
                - theme: Mermaid theme ('default', 'neutral', 'dark', 'forest'). Default: 'neutral'.
701
                - bgColor: Background color in hexadecimal (e.g., 'FFFFFF') or named format (e.g., '!white').
702
                - width: Width of the output image (integer).
703
                - height: Height of the output image (integer).
704
                - scale: Scaling factor (1–3). Only applicable if 'width' or 'height' is specified.
705
                - fit: Whether to fit the diagram size to the page (PDF only, boolean).
706
                - paper: Paper size for PDFs (e.g., 'a4', 'a3'). Ignored if 'fit' is true.
707
                - landscape: Landscape orientation for PDFs (boolean). Ignored if 'fit' is true.
708

709
        :param timeout:
710
            Timeout in seconds for the request to the Mermaid server.
711

712
        :param super_component_expansion:
713
            If set to True and the pipeline contains SuperComponents the diagram will show the internal structure of
714
            super-components as if they were components part of the pipeline instead of a "black-box".
715
            Otherwise, only the super-component itself will be displayed.
716

717
        :raises PipelineDrawingError:
718
            If the function is called outside of a Jupyter notebook or if there is an issue with rendering.
719
        """
720

721
        # Call the internal implementation with keyword arguments
722
        self._show_internal(
1✔
723
            server_url=server_url, params=params, timeout=timeout, super_component_expansion=super_component_expansion
724
        )
725

726
    def _show_internal(
1✔
727
        self,
728
        *,
729
        server_url: str = "https://mermaid.ink",
730
        params: Optional[dict] = None,
731
        timeout: int = 30,
732
        super_component_expansion: bool = False,
733
    ) -> None:
734
        """
735
        Internal implementation of show() that uses keyword-only arguments.
736

737
        ToDo: after 2.14.0 release make this the main function and remove the old one.
738
        """
739
        if is_in_jupyter():
1✔
740
            from IPython.display import Image, display  # type: ignore
1✔
741

742
            if super_component_expansion:
1✔
743
                graph, super_component_mapping = self._merge_super_component_pipelines()
×
744
            else:
745
                graph = self.graph
1✔
746
                super_component_mapping = None
1✔
747

748
            image_data = _to_mermaid_image(
1✔
749
                graph,
750
                server_url=server_url,
751
                params=params,
752
                timeout=timeout,
753
                super_component_mapping=super_component_mapping,
754
            )
755
            display(Image(image_data))
1✔
756
        else:
757
            msg = "This method is only supported in Jupyter notebooks. Use Pipeline.draw() to save an image locally."
1✔
758
            raise PipelineDrawingError(msg)
1✔
759

760
    @args_deprecated
1✔
761
    def draw(  # pylint: disable=too-many-positional-arguments
1✔
762
        self,
763
        path: Path,
764
        server_url: str = "https://mermaid.ink",
765
        params: Optional[dict] = None,
766
        timeout: int = 30,
767
        super_component_expansion: bool = False,
768
    ) -> None:
769
        """
770
        Save an image representing this `Pipeline` to the specified file path.
771

772
        This function generates a diagram of the `Pipeline` using the Mermaid server and saves it to the provided path.
773

774
        :param path:
775
            The file path where the generated image will be saved.
776

777
        :param server_url:
778
            The base URL of the Mermaid server used for rendering (default: 'https://mermaid.ink').
779
            See https://github.com/jihchi/mermaid.ink and https://github.com/mermaid-js/mermaid-live-editor for more
780
            info on how to set up your own Mermaid server.
781

782
        :param params:
783
            Dictionary of customization parameters to modify the output. Refer to Mermaid documentation for more details
784
            Supported keys:
785
                - format: Output format ('img', 'svg', or 'pdf'). Default: 'img'.
786
                - type: Image type for /img endpoint ('jpeg', 'png', 'webp'). Default: 'png'.
787
                - theme: Mermaid theme ('default', 'neutral', 'dark', 'forest'). Default: 'neutral'.
788
                - bgColor: Background color in hexadecimal (e.g., 'FFFFFF') or named format (e.g., '!white').
789
                - width: Width of the output image (integer).
790
                - height: Height of the output image (integer).
791
                - scale: Scaling factor (1–3). Only applicable if 'width' or 'height' is specified.
792
                - fit: Whether to fit the diagram size to the page (PDF only, boolean).
793
                - paper: Paper size for PDFs (e.g., 'a4', 'a3'). Ignored if 'fit' is true.
794
                - landscape: Landscape orientation for PDFs (boolean). Ignored if 'fit' is true.
795

796
        :param timeout:
797
            Timeout in seconds for the request to the Mermaid server.
798

799
        :param super_component_expansion:
800
            If set to True and the pipeline contains SuperComponents the diagram will show the internal structure of
801
            super-components as if they were components part of the pipeline instead of a "black-box".
802
            Otherwise, only the super-component itself will be displayed.
803

804
        :raises PipelineDrawingError:
805
            If there is an issue with rendering or saving the image.
806
        """
807

808
        # Call the internal implementation with keyword arguments
809
        self._draw_internal(
1✔
810
            path=path,
811
            server_url=server_url,
812
            params=params,
813
            timeout=timeout,
814
            super_component_expansion=super_component_expansion,
815
        )
816

817
    def _draw_internal(
1✔
818
        self,
819
        *,
820
        path: Path,
821
        server_url: str = "https://mermaid.ink",
822
        params: Optional[dict] = None,
823
        timeout: int = 30,
824
        super_component_expansion: bool = False,
825
    ) -> None:
826
        """
827
        Internal implementation of draw() that uses keyword-only arguments.
828

829
        ToDo: after 2.14.0 release make this the main function and remove the old one.
830
        """
831
        # Before drawing we edit a bit the graph, to avoid modifying the original that is
832
        # used for running the pipeline we copy it.
833
        if super_component_expansion:
1✔
834
            graph, super_component_mapping = self._merge_super_component_pipelines()
×
835
        else:
836
            graph = self.graph
1✔
837
            super_component_mapping = None
1✔
838

839
        image_data = _to_mermaid_image(
1✔
840
            graph,
841
            server_url=server_url,
842
            params=params,
843
            timeout=timeout,
844
            super_component_mapping=super_component_mapping,
845
        )
846
        Path(path).write_bytes(image_data)
1✔
847

848
    def walk(self) -> Iterator[Tuple[str, Component]]:
1✔
849
        """
850
        Visits each component in the pipeline exactly once and yields its name and instance.
851

852
        No guarantees are provided on the visiting order.
853

854
        :returns:
855
            An iterator of tuples of component name and component instance.
856
        """
857
        for component_name, instance in self.graph.nodes(data="instance"):  # type: ignore # type is wrong in networkx
1✔
858
            yield component_name, instance
1✔
859

860
    def warm_up(self):
1✔
861
        """
862
        Make sure all nodes are warm.
863

864
        It's the node's responsibility to make sure this method can be called at every `Pipeline.run()`
865
        without re-initializing everything.
866
        """
867
        for node in self.graph.nodes:
1✔
868
            if hasattr(self.graph.nodes[node]["instance"], "warm_up"):
1✔
869
                logger.info("Warming up component {node}...", node=node)
1✔
870
                self.graph.nodes[node]["instance"].warm_up()
1✔
871

872
    @staticmethod
1✔
873
    def _create_component_span(
1✔
874
        component_name: str, instance: Component, inputs: Dict[str, Any], parent_span: Optional[tracing.Span] = None
875
    ):
876
        return tracing.tracer.trace(
1✔
877
            "haystack.component.run",
878
            tags={
879
                "haystack.component.name": component_name,
880
                "haystack.component.type": instance.__class__.__name__,
881
                "haystack.component.input_types": {k: type(v).__name__ for k, v in inputs.items()},
882
                "haystack.component.input_spec": {
883
                    key: {
884
                        "type": (value.type.__name__ if isinstance(value.type, type) else str(value.type)),
885
                        "senders": value.senders,
886
                    }
887
                    for key, value in instance.__haystack_input__._sockets_dict.items()  # type: ignore
888
                },
889
                "haystack.component.output_spec": {
890
                    key: {
891
                        "type": (value.type.__name__ if isinstance(value.type, type) else str(value.type)),
892
                        "receivers": value.receivers,
893
                    }
894
                    for key, value in instance.__haystack_output__._sockets_dict.items()  # type: ignore
895
                },
896
            },
897
            parent_span=parent_span,
898
        )
899

900
    def _validate_input(self, data: Dict[str, Any]):
1✔
901
        """
902
        Validates pipeline input data.
903

904
        Validates that data:
905
        * Each Component name actually exists in the Pipeline
906
        * Each Component is not missing any input
907
        * Each Component has only one input per input socket, if not variadic
908
        * Each Component doesn't receive inputs that are already sent by another Component
909

910
        :param data:
911
            A dictionary of inputs for the pipeline's components. Each key is a component name.
912

913
        :raises ValueError:
914
            If inputs are invalid according to the above.
915
        """
916
        for component_name, component_inputs in data.items():
1✔
917
            if component_name not in self.graph.nodes:
1✔
918
                raise ValueError(f"Component named {component_name} not found in the pipeline.")
1✔
919
            instance = self.graph.nodes[component_name]["instance"]
1✔
920
            for socket_name, socket in instance.__haystack_input__._sockets_dict.items():
1✔
921
                if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs:
1✔
922
                    raise ValueError(f"Missing input for component {component_name}: {socket_name}")
1✔
923
            for input_name in component_inputs.keys():
1✔
924
                if input_name not in instance.__haystack_input__._sockets_dict:
1✔
925
                    raise ValueError(f"Input {input_name} not found in component {component_name}.")
1✔
926

927
        for component_name in self.graph.nodes:
1✔
928
            instance = self.graph.nodes[component_name]["instance"]
1✔
929
            for socket_name, socket in instance.__haystack_input__._sockets_dict.items():
1✔
930
                component_inputs = data.get(component_name, {})
1✔
931
                if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs:
1✔
932
                    raise ValueError(f"Missing input for component {component_name}: {socket_name}")
1✔
933
                if socket.senders and socket_name in component_inputs and not socket.is_variadic:
1✔
934
                    raise ValueError(
1✔
935
                        f"Input {socket_name} for component {component_name} is already sent by {socket.senders}."
936
                    )
937

938
    def _prepare_component_input_data(self, data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
1✔
939
        """
940
        Prepares input data for pipeline components.
941

942
        Organizes input data for pipeline components and identifies any inputs that are not matched to any
943
        component's input slots. Deep-copies data items to avoid sharing mutables across multiple components.
944

945
        This method processes a flat dictionary of input data, where each key-value pair represents an input name
946
        and its corresponding value. It distributes these inputs to the appropriate pipeline components based on
947
        their input requirements. Inputs that don't match any component's input slots are classified as unresolved.
948

949
        :param data:
950
            A dictionary potentially having input names as keys and input values as values.
951

952
        :returns:
953
            A dictionary mapping component names to their respective matched inputs.
954
        """
955
        # check whether the data is a nested dictionary of component inputs where each key is a component name
956
        # and each value is a dictionary of input parameters for that component
957
        is_nested_component_input = all(isinstance(value, dict) for value in data.values())
1✔
958
        if not is_nested_component_input:
1✔
959
            # flat input, a dict where keys are input names and values are the corresponding values
960
            # we need to convert it to a nested dictionary of component inputs and then run the pipeline
961
            # just like in the previous case
962
            pipeline_input_data: Dict[str, Dict[str, Any]] = defaultdict(dict)
1✔
963
            unresolved_kwargs = {}
1✔
964

965
            # Retrieve the input slots for each component in the pipeline
966
            available_inputs: Dict[str, Dict[str, Any]] = self.inputs()
1✔
967

968
            # Go through all provided to distribute them to the appropriate component inputs
969
            for input_name, input_value in data.items():
1✔
970
                resolved_at_least_once = False
1✔
971

972
                # Check each component to see if it has a slot for the current kwarg
973
                for component_name, component_inputs in available_inputs.items():
1✔
974
                    if input_name in component_inputs:
1✔
975
                        # If a match is found, add the kwarg to the component's input data
976
                        pipeline_input_data[component_name][input_name] = input_value
1✔
977
                        resolved_at_least_once = True
1✔
978

979
                if not resolved_at_least_once:
1✔
980
                    unresolved_kwargs[input_name] = input_value
1✔
981

982
            if unresolved_kwargs:
1✔
983
                logger.warning(
1✔
984
                    "Inputs {input_keys} were not matched to any component inputs, please check your run parameters.",
985
                    input_keys=list(unresolved_kwargs.keys()),
986
                )
987

988
            data = dict(pipeline_input_data)
1✔
989

990
        # deepcopying the inputs prevents the Pipeline run logic from being altered unexpectedly
991
        # when the same input reference is passed to multiple components.
992
        for component_name, component_inputs in data.items():
1✔
993
            data[component_name] = {k: _deepcopy_with_exceptions(v) for k, v in component_inputs.items()}
1✔
994

995
        return data
1✔
996

997
    @classmethod
1✔
998
    def from_template(
1✔
999
        cls, predefined_pipeline: PredefinedPipeline, template_params: Optional[Dict[str, Any]] = None
1000
    ) -> "PipelineBase":
1001
        """
1002
        Create a Pipeline from a predefined template. See `PredefinedPipeline` for available options.
1003

1004
        :param predefined_pipeline:
1005
            The predefined pipeline to use.
1006
        :param template_params:
1007
            An optional dictionary of parameters to use when rendering the pipeline template.
1008
        :returns:
1009
            An instance of `Pipeline`.
1010
        """
1011
        tpl = PipelineTemplate.from_predefined(predefined_pipeline)
1✔
1012
        # If tpl.render() fails, we let bubble up the original error
1013
        rendered = tpl.render(template_params)
1✔
1014

1015
        # If there was a problem with the rendered version of the
1016
        # template, we add it to the error stack for debugging
1017
        try:
1✔
1018
            return cls.loads(rendered)
1✔
1019
        except Exception as e:
×
1020
            msg = f"Error unmarshalling pipeline: {e}\n"
×
1021
            msg += f"Source:\n{rendered}"
×
1022
            raise PipelineUnmarshalError(msg)
×
1023

1024
    def _find_receivers_from(self, component_name: str) -> List[Tuple[str, OutputSocket, InputSocket]]:
1✔
1025
        """
1026
        Utility function to find all Components that receive input from `component_name`.
1027

1028
        :param component_name:
1029
            Name of the sender Component
1030

1031
        :returns:
1032
            List of tuples containing name of the receiver Component and sender OutputSocket
1033
            and receiver InputSocket instances
1034
        """
1035
        res = []
1✔
1036
        for _, receiver_name, connection in self.graph.edges(nbunch=component_name, data=True):
1✔
1037
            sender_socket: OutputSocket = connection["from_socket"]
1✔
1038
            receiver_socket: InputSocket = connection["to_socket"]
1✔
1039
            res.append((receiver_name, sender_socket, receiver_socket))
1✔
1040
        return res
1✔
1041

1042
    @staticmethod
1✔
1043
    def _convert_to_internal_format(pipeline_inputs: Dict[str, Any]) -> Dict[str, Dict[str, List]]:
1✔
1044
        """
1045
        Converts the inputs to the pipeline to the format that is needed for the internal `Pipeline.run` logic.
1046

1047
        Example Input:
1048
        {'prompt_builder': {'question': 'Who lives in Paris?'}, 'retriever': {'query': 'Who lives in Paris?'}}
1049
        Example Output:
1050
        {'prompt_builder': {'question': [{'sender': None, 'value': 'Who lives in Paris?'}]},
1051
         'retriever': {'query': [{'sender': None, 'value': 'Who lives in Paris?'}]}}
1052

1053
        :param pipeline_inputs: Inputs to the pipeline.
1054
        :returns: Converted inputs that can be used by the internal `Pipeline.run` logic.
1055
        """
1056
        inputs: Dict[str, Dict[str, List[Dict[str, Any]]]] = {}
1✔
1057
        for component_name, socket_dict in pipeline_inputs.items():
1✔
1058
            inputs[component_name] = {}
1✔
1059
            for socket_name, value in socket_dict.items():
1✔
1060
                inputs[component_name][socket_name] = [{"sender": None, "value": value}]
1✔
1061

1062
        return inputs
1✔
1063

1064
    @staticmethod
1✔
1065
    def _consume_component_inputs(component_name: str, component: Dict, inputs: Dict) -> Dict[str, Any]:
1✔
1066
        """
1067
        Extracts the inputs needed to run for the component and removes them from the global inputs state.
1068

1069
        :param component_name: The name of a component.
1070
        :param component: Component with component metadata.
1071
        :param inputs: Global inputs state.
1072
        :returns: The inputs for the component.
1073
        """
1074
        component_inputs = inputs.get(component_name, {})
1✔
1075
        consumed_inputs = {}
1✔
1076
        greedy_inputs_to_remove = set()
1✔
1077
        for socket_name, socket in component["input_sockets"].items():
1✔
1078
            socket_inputs = component_inputs.get(socket_name, [])
1✔
1079
            socket_inputs = [sock["value"] for sock in socket_inputs if sock["value"] is not _NO_OUTPUT_PRODUCED]
1✔
1080
            if socket_inputs:
1✔
1081
                if not socket.is_variadic:
1✔
1082
                    # We only care about the first input provided to the socket.
1083
                    consumed_inputs[socket_name] = socket_inputs[0]
1✔
1084
                elif socket.is_greedy:
1✔
1085
                    # We need to keep track of greedy inputs because we always remove them, even if they come from
1086
                    # outside the pipeline. Otherwise, a greedy input from the user would trigger a pipeline to run
1087
                    # indefinitely.
1088
                    greedy_inputs_to_remove.add(socket_name)
1✔
1089
                    consumed_inputs[socket_name] = [socket_inputs[0]]
1✔
1090
                elif is_socket_lazy_variadic(socket):
1✔
1091
                    # We use all inputs provided to the socket on a lazy variadic socket.
1092
                    consumed_inputs[socket_name] = socket_inputs
1✔
1093

1094
        # We prune all inputs except for those that were provided from outside the pipeline (e.g. user inputs).
1095
        pruned_inputs = {
1✔
1096
            socket_name: [
1097
                sock for sock in socket if sock["sender"] is None and not socket_name in greedy_inputs_to_remove
1098
            ]
1099
            for socket_name, socket in component_inputs.items()
1100
        }
1101
        pruned_inputs = {socket_name: socket for socket_name, socket in pruned_inputs.items() if len(socket) > 0}
1✔
1102

1103
        inputs[component_name] = pruned_inputs
1✔
1104

1105
        return consumed_inputs
1✔
1106

1107
    def _fill_queue(
1✔
1108
        self, component_names: List[str], inputs: Dict[str, Any], component_visits: Dict[str, int]
1109
    ) -> FIFOPriorityQueue:
1110
        """
1111
        Calculates the execution priority for each component and inserts it into the priority queue.
1112

1113
        :param component_names: Names of the components to put into the queue.
1114
        :param inputs: Inputs to the components.
1115
        :param component_visits: Current state of component visits.
1116
        :returns: A prioritized queue of component names.
1117
        """
1118
        priority_queue = FIFOPriorityQueue()
1✔
1119
        for component_name in component_names:
1✔
1120
            component = self._get_component_with_graph_metadata_and_visits(
1✔
1121
                component_name, component_visits[component_name]
1122
            )
1123
            priority = self._calculate_priority(component, inputs.get(component_name, {}))
1✔
1124
            priority_queue.push(component_name, priority)
1✔
1125

1126
        return priority_queue
1✔
1127

1128
    @staticmethod
1✔
1129
    def _calculate_priority(component: Dict, inputs: Dict) -> ComponentPriority:
1✔
1130
        """
1131
        Calculates the execution priority for a component depending on the component's inputs.
1132

1133
        :param component: Component metadata and component instance.
1134
        :param inputs: Inputs to the component.
1135
        :returns: Priority value for the component.
1136
        """
1137
        if not can_component_run(component, inputs):
1✔
1138
            return ComponentPriority.BLOCKED
1✔
1139
        elif is_any_greedy_socket_ready(component, inputs) and are_all_sockets_ready(component, inputs):
1✔
1140
            return ComponentPriority.HIGHEST
1✔
1141
        elif all_predecessors_executed(component, inputs):
1✔
1142
            return ComponentPriority.READY
1✔
1143
        elif are_all_lazy_variadic_sockets_resolved(component, inputs):
1✔
1144
            return ComponentPriority.DEFER
1✔
1145
        else:
1146
            return ComponentPriority.DEFER_LAST
1✔
1147

1148
    def _get_component_with_graph_metadata_and_visits(self, component_name: str, visits: int) -> Dict[str, Any]:
1✔
1149
        """
1150
        Returns the component instance alongside input/output-socket metadata from the graph and adds current visits.
1151

1152
        We can't store visits in the pipeline graph because this would prevent reentrance / thread-safe execution.
1153

1154
        :param component_name: The name of the component.
1155
        :param visits: Number of visits for the component.
1156
        :returns: Dict including component instance, input/output-sockets and visits.
1157
        """
1158
        comp_dict = self.graph.nodes[component_name]
1✔
1159
        comp_dict = {**comp_dict, "visits": visits}
1✔
1160
        return comp_dict
1✔
1161

1162
    def _get_next_runnable_component(
1✔
1163
        self, priority_queue: FIFOPriorityQueue, component_visits: Dict[str, int]
1164
    ) -> Union[Tuple[ComponentPriority, str, Dict[str, Any]], None]:
1165
        """
1166
        Returns the next runnable component alongside its metadata from the priority queue.
1167

1168
        :param priority_queue: Priority queue of component names.
1169
        :param component_visits: Current state of component visits.
1170
        :returns: The next runnable component, the component name, and its priority
1171
            or None if no component in the queue can run.
1172
        :raises: PipelineMaxComponentRuns if the next runnable component has exceeded the maximum number of runs.
1173
        """
1174
        priority_and_component_name: Union[Tuple[ComponentPriority, str], None] = (
1✔
1175
            None if (item := priority_queue.get()) is None else (ComponentPriority(item[0]), str(item[1]))
1176
        )
1177

1178
        if priority_and_component_name is not None and priority_and_component_name[0] != ComponentPriority.BLOCKED:
1✔
1179
            priority, component_name = priority_and_component_name
1✔
1180
            component = self._get_component_with_graph_metadata_and_visits(
1✔
1181
                component_name, component_visits[component_name]
1182
            )
1183
            if component["visits"] > self._max_runs_per_component:
1✔
1184
                msg = f"Maximum run count {self._max_runs_per_component} reached for component '{component_name}'"
1✔
1185
                raise PipelineMaxComponentRuns(msg)
1✔
1186

1187
            return priority, component_name, component
1✔
1188

1189
        return None
1✔
1190

1191
    @staticmethod
1✔
1192
    def _add_missing_input_defaults(component_inputs: Dict[str, Any], component_input_sockets: Dict[str, InputSocket]):
1✔
1193
        """
1194
        Updates the inputs with the default values for the inputs that are missing
1195

1196
        :param component_inputs: Inputs for the component.
1197
        :param component_input_sockets: Input sockets of the component.
1198
        """
1199
        for name, socket in component_input_sockets.items():
1✔
1200
            if not socket.is_mandatory and name not in component_inputs:
1✔
1201
                if socket.is_variadic:
1✔
1202
                    component_inputs[name] = [socket.default_value]
1✔
1203
                else:
1204
                    component_inputs[name] = socket.default_value
1✔
1205

1206
        return component_inputs
1✔
1207

1208
    def _tiebreak_waiting_components(
1✔
1209
        self,
1210
        component_name: str,
1211
        priority: ComponentPriority,
1212
        priority_queue: FIFOPriorityQueue,
1213
        topological_sort: Union[Dict[str, int], None],
1214
    ):
1215
        """
1216
        Decides which component to run when multiple components are waiting for inputs with the same priority.
1217

1218
        :param component_name: The name of the component.
1219
        :param priority: Priority of the component.
1220
        :param priority_queue: Priority queue of component names.
1221
        :param topological_sort: Cached topological sort of all components in the pipeline.
1222
        """
1223
        components_with_same_priority = [component_name]
1✔
1224

1225
        while len(priority_queue) > 0:
1✔
1226
            next_priority, next_component_name = priority_queue.peek()
1✔
1227
            if next_priority == priority:
1✔
1228
                priority_queue.pop()  # actually remove the component
×
1229
                components_with_same_priority.append(next_component_name)
×
1230
            else:
1231
                break
×
1232

1233
        if len(components_with_same_priority) > 1:
1✔
1234
            if topological_sort is None:
×
1235
                if networkx.is_directed_acyclic_graph(self.graph):
×
1236
                    topological_sort = networkx.lexicographical_topological_sort(self.graph)
×
1237
                    topological_sort = {node: idx for idx, node in enumerate(topological_sort)}
×
1238
                else:
1239
                    condensed = networkx.condensation(self.graph)
×
1240
                    condensed_sorted = {node: idx for idx, node in enumerate(networkx.topological_sort(condensed))}
×
1241
                    topological_sort = {
×
1242
                        component_name: condensed_sorted[node]
1243
                        for component_name, node in condensed.graph["mapping"].items()
1244
                    }
1245

1246
            components_with_same_priority = sorted(
×
1247
                components_with_same_priority, key=lambda comp_name: (topological_sort[comp_name], comp_name.lower())
1248
            )
1249

1250
            component_name = components_with_same_priority[0]
×
1251

1252
        return component_name, topological_sort
1✔
1253

1254
    @staticmethod
1✔
1255
    def _write_component_outputs(
1✔
1256
        component_name: str,
1257
        component_outputs: Dict[str, Any],
1258
        inputs: Dict[str, Any],
1259
        receivers: List[Tuple],
1260
        include_outputs_from: Set[str],
1261
    ) -> Dict[str, Any]:
1262
        """
1263
        Distributes the outputs of a component to the input sockets that it is connected to.
1264

1265
        :param component_name: The name of the component.
1266
        :param component_outputs: The outputs of the component.
1267
        :param inputs: The current global input state.
1268
        :param receivers: List of components that receive inputs from the component.
1269
        :param include_outputs_from: List of component names that should always return an output from the pipeline.
1270
        """
1271
        for receiver_name, sender_socket, receiver_socket in receivers:
1✔
1272
            # We either get the value that was produced by the actor or we use the _NO_OUTPUT_PRODUCED class to indicate
1273
            # that the sender did not produce an output for this socket.
1274
            # This allows us to track if a predecessor already ran but did not produce an output.
1275
            value = component_outputs.get(sender_socket.name, _NO_OUTPUT_PRODUCED)
1✔
1276

1277
            if receiver_name not in inputs:
1✔
1278
                inputs[receiver_name] = {}
1✔
1279

1280
            if is_socket_lazy_variadic(receiver_socket):
1✔
1281
                # If the receiver socket is lazy variadic, we append the new input.
1282
                # Lazy variadic sockets can collect multiple inputs.
1283
                _write_to_lazy_variadic_socket(
1✔
1284
                    inputs=inputs,
1285
                    receiver_name=receiver_name,
1286
                    receiver_socket_name=receiver_socket.name,
1287
                    component_name=component_name,
1288
                    value=value,
1289
                )
1290
            else:
1291
                # If the receiver socket is not lazy variadic, it is greedy variadic or non-variadic.
1292
                # We overwrite with the new input if it's not _NO_OUTPUT_PRODUCED or if the current value is None.
1293
                _write_to_standard_socket(
1✔
1294
                    inputs=inputs,
1295
                    receiver_name=receiver_name,
1296
                    receiver_socket_name=receiver_socket.name,
1297
                    component_name=component_name,
1298
                    value=value,
1299
                )
1300

1301
        # If we want to include all outputs from this actor in the final outputs, we don't need to prune any consumed
1302
        # outputs
1303
        if component_name in include_outputs_from:
1✔
1304
            return component_outputs
1✔
1305

1306
        # We prune outputs that were consumed by any receiving sockets.
1307
        # All remaining outputs will be added to the final outputs of the pipeline.
1308
        consumed_outputs = {sender_socket.name for _, sender_socket, __ in receivers}
1✔
1309
        pruned_outputs = {key: value for key, value in component_outputs.items() if key not in consumed_outputs}
1✔
1310

1311
        return pruned_outputs
1✔
1312

1313
    @staticmethod
1✔
1314
    def _is_queue_stale(priority_queue: FIFOPriorityQueue) -> bool:
1✔
1315
        """
1316
        Checks if the priority queue needs to be recomputed because the priorities might have changed.
1317

1318
        :param priority_queue: Priority queue of component names.
1319
        """
1320
        return len(priority_queue) == 0 or priority_queue.peek()[0] > ComponentPriority.READY
1✔
1321

1322
    @staticmethod
1✔
1323
    def validate_pipeline(priority_queue: FIFOPriorityQueue) -> None:
1✔
1324
        """
1325
        Validate the pipeline to check if it is blocked or has no valid entry point.
1326

1327
        :param priority_queue: Priority queue of component names.
1328
        :raises PipelineRuntimeError:
1329
            If the pipeline is blocked or has no valid entry point.
1330
        """
1331
        if len(priority_queue) == 0:
1✔
1332
            return
×
1333

1334
        candidate = priority_queue.peek()
1✔
1335
        if candidate is not None and candidate[0] == ComponentPriority.BLOCKED:
1✔
1336
            raise PipelineComponentsBlockedError()
×
1337

1338
    def _find_super_components(self) -> list[tuple[str, Component]]:
1✔
1339
        """
1340
        Find all SuperComponents in the pipeline.
1341

1342
        :returns:
1343
            List of tuples containing (component_name, component_instance) representing a SuperComponent.
1344
        """
1345

1346
        super_components = []
1✔
1347
        for comp_name, comp in self.walk():
1✔
1348
            # a SuperComponent has a "pipeline" attribute which itself a Pipeline instance
1349
            # we don't test against SuperComponent because doing so always lead to circular imports
1350
            if hasattr(comp, "pipeline") and isinstance(comp.pipeline, self.__class__):
1✔
1351
                super_components.append((comp_name, comp))
1✔
1352
        return super_components
1✔
1353

1354
    def _merge_super_component_pipelines(self) -> Tuple["networkx.MultiDiGraph", Dict[str, str]]:
1✔
1355
        """
1356
        Merge the internal pipelines of SuperComponents into the main pipeline graph structure.
1357

1358
        This creates a new networkx.MultiDiGraph containing all the components from both the main pipeline
1359
        and all the internal SuperComponents' pipelines. The SuperComponents are removed and their internal
1360
        components are connected to corresponding input and output sockets of the main pipeline.
1361

1362
        :returns:
1363
            A tuple containing:
1364
            - A networkx.MultiDiGraph with the expanded structure of the main pipeline and all it's SuperComponents
1365
            - A dictionary mapping component names to boolean indicating that this component was part of a
1366
              SuperComponent
1367
            - A dictionary mapping component names to their SuperComponent name
1368
        """
1369
        merged_graph = self.graph.copy()
1✔
1370
        super_component_mapping: Dict[str, str] = {}
1✔
1371

1372
        for super_name, super_component in self._find_super_components():
1✔
1373
            internal_pipeline = super_component.pipeline  # type: ignore
1✔
1374
            internal_graph = internal_pipeline.graph.copy()
1✔
1375

1376
            # Mark all components in the internal pipeline as being part of a SuperComponent
1377
            for node in internal_graph.nodes():
1✔
1378
                super_component_mapping[node] = super_name
1✔
1379

1380
            # edges connected to the super component
1381
            incoming_edges = list(merged_graph.in_edges(super_name, data=True))
1✔
1382
            outgoing_edges = list(merged_graph.out_edges(super_name, data=True))
1✔
1383

1384
            # merge the SuperComponent graph into the main graph and remove the super component node
1385
            # since its components are now part of the main graph
1386
            merged_graph = networkx.compose(merged_graph, internal_graph)
1✔
1387
            merged_graph.remove_node(super_name)
1✔
1388

1389
            # get the entry and exit points of the SuperComponent internal pipeline
1390
            entry_points = [n for n in internal_graph.nodes() if internal_graph.in_degree(n) == 0]
1✔
1391
            exit_points = [n for n in internal_graph.nodes() if internal_graph.out_degree(n) == 0]
1✔
1392

1393
            # connect the incoming edges to entry points
1394
            for sender, _, edge_data in incoming_edges:
1✔
1395
                sender_socket = edge_data["from_socket"]
1✔
1396
                for entry_point in entry_points:
1✔
1397
                    # find a matching input socket in the entry point
1398
                    entry_point_sockets = internal_graph.nodes[entry_point]["input_sockets"]
1✔
1399
                    for socket_name, socket in entry_point_sockets.items():
1✔
1400
                        if _types_are_compatible(sender_socket.type, socket.type, self._connection_type_validation):
1✔
1401
                            merged_graph.add_edge(
1✔
1402
                                sender,
1403
                                entry_point,
1404
                                key=f"{sender_socket.name}/{socket_name}",
1405
                                conn_type=_type_name(sender_socket.type),
1406
                                from_socket=sender_socket,
1407
                                to_socket=socket,
1408
                                mandatory=socket.is_mandatory,
1409
                            )
1410

1411
            # connect outgoing edges from exit points
1412
            for _, receiver, edge_data in outgoing_edges:
1✔
1413
                receiver_socket = edge_data["to_socket"]
1✔
1414
                for exit_point in exit_points:
1✔
1415
                    # find a matching output socket in the exit point
1416
                    exit_point_sockets = internal_graph.nodes[exit_point]["output_sockets"]
1✔
1417
                    for socket_name, socket in exit_point_sockets.items():
1✔
1418
                        if _types_are_compatible(socket.type, receiver_socket.type, self._connection_type_validation):
1✔
1419
                            merged_graph.add_edge(
1✔
1420
                                exit_point,
1421
                                receiver,
1422
                                key=f"{socket_name}/{receiver_socket.name}",
1423
                                conn_type=_type_name(socket.type),
1424
                                from_socket=socket,
1425
                                to_socket=receiver_socket,
1426
                                mandatory=receiver_socket.is_mandatory,
1427
                            )
1428

1429
        return merged_graph, super_component_mapping
1✔
1430

1431

1432
def _connections_status(
1✔
1433
    sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket]
1434
) -> str:
1435
    """
1436
    Lists the status of the sockets, for error messages.
1437
    """
1438
    sender_sockets_entries = []
1✔
1439
    for sender_socket in sender_sockets:
1✔
1440
        sender_sockets_entries.append(f" - {sender_socket.name}: {_type_name(sender_socket.type)}")
1✔
1441
    sender_sockets_list = "\n".join(sender_sockets_entries)
1✔
1442

1443
    receiver_sockets_entries = []
1✔
1444
    for receiver_socket in receiver_sockets:
1✔
1445
        if receiver_socket.senders:
1✔
1446
            sender_status = f"sent by {','.join(receiver_socket.senders)}"
1✔
1447
        else:
1448
            sender_status = "available"
1✔
1449
        receiver_sockets_entries.append(
1✔
1450
            f" - {receiver_socket.name}: {_type_name(receiver_socket.type)} ({sender_status})"
1451
        )
1452
    receiver_sockets_list = "\n".join(receiver_sockets_entries)
1✔
1453

1454
    return f"'{sender_node}':\n{sender_sockets_list}\n'{receiver_node}':\n{receiver_sockets_list}"
1✔
1455

1456

1457
# Utility functions for writing to sockets
1458

1459

1460
def _write_to_lazy_variadic_socket(
1✔
1461
    inputs: Dict[str, Any], receiver_name: str, receiver_socket_name: str, component_name: str, value: Any
1462
) -> None:
1463
    """
1464
    Write to a lazy variadic socket.
1465

1466
    Mutates inputs in place.
1467
    """
1468
    if not inputs[receiver_name].get(receiver_socket_name):
1✔
1469
        inputs[receiver_name][receiver_socket_name] = []
1✔
1470

1471
    inputs[receiver_name][receiver_socket_name].append({"sender": component_name, "value": value})
1✔
1472

1473

1474
def _write_to_standard_socket(
1✔
1475
    inputs: Dict[str, Any], receiver_name: str, receiver_socket_name: str, component_name: str, value: Any
1476
) -> None:
1477
    """
1478
    Write to a greedy variadic or non-variadic socket.
1479

1480
    Mutates inputs in place.
1481
    """
1482
    current_value = inputs[receiver_name].get(receiver_socket_name)
1✔
1483

1484
    # Only overwrite if there's no existing value, or we have a new value to provide
1485
    if current_value is None or value is not _NO_OUTPUT_PRODUCED:
1✔
1486
        inputs[receiver_name][receiver_socket_name] = [{"sender": component_name, "value": value}]
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc