
deepset-ai / haystack / 13972131258
20 Mar 2025 02:43PM UTC coverage: 90.021% (-0.03%) from 90.054%
Pull Request #9069: refactor!: `ChatMessage` serialization-deserialization updates
Merge 8371761b0 into 67ab3788e (github, web-flow)
9833 of 10923 relevant lines covered (90.02%)
0.9 hits per line

Source file: haystack/core/pipeline/base.py (91.08% covered)
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0

import itertools
from collections import defaultdict
from copy import deepcopy
from datetime import datetime
from enum import IntEnum
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Set, TextIO, Tuple, Type, TypeVar, Union

import networkx  # type:ignore

from haystack import logging
from haystack.core.component import Component, InputSocket, OutputSocket, component
from haystack.core.errors import (
    DeserializationError,
    PipelineConnectError,
    PipelineDrawingError,
    PipelineError,
    PipelineMaxComponentRuns,
    PipelineRuntimeError,
    PipelineUnmarshalError,
    PipelineValidationError,
)
from haystack.core.pipeline.component_checks import (
    _NO_OUTPUT_PRODUCED,
    all_predecessors_executed,
    are_all_lazy_variadic_sockets_resolved,
    are_all_sockets_ready,
    can_component_run,
    is_any_greedy_socket_ready,
    is_socket_lazy_variadic,
)
from haystack.core.pipeline.utils import FIFOPriorityQueue, parse_connect_string
from haystack.core.serialization import DeserializationCallbacks, component_from_dict, component_to_dict
from haystack.core.type_utils import _type_name, _types_are_compatible
from haystack.marshal import Marshaller, YamlMarshaller
from haystack.utils import is_in_jupyter, type_serialization

from .descriptions import find_pipeline_inputs, find_pipeline_outputs
from .draw import _to_mermaid_image
from .template import PipelineTemplate, PredefinedPipeline

DEFAULT_MARSHALLER = YamlMarshaller()

# We use a generic type to annotate the return value of class methods,
# so that static analyzers won't be confused when derived classes
# use those methods.
T = TypeVar("T", bound="PipelineBase")

logger = logging.getLogger(__name__)


class ComponentPriority(IntEnum):
    HIGHEST = 1
    READY = 2
    DEFER = 3
    DEFER_LAST = 4
    BLOCKED = 5


class PipelineBase:
    """
    Components orchestration engine.

    Builds a graph of components and orchestrates their execution according to the execution graph.
    """

    def __init__(
        self,
        metadata: Optional[Dict[str, Any]] = None,
        max_runs_per_component: int = 100,
        connection_type_validation: bool = True,
    ):
        """
        Creates the Pipeline.

        :param metadata:
            Arbitrary dictionary to store metadata about this `Pipeline`. Make sure all the values contained in
            this dictionary can be serialized and deserialized if you wish to save this `Pipeline` to file.
        :param max_runs_per_component:
            How many times the `Pipeline` can run the same Component.
            If this limit is reached, a `PipelineMaxComponentRuns` exception is raised.
            If not set, it defaults to 100 runs per Component.
        :param connection_type_validation: Whether the pipeline will validate the types of the connections.
            Defaults to True.
        """
        self._telemetry_runs = 0
        self._last_telemetry_sent: Optional[datetime] = None
        self.metadata = metadata or {}
        self.graph = networkx.MultiDiGraph()
        self._max_runs_per_component = max_runs_per_component
        self._connection_type_validation = connection_type_validation

    def __eq__(self, other) -> bool:
        """
        Pipeline equality is defined by their type and the equality of their serialized form.

        Pipelines of the same type share every metadata, node and edge, but they're not required to use
        the same node instances: this allows a pipeline saved and then loaded back to be equal to itself.
        """
        if not isinstance(self, type(other)):
            return False
        return self.to_dict() == other.to_dict()

    def __repr__(self) -> str:
        """
        Returns a text representation of the Pipeline.
        """
        res = f"{object.__repr__(self)}\n"
        if self.metadata:
            res += "🧱 Metadata\n"
            for k, v in self.metadata.items():
                res += f"  - {k}: {v}\n"

        res += "🚅 Components\n"
        for name, instance in self.graph.nodes(data="instance"):  # type: ignore # type wrongly defined in networkx
            res += f"  - {name}: {instance.__class__.__name__}\n"

        res += "🛤️ Connections\n"
        for sender, receiver, edge_data in self.graph.edges(data=True):
            sender_socket = edge_data["from_socket"].name
            receiver_socket = edge_data["to_socket"].name
            res += f"  - {sender}.{sender_socket} -> {receiver}.{receiver_socket} ({edge_data['conn_type']})\n"

        return res

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the pipeline to a dictionary.

        This is meant to be an intermediate representation but it can also be used to save a pipeline to file.

        :returns:
            Dictionary with serialized data.
        """
        components = {}
        for name, instance in self.graph.nodes(data="instance"):  # type:ignore
            components[name] = component_to_dict(instance, name)

        connections = []
        for sender, receiver, edge_data in self.graph.edges.data():
            sender_socket = edge_data["from_socket"].name
            receiver_socket = edge_data["to_socket"].name
            connections.append({"sender": f"{sender}.{sender_socket}", "receiver": f"{receiver}.{receiver_socket}"})
        return {
            "metadata": self.metadata,
            "max_runs_per_component": self._max_runs_per_component,
            "components": components,
            "connections": connections,
            "connection_type_validation": self._connection_type_validation,
        }

    @classmethod
    def from_dict(
        cls: Type[T], data: Dict[str, Any], callbacks: Optional[DeserializationCallbacks] = None, **kwargs
    ) -> T:
        """
        Deserializes the pipeline from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :param callbacks:
            Callbacks to invoke during deserialization.
        :param kwargs:
            `components`: a dictionary of {name: instance} to reuse instances of components instead of creating new
            ones.
        :returns:
            Deserialized pipeline.
        """
        data_copy = deepcopy(data)  # to prevent modification of original data
        metadata = data_copy.get("metadata", {})
        max_runs_per_component = data_copy.get("max_runs_per_component", 100)
        connection_type_validation = data_copy.get("connection_type_validation", True)
        pipe = cls(
            metadata=metadata,
            max_runs_per_component=max_runs_per_component,
            connection_type_validation=connection_type_validation,
        )
        components_to_reuse = kwargs.get("components", {})
        for name, component_data in data_copy.get("components", {}).items():
            if name in components_to_reuse:
                # Reuse an instance
                instance = components_to_reuse[name]
            else:
                if "type" not in component_data:
                    raise PipelineError(f"Missing 'type' in component '{name}'")

                if component_data["type"] not in component.registry:
                    try:
                        # Import the module first...
                        module, _ = component_data["type"].rsplit(".", 1)
                        logger.debug("Trying to import module {module_name}", module_name=module)
                        type_serialization.thread_safe_import(module)
                        # ...then try again
                        if component_data["type"] not in component.registry:
                            raise PipelineError(
                                f"Successfully imported module '{module}' but couldn't find "
                                f"'{component_data['type']}' in the component registry.\n"
                                f"The component might be registered under a different path. "
                                f"Here are the registered components:\n {list(component.registry.keys())}\n"
                            )
                    except (ImportError, PipelineError, ValueError) as e:
                        raise PipelineError(
                            f"Component '{component_data['type']}' (name: '{name}') not imported. Please "
                            f"check that the package is installed and the component path is correct."
                        ) from e

                # Create a new one
                component_class = component.registry[component_data["type"]]

                try:
                    instance = component_from_dict(component_class, component_data, name, callbacks)
                except Exception as e:
                    msg = (
                        f"Couldn't deserialize component '{name}' of class '{component_class.__name__}' "
                        f"with the following data: {str(component_data)}. Possible reasons include "
                        "malformed serialized data, mismatch between the serialized component and the "
                        "loaded one (due to a breaking change, see "
                        "https://github.com/deepset-ai/haystack/releases), etc."
                    )
                    raise DeserializationError(msg) from e
            pipe.add_component(name=name, instance=instance)

        for connection in data.get("connections", []):
            if "sender" not in connection:
                raise PipelineError(f"Missing sender in connection: {connection}")
            if "receiver" not in connection:
                raise PipelineError(f"Missing receiver in connection: {connection}")
            pipe.connect(sender=connection["sender"], receiver=connection["receiver"])

        return pipe
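
    # Usage sketch (illustrative, not part of the original file): a pipeline
    # round-trips through its dict form, and `__eq__` compares exactly these
    # serialized forms. `SomeComponent` is a hypothetical @component class.
    #
    #     pipe = PipelineBase()
    #     pipe.add_component("comp", SomeComponent())
    #     assert PipelineBase.from_dict(pipe.to_dict()) == pipe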

    def dumps(self, marshaller: Marshaller = DEFAULT_MARSHALLER) -> str:
        """
        Returns the string representation of this pipeline according to the format dictated by the `Marshaller` in use.

        :param marshaller:
            The Marshaller used to create the string representation. Defaults to `YamlMarshaller`.
        :returns:
            A string representing the pipeline.
        """
        return marshaller.marshal(self.to_dict())

    def dump(self, fp: TextIO, marshaller: Marshaller = DEFAULT_MARSHALLER):
        """
        Writes the string representation of this pipeline to the file-like object passed in the `fp` argument.

        :param fp:
            A file-like object ready to be written to.
        :param marshaller:
            The Marshaller used to create the string representation. Defaults to `YamlMarshaller`.
        """
        fp.write(marshaller.marshal(self.to_dict()))

    @classmethod
    def loads(
        cls: Type[T],
        data: Union[str, bytes, bytearray],
        marshaller: Marshaller = DEFAULT_MARSHALLER,
        callbacks: Optional[DeserializationCallbacks] = None,
    ) -> T:
        """
        Creates a `Pipeline` object from the string representation passed in the `data` argument.

        :param data:
            The string representation of the pipeline, can be `str`, `bytes` or `bytearray`.
        :param marshaller:
            The Marshaller used to create the string representation. Defaults to `YamlMarshaller`.
        :param callbacks:
            Callbacks to invoke during deserialization.
        :raises DeserializationError:
            If an error occurs during deserialization.
        :returns:
            A `Pipeline` object.
        """
        try:
            deserialized_data = marshaller.unmarshal(data)
        except Exception as e:
            raise DeserializationError(
                "Error while unmarshalling serialized pipeline data. This is usually "
                "caused by malformed or invalid syntax in the serialized representation."
            ) from e

        return cls.from_dict(deserialized_data, callbacks)

    @classmethod
    def load(
        cls: Type[T],
        fp: TextIO,
        marshaller: Marshaller = DEFAULT_MARSHALLER,
        callbacks: Optional[DeserializationCallbacks] = None,
    ) -> T:
        """
        Creates a `Pipeline` object from a string representation.

        The string representation is read from the file-like object passed in the `fp` argument.

        :param fp:
            A file-like object ready to be read from.
        :param marshaller:
            The Marshaller used to create the string representation. Defaults to `YamlMarshaller`.
        :param callbacks:
            Callbacks to invoke during deserialization.
        :raises DeserializationError:
            If an error occurs during deserialization.
        :returns:
            A `Pipeline` object.
        """
        return cls.loads(fp.read(), marshaller, callbacks)
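
    # Usage sketch (illustrative): persisting a pipeline as YAML, the default
    # marshalling format, and loading it back.
    #
    #     with open("pipeline.yaml", "w") as fp:
    #         pipe.dump(fp)
    #     with open("pipeline.yaml") as fp:
    #         same_pipe = PipelineBase.load(fp)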

    def add_component(self, name: str, instance: Component) -> None:
        """
        Add the given component to the pipeline.

        Components are not connected to anything by default: use `Pipeline.connect()` to connect components together.
        Component names must be unique, but component instances can be reused if needed.

        :param name:
            The name of the component to add.
        :param instance:
            The component instance to add.

        :raises ValueError:
            If a component with the same name already exists.
        :raises PipelineValidationError:
            If the given instance is not a component.
        """
        # Component names are unique
        if name in self.graph.nodes:
            raise ValueError(f"A component named '{name}' already exists in this pipeline: choose another name.")

        # Components can't be named `_debug`
        if name == "_debug":
            raise ValueError("'_debug' is a reserved name for debug output. Choose another name.")

        # Component instances must be components
        if not isinstance(instance, Component):
            raise PipelineValidationError(
                f"'{type(instance)}' doesn't seem to be a component. Is this class decorated with @component?"
            )

        if getattr(instance, "__haystack_added_to_pipeline__", None):
            msg = (
                "Component has already been added in another Pipeline. Components can't be shared between Pipelines. "
                "Create a new instance instead."
            )
            raise PipelineError(msg)

        setattr(instance, "__haystack_added_to_pipeline__", self)

        # Add component to the graph, disconnected
        logger.debug("Adding component '{component_name}' ({component})", component_name=name, component=instance)
        # We're completely sure the fields exist so we ignore the type error
        self.graph.add_node(
            name,
            instance=instance,
            input_sockets=instance.__haystack_input__._sockets_dict,  # type: ignore[attr-defined]
            output_sockets=instance.__haystack_output__._sockets_dict,  # type: ignore[attr-defined]
            visits=0,
        )

    def remove_component(self, name: str) -> Component:
        """
        Removes and returns a component from the pipeline.

        Remove an existing component from the pipeline by providing its name.
        All edges that connect to the component will also be deleted.

        :param name:
            The name of the component to remove.
        :returns:
            The removed Component instance.

        :raises ValueError:
            If there is no component with that name in the Pipeline.
        """

        # Check that a component with that name is in the Pipeline
        try:
            instance = self.get_component(name)
        except ValueError as exc:
            raise ValueError(
                f"There is no component named '{name}' in the pipeline. The valid component names are: ",
                ", ".join(n for n in self.graph.nodes),
            ) from exc

        # Delete component from the graph, deleting all its connections
        self.graph.remove_node(name)

        # Reset the Component sockets' senders and receivers
        input_sockets = instance.__haystack_input__._sockets_dict  # type: ignore[attr-defined]
        for socket in input_sockets.values():
            socket.senders = []

        output_sockets = instance.__haystack_output__._sockets_dict  # type: ignore[attr-defined]
        for socket in output_sockets.values():
            socket.receivers = []

        # Reset the Component's pipeline reference
        setattr(instance, "__haystack_added_to_pipeline__", None)

        return instance

    def connect(self, sender: str, receiver: str) -> "PipelineBase":  # noqa: PLR0915 PLR0912
        """
        Connects two components together.

        All components to connect must exist in the pipeline.
        If connecting to a component that has several output connections, specify the input and output names as
        'component_name.connection_name'.

        :param sender:
            The component that delivers the value. This can be either just a component name or can be
            in the format `component_name.connection_name` if the component has multiple outputs.
        :param receiver:
            The component that receives the value. This can be either just a component name or can be
            in the format `component_name.connection_name` if the component has multiple inputs.
        :returns:
            The Pipeline instance.

        :raises PipelineConnectError:
            If the two components cannot be connected (for example if one of the components is
            not present in the pipeline, or the connections don't match by type, and so on).
        """
        # Edges may be named explicitly by passing 'node_name.edge_name' to connect().
        sender_component_name, sender_socket_name = parse_connect_string(sender)
        receiver_component_name, receiver_socket_name = parse_connect_string(receiver)

        if sender_component_name == receiver_component_name:
            raise PipelineConnectError("Connecting a Component to itself is not supported.")

        # Get the nodes data.
        try:
            sender_sockets = self.graph.nodes[sender_component_name]["output_sockets"]
        except KeyError as exc:
            raise ValueError(f"Component named {sender_component_name} not found in the pipeline.") from exc
        try:
            receiver_sockets = self.graph.nodes[receiver_component_name]["input_sockets"]
        except KeyError as exc:
            raise ValueError(f"Component named {receiver_component_name} not found in the pipeline.") from exc

        # If the name of either socket is given, get the socket
        sender_socket: Optional[OutputSocket] = None
        if sender_socket_name:
            sender_socket = sender_sockets.get(sender_socket_name)
            if not sender_socket:
                raise PipelineConnectError(
                    f"'{sender}' does not exist. "
                    f"Output connections of {sender_component_name} are: "
                    + ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in sender_sockets.items()])
                )

        receiver_socket: Optional[InputSocket] = None
        if receiver_socket_name:
            receiver_socket = receiver_sockets.get(receiver_socket_name)
            if not receiver_socket:
                raise PipelineConnectError(
                    f"'{receiver}' does not exist. "
                    f"Input connections of {receiver_component_name} are: "
                    + ", ".join(
                        [f"{name} (type {_type_name(socket.type)})" for name, socket in receiver_sockets.items()]
                    )
                )

        # Look for a matching connection among the possible ones.
        # Note that if there is more than one possible connection but two sockets match by name, they're paired.
        sender_socket_candidates: List[OutputSocket] = (
            [sender_socket] if sender_socket else list(sender_sockets.values())
        )
        receiver_socket_candidates: List[InputSocket] = (
            [receiver_socket] if receiver_socket else list(receiver_sockets.values())
        )

        # Find all possible connections between these two components
        possible_connections = []
        for sender_sock, receiver_sock in itertools.product(sender_socket_candidates, receiver_socket_candidates):
            if _types_are_compatible(sender_sock.type, receiver_sock.type, self._connection_type_validation):
                possible_connections.append((sender_sock, receiver_sock))

        # We need this status for error messages; since we might need it in multiple places we calculate it here
        status = _connections_status(
            sender_node=sender_component_name,
            sender_sockets=sender_socket_candidates,
            receiver_node=receiver_component_name,
            receiver_sockets=receiver_socket_candidates,
        )

        if not possible_connections:
            # There's no possible connection between these two components
            if len(sender_socket_candidates) == len(receiver_socket_candidates) == 1:
                msg = (
                    f"Cannot connect '{sender_component_name}.{sender_socket_candidates[0].name}' with "
                    f"'{receiver_component_name}.{receiver_socket_candidates[0].name}': "
                    f"their declared input and output types do not match.\n{status}"
                )
            else:
                msg = (
                    f"Cannot connect '{sender_component_name}' with '{receiver_component_name}': "
                    f"no matching connections available.\n{status}"
                )
            raise PipelineConnectError(msg)

        if len(possible_connections) == 1:
            # There's only one possible connection, use it
            sender_socket = possible_connections[0][0]
            receiver_socket = possible_connections[0][1]

        if len(possible_connections) > 1:
            # There are multiple possible connections, let's try to match them by name
            name_matches = [
                (out_sock, in_sock) for out_sock, in_sock in possible_connections if in_sock.name == out_sock.name
            ]
            if len(name_matches) != 1:
                # There are either no matches or more than one; we can't pick one reliably
                msg = (
                    f"Cannot connect '{sender_component_name}' with "
                    f"'{receiver_component_name}': more than one connection is possible "
                    "between these components. Please specify the connection name, like: "
                    f"pipeline.connect('{sender_component_name}.{possible_connections[0][0].name}', "
                    f"'{receiver_component_name}.{possible_connections[0][1].name}').\n{status}"
                )
                raise PipelineConnectError(msg)

            # Get the only possible match
            sender_socket = name_matches[0][0]
            receiver_socket = name_matches[0][1]

        # Connection must be valid on both sender/receiver sides
        if not sender_socket or not receiver_socket or not sender_component_name or not receiver_component_name:
            if sender_component_name and sender_socket:
                sender_repr = f"{sender_component_name}.{sender_socket.name} ({_type_name(sender_socket.type)})"
            else:
                sender_repr = "input needed"

            if receiver_component_name and receiver_socket:
                receiver_repr = f"({_type_name(receiver_socket.type)}) {receiver_component_name}.{receiver_socket.name}"
            else:
                receiver_repr = "output"
            msg = f"Connection must have both sender and receiver: {sender_repr} -> {receiver_repr}"
            raise PipelineConnectError(msg)

        logger.debug(
            "Connecting '{sender_component}.{sender_socket_name}' to '{receiver_component}.{receiver_socket_name}'",
            sender_component=sender_component_name,
            sender_socket_name=sender_socket.name,
            receiver_component=receiver_component_name,
            receiver_socket_name=receiver_socket.name,
        )

        if receiver_component_name in sender_socket.receivers and sender_component_name in receiver_socket.senders:
            # This is already connected, nothing to do
            return self

        if receiver_socket.senders and not receiver_socket.is_variadic:
            # Only variadic input sockets can receive from multiple senders
            msg = (
                f"Cannot connect '{sender_component_name}.{sender_socket.name}' with "
                f"'{receiver_component_name}.{receiver_socket.name}': "
                f"{receiver_component_name}.{receiver_socket.name} is already connected to {receiver_socket.senders}.\n"
            )
            raise PipelineConnectError(msg)

        # Update the sockets with the new connection
        sender_socket.receivers.append(receiver_component_name)
        receiver_socket.senders.append(sender_component_name)

        # Create the new connection
        self.graph.add_edge(
            sender_component_name,
            receiver_component_name,
            key=f"{sender_socket.name}/{receiver_socket.name}",
            conn_type=_type_name(sender_socket.type),
            from_socket=sender_socket,
            to_socket=receiver_socket,
            mandatory=receiver_socket.is_mandatory,
        )
        return self
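
    # Usage sketch (illustrative; the component classes are hypothetical):
    # explicit socket names are only required when more than one connection
    # would be possible between the two components.
    #
    #     pipe = PipelineBase()
    #     pipe.add_component("fetcher", Fetcher())
    #     pipe.add_component("converter", Converter())
    #     pipe.connect("fetcher.streams", "converter.sources")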

    def get_component(self, name: str) -> Component:
        """
        Get the component with the specified name from the pipeline.

        :param name:
            The name of the component.
        :returns:
            The instance of that component.

        :raises ValueError:
            If a component with that name is not present in the pipeline.
        """
        try:
            return self.graph.nodes[name]["instance"]
        except KeyError as exc:
            raise ValueError(f"Component named {name} not found in the pipeline.") from exc

    def get_component_name(self, instance: Component) -> str:
        """
        Returns the name of the Component instance if it has been added to this Pipeline, or an empty string otherwise.

        :param instance:
            The Component instance to look for.
        :returns:
            The name of the Component instance.
        """
        for name, inst in self.graph.nodes(data="instance"):  # type: ignore # type wrongly defined in networkx
            if inst == instance:
                return name
        return ""

    def inputs(self, include_components_with_connected_inputs: bool = False) -> Dict[str, Dict[str, Any]]:
        """
        Returns a dictionary containing the inputs of a pipeline.

        Each key in the dictionary corresponds to a component name, and its value is another dictionary that describes
        the input sockets of that component, including their types and whether they are optional.

        :param include_components_with_connected_inputs:
            If `False`, only components that have disconnected input edges are
            included in the output.
        :returns:
            A dictionary where each key is a pipeline component name and each value is a dictionary of
            input sockets of that component.
        """
        inputs: Dict[str, Dict[str, Any]] = {}
        for component_name, data in find_pipeline_inputs(self.graph, include_components_with_connected_inputs).items():
            sockets_description = {}
            for socket in data:
                sockets_description[socket.name] = {"type": socket.type, "is_mandatory": socket.is_mandatory}
                if not socket.is_mandatory:
                    sockets_description[socket.name]["default_value"] = socket.default_value

            if sockets_description:
                inputs[component_name] = sockets_description
        return inputs

    def outputs(self, include_components_with_connected_outputs: bool = False) -> Dict[str, Dict[str, Any]]:
        """
        Returns a dictionary containing the outputs of a pipeline.

        Each key in the dictionary corresponds to a component name, and its value is another dictionary that describes
        the output sockets of that component.

        :param include_components_with_connected_outputs:
            If `False`, only components that have disconnected output edges are
            included in the output.
        :returns:
            A dictionary where each key is a pipeline component name and each value is a dictionary of
            output sockets of that component.
        """
        outputs = {
            comp: {socket.name: {"type": socket.type} for socket in data}
            for comp, data in find_pipeline_outputs(self.graph, include_components_with_connected_outputs).items()
            if data
        }
        return outputs
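
    # Usage sketch (illustrative): inspecting a pipeline's open sockets.
    # `inputs()` maps each component to its disconnected (or, optionally, all)
    # input sockets; `outputs()` does the same for output sockets.
    #
    #     print(pipe.inputs())
    #     print(pipe.outputs(include_components_with_connected_outputs=True))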

    def show(self, server_url: str = "https://mermaid.ink", params: Optional[dict] = None, timeout: int = 30) -> None:
        """
        Display an image representing this `Pipeline` in a Jupyter notebook.

        This function generates a diagram of the `Pipeline` using a Mermaid server and displays it directly in
        the notebook.

        :param server_url:
            The base URL of the Mermaid server used for rendering (default: 'https://mermaid.ink').
            See https://github.com/jihchi/mermaid.ink and https://github.com/mermaid-js/mermaid-live-editor for more
            info on how to set up your own Mermaid server.

        :param params:
            Dictionary of customization parameters to modify the output. Refer to the Mermaid documentation for more
            details. Supported keys:
                - format: Output format ('img', 'svg', or 'pdf'). Default: 'img'.
                - type: Image type for /img endpoint ('jpeg', 'png', 'webp'). Default: 'png'.
                - theme: Mermaid theme ('default', 'neutral', 'dark', 'forest'). Default: 'neutral'.
                - bgColor: Background color in hexadecimal (e.g., 'FFFFFF') or named format (e.g., '!white').
                - width: Width of the output image (integer).
                - height: Height of the output image (integer).
                - scale: Scaling factor (1–3). Only applicable if 'width' or 'height' is specified.
                - fit: Whether to fit the diagram size to the page (PDF only, boolean).
                - paper: Paper size for PDFs (e.g., 'a4', 'a3'). Ignored if 'fit' is true.
                - landscape: Landscape orientation for PDFs (boolean). Ignored if 'fit' is true.

        :param timeout:
            Timeout in seconds for the request to the Mermaid server.

        :raises PipelineDrawingError:
            If the function is called outside of a Jupyter notebook or if there is an issue with rendering.
        """
        if is_in_jupyter():
            from IPython.display import Image, display  # type: ignore

            image_data = _to_mermaid_image(self.graph, server_url=server_url, params=params, timeout=timeout)
            display(Image(image_data))
        else:
            msg = "This method is only supported in Jupyter notebooks. Use Pipeline.draw() to save an image locally."
            raise PipelineDrawingError(msg)

    def draw(
        self, path: Path, server_url: str = "https://mermaid.ink", params: Optional[dict] = None, timeout: int = 30
    ) -> None:
        """
        Save an image representing this `Pipeline` to the specified file path.

        This function generates a diagram of the `Pipeline` using the Mermaid server and saves it to the provided path.

        :param path:
            The file path where the generated image will be saved.
        :param server_url:
            The base URL of the Mermaid server used for rendering (default: 'https://mermaid.ink').
            See https://github.com/jihchi/mermaid.ink and https://github.com/mermaid-js/mermaid-live-editor for more
            info on how to set up your own Mermaid server.
        :param params:
            Dictionary of customization parameters to modify the output. Refer to the Mermaid documentation for more
            details. Supported keys:
                - format: Output format ('img', 'svg', or 'pdf'). Default: 'img'.
                - type: Image type for /img endpoint ('jpeg', 'png', 'webp'). Default: 'png'.
                - theme: Mermaid theme ('default', 'neutral', 'dark', 'forest'). Default: 'neutral'.
                - bgColor: Background color in hexadecimal (e.g., 'FFFFFF') or named format (e.g., '!white').
                - width: Width of the output image (integer).
                - height: Height of the output image (integer).
                - scale: Scaling factor (1–3). Only applicable if 'width' or 'height' is specified.
                - fit: Whether to fit the diagram size to the page (PDF only, boolean).
                - paper: Paper size for PDFs (e.g., 'a4', 'a3'). Ignored if 'fit' is true.
                - landscape: Landscape orientation for PDFs (boolean). Ignored if 'fit' is true.

        :param timeout:
            Timeout in seconds for the request to the Mermaid server.

        :raises PipelineDrawingError:
            If there is an issue with rendering or saving the image.
        """
        # Render the diagram with the Mermaid server and write the image bytes to the given path.
        image_data = _to_mermaid_image(self.graph, server_url=server_url, params=params, timeout=timeout)
        Path(path).write_bytes(image_data)
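
    # Usage sketch (illustrative): saving a PNG diagram rendered by the default
    # Mermaid server, using the params keys documented above.
    #
    #     pipe.draw(Path("pipeline.png"), params={"type": "png", "theme": "dark"})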

    def walk(self) -> Iterator[Tuple[str, Component]]:
        """
        Visits each component in the pipeline exactly once and yields its name and instance.

        No guarantees are provided on the visiting order.

        :returns:
            An iterator of tuples of component name and component instance.
        """
        for component_name, instance in self.graph.nodes(data="instance"):  # type: ignore # type is wrong in networkx
            yield component_name, instance

    def warm_up(self):
        """
        Make sure all nodes are warm.

        It's the node's responsibility to make sure this method can be called at every `Pipeline.run()`
        without re-initializing everything.
        """
        for node in self.graph.nodes:
            if hasattr(self.graph.nodes[node]["instance"], "warm_up"):
                logger.info("Warming up component {node}...", node=node)
                self.graph.nodes[node]["instance"].warm_up()
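
    # Usage sketch (illustrative): iterating over all components, in no
    # guaranteed order.
    #
    #     for name, instance in pipe.walk():
    #         print(name, type(instance).__name__)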

    def _validate_input(self, data: Dict[str, Any]):
        """
        Validates pipeline input data.

        Validates that data:
        * Each Component name actually exists in the Pipeline
        * Each Component is not missing any input
        * Each Component has only one input per input socket, if not variadic
        * Each Component doesn't receive inputs that are already sent by another Component

        :param data:
            A dictionary of inputs for the pipeline's components. Each key is a component name.

        :raises ValueError:
            If inputs are invalid according to the above.
        """
        for component_name, component_inputs in data.items():
            if component_name not in self.graph.nodes:
                raise ValueError(f"Component named {component_name} not found in the pipeline.")
            instance = self.graph.nodes[component_name]["instance"]
            for socket_name, socket in instance.__haystack_input__._sockets_dict.items():
                if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs:
                    raise ValueError(f"Missing input for component {component_name}: {socket_name}")
            for input_name in component_inputs.keys():
                if input_name not in instance.__haystack_input__._sockets_dict:
                    raise ValueError(f"Input {input_name} not found in component {component_name}.")

        for component_name in self.graph.nodes:
            instance = self.graph.nodes[component_name]["instance"]
            for socket_name, socket in instance.__haystack_input__._sockets_dict.items():
                component_inputs = data.get(component_name, {})
                if socket.senders == [] and socket.is_mandatory and socket_name not in component_inputs:
                    raise ValueError(f"Missing input for component {component_name}: {socket_name}")
                if socket.senders and socket_name in component_inputs and not socket.is_variadic:
                    raise ValueError(
                        f"Input {socket_name} for component {component_name} is already sent by {socket.senders}."
                    )

    def _prepare_component_input_data(self, data: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
        """
        Prepares input data for pipeline components.

        Organizes input data for pipeline components and identifies any inputs that are not matched to any
        component's input slots. Deep-copies data items to avoid sharing mutables across multiple components.

        This method processes a flat dictionary of input data, where each key-value pair represents an input name
        and its corresponding value. It distributes these inputs to the appropriate pipeline components based on
        their input requirements. Inputs that don't match any component's input slots are classified as unresolved.

        :param data:
            A dictionary potentially having input names as keys and input values as values.

        :returns:
            A dictionary mapping component names to their respective matched inputs.
        """
        # check whether the data is a nested dictionary of component inputs where each key is a component name
        # and each value is a dictionary of input parameters for that component
        is_nested_component_input = all(isinstance(value, dict) for value in data.values())
        if not is_nested_component_input:
            # flat input, a dict where keys are input names and values are the corresponding values
            # we need to convert it to a nested dictionary of component inputs and then run the pipeline
            # just like in the previous case
            pipeline_input_data: Dict[str, Dict[str, Any]] = defaultdict(dict)
            unresolved_kwargs = {}

            # Retrieve the input slots for each component in the pipeline
            available_inputs: Dict[str, Dict[str, Any]] = self.inputs()

            # Go through all provided inputs to distribute them to the appropriate component inputs
            for input_name, input_value in data.items():
                resolved_at_least_once = False

                # Check each component to see if it has a slot for the current kwarg
                for component_name, component_inputs in available_inputs.items():
                    if input_name in component_inputs:
                        # If a match is found, add the kwarg to the component's input data
                        pipeline_input_data[component_name][input_name] = input_value
                        resolved_at_least_once = True

                if not resolved_at_least_once:
                    unresolved_kwargs[input_name] = input_value

            if unresolved_kwargs:
                logger.warning(
                    "Inputs {input_keys} were not matched to any component inputs, please check your run parameters.",
                    input_keys=list(unresolved_kwargs.keys()),
                )

            data = dict(pipeline_input_data)

        # deepcopying the inputs prevents the Pipeline run logic from being altered unexpectedly
        # when the same input reference is passed to multiple components.
        for component_name, component_inputs in data.items():
            data[component_name] = {k: deepcopy(v) for k, v in component_inputs.items()}

        return data
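
    # Illustrative sketch of the flat-to-nested expansion performed above,
    # assuming hypothetical components "retriever_a" and "retriever_b" that
    # both expose a socket named "query":
    #
    #     {"query": "Who lives in Paris?"}
    # becomes
    #     {"retriever_a": {"query": "Who lives in Paris?"},
    #      "retriever_b": {"query": "Who lives in Paris?"}}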

    @classmethod
    def from_template(
        cls, predefined_pipeline: PredefinedPipeline, template_params: Optional[Dict[str, Any]] = None
    ) -> "PipelineBase":
        """
        Create a Pipeline from a predefined template. See `PredefinedPipeline` for available options.

        :param predefined_pipeline:
            The predefined pipeline to use.
        :param template_params:
            An optional dictionary of parameters to use when rendering the pipeline template.
        :returns:
            An instance of `Pipeline`.
        """
        tpl = PipelineTemplate.from_predefined(predefined_pipeline)
        # If tpl.render() fails, we let the original error bubble up
        rendered = tpl.render(template_params)

        # If there was a problem with the rendered version of the
        # template, we add it to the error stack for debugging
        try:
            return cls.loads(rendered)
        except Exception as e:
            msg = f"Error unmarshalling pipeline: {e}\n"
            msg += f"Source:\n{rendered}"
            raise PipelineUnmarshalError(msg)
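
    # Usage sketch (illustrative; template availability depends on the
    # `PredefinedPipeline` enum of the installed version):
    #
    #     pipe = PipelineBase.from_template(PredefinedPipeline.INDEXING)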

    def _find_receivers_from(self, component_name: str) -> List[Tuple[str, OutputSocket, InputSocket]]:
        """
        Utility function to find all Components that receive input from `component_name`.

        :param component_name:
            Name of the sender Component.

        :returns:
            List of tuples containing the name of the receiver Component, the sender OutputSocket,
            and the receiver InputSocket instances.
        """
        res = []
        for _, receiver_name, connection in self.graph.edges(nbunch=component_name, data=True):
            sender_socket: OutputSocket = connection["from_socket"]
            receiver_socket: InputSocket = connection["to_socket"]
            res.append((receiver_name, sender_socket, receiver_socket))
        return res

    @staticmethod
    def _convert_to_internal_format(pipeline_inputs: Dict[str, Any]) -> Dict[str, Dict[str, List]]:
        """
        Converts the inputs to the pipeline to the format that is needed for the internal `Pipeline.run` logic.

        Example Input:
        {'prompt_builder': {'question': 'Who lives in Paris?'}, 'retriever': {'query': 'Who lives in Paris?'}}
        Example Output:
        {'prompt_builder': {'question': [{'sender': None, 'value': 'Who lives in Paris?'}]},
         'retriever': {'query': [{'sender': None, 'value': 'Who lives in Paris?'}]}}

        :param pipeline_inputs: Inputs to the pipeline.
        :returns: Converted inputs that can be used by the internal `Pipeline.run` logic.
        """
        inputs: Dict[str, Dict[str, List[Dict[str, Any]]]] = {}
        for component_name, socket_dict in pipeline_inputs.items():
            inputs[component_name] = {}
            for socket_name, value in socket_dict.items():
                inputs[component_name][socket_name] = [{"sender": None, "value": value}]

        return inputs

    @staticmethod
    def _consume_component_inputs(component_name: str, component: Dict, inputs: Dict) -> Dict[str, Any]:
        """
        Extracts the inputs needed to run the component and removes them from the global inputs state.

        :param component_name: The name of a component.
        :param component: Component with component metadata.
        :param inputs: Global inputs state.
        :returns: The inputs for the component.
        """
        component_inputs = inputs.get(component_name, {})
        consumed_inputs = {}
        greedy_inputs_to_remove = set()
        for socket_name, socket in component["input_sockets"].items():
            socket_inputs = component_inputs.get(socket_name, [])
            socket_inputs = [sock["value"] for sock in socket_inputs if sock["value"] is not _NO_OUTPUT_PRODUCED]
            if socket_inputs:
                if not socket.is_variadic:
                    # We only care about the first input provided to the socket.
                    consumed_inputs[socket_name] = socket_inputs[0]
                elif socket.is_greedy:
                    # We need to keep track of greedy inputs because we always remove them, even if they come from
                    # outside the pipeline. Otherwise, a greedy input from the user would trigger a pipeline to run
                    # indefinitely.
                    greedy_inputs_to_remove.add(socket_name)
                    consumed_inputs[socket_name] = [socket_inputs[0]]
                elif is_socket_lazy_variadic(socket):
                    # We use all inputs provided to the socket on a lazy variadic socket.
                    consumed_inputs[socket_name] = socket_inputs

        # We prune all inputs except for those that were provided from outside the pipeline (e.g. user inputs).
        pruned_inputs = {
            socket_name: [
                sock for sock in socket if sock["sender"] is None and socket_name not in greedy_inputs_to_remove
            ]
            for socket_name, socket in component_inputs.items()
        }
        pruned_inputs = {socket_name: socket for socket_name, socket in pruned_inputs.items() if len(socket) > 0}

        inputs[component_name] = pruned_inputs

        return consumed_inputs

    def _fill_queue(
        self, component_names: List[str], inputs: Dict[str, Any], component_visits: Dict[str, int]
    ) -> FIFOPriorityQueue:
        """
        Calculates the execution priority for each component and inserts it into the priority queue.

        :param component_names: Names of the components to put into the queue.
        :param inputs: Inputs to the components.
        :param component_visits: Current state of component visits.
        :returns: A prioritized queue of component names.
        """
        priority_queue = FIFOPriorityQueue()
        for component_name in component_names:
            component = self._get_component_with_graph_metadata_and_visits(
                component_name, component_visits[component_name]
            )
            priority = self._calculate_priority(component, inputs.get(component_name, {}))
            priority_queue.push(component_name, priority)

        return priority_queue

    @staticmethod
    def _calculate_priority(component: Dict, inputs: Dict) -> ComponentPriority:
        """
        Calculates the execution priority for a component depending on the component's inputs.

        :param component: Component metadata and component instance.
        :param inputs: Inputs to the component.
        :returns: Priority value for the component.
        """
        if not can_component_run(component, inputs):
            return ComponentPriority.BLOCKED
        elif is_any_greedy_socket_ready(component, inputs) and are_all_sockets_ready(component, inputs):
            return ComponentPriority.HIGHEST
        elif all_predecessors_executed(component, inputs):
            return ComponentPriority.READY
        elif are_all_lazy_variadic_sockets_resolved(component, inputs):
            return ComponentPriority.DEFER
        else:
            return ComponentPriority.DEFER_LAST
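
    # Note (illustrative): `ComponentPriority` is an IntEnum, so lower values
    # sort first in the FIFOPriorityQueue used by `_fill_queue`:
    #
    #     assert ComponentPriority.HIGHEST < ComponentPriority.READY < ComponentPriority.BLOCKED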

1015
    def _get_component_with_graph_metadata_and_visits(self, component_name: str, visits: int) -> Dict[str, Any]:
1✔
1016
        """
1017
        Returns the component instance alongside input/output-socket metadata from the graph and adds current visits.
1018

1019
        We can't store visits in the pipeline graph because this would prevent reentrance / thread-safe execution.
1020

1021
        :param component_name: The name of the component.
1022
        :param visits: Number of visits for the component.
1023
        :returns: Dict including component instance, input/output-sockets and visits.
1024
        """
1025
        comp_dict = self.graph.nodes[component_name]
1✔
1026
        comp_dict = {**comp_dict, "visits": visits}
1✔
1027
        return comp_dict
1✔
1028

1029
    def _get_next_runnable_component(
1✔
1030
        self, priority_queue: FIFOPriorityQueue, component_visits: Dict[str, int]
1031
    ) -> Union[Tuple[ComponentPriority, str, Dict[str, Any]], None]:
1032
        """
1033
        Returns the next runnable component alongside its metadata from the priority queue.
1034

1035
        :param priority_queue: Priority queue of component names.
1036
        :param component_visits: Current state of component visits.
1037
        :returns: The next runnable component, the component name, and its priority
1038
            or None if no component in the queue can run.
1039
        :raises: PipelineMaxComponentRuns if the next runnable component has exceeded the maximum number of runs.
1040
        """
1041
        priority_and_component_name: Union[Tuple[ComponentPriority, str], None] = (
1✔
1042
            None if (item := priority_queue.get()) is None else (ComponentPriority(item[0]), str(item[1]))
1043
        )
1044

1045
        if priority_and_component_name is not None and priority_and_component_name[0] != ComponentPriority.BLOCKED:
1✔
1046
            priority, component_name = priority_and_component_name
1✔
1047
            component = self._get_component_with_graph_metadata_and_visits(
1✔
1048
                component_name, component_visits[component_name]
1049
            )
1050
            if component["visits"] > self._max_runs_per_component:
1✔
1051
                msg = f"Maximum run count {self._max_runs_per_component} reached for component '{component_name}'"
1✔
1052
                raise PipelineMaxComponentRuns(msg)
1✔
1053

1054
            return priority, component_name, component
1✔
1055

1056
        return None
1✔
1057

1058
    @staticmethod
1✔
1059
    def _add_missing_input_defaults(component_inputs: Dict[str, Any], component_input_sockets: Dict[str, InputSocket]):
1✔
1060
        """
1061
        Updates the inputs with the default values for the inputs that are missing
1062

1063
        :param component_inputs: Inputs for the component.
1064
        :param component_input_sockets: Input sockets of the component.
1065
        """
1066
        for name, socket in component_input_sockets.items():
1✔
1067
            if not socket.is_mandatory and name not in component_inputs:
1✔
1068
                if socket.is_variadic:
1✔
1069
                    component_inputs[name] = [socket.default_value]
1✔
1070
                else:
1071
                    component_inputs[name] = socket.default_value
1✔
1072

1073
        return component_inputs
1✔
1074

    def _tiebreak_waiting_components(
1✔
        self,
        component_name: str,
        priority: ComponentPriority,
        priority_queue: FIFOPriorityQueue,
        topological_sort: Union[Dict[str, int], None],
    ):
        """
        Decides which component to run when multiple components are waiting for inputs with the same priority.

        :param component_name: The name of the component.
        :param priority: Priority of the component.
        :param priority_queue: Priority queue of component names.
        :param topological_sort: Cached topological sort of all components in the pipeline.
        :returns: The chosen component name and the (possibly newly computed) topological sort.
        """
        components_with_same_priority = [component_name]
×

        while len(priority_queue) > 0:
×
            next_priority, next_component_name = priority_queue.peek()
×
            if next_priority == priority:
×
                priority_queue.pop()  # actually remove the component
×
                components_with_same_priority.append(next_component_name)
×
            else:
                break
×

        if len(components_with_same_priority) > 1:
×
            if topological_sort is None:
×
                if networkx.is_directed_acyclic_graph(self.graph):
×
                    topological_sort = networkx.lexicographical_topological_sort(self.graph)
×
                    topological_sort = {node: idx for idx, node in enumerate(topological_sort)}
×
                else:
                    condensed = networkx.condensation(self.graph)
×
                    condensed_sorted = {node: idx for idx, node in enumerate(networkx.topological_sort(condensed))}
×
                    topological_sort = {
×
                        component_name: condensed_sorted[node]
                        for component_name, node in condensed.graph["mapping"].items()
                    }

            components_with_same_priority = sorted(
×
                components_with_same_priority, key=lambda comp_name: (topological_sort[comp_name], comp_name.lower())
            )

            component_name = components_with_same_priority[0]
×

        return component_name, topological_sort
×

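    # Illustrative sketch (hypothetical names): if "reader" and "ranker" both wait
    # with the same priority and "ranker" precedes "reader" in the topological
    # order, "ranker" wins; components at the same topological position fall back
    # to case-insensitive name order. For cyclic graphs the sort is computed on
    # the graph's condensation, so components inside a cycle share one position.
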
    @staticmethod
1✔
    def _write_component_outputs(
1✔
        component_name: str,
        component_outputs: Dict[str, Any],
        inputs: Dict[str, Any],
        receivers: List[Tuple],
        include_outputs_from: Set[str],
    ) -> Dict[str, Any]:
        """
        Distributes the outputs of a component to the input sockets that it is connected to.

        :param component_name: The name of the component.
        :param component_outputs: The outputs of the component.
        :param inputs: The current global input state.
        :param receivers: List of (receiver name, sender socket, receiver socket) tuples for the components
            that receive inputs from this component.
        :param include_outputs_from: Set of component names that should always return an output from the pipeline.
        :returns: The outputs that should be included in the final pipeline outputs: all of them if the
            component is listed in `include_outputs_from`, otherwise only those not consumed by a receiver.
        """
        for receiver_name, sender_socket, receiver_socket in receivers:
1✔
            # We either get the value that was produced by the sender or we use the _NO_OUTPUT_PRODUCED class to
            # indicate that the sender did not produce an output for this socket.
            # This allows us to track if a predecessor already ran but did not produce an output.
            value = component_outputs.get(sender_socket.name, _NO_OUTPUT_PRODUCED)
1✔

            if receiver_name not in inputs:
1✔
                inputs[receiver_name] = {}
1✔

            if is_socket_lazy_variadic(receiver_socket):
1✔
                # If the receiver socket is lazy variadic, we append the new input.
                # Lazy variadic sockets can collect multiple inputs.
                _write_to_lazy_variadic_socket(
1✔
                    inputs=inputs,
                    receiver_name=receiver_name,
                    receiver_socket_name=receiver_socket.name,
                    component_name=component_name,
                    value=value,
                )
            else:
                # If the receiver socket is not lazy variadic, it is greedy variadic or non-variadic.
                # We overwrite with the new input if it's not _NO_OUTPUT_PRODUCED or if the current value is None.
                _write_to_standard_socket(
1✔
                    inputs=inputs,
                    receiver_name=receiver_name,
                    receiver_socket_name=receiver_socket.name,
                    component_name=component_name,
                    value=value,
                )

        # If we want to include all outputs from this component in the final outputs, we don't need to prune any
        # consumed outputs.
        if component_name in include_outputs_from:
1✔
            return component_outputs
1✔

        # We prune outputs that were consumed by any receiving sockets.
        # All remaining outputs will be added to the final outputs of the pipeline.
        consumed_outputs = {sender_socket.name for _, sender_socket, __ in receivers}
1✔
        pruned_outputs = {key: value for key, value in component_outputs.items() if key not in consumed_outputs}
1✔

        return pruned_outputs
1✔

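    # Illustrative sketch (hypothetical names): if "retriever" produces
    # {"documents": [...], "debug": {...}} and only the "documents" socket is
    # connected, each receiver gets {"sender": "retriever", "value": [...]}
    # written into its input slot, while the unconsumed "debug" output is
    # returned and surfaces in the final pipeline result (all outputs would
    # surface if "retriever" were listed in include_outputs_from).
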
    @staticmethod
1✔
    def _is_queue_stale(priority_queue: FIFOPriorityQueue) -> bool:
1✔
        """
        Checks if the priority queue needs to be recomputed because the priorities might have changed.

        :param priority_queue: Priority queue of component names.
        """
        return len(priority_queue) == 0 or priority_queue.peek()[0] > ComponentPriority.READY
1✔

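    # Illustrative sketch: the queue is stale when it is empty or when its best
    # candidate is merely waiting (priority worse than READY). For example, a
    # peeked (ComponentPriority.DEFER, "reader") entry - "reader" being a
    # hypothetical name - signals that priorities should be recomputed after new
    # inputs have been distributed.
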
    @staticmethod
1✔
    def validate_pipeline(priority_queue: FIFOPriorityQueue) -> None:
1✔
        """
        Validate the pipeline to check if it is blocked or has no valid entry point.

        :param priority_queue: Priority queue of component names.
        :raises: PipelineRuntimeError if all components in the pipeline are blocked.
        """
        if len(priority_queue) == 0:
1✔
            return
×

        candidate = priority_queue.peek()
1✔
        if candidate is not None and candidate[0] == ComponentPriority.BLOCKED:
1✔
            raise PipelineRuntimeError(
×
                "Cannot run pipeline - all components are blocked. "
                "This typically happens when:\n"
                "1. There is no valid entry point for the pipeline\n"
                "2. There is a circular dependency preventing the pipeline from running\n"
                "Check the connections between your components and ensure all required inputs are provided."
            )


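# Illustrative sketch: a pipeline whose every component requires an input produced
# by another component has no entry point, so the queue's best candidate peeks as
# (ComponentPriority.BLOCKED, <name>) and validate_pipeline raises
# PipelineRuntimeError before any component runs.
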
def _connections_status(
1✔
    sender_node: str, receiver_node: str, sender_sockets: List[OutputSocket], receiver_sockets: List[InputSocket]
) -> str:
    """
    Lists the status of the sockets, for error messages.
    """
    sender_sockets_entries = []
1✔
    for sender_socket in sender_sockets:
1✔
        sender_sockets_entries.append(f" - {sender_socket.name}: {_type_name(sender_socket.type)}")
1✔
    sender_sockets_list = "\n".join(sender_sockets_entries)
1✔

    receiver_sockets_entries = []
1✔
    for receiver_socket in receiver_sockets:
1✔
        if receiver_socket.senders:
1✔
            sender_status = f"sent by {','.join(receiver_socket.senders)}"
1✔
        else:
            sender_status = "available"
1✔
        receiver_sockets_entries.append(
1✔
            f" - {receiver_socket.name}: {_type_name(receiver_socket.type)} ({sender_status})"
        )
    receiver_sockets_list = "\n".join(receiver_sockets_entries)
1✔

    return f"'{sender_node}':\n{sender_sockets_list}\n'{receiver_node}':\n{receiver_sockets_list}"
1✔


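# Illustrative sketch (hypothetical sockets): for a sender "retriever" with an
# output socket "documents: List[Document]" and a receiver "reader" whose
# "documents" input is still unconnected, the rendered status reads roughly:
#
#   'retriever':
#    - documents: List[Document]
#   'reader':
#    - documents: List[Document] (available)
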
# Utility functions for writing to sockets


def _write_to_lazy_variadic_socket(
1✔
    inputs: Dict[str, Any], receiver_name: str, receiver_socket_name: str, component_name: str, value: Any
) -> None:
    """
    Write to a lazy variadic socket.

    Mutates inputs in place.
    """
    if not inputs[receiver_name].get(receiver_socket_name):
1✔
        inputs[receiver_name][receiver_socket_name] = []
1✔

    inputs[receiver_name][receiver_socket_name].append({"sender": component_name, "value": value})
1✔


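# Illustrative sketch (hypothetical names): two successive writes to the lazy
# variadic socket "documents" of "joiner" accumulate rather than overwrite:
#
#   inputs = {"joiner": {}}
#   _write_to_lazy_variadic_socket(inputs, "joiner", "documents", "retriever_a", [doc1])
#   _write_to_lazy_variadic_socket(inputs, "joiner", "documents", "retriever_b", [doc2])
#   # inputs["joiner"]["documents"] is now
#   # [{"sender": "retriever_a", "value": [doc1]}, {"sender": "retriever_b", "value": [doc2]}]
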
def _write_to_standard_socket(
1✔
    inputs: Dict[str, Any], receiver_name: str, receiver_socket_name: str, component_name: str, value: Any
) -> None:
    """
    Write to a greedy variadic or non-variadic socket.

    Mutates inputs in place.
    """
    current_value = inputs[receiver_name].get(receiver_socket_name)
1✔

    # Only overwrite if there's no existing value, or we have a new value to provide
    if current_value is None or value is not _NO_OUTPUT_PRODUCED:
1✔
        inputs[receiver_name][receiver_socket_name] = [{"sender": component_name, "value": value}]
1✔
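
# Illustrative sketch (hypothetical names): unlike the lazy variadic case, a
# second write to a standard socket replaces the first, unless the new value is
# _NO_OUTPUT_PRODUCED and a real value is already present:
#
#   inputs = {"reader": {}}
#   _write_to_standard_socket(inputs, "reader", "query", "router_a", "q1")
#   _write_to_standard_socket(inputs, "reader", "query", "router_b", "q2")
#   # inputs["reader"]["query"] == [{"sender": "router_b", "value": "q2"}]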