deepset-ai / haystack / build 13633822034 (push, via GitHub)

03 Mar 2025 03:00PM UTC coverage: 89.98% (+0.007%) from 89.973%
feat: Add Type Validation parameter for Pipeline Connections (#8875)

* Starting to refactor type util tests to be more systematic

* refactoring

* Expand tests

* Update to type utils

* Add missing subclass check

* Expand and refactor tests, introduce type_validation Literal

* More test refactoring

* Test refactoring, adding type validation variable to pipeline base

* Update relaxed version of type checking to pass all newly added tests

* trim whitespace

* Add tests

* cleanup

* Updates docstrings

* Add reno

* docs

* Fix mypy and add docstrings

* Changes based on advice from Tobi

* Remove unused imports

* Doc strings

* Add connection type validation to to_dict and from_dict

* Update tests

* Fix test

* Also save connection_type_validation at global pipeline level

* Fix tests

* Remove connection type validation from the connect level, only keep at pipeline level

* Formatting

* Fix tests

* formatting

9573 of 10639 relevant lines covered (89.98%)

0.9 hits per line

Source File

haystack/core/pipeline/base.py (90.85% covered)
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import itertools
from collections import defaultdict
from copy import deepcopy
from datetime import datetime
from enum import IntEnum
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Set, TextIO, Tuple, Type, TypeVar, Union

import networkx  # type:ignore

from haystack import logging
from haystack.core.component import Component, InputSocket, OutputSocket, component
from haystack.core.errors import (
    DeserializationError,
    PipelineConnectError,
    PipelineDrawingError,
    PipelineError,
    PipelineMaxComponentRuns,
    PipelineRuntimeError,
    PipelineUnmarshalError,
    PipelineValidationError,
)
from haystack.core.pipeline.component_checks import (
    _NO_OUTPUT_PRODUCED,
    all_predecessors_executed,
    are_all_lazy_variadic_sockets_resolved,
    are_all_sockets_ready,
    can_component_run,
    is_any_greedy_socket_ready,
    is_socket_lazy_variadic,
)
from haystack.core.pipeline.utils import FIFOPriorityQueue, parse_connect_string
from haystack.core.serialization import DeserializationCallbacks, component_from_dict, component_to_dict
from haystack.core.type_utils import _type_name, _types_are_compatible
from haystack.marshal import Marshaller, YamlMarshaller
from haystack.utils import is_in_jupyter, type_serialization

from .descriptions import find_pipeline_inputs, find_pipeline_outputs
from .draw import _to_mermaid_image
from .template import PipelineTemplate, PredefinedPipeline

DEFAULT_MARSHALLER = YamlMarshaller()

# We use a generic type to annotate the return value of class methods,
# so that static analyzers won't be confused when derived classes
# use those methods.
T = TypeVar("T", bound="PipelineBase")

logger = logging.getLogger(__name__)


class ComponentPriority(IntEnum):
    HIGHEST = 1
    READY = 2
    DEFER = 3
    DEFER_LAST = 4
    BLOCKED = 5

class PipelineBase:
    """
    Components orchestration engine.

    Builds a graph of components and orchestrates their execution according to the execution graph.
    """

    def __init__(
        self,
        metadata: Optional[Dict[str, Any]] = None,
        max_runs_per_component: int = 100,
        connection_type_validation: bool = True,
    ):
        """
        Creates the Pipeline.

        :param metadata:
            Arbitrary dictionary to store metadata about this `Pipeline`. Make sure all the values contained in
            this dictionary can be serialized and deserialized if you wish to save this `Pipeline` to file.
        :param max_runs_per_component:
            How many times the `Pipeline` can run the same Component.
            If this limit is reached, a `PipelineMaxComponentRuns` exception is raised.
            If not set, it defaults to 100 runs per Component.
        :param connection_type_validation: Whether the pipeline will validate the types of the connections.
            Defaults to True.
        """
        self._telemetry_runs = 0
        self._last_telemetry_sent: Optional[datetime] = None
        self.metadata = metadata or {}
        self.graph = networkx.MultiDiGraph()
        self._max_runs_per_component = max_runs_per_component
        self._connection_type_validation = connection_type_validation

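    # A minimal construction sketch for the connection_type_validation flag added
    # in this PR. `Pipeline` is the public subclass of PipelineBase that users
    # typically instantiate; treat the exact import path as an assumption here:
    #
    #     from haystack import Pipeline
    #
    #     pipe = Pipeline()  # default: connections are type-validated
    #     relaxed = Pipeline(connection_type_validation=False)  # relaxed type checking
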
    def __eq__(self, other) -> bool:
        """
        Pipeline equality is defined by type and by the equality of the serialized form.

        Pipelines of the same type share every metadata, node and edge, but they're not required to use
        the same node instances: this allows a pipeline saved and then loaded back to be equal to itself.
        """
        if not isinstance(self, type(other)):
            return False
        return self.to_dict() == other.to_dict()

    def __repr__(self) -> str:
        """
        Returns a text representation of the Pipeline.
        """
        res = f"{object.__repr__(self)}\n"
        if self.metadata:
            res += "🧱 Metadata\n"
            for k, v in self.metadata.items():
                res += f"  - {k}: {v}\n"

        res += "🚅 Components\n"
        for name, instance in self.graph.nodes(data="instance"):  # type: ignore # type wrongly defined in networkx
            res += f"  - {name}: {instance.__class__.__name__}\n"

        res += "🛤️ Connections\n"
        for sender, receiver, edge_data in self.graph.edges(data=True):
            sender_socket = edge_data["from_socket"].name
            receiver_socket = edge_data["to_socket"].name
            res += f"  - {sender}.{sender_socket} -> {receiver}.{receiver_socket} ({edge_data['conn_type']})\n"

        return res

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the pipeline to a dictionary.

        This is meant to be an intermediate representation, but it can also be used to save a pipeline to file.

        :returns:
            Dictionary with serialized data.
        """
        components = {}
        for name, instance in self.graph.nodes(data="instance"):  # type:ignore
            components[name] = component_to_dict(instance, name)

        connections = []
        for sender, receiver, edge_data in self.graph.edges.data():
            sender_socket = edge_data["from_socket"].name
            receiver_socket = edge_data["to_socket"].name
            connections.append({"sender": f"{sender}.{sender_socket}", "receiver": f"{receiver}.{receiver_socket}"})
        return {
            "metadata": self.metadata,
            "max_runs_per_component": self._max_runs_per_component,
            "components": components,
            "connections": connections,
            "connection_type_validation": self._connection_type_validation,
        }

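    # Shape of the serialized dictionary (sketch; the component entries depend on
    # the concrete components, and "fetcher"/"converter" are hypothetical names):
    #
    #     {
    #         "metadata": {},
    #         "max_runs_per_component": 100,
    #         "connection_type_validation": True,
    #         "components": {"fetcher": {"type": "...", "init_parameters": {}}},
    #         "connections": [{"sender": "fetcher.streams", "receiver": "converter.sources"}],
    #     }
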
    @classmethod
    def from_dict(
        cls: Type[T], data: Dict[str, Any], callbacks: Optional[DeserializationCallbacks] = None, **kwargs
    ) -> T:
        """
        Deserializes the pipeline from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :param callbacks:
            Callbacks to invoke during deserialization.
        :param kwargs:
            `components`: a dictionary of {name: instance} to reuse instances of components instead of creating new
            ones.
        :returns:
            Deserialized pipeline.
        """
        data_copy = deepcopy(data)  # to prevent modification of original data
        metadata = data_copy.get("metadata", {})
        max_runs_per_component = data_copy.get("max_runs_per_component", 100)
        connection_type_validation = data_copy.get("connection_type_validation", True)
        pipe = cls(
            metadata=metadata,
            max_runs_per_component=max_runs_per_component,
            connection_type_validation=connection_type_validation,
        )
        components_to_reuse = kwargs.get("components", {})
        for name, component_data in data_copy.get("components", {}).items():
            if name in components_to_reuse:
                # Reuse an instance
                instance = components_to_reuse[name]
            else:
                if "type" not in component_data:
                    raise PipelineError(f"Missing 'type' in component '{name}'")

                if component_data["type"] not in component.registry:
                    try:
                        # Import the module first...
                        module, _ = component_data["type"].rsplit(".", 1)
                        logger.debug("Trying to import module {module_name}", module_name=module)
                        type_serialization.thread_safe_import(module)
                        # ...then try again
                        if component_data["type"] not in component.registry:
                            raise PipelineError(
                                f"Successfully imported module {module} but can't find it in the component registry. "
                                "This is unexpected and most likely a bug."
                            )
                    except (ImportError, PipelineError, ValueError) as e:
                        raise PipelineError(
                            f"Component '{component_data['type']}' (name: '{name}') not imported."
                        ) from e

                # Create a new one
                component_class = component.registry[component_data["type"]]

                try:
                    instance = component_from_dict(component_class, component_data, name, callbacks)
                except Exception as e:
                    msg = (
                        f"Couldn't deserialize component '{name}' of class '{component_class.__name__}' "
                        f"with the following data: {str(component_data)}. Possible reasons include "
                        "malformed serialized data, mismatch between the serialized component and the "
                        "loaded one (due to a breaking change, see "
                        "https://github.com/deepset-ai/haystack/releases), etc."
                    )
                    raise DeserializationError(msg) from e
            pipe.add_component(name=name, instance=instance)

        for connection in data.get("connections", []):
            if "sender" not in connection:
                raise PipelineError(f"Missing sender in connection: {connection}")
            if "receiver" not in connection:
                raise PipelineError(f"Missing receiver in connection: {connection}")
            pipe.connect(sender=connection["sender"], receiver=connection["receiver"])

        return pipe

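    # Serialization round trip (sketch; `pipe` is an existing pipeline instance):
    #
    #     serialized = pipe.to_dict()
    #     restored = Pipeline.from_dict(serialized)
    #     assert restored == pipe  # __eq__ compares the serialized forms
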
    def dumps(self, marshaller: Marshaller = DEFAULT_MARSHALLER) -> str:
        """
        Returns the string representation of this pipeline according to the format dictated by the `Marshaller` in use.

        :param marshaller:
            The Marshaller used to create the string representation. Defaults to `YamlMarshaller`.
        :returns:
            A string representing the pipeline.
        """
        return marshaller.marshal(self.to_dict())

    def dump(self, fp: TextIO, marshaller: Marshaller = DEFAULT_MARSHALLER):
        """
        Writes the string representation of this pipeline to the file-like object passed in the `fp` argument.

        :param fp:
            A file-like object ready to be written to.
        :param marshaller:
            The Marshaller used to create the string representation. Defaults to `YamlMarshaller`.
        """
        fp.write(marshaller.marshal(self.to_dict()))

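    # Persisting to YAML with the default marshaller (sketch):
    #
    #     with open("pipeline.yaml", "w") as fp:
    #         pipe.dump(fp)
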
    @classmethod
    def loads(
        cls: Type[T],
        data: Union[str, bytes, bytearray],
        marshaller: Marshaller = DEFAULT_MARSHALLER,
        callbacks: Optional[DeserializationCallbacks] = None,
    ) -> T:
        """
        Creates a `Pipeline` object from the string representation passed in the `data` argument.

        :param data:
            The string representation of the pipeline, can be `str`, `bytes` or `bytearray`.
        :param marshaller:
            The Marshaller used to create the string representation. Defaults to `YamlMarshaller`.
        :param callbacks:
            Callbacks to invoke during deserialization.
        :raises DeserializationError:
            If an error occurs during deserialization.
        :returns:
            A `Pipeline` object.
        """
        try:
            deserialized_data = marshaller.unmarshal(data)
        except Exception as e:
            raise DeserializationError(
                "Error while unmarshalling serialized pipeline data. This is usually "
                "caused by malformed or invalid syntax in the serialized representation."
            ) from e

        return cls.from_dict(deserialized_data, callbacks)

    @classmethod
    def load(
        cls: Type[T],
        fp: TextIO,
        marshaller: Marshaller = DEFAULT_MARSHALLER,
        callbacks: Optional[DeserializationCallbacks] = None,
    ) -> T:
        """
        Creates a `Pipeline` object from a string representation.

        The string representation is read from the file-like object passed in the `fp` argument.

        :param fp:
            A file-like object ready to be read from.
        :param marshaller:
            The Marshaller used to create the string representation. Defaults to `YamlMarshaller`.
        :param callbacks:
            Callbacks to invoke during deserialization.
        :raises DeserializationError:
            If an error occurs during deserialization.
        :returns:
            A `Pipeline` object.
        """
        return cls.loads(fp.read(), marshaller, callbacks)

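    # Loading back from YAML (sketch, mirrors the dump example above):
    #
    #     with open("pipeline.yaml") as fp:
    #         pipe = Pipeline.load(fp)
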
    def add_component(self, name: str, instance: Component) -> None:
        """
        Add the given component to the pipeline.

        Components are not connected to anything by default: use `Pipeline.connect()` to connect components together.
        Component names must be unique, but component instances can be reused if needed.

        :param name:
            The name of the component to add.
        :param instance:
            The component instance to add.

        :raises ValueError:
            If a component with the same name already exists.
        :raises PipelineValidationError:
            If the given instance is not a component.
        """
        # Component names are unique
        if name in self.graph.nodes:
            raise ValueError(f"A component named '{name}' already exists in this pipeline: choose another name.")

        # Components can't be named `_debug`
        if name == "_debug":
            raise ValueError("'_debug' is a reserved name for debug output. Choose another name.")

        # Component instances must be components
        if not isinstance(instance, Component):
            raise PipelineValidationError(
                f"'{type(instance)}' doesn't seem to be a component. Is this class decorated with @component?"
            )

        if getattr(instance, "__haystack_added_to_pipeline__", None):
            msg = (
                "Component has already been added in another Pipeline. Components can't be shared between Pipelines. "
                "Create a new instance instead."
            )
            raise PipelineError(msg)

        setattr(instance, "__haystack_added_to_pipeline__", self)

        # Add component to the graph, disconnected
        logger.debug("Adding component '{component_name}' ({component})", component_name=name, component=instance)
        # We're completely sure the fields exist so we ignore the type error
        self.graph.add_node(
            name,
            instance=instance,
            input_sockets=instance.__haystack_input__._sockets_dict,  # type: ignore[attr-defined]
            output_sockets=instance.__haystack_output__._sockets_dict,  # type: ignore[attr-defined]
            visits=0,
        )

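    # Building up a pipeline (sketch; MyFetcher and MyConverter stand in for any
    # @component-decorated classes):
    #
    #     pipe = Pipeline()
    #     pipe.add_component("fetcher", MyFetcher())
    #     pipe.add_component("converter", MyConverter())
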
    def remove_component(self, name: str) -> Component:
        """
        Removes and returns a component from the pipeline.

        Remove an existing component from the pipeline by providing its name.
        All edges that connect to the component will also be deleted.

        :param name:
            The name of the component to remove.
        :returns:
            The removed Component instance.

        :raises ValueError:
            If there is no component with that name already in the Pipeline.
        """

        # Check that a component with that name is in the Pipeline
        try:
            instance = self.get_component(name)
        except ValueError as exc:
            raise ValueError(
                f"There is no component named '{name}' in the pipeline. "
                f"The valid component names are: {', '.join(self.graph.nodes)}"
            ) from exc

        # Delete component from the graph, deleting all its connections
        self.graph.remove_node(name)

        # Reset the Component sockets' senders and receivers
        input_sockets = instance.__haystack_input__._sockets_dict  # type: ignore[attr-defined]
        for socket in input_sockets.values():
            socket.senders = []

        output_sockets = instance.__haystack_output__._sockets_dict  # type: ignore[attr-defined]
        for socket in output_sockets.values():
            socket.receivers = []

        # Reset the Component's pipeline reference
        setattr(instance, "__haystack_added_to_pipeline__", None)

        return instance

    def connect(self, sender: str, receiver: str) -> "PipelineBase":  # noqa: PLR0915 PLR0912
        """
        Connects two components together.

        All components to connect must exist in the pipeline.
        If connecting to a component that has several output connections, specify the input and output names as
        'component_name.connection_name'.

        :param sender:
            The component that delivers the value. This can be either just a component name or can be
            in the format `component_name.connection_name` if the component has multiple outputs.
        :param receiver:
            The component that receives the value. This can be either just a component name or can be
            in the format `component_name.connection_name` if the component has multiple inputs.
        :returns:
            The Pipeline instance.

        :raises PipelineConnectError:
            If the two components cannot be connected (for example if one of the components is
            not present in the pipeline, or the connections don't match by type, and so on).
        """
        # Edges may be named explicitly by passing 'node_name.edge_name' to connect().
        sender_component_name, sender_socket_name = parse_connect_string(sender)
        receiver_component_name, receiver_socket_name = parse_connect_string(receiver)

        if sender_component_name == receiver_component_name:
            raise PipelineConnectError("Connecting a Component to itself is not supported.")

        # Get the nodes data.
        try:
            sender_sockets = self.graph.nodes[sender_component_name]["output_sockets"]
        except KeyError as exc:
            raise ValueError(f"Component named {sender_component_name} not found in the pipeline.") from exc
        try:
            receiver_sockets = self.graph.nodes[receiver_component_name]["input_sockets"]
        except KeyError as exc:
            raise ValueError(f"Component named {receiver_component_name} not found in the pipeline.") from exc

        # If the name of either socket is given, get the socket
        sender_socket: Optional[OutputSocket] = None
        if sender_socket_name:
            sender_socket = sender_sockets.get(sender_socket_name)
            if not sender_socket:
                raise PipelineConnectError(
                    f"'{sender}' does not exist. "
                    f"Output connections of {sender_component_name} are: "
                    + ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in sender_sockets.items()])
                )

        receiver_socket: Optional[InputSocket] = None
        if receiver_socket_name:
            receiver_socket = receiver_sockets.get(receiver_socket_name)
            if not receiver_socket:
                raise PipelineConnectError(
                    f"'{receiver}' does not exist. "
                    f"Input connections of {receiver_component_name} are: "
                    + ", ".join(
                        [f"{name} (type {_type_name(socket.type)})" for name, socket in receiver_sockets.items()]
                    )
                )

        # Look for a matching connection among the possible ones.
        # Note that if there is more than one possible connection but two sockets match by name, they're paired.
        sender_socket_candidates: List[OutputSocket] = (
            [sender_socket] if sender_socket else list(sender_sockets.values())
        )
        receiver_socket_candidates: List[InputSocket] = (
            [receiver_socket] if receiver_socket else list(receiver_sockets.values())
        )

        # Find all possible connections between these two components
        possible_connections = []
        for sender_sock, receiver_sock in itertools.product(sender_socket_candidates, receiver_socket_candidates):
            if _types_are_compatible(sender_sock.type, receiver_sock.type, self._connection_type_validation):
                possible_connections.append((sender_sock, receiver_sock))

        # We need this status for error messages; since we might need it in multiple places, we calculate it here
        status = _connections_status(
            sender_node=sender_component_name,
            sender_sockets=sender_socket_candidates,
            receiver_node=receiver_component_name,
            receiver_sockets=receiver_socket_candidates,
        )

        if not possible_connections:
            # There's no possible connection between these two components
            if len(sender_socket_candidates) == len(receiver_socket_candidates) == 1:
                msg = (
                    f"Cannot connect '{sender_component_name}.{sender_socket_candidates[0].name}' with "
                    f"'{receiver_component_name}.{receiver_socket_candidates[0].name}': "
                    f"their declared input and output types do not match.\n{status}"
                )
            else:
                msg = (
                    f"Cannot connect '{sender_component_name}' with '{receiver_component_name}': "
                    f"no matching connections available.\n{status}"
                )
            raise PipelineConnectError(msg)

        if len(possible_connections) == 1:
            # There's only one possible connection, use it
            sender_socket = possible_connections[0][0]
            receiver_socket = possible_connections[0][1]

        if len(possible_connections) > 1:
            # There are multiple possible connections, let's try to match them by name
            name_matches = [
                (out_sock, in_sock) for out_sock, in_sock in possible_connections if in_sock.name == out_sock.name
            ]
            if len(name_matches) != 1:
                # There are either no matches or more than one, we can't pick one reliably
                msg = (
                    f"Cannot connect '{sender_component_name}' with "
                    f"'{receiver_component_name}': more than one connection is possible "
                    "between these components. Please specify the connection name, like: "
                    f"pipeline.connect('{sender_component_name}.{possible_connections[0][0].name}', "
                    f"'{receiver_component_name}.{possible_connections[0][1].name}').\n{status}"
                )
                raise PipelineConnectError(msg)

            # Get the only possible match
            sender_socket = name_matches[0][0]
            receiver_socket = name_matches[0][1]

        # Connection must be valid on both sender/receiver sides
        if not sender_socket or not receiver_socket or not sender_component_name or not receiver_component_name:
            if sender_component_name and sender_socket:
                sender_repr = f"{sender_component_name}.{sender_socket.name} ({_type_name(sender_socket.type)})"
            else:
                sender_repr = "input needed"

            if receiver_component_name and receiver_socket:
                receiver_repr = f"({_type_name(receiver_socket.type)}) {receiver_component_name}.{receiver_socket.name}"
            else:
                receiver_repr = "output"
            msg = f"Connection must have both sender and receiver: {sender_repr} -> {receiver_repr}"
            raise PipelineConnectError(msg)

        logger.debug(
            "Connecting '{sender_component}.{sender_socket_name}' to '{receiver_component}.{receiver_socket_name}'",
            sender_component=sender_component_name,
            sender_socket_name=sender_socket.name,
            receiver_component=receiver_component_name,
            receiver_socket_name=receiver_socket.name,
        )

        if receiver_component_name in sender_socket.receivers and sender_component_name in receiver_socket.senders:
            # This is already connected, nothing to do
            return self

        if receiver_socket.senders and not receiver_socket.is_variadic:
            # Only variadic input sockets can receive from multiple senders
            msg = (
                f"Cannot connect '{sender_component_name}.{sender_socket.name}' with "
                f"'{receiver_component_name}.{receiver_socket.name}': "
                f"{receiver_component_name}.{receiver_socket.name} is already connected to {receiver_socket.senders}.\n"
            )
            raise PipelineConnectError(msg)

        # Update the sockets with the new connection
        sender_socket.receivers.append(receiver_component_name)
        receiver_socket.senders.append(sender_component_name)

        # Create the new connection
        self.graph.add_edge(
            sender_component_name,
            receiver_component_name,
            key=f"{sender_socket.name}/{receiver_socket.name}",
            conn_type=_type_name(sender_socket.type),
            from_socket=sender_socket,
            to_socket=receiver_socket,
            mandatory=receiver_socket.is_mandatory,
        )
        return self

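    # Connecting components (sketch; names are hypothetical). A bare component
    # name works when the connection is unambiguous; otherwise use the
    # 'component_name.socket_name' form described in the docstring above:
    #
    #     pipe.connect("fetcher", "converter")
    #     pipe.connect("converter.documents", "writer.documents")
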
    def get_component(self, name: str) -> Component:
1✔
583
        """
584
        Get the component with the specified name from the pipeline.
585

586
        :param name:
587
            The name of the component.
588
        :returns:
589
            The instance of that component.
590

591
        :raises ValueError:
592
            If a component with that name is not present in the pipeline.
593
        """
594
        try:
1✔
595
            return self.graph.nodes[name]["instance"]
1✔
596
        except KeyError as exc:
1✔
597
            raise ValueError(f"Component named {name} not found in the pipeline.") from exc
1✔
598

599
    def get_component_name(self, instance: Component) -> str:
1✔
600
        """
601
        Returns the name of the Component instance if it has been added to this Pipeline or an empty string otherwise.
602

603
        :param instance:
604
            The Component instance to look for.
605
        :returns:
606
            The name of the Component instance.
607
        """
608
        for name, inst in self.graph.nodes(data="instance"):  # type: ignore # type wrongly defined in networkx
1✔
609
            if inst == instance:
1✔
610
                return name
1✔
611
        return ""
1✔
612

    def inputs(self, include_components_with_connected_inputs: bool = False) -> Dict[str, Dict[str, Any]]:
        """
        Returns a dictionary containing the inputs of a pipeline.

        Each key in the dictionary corresponds to a component name, and its value is another dictionary that describes
        the input sockets of that component, including their types and whether they are optional.

        :param include_components_with_connected_inputs:
            If `False`, only components that have disconnected input edges are
            included in the output.
        :returns:
            A dictionary where each key is a pipeline component name and each value is a dictionary of
            input sockets of that component.
        """
        inputs: Dict[str, Dict[str, Any]] = {}
        for component_name, data in find_pipeline_inputs(self.graph, include_components_with_connected_inputs).items():
            sockets_description = {}
            for socket in data:
                sockets_description[socket.name] = {"type": socket.type, "is_mandatory": socket.is_mandatory}
                if not socket.is_mandatory:
                    sockets_description[socket.name]["default_value"] = socket.default_value

            if sockets_description:
                inputs[component_name] = sockets_description
        return inputs

    def outputs(self, include_components_with_connected_outputs: bool = False) -> Dict[str, Dict[str, Any]]:
        """
        Returns a dictionary containing the outputs of a pipeline.

        Each key in the dictionary corresponds to a component name, and its value is another dictionary that describes
        the output sockets of that component.

        :param include_components_with_connected_outputs:
            If `False`, only components that have disconnected output edges are
            included in the output.
        :returns:
            A dictionary where each key is a pipeline component name and each value is a dictionary of
            output sockets of that component.
        """
        outputs = {
            comp: {socket.name: {"type": socket.type} for socket in data}
            for comp, data in find_pipeline_outputs(self.graph, include_components_with_connected_outputs).items()
            if data
        }
        return outputs

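    # Introspection sketch (hypothetical components; the shapes follow the
    # docstrings above):
    #
    #     pipe.inputs()
    #     # {'fetcher': {'urls': {'type': typing.List[str], 'is_mandatory': True}}}
    #     pipe.outputs()
    #     # {'writer': {'documents_written': {'type': int}}}
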
    def show(self, server_url: str = "https://mermaid.ink", params: Optional[dict] = None) -> None:
        """
        Display an image representing this `Pipeline` in a Jupyter notebook.

        This function generates a diagram of the `Pipeline` using a Mermaid server and displays it directly in
        the notebook.

        :param server_url:
            The base URL of the Mermaid server used for rendering (default: 'https://mermaid.ink').
            See https://github.com/jihchi/mermaid.ink and https://github.com/mermaid-js/mermaid-live-editor for more
            info on how to set up your own Mermaid server.

        :param params:
            Dictionary of customization parameters to modify the output. Refer to the Mermaid documentation for more
            details. Supported keys:
                - format: Output format ('img', 'svg', or 'pdf'). Default: 'img'.
                - type: Image type for the /img endpoint ('jpeg', 'png', 'webp'). Default: 'png'.
                - theme: Mermaid theme ('default', 'neutral', 'dark', 'forest'). Default: 'neutral'.
                - bgColor: Background color in hexadecimal (e.g., 'FFFFFF') or named format (e.g., '!white').
                - width: Width of the output image (integer).
                - height: Height of the output image (integer).
                - scale: Scaling factor (1–3). Only applicable if 'width' or 'height' is specified.
                - fit: Whether to fit the diagram size to the page (PDF only, boolean).
                - paper: Paper size for PDFs (e.g., 'a4', 'a3'). Ignored if 'fit' is true.
                - landscape: Landscape orientation for PDFs (boolean). Ignored if 'fit' is true.

        :raises PipelineDrawingError:
            If the function is called outside of a Jupyter notebook or if there is an issue with rendering.
        """
        if is_in_jupyter():
            from IPython.display import Image, display  # type: ignore

            image_data = _to_mermaid_image(self.graph, server_url=server_url, params=params)
            display(Image(image_data))
        else:
            msg = "This method is only supported in Jupyter notebooks. Use Pipeline.draw() to save an image locally."
            raise PipelineDrawingError(msg)

    def draw(self, path: Path, server_url: str = "https://mermaid.ink", params: Optional[dict] = None) -> None:
        """
        Save an image representing this `Pipeline` to the specified file path.

        This function generates a diagram of the `Pipeline` using the Mermaid server and saves it to the provided path.

        :param path:
            The file path where the generated image will be saved.
        :param server_url:
            The base URL of the Mermaid server used for rendering (default: 'https://mermaid.ink').
            See https://github.com/jihchi/mermaid.ink and https://github.com/mermaid-js/mermaid-live-editor for more
            info on how to set up your own Mermaid server.
        :param params:
            Dictionary of customization parameters to modify the output. Refer to the Mermaid documentation for more
            details. Supported keys:
                - format: Output format ('img', 'svg', or 'pdf'). Default: 'img'.
                - type: Image type for the /img endpoint ('jpeg', 'png', 'webp'). Default: 'png'.
                - theme: Mermaid theme ('default', 'neutral', 'dark', 'forest'). Default: 'neutral'.
                - bgColor: Background color in hexadecimal (e.g., 'FFFFFF') or named format (e.g., '!white').
                - width: Width of the output image (integer).
                - height: Height of the output image (integer).
                - scale: Scaling factor (1–3). Only applicable if 'width' or 'height' is specified.
                - fit: Whether to fit the diagram size to the page (PDF only, boolean).
                - paper: Paper size for PDFs (e.g., 'a4', 'a3'). Ignored if 'fit' is true.
                - landscape: Landscape orientation for PDFs (boolean). Ignored if 'fit' is true.

        :raises PipelineDrawingError:
            If there is an issue with rendering or saving the image.
        """
        # Before drawing, the graph gets lightly edited; to avoid modifying the original
        # graph used for running the pipeline, the rendering works on a copy.
        image_data = _to_mermaid_image(self.graph, server_url=server_url, params=params)
        Path(path).write_bytes(image_data)

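    # Rendering sketch; the params dictionary follows the docstring above:
    #
    #     pipe.draw(Path("pipeline.png"), params={"type": "png", "theme": "dark"})
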
    def walk(self) -> Iterator[Tuple[str, Component]]:
        """
        Visits each component in the pipeline exactly once and yields its name and instance.

        No guarantees are provided on the visiting order.

        :returns:
            An iterator of tuples of component name and component instance.
        """
        for component_name, instance in self.graph.nodes(data="instance"):  # type: ignore # type is wrong in networkx
            yield component_name, instance

    def warm_up(self):
        """
        Make sure all nodes are warm.

        It's the node's responsibility to make sure this method can be called at every `Pipeline.run()`
        without re-initializing everything.
        """
        for node in self.graph.nodes:
            if hasattr(self.graph.nodes[node]["instance"], "warm_up"):
                logger.info("Warming up component {node}...", node=node)
                self.graph.nodes[node]["instance"].warm_up()

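    # Iterating over all components (sketch):
    #
    #     for name, instance in pipe.walk():
    #         print(name, type(instance).__name__)
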
    def _validate_input(self, data: Dict[str, Any]):
        """
        Validates pipeline input data.

        Validates that data:
        * Each Component name actually exists in the Pipeline
        * Each Component is not missing any input
        * Each Component has only one input per input socket, if not variadic
        * Each Component doesn't receive inputs that are already sent by another Component

        :param data:
            A dictionary of inputs for the pipeline's components. Each key is a component name.

        :raises ValueError:
            If inputs are invalid according to the above.
        """
        for component_name, component_inputs in data.items():
            if component_name not in self.graph.nodes:
                raise ValueError(f"Component named {component_name} not found in the pipeline.")
            instance = self.graph.nodes[component_name]["instance"]
            for socket_name, socket in instance.__haystack_input__._sockets_dict.items():
                if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs:
                    raise ValueError(f"Missing input for component {component_name}: {socket_name}")
            for input_name in component_inputs.keys():
                if input_name not in instance.__haystack_input__._sockets_dict:
                    raise ValueError(f"Input {input_name} not found in component {component_name}.")

        for component_name in self.graph.nodes:
            instance = self.graph.nodes[component_name]["instance"]
            for socket_name, socket in instance.__haystack_input__._sockets_dict.items():
                component_inputs = data.get(component_name, {})
                if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs:
                    raise ValueError(f"Missing input for component {component_name}: {socket_name}")
                if socket.senders and socket_name in component_inputs and not socket.is_variadic:
                    raise ValueError(
                        f"Input {socket_name} for component {component_name} is already sent by {socket.senders}."
                    )

    def _prepare_component_input_data(self, data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
        """
        Prepares input data for pipeline components.

        Organizes input data for pipeline components and identifies any inputs that are not matched to any
        component's input slots. Deep-copies data items to avoid sharing mutables across multiple components.

        This method processes a flat dictionary of input data, where each key-value pair represents an input name
        and its corresponding value. It distributes these inputs to the appropriate pipeline components based on
        their input requirements. Inputs that don't match any component's input slots are classified as unresolved.

        :param data:
            A dictionary potentially having input names as keys and input values as values.

        :returns:
            A dictionary mapping component names to their respective matched inputs.
        """
        # Check whether the data is a nested dictionary of component inputs where each key is a component name
        # and each value is a dictionary of input parameters for that component.
        is_nested_component_input = all(isinstance(value, dict) for value in data.values())
        if not is_nested_component_input:
            # flat input, a dict where keys are input names and values are the corresponding values
            # we need to convert it to a nested dictionary of component inputs and then run the pipeline
            # just like in the previous case
            pipeline_input_data: Dict[str, Dict[str, Any]] = defaultdict(dict)
            unresolved_kwargs = {}

            # Retrieve the input slots for each component in the pipeline
            available_inputs: Dict[str, Dict[str, Any]] = self.inputs()

            # Go through all provided inputs to distribute them to the appropriate component inputs
            for input_name, input_value in data.items():
                resolved_at_least_once = False

                # Check each component to see if it has a slot for the current kwarg
                for component_name, component_inputs in available_inputs.items():
                    if input_name in component_inputs:
                        # If a match is found, add the kwarg to the component's input data
                        pipeline_input_data[component_name][input_name] = input_value
                        resolved_at_least_once = True

                if not resolved_at_least_once:
                    unresolved_kwargs[input_name] = input_value

            if unresolved_kwargs:
                logger.warning(
                    "Inputs {input_keys} were not matched to any component inputs, please check your run parameters.",
                    input_keys=list(unresolved_kwargs.keys()),
                )

            data = dict(pipeline_input_data)

        # Deep-copying the inputs prevents the Pipeline run logic from being altered unexpectedly
        # when the same input reference is passed to multiple components.
        for component_name, component_inputs in data.items():
            data[component_name] = {k: deepcopy(v) for k, v in component_inputs.items()}

        return data

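    # Input fan-out sketch: a flat dict is distributed to every component that
    # exposes a matching input name (hypothetical components shown):
    #
    #     pipe._prepare_component_input_data({"query": "Who lives in Paris?"})
    #     # -> {'retriever': {'query': 'Who lives in Paris?'},
    #     #     'ranker': {'query': 'Who lives in Paris?'}}
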
    @classmethod
    def from_template(
        cls, predefined_pipeline: PredefinedPipeline, template_params: Optional[Dict[str, Any]] = None
    ) -> "PipelineBase":
        """
        Create a Pipeline from a predefined template. See `PredefinedPipeline` for available options.

        :param predefined_pipeline:
            The predefined pipeline to use.
        :param template_params:
            An optional dictionary of parameters to use when rendering the pipeline template.
        :returns:
            An instance of `Pipeline`.
        """
        tpl = PipelineTemplate.from_predefined(predefined_pipeline)
        # If tpl.render() fails, we let the original error bubble up
        rendered = tpl.render(template_params)

        # If there was a problem with the rendered version of the
        # template, we add it to the error stack for debugging
        try:
            return cls.loads(rendered)
        except Exception as e:
            msg = f"Error unmarshalling pipeline: {e}\n"
            msg += f"Source:\n{rendered}"
            raise PipelineUnmarshalError(msg)

    def _find_receivers_from(self, component_name: str) -> List[Tuple[str, OutputSocket, InputSocket]]:
        """
        Utility function to find all Components that receive input from `component_name`.

        :param component_name:
            Name of the sender Component.

        :returns:
            List of tuples containing the name of the receiver Component, the sender OutputSocket,
            and the receiver InputSocket instances.
        """
        res = []
        for _, receiver_name, connection in self.graph.edges(nbunch=component_name, data=True):
            sender_socket: OutputSocket = connection["from_socket"]
            receiver_socket: InputSocket = connection["to_socket"]
            res.append((receiver_name, sender_socket, receiver_socket))
        return res

    @staticmethod
    def _convert_to_internal_format(pipeline_inputs: Dict[str, Any]) -> Dict[str, Dict[str, List]]:
        """
        Converts the inputs to the pipeline to the format that is needed for the internal `Pipeline.run` logic.

        Example Input:
        {'prompt_builder': {'question': 'Who lives in Paris?'}, 'retriever': {'query': 'Who lives in Paris?'}}
        Example Output:
        {'prompt_builder': {'question': [{'sender': None, 'value': 'Who lives in Paris?'}]},
         'retriever': {'query': [{'sender': None, 'value': 'Who lives in Paris?'}]}}

        :param pipeline_inputs: Inputs to the pipeline.
        :returns: Converted inputs that can be used by the internal `Pipeline.run` logic.
        """
        inputs: Dict[str, Dict[str, List[Dict[str, Any]]]] = {}
        for component_name, socket_dict in pipeline_inputs.items():
            inputs[component_name] = {}
            for socket_name, value in socket_dict.items():
                inputs[component_name][socket_name] = [{"sender": None, "value": value}]

        return inputs

    @staticmethod
    def _consume_component_inputs(component_name: str, component: Dict, inputs: Dict) -> Dict[str, Any]:
        """
        Extracts the inputs needed to run the component and removes them from the global inputs state.

        :param component_name: The name of a component.
        :param component: Component with component metadata.
        :param inputs: Global inputs state.
        :returns: The inputs for the component.
        """
        component_inputs = inputs.get(component_name, {})
        consumed_inputs = {}
        greedy_inputs_to_remove = set()
        for socket_name, socket in component["input_sockets"].items():
            socket_inputs = component_inputs.get(socket_name, [])
            socket_inputs = [sock["value"] for sock in socket_inputs if sock["value"] is not _NO_OUTPUT_PRODUCED]
            if socket_inputs:
                if not socket.is_variadic:
                    # We only care about the first input provided to the socket.
                    consumed_inputs[socket_name] = socket_inputs[0]
                elif socket.is_greedy:
                    # We need to keep track of greedy inputs because we always remove them, even if they come from
                    # outside the pipeline. Otherwise, a greedy input from the user would trigger a pipeline to run
                    # indefinitely.
                    greedy_inputs_to_remove.add(socket_name)
                    consumed_inputs[socket_name] = [socket_inputs[0]]
                elif is_socket_lazy_variadic(socket):
                    # We use all inputs provided to the socket on a lazy variadic socket.
                    consumed_inputs[socket_name] = socket_inputs

        # We prune all inputs except for those that were provided from outside the pipeline (e.g. user inputs).
        pruned_inputs = {
            socket_name: [
                sock for sock in socket if sock["sender"] is None and socket_name not in greedy_inputs_to_remove
            ]
            for socket_name, socket in component_inputs.items()
        }
        pruned_inputs = {socket_name: socket for socket_name, socket in pruned_inputs.items() if len(socket) > 0}

        inputs[component_name] = pruned_inputs

        return consumed_inputs

    def _fill_queue(
        self, component_names: List[str], inputs: Dict[str, Any], component_visits: Dict[str, int]
    ) -> FIFOPriorityQueue:
        """
        Calculates the execution priority for each component and inserts it into the priority queue.

        :param component_names: Names of the components to put into the queue.
        :param inputs: Inputs to the components.
        :param component_visits: Current state of component visits.
        :returns: A prioritized queue of component names.
        """
        priority_queue = FIFOPriorityQueue()
        for component_name in component_names:
            component = self._get_component_with_graph_metadata_and_visits(
                component_name, component_visits[component_name]
            )
            priority = self._calculate_priority(component, inputs.get(component_name, {}))
            priority_queue.push(component_name, priority)

        return priority_queue

    @staticmethod
    def _calculate_priority(component: Dict, inputs: Dict) -> ComponentPriority:
        """
        Calculates the execution priority for a component depending on the component's inputs.

        :param component: Component metadata and component instance.
        :param inputs: Inputs to the component.
        :returns: Priority value for the component.
        """
        if not can_component_run(component, inputs):
            return ComponentPriority.BLOCKED
        elif is_any_greedy_socket_ready(component, inputs) and are_all_sockets_ready(component, inputs):
            return ComponentPriority.HIGHEST
        elif all_predecessors_executed(component, inputs):
            return ComponentPriority.READY
        elif are_all_lazy_variadic_sockets_resolved(component, inputs):
            return ComponentPriority.DEFER
        else:
            return ComponentPriority.DEFER_LAST

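    # Scheduling sketch: lower ComponentPriority values run first, and
    # FIFOPriorityQueue breaks ties by insertion order (illustrative names):
    #
    #     queue = FIFOPriorityQueue()
    #     queue.push("retriever", ComponentPriority.READY)
    #     queue.push("ranker", ComponentPriority.BLOCKED)
    #     queue.peek()  # -> (ComponentPriority.READY, "retriever")
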
    def _get_component_with_graph_metadata_and_visits(self, component_name: str, visits: int) -> Dict[str, Any]:
        """
        Returns the component instance alongside input/output-socket metadata from the graph and adds current visits.

        We can't store visits in the pipeline graph because this would prevent reentrance / thread-safe execution.

        :param component_name: The name of the component.
        :param visits: Number of visits for the component.
        :returns: Dict including component instance, input/output-sockets and visits.
        """
        comp_dict = self.graph.nodes[component_name]
        comp_dict = {**comp_dict, "visits": visits}
        return comp_dict

    def _get_next_runnable_component(
        self, priority_queue: FIFOPriorityQueue, component_visits: Dict[str, int]
    ) -> Union[Tuple[ComponentPriority, str, Dict[str, Any]], None]:
        """
        Returns the next runnable component alongside its metadata from the priority queue.

        :param priority_queue: Priority queue of component names.
        :param component_visits: Current state of component visits.
        :returns: A tuple of (priority, component name, component metadata),
            or None if no component in the queue can run.
        :raises: PipelineMaxComponentRuns if the next runnable component has exceeded the maximum number of runs.
        """
        priority_and_component_name: Union[Tuple[ComponentPriority, str], None] = (
            None if (item := priority_queue.get()) is None else (ComponentPriority(item[0]), str(item[1]))
        )

        if priority_and_component_name is not None and priority_and_component_name[0] != ComponentPriority.BLOCKED:
            priority, component_name = priority_and_component_name
            component = self._get_component_with_graph_metadata_and_visits(
                component_name, component_visits[component_name]
            )
            if component["visits"] > self._max_runs_per_component:
                msg = f"Maximum run count {self._max_runs_per_component} reached for component '{component_name}'"
                raise PipelineMaxComponentRuns(msg)

            return priority, component_name, component

        return None

    @staticmethod
    def _add_missing_input_defaults(component_inputs: Dict[str, Any], component_input_sockets: Dict[str, InputSocket]):
        """
        Updates the inputs with the default values for the inputs that are missing.

        :param component_inputs: Inputs for the component.
        :param component_input_sockets: Input sockets of the component.
        """
        for name, socket in component_input_sockets.items():
            if not socket.is_mandatory and name not in component_inputs:
                if socket.is_variadic:
                    component_inputs[name] = [socket.default_value]
                else:
                    component_inputs[name] = socket.default_value

        return component_inputs

    def _tiebreak_waiting_components(
        self,
        component_name: str,
        priority: ComponentPriority,
        priority_queue: FIFOPriorityQueue,
        topological_sort: Union[Dict[str, int], None],
    ):
        """
        Decides which component to run when multiple components are waiting for inputs with the same priority.

        :param component_name: The name of the component.
        :param priority: Priority of the component.
        :param priority_queue: Priority queue of component names.
        :param topological_sort: Cached topological sort of all components in the pipeline.
        :returns: The name of the component that should run first and the (possibly newly computed) topological sort.
        """
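        # Tiebreak sketch (hypothetical names): if "ranker" and "summarizer" wait with
        # the same priority, the one earlier in the cached topological order runs first;
        # equal ranks fall back to case-insensitive name order, so the choice is
        # deterministic. For cyclic graphs the order is computed on the condensation
        # of strongly connected components.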
        components_with_same_priority = [component_name]

        while len(priority_queue) > 0:
            next_priority, next_component_name = priority_queue.peek()
            if next_priority == priority:
                priority_queue.pop()  # actually remove the component
                components_with_same_priority.append(next_component_name)
            else:
                break

        if len(components_with_same_priority) > 1:
            if topological_sort is None:
                if networkx.is_directed_acyclic_graph(self.graph):
                    topological_sort = networkx.lexicographical_topological_sort(self.graph)
                    topological_sort = {node: idx for idx, node in enumerate(topological_sort)}
                else:
                    condensed = networkx.condensation(self.graph)
                    condensed_sorted = {node: idx for idx, node in enumerate(networkx.topological_sort(condensed))}
                    topological_sort = {
                        component_name: condensed_sorted[node]
                        for component_name, node in condensed.graph["mapping"].items()
                    }

            components_with_same_priority = sorted(
                components_with_same_priority, key=lambda comp_name: (topological_sort[comp_name], comp_name.lower())
            )

            component_name = components_with_same_priority[0]

        return component_name, topological_sort

    @staticmethod
    def _write_component_outputs(
        component_name: str,
        component_outputs: Dict[str, Any],
        inputs: Dict[str, Any],
        receivers: List[Tuple],
        include_outputs_from: Set[str],
    ) -> Dict[str, Any]:
        """
        Distributes the outputs of a component to the input sockets that it is connected to.

        :param component_name: The name of the component.
        :param component_outputs: The outputs of the component.
        :param inputs: The current global input state.
        :param receivers: List of (receiver name, sender socket, receiver socket) tuples that receive inputs from
            the component.
        :param include_outputs_from: Set of component names that should always return an output from the pipeline.
        """
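        # Shape of the global input state written below (illustrative):
        # inputs = {"receiver": {"socket": [{"sender": "sender_name", "value": ...}]}}
        # Lazy variadic sockets accumulate one entry per sender; all other sockets hold
        # a single entry that a later sender may overwrite.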
        for receiver_name, sender_socket, receiver_socket in receivers:
            # We either get the value that was produced by the actor or we use the _NO_OUTPUT_PRODUCED class to
            # indicate that the sender did not produce an output for this socket.
            # This allows us to track if a predecessor already ran but did not produce an output.
            value = component_outputs.get(sender_socket.name, _NO_OUTPUT_PRODUCED)

            if receiver_name not in inputs:
                inputs[receiver_name] = {}

            if is_socket_lazy_variadic(receiver_socket):
                # If the receiver socket is lazy variadic, we append the new input.
                # Lazy variadic sockets can collect multiple inputs.
                _write_to_lazy_variadic_socket(
                    inputs=inputs,
                    receiver_name=receiver_name,
                    receiver_socket_name=receiver_socket.name,
                    component_name=component_name,
                    value=value,
                )
            else:
                # If the receiver socket is not lazy variadic, it is greedy variadic or non-variadic.
                # We overwrite with the new input if it's not _NO_OUTPUT_PRODUCED or if the current value is None.
                _write_to_standard_socket(
                    inputs=inputs,
                    receiver_name=receiver_name,
                    receiver_socket_name=receiver_socket.name,
                    component_name=component_name,
                    value=value,
                )

        # If we want to include all outputs from this actor in the final outputs, we don't need to prune any
        # consumed outputs.
        if component_name in include_outputs_from:
            return component_outputs

        # We prune outputs that were consumed by any receiving sockets.
        # All remaining outputs will be added to the final outputs of the pipeline.
        consumed_outputs = {sender_socket.name for _, sender_socket, __ in receivers}
        pruned_outputs = {key: value for key, value in component_outputs.items() if key not in consumed_outputs}

        return pruned_outputs

    @staticmethod
    def _is_queue_stale(priority_queue: FIFOPriorityQueue) -> bool:
        """
        Checks if the priority queue needs to be recomputed because the priorities might have changed.

        :param priority_queue: Priority queue of component names.
        :returns: True if the queue is empty or if its highest-priority component is not ready to run.
        """
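        # Example: a queue whose best entry is (ComponentPriority.DEFER, "ranker") is
        # stale (hypothetical name), since only HIGHEST- or READY-priority components
        # can be scheduled without re-computing priorities.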
        return len(priority_queue) == 0 or priority_queue.peek()[0] > ComponentPriority.READY

    @staticmethod
    def validate_pipeline(priority_queue: FIFOPriorityQueue) -> None:
        """
        Validate the pipeline to check if it is blocked or has no valid entry point.

        :param priority_queue: Priority queue of component names.
        :raises PipelineRuntimeError: If all components in a non-empty queue are blocked.
        """
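        # Note: an empty queue is valid; it simply means nothing is left to schedule.
        # Only a non-empty queue whose best candidate is BLOCKED is an error, e.g. two
        # components that each wait for the other's output with no other entry point.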
        if len(priority_queue) == 0:
            return

        candidate = priority_queue.peek()
        if candidate is not None and candidate[0] == ComponentPriority.BLOCKED:
            raise PipelineRuntimeError(
                "Cannot run pipeline - all components are blocked. "
                "This typically happens when:\n"
                "1. There is no valid entry point for the pipeline\n"
                "2. There is a circular dependency preventing the pipeline from running\n"
                "Check the connections between these components and ensure all required inputs are provided."
            )


def _connections_status(
    sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket]
) -> str:
    """
    Lists the status of the sockets, for error messages.
    """
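    # Illustrative rendering (hypothetical nodes and sockets):
    # 'fetcher':
    #  - documents: List[Document]
    # 'ranker':
    #  - documents: List[Document] (available)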
    sender_sockets_entries = []
    for sender_socket in sender_sockets:
        sender_sockets_entries.append(f" - {sender_socket.name}: {_type_name(sender_socket.type)}")
    sender_sockets_list = "\n".join(sender_sockets_entries)

    receiver_sockets_entries = []
    for receiver_socket in receiver_sockets:
        if receiver_socket.senders:
            sender_status = f"sent by {','.join(receiver_socket.senders)}"
        else:
            sender_status = "available"
        receiver_sockets_entries.append(
            f" - {receiver_socket.name}: {_type_name(receiver_socket.type)} ({sender_status})"
        )
    receiver_sockets_list = "\n".join(receiver_sockets_entries)

    return f"'{sender_node}':\n{sender_sockets_list}\n'{receiver_node}':\n{receiver_sockets_list}"


# Utility functions for writing to sockets


def _write_to_lazy_variadic_socket(
    inputs: Dict[str, Any], receiver_name: str, receiver_socket_name: str, component_name: str, value: Any
) -> None:
    """
    Write to a lazy variadic socket.

    Mutates inputs in place.
    """
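    # Example accumulation (hypothetical senders): after "retriever" and "websearch"
    # both write to the same lazy variadic socket, the slot holds
    # [{"sender": "retriever", "value": ...}, {"sender": "websearch", "value": ...}].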
    if not inputs[receiver_name].get(receiver_socket_name):
        inputs[receiver_name][receiver_socket_name] = []

    inputs[receiver_name][receiver_socket_name].append({"sender": component_name, "value": value})


def _write_to_standard_socket(
    inputs: Dict[str, Any], receiver_name: str, receiver_socket_name: str, component_name: str, value: Any
) -> None:
    """
    Write to a greedy variadic or non-variadic socket.

    Mutates inputs in place.
    """
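    # Example semantics (hypothetical values): a real output replaces whatever is
    # stored, but a _NO_OUTPUT_PRODUCED marker never clobbers a value that another
    # sender already provided; it is only written when the slot is still empty.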
    current_value = inputs[receiver_name].get(receiver_socket_name)

    # Only overwrite if there's no existing value, or we have a new value to provide
    if current_value is None or value is not _NO_OUTPUT_PRODUCED:
        inputs[receiver_name][receiver_socket_name] = [{"sender": component_name, "value": value}]