• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / canals / 6577725327

19 Oct 2023 04:48PM UTC coverage: 92.743% (+0.07%) from 92.674%
6577725327

Pull #122

github

web-flow
Merge 365ab75d4 into 429c3475c
Pull Request #122: re-introduce variadics to support Joiner node

154 of 158 branches covered (0.0%)

Branch coverage included in aggregate %.

600 of 655 relevant lines covered (91.6%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

89.94
canals/pipeline/pipeline.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4
from typing import Optional, Any, Dict, List, Literal, Union
1✔
5

6
import os
1✔
7
import json
1✔
8
import datetime
1✔
9
import logging
1✔
10
from pathlib import Path
1✔
11
from copy import deepcopy
1✔
12
from collections import OrderedDict
1✔
13

14
import networkx
1✔
15

16
from canals.component import component, Component, InputSocket, OutputSocket
1✔
17
from canals.errors import (
1✔
18
    PipelineError,
19
    PipelineConnectError,
20
    PipelineMaxLoops,
21
    PipelineRuntimeError,
22
    PipelineValidationError,
23
)
24
from canals.pipeline.draw import _draw, _convert_for_debug, RenderingEngines
1✔
25
from canals.pipeline.validation import validate_pipeline_input
1✔
26
from canals.pipeline.connections import parse_connection, _find_unambiguous_connection
1✔
27
from canals.type_utils import _type_name
1✔
28
from canals.serialization import component_to_dict, component_from_dict
1✔
29

30
logger = logging.getLogger(__name__)
1✔
31

32

33
class Pipeline:
1✔
34
    """
35
    Components orchestration engine.
36

37
    Builds a graph of components and orchestrates their execution according to the execution graph.
38
    """
39

40
    def __init__(
        self,
        metadata: Optional[Dict[str, Any]] = None,
        max_loops_allowed: int = 100,
        debug_path: Union[Path, str] = Path(".canals_debug/"),
    ):
        """
        Creates the Pipeline.

        Args:
            metadata: arbitrary dictionary to store metadata about this pipeline. Make sure all the values contained in
                this dictionary can be serialized and deserialized if you wish to save this pipeline to file with
                `save_pipelines()/load_pipelines()`.
            max_loops_allowed: how many times the pipeline can run the same node before throwing an exception.
            debug_path: when debug is enabled in `run()`, where to save the debug data.
        """
        # The execution graph starts empty: nodes come from add_component(), edges from connect().
        self.graph = networkx.MultiDiGraph()
        self.metadata = metadata if metadata else {}
        self.max_loops_allowed = max_loops_allowed
        self.debug_path = Path(debug_path)
        # Step-by-step snapshots collected by run(debug=True), keyed by step number.
        self.debug: Dict[int, Dict[str, Any]] = {}
61

62
    def __eq__(self, other) -> bool:
        """
        Equal pipelines share every metadata, node and edge, but they're not required to use
        the same node instances: this allows pipeline saved and then loaded back to be equal to themselves.
        """
        # Guard clauses: wrong type, differing settings, or a missing graph mean "not equal".
        if not isinstance(other, type(self)):
            return False
        if getattr(self, "metadata") != getattr(other, "metadata"):
            return False
        if getattr(self, "max_loops_allowed") != getattr(other, "max_loops_allowed"):
            return False
        if not hasattr(self, "graph") or not hasattr(other, "graph"):
            return False

        # Compare adjacency, node contents (instances reduced to their classes), and graph attributes.
        if self.graph.adj != other.graph.adj:
            return False
        if self._comparable_nodes_list(self.graph) != self._comparable_nodes_list(other.graph):
            return False
        return self.graph.graph == other.graph.graph
81

82
    def to_dict(self) -> Dict[str, Any]:
        """
        Returns this Pipeline instance as a dictionary.
        This is meant to be an intermediate representation but it can be also used to save a pipeline to file.
        """
        # Serialize every component instance under its node name.
        components = {name: component_to_dict(instance) for name, instance in self.graph.nodes(data="instance")}

        # Each edge becomes a {"sender": "node.socket", "receiver": "node.socket"} entry.
        connections = []
        for sender, receiver, edge_data in self.graph.edges.data():
            sender_ref = f"{sender}.{edge_data['from_socket'].name}"
            receiver_ref = f"{receiver}.{edge_data['to_socket'].name}"
            connections.append({"sender": sender_ref, "receiver": receiver_ref})

        return {
            "metadata": self.metadata,
            "max_loops_allowed": self.max_loops_allowed,
            "components": components,
            "connections": connections,
        }
107

108
    @classmethod
    def from_dict(cls, data: Dict[str, Any], **kwargs) -> "Pipeline":
        """
        Creates a Pipeline instance from a dictionary.
        A sample `data` dictionary could be formatted like so:
        ```
        {
            "metadata": {"test": "test"},
            "max_loops_allowed": 100,
            "components": {
                "add_two": {
                    "type": "AddFixedValue",
                    "init_parameters": {"add": 2},
                },
                "add_default": {
                    "type": "AddFixedValue",
                    "init_parameters": {"add": 1},
                },
                "double": {
                    "type": "Double",
                },
            },
            "connections": [
                {"sender": "add_two.result", "receiver": "double.value"},
                {"sender": "double.value", "receiver": "add_default.value"},
            ],
        }
        ```

        Supported kwargs:
        `components`: a dictionary of {name: instance} to reuse instances of components instead of creating new ones.
        """
        # Build the empty pipeline first, falling back to defaults for any missing key.
        pipe = cls(
            metadata=data.get("metadata", {}),
            max_loops_allowed=data.get("max_loops_allowed", 100),
            debug_path=Path(data.get("debug_path", ".canals_debug/")),
        )

        components_to_reuse = kwargs.get("components", {})
        for name, component_data in data.get("components", {}).items():
            if name in components_to_reuse:
                # The caller supplied a ready-made instance: use it as-is.
                instance = components_to_reuse[name]
            else:
                # Otherwise rebuild the instance from its serialized form.
                if "type" not in component_data:
                    raise PipelineError(f"Missing 'type' in component '{name}'")
                if component_data["type"] not in component.registry:
                    raise PipelineError(f"Component '{component_data['type']}' not imported.")
                component_class = component.registry[component_data["type"]]
                instance = component_from_dict(component_class, component_data)
            pipe.add_component(name=name, instance=instance)

        # Re-create the edges; both endpoints of every connection are mandatory.
        for connection in data.get("connections", []):
            if "sender" not in connection:
                raise PipelineError(f"Missing sender in connection: {connection}")
            if "receiver" not in connection:
                raise PipelineError(f"Missing receiver in connection: {connection}")
            pipe.connect(connect_from=connection["sender"], connect_to=connection["receiver"])

        return pipe
171

172
    def _comparable_nodes_list(self, graph: networkx.MultiDiGraph) -> List[Dict[str, Any]]:
        """
        Returns the graph's node data in a deterministic, comparable form.

        Component instances are replaced with their class so that two pipelines built from
        different (but equivalent) instances compare equal.

        Args:
            graph: the graph whose node data should be listed.

        Returns:
            The node data dictionaries ordered by node name, each with its "instance" value
            replaced by the instance's class.
        """
        nodes = []
        # Sort by node name: dicts are not orderable in Python 3, so sorting the data
        # dictionaries themselves (as the previous implementation did with nodes.sort())
        # raised TypeError for any non-empty graph.
        for node in sorted(graph.nodes):
            # Copy the node data instead of mutating it in place: assigning to
            # graph.nodes[node]["instance"] would permanently replace the live component
            # instance with its class on the pipeline being compared.
            comparable_node = dict(graph.nodes[node])
            comparable_node["instance"] = comparable_node["instance"].__class__
            nodes.append(comparable_node)
        return nodes
183

184
    def add_component(self, name: str, instance: Component) -> None:
        """
        Create a component for the given component. Components are not connected to anything by default:
        use `Pipeline.connect()` to connect components together.

        Component names must be unique, but component instances can be reused if needed.

        Args:
            name: the name of the component.
            instance: the component instance.

        Returns:
            None

        Raises:
            ValueError: if a component with the same name already exists
            PipelineValidationError: if the given instance is not a Canals component
        """
        # Every node must be addressable by a unique name.
        if name in self.graph.nodes:
            raise ValueError(f"A component named '{name}' already exists in this pipeline: choose another name.")

        # '_debug' is reserved: run(debug=True) stores its output under that key.
        if name == "_debug":
            raise ValueError("'_debug' is a reserved name for debug output. Choose another name.")

        # Only objects decorated with @component can be orchestrated.
        if not isinstance(instance, Component):
            raise PipelineValidationError(
                f"'{type(instance)}' doesn't seem to be a component. Is this class decorated with @component?"
            )

        # Read the sockets that the @component decorator attached to the instance.
        input_sockets = getattr(instance, "__canals_input__", {})
        output_sockets = getattr(instance, "__canals_output__", {})

        # Register the component as an isolated node; edges are added later by connect().
        logger.debug("Adding component '%s' (%s)", name, instance)
        self.graph.add_node(
            name, instance=instance, input_sockets=input_sockets, output_sockets=output_sockets, visits=0
        )
229

230
    def connect(self, connect_from: str, connect_to: str) -> None:
        """
        Connects two components together. All components to connect must exist in the pipeline.
        If connecting to a component that has several output connections, specify the inputs and output names as
        'component_name.connections_name'.

        Args:
            connect_from: the component that delivers the value. This can be either just a component name or can be
                in the format `component_name.connection_name` if the component has multiple outputs.
            connect_to: the component that receives the value. This can be either just a component name or can be
                in the format `component_name.connection_name` if the component has multiple inputs.

        Returns:
            None

        Raises:
            ValueError: if either component is not present in the pipeline.
            PipelineConnectError: if the two components cannot be connected (for example if the named socket does
                not exist, the connections don't match by type, or the receiving socket is already taken).
        """
        # Edges may be named explicitly by passing 'node_name.edge_name' to connect().
        # The socket-name part is None when the caller passed only a component name.
        from_node, from_socket_name = parse_connection(connect_from)
        to_node, to_socket_name = parse_connection(connect_to)

        # Get the nodes data. A KeyError here means the component was never added to the pipeline.
        try:
            from_sockets = self.graph.nodes[from_node]["output_sockets"]
        except KeyError as exc:
            raise ValueError(f"Component named {from_node} not found in the pipeline.") from exc

        try:
            to_sockets = self.graph.nodes[to_node]["input_sockets"]
        except KeyError as exc:
            raise ValueError(f"Component named {to_node} not found in the pipeline.") from exc

        # If the name of either socket is given, resolve it now; the error message lists the
        # sockets that do exist (with their types) to help the caller fix the connection string.
        # NOTE: from_socket/to_socket are only bound inside these branches — the lines further
        # down guard their use with the same from_socket_name/to_socket_name conditions.
        if from_socket_name:
            from_socket = from_sockets.get(from_socket_name, None)
            if not from_socket:
                raise PipelineConnectError(
                    f"'{from_node}.{from_socket_name} does not exist. "
                    f"Output connections of {from_node} are: "
                    + ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in from_sockets.items()])
                )
        if to_socket_name:
            to_socket = to_sockets.get(to_socket_name, None)
            if not to_socket:
                raise PipelineConnectError(
                    f"'{to_node}.{to_socket_name} does not exist. "
                    f"Input connections of {to_node} are: "
                    + ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in to_sockets.items()])
                )

        # Look for an unambiguous connection among the possible ones.
        # Note that if there is more than one possible connection but two sockets match by name, they're paired.
        # When a socket name was given, the candidate list collapses to that single socket.
        from_sockets = [from_socket] if from_socket_name else list(from_sockets.values())
        to_sockets = [to_socket] if to_socket_name else list(to_sockets.values())
        from_socket, to_socket = _find_unambiguous_connection(
            sender_node=from_node, sender_sockets=from_sockets, receiver_node=to_node, receiver_sockets=to_sockets
        )

        # Connect the components on these sockets
        self._direct_connect(from_node=from_node, from_socket=from_socket, to_node=to_node, to_socket=to_socket)
292

293
    def _direct_connect(self, from_node: str, from_socket: OutputSocket, to_node: str, to_socket: InputSocket) -> None:
        """
        Directly connect socket to socket. This method does not type-check the connections: use 'Pipeline.connect()'
        instead (which uses 'find_unambiguous_connection()' to validate types).
        """
        # A non-variadic input socket accepts exactly one upstream sender. Output sockets may
        # fan out to any number of receivers, so only the receiving side needs this check.
        if to_socket.sender and not to_socket.is_variadic:
            raise PipelineConnectError(
                f"Cannot connect '{from_node}.{from_socket.name}' with '{to_node}.{to_socket.name}': "
                f"{to_node}.{to_socket.name} is already connected to {to_socket.sender}.\n"
            )

        logger.debug("Connecting '%s.%s' to '%s.%s'", from_node, from_socket.name, to_node, to_socket.name)
        # The edge key encodes the socket pair so parallel edges between the same nodes stay distinct.
        self.graph.add_edge(
            from_node,
            to_node,
            key=f"{from_socket.name}/{to_socket.name}",
            conn_type=_type_name(from_socket.type),
            from_socket=from_socket,
            to_socket=to_socket,
        )

        # Record on the receiving socket which node feeds it.
        to_socket.sender.append(from_node)
320

321
    def get_component(self, name: str) -> Component:
        """
        Returns an instance of a component.

        Args:
            name: the name of the component

        Returns:
            The instance of that component.

        Raises:
            ValueError: if a component with that name is not present in the pipeline.
        """
        # EAFP: a missing node surfaces as KeyError, which we translate for the caller.
        try:
            instance = self.graph.nodes[name]["instance"]
        except KeyError as exc:
            raise ValueError(f"Component named {name} not found in the pipeline.") from exc
        return instance
338

339
    def draw(self, path: Path, engine: RenderingEngines = "mermaid-image") -> None:
        """
        Draws the pipeline. Requires either `graphviz` as a system dependency, or an internet connection for Mermaid.
        Run `pip install canals[graphviz]` or `pip install canals[mermaid]` to install missing dependencies.

        Args:
            path: where to save the diagram.
            engine: which format to save the graph as. Accepts 'graphviz', 'mermaid-text', 'mermaid-image'.
                Default is 'mermaid-image'.

        Returns:
            None

        Raises:
            ImportError: if `engine='graphviz'` and `pygraphviz` is not installed.
            HTTPConnectionError: (and similar) if the internet connection is down or other connection issues.
        """
        # Hand the renderer a copy of the execution graph, not the pipeline's own graph.
        graph_copy = networkx.MultiDiGraph(self.graph)
        _draw(graph=graph_copy, path=path, engine=engine)
357

358
    def warm_up(self):
1✔
359
        """
360
        Make sure all nodes are warm.
361

362
        It's the node's responsibility to make sure this method can be called at every `Pipeline.run()`
363
        without re-initializing everything.
364
        """
365
        for node in self.graph.nodes:
1✔
366
            if hasattr(self.graph.nodes[node]["instance"], "warm_up"):
1✔
367
                logger.info("Warming up component %s...", node)
×
368
                self.graph.nodes[node]["instance"].warm_up()
×
369

370
    def run(self, data: Dict[str, Any], debug: bool = False) -> Dict[str, Any]:
        """
        Runs the pipeline.

        Args:
            data: the inputs to give to the input components of the Pipeline.
            debug: whether to collect and return debug information.

        Returns:
            A dictionary with the outputs of the output components of the Pipeline. When `debug`
            is True, the result also contains a `_debug` key with per-step snapshots.

        Raises:
            PipelineRuntimeError: if any of the components fail or return unexpected output.
            PipelineMaxLoops: if a component is run more than `max_loops_allowed` times.
        """
        # **** The Pipeline.run() algorithm ****
        #
        # Nodes are run as soon as an input for them appears in the inputs buffer.
        # When there's more than a node at once in the buffer (which means some
        # branches are running in parallel or that there are loops) they are selected to
        # run in FIFO order by the `inputs_buffer` OrderedDict.
        #
        # Inputs are labeled with the name of the node they're aimed for:
        #
        #   ```
        #   inputs_buffer[target_node] = {"input_name": input_value, ...}
        #   ```
        #
        # Nodes should wait until all the necessary input data has arrived before running.
        # If they're popped from the input_buffer before they're ready, they're put back in.
        # If the pipeline has branches of different lengths, it's possible that a node has to
        # wait a bit and let other nodes pass before receiving all the input data it needs.
        #
        # Cheatsheet for networkx data access:
        # - Name of the node       # self.graph.nodes  (List[str])
        # - Node instance          # self.graph.nodes[node]["instance"]
        # - Input nodes            # [e[0] for e in self.graph.in_edges(node)]
        # - Output nodes           # [e[1] for e in self.graph.out_edges(node)]
        # - Output edges           # [e[2]["label"] for e in self.graph.out_edges(node, data=True)]

        data = validate_pipeline_input(self.graph, input_values=data)

        logger.info("Pipeline execution started.")
        inputs_buffer = self._prepare_inputs_buffer(data)
        pipeline_output: Dict[str, Dict[str, Any]] = {}
        # Visit counters start from zero on every run so loop detection works per-run.
        self._clear_visits_count()
        self.warm_up()

        if debug:
            logger.info("Debug mode ON.")
        self.debug = {}

        # *** PIPELINE EXECUTION LOOP ***
        # We select the nodes to run by popping them in FIFO order from the inputs buffer.
        step = 0
        while inputs_buffer:
            step += 1
            if debug:
                self._record_pipeline_step(step, inputs_buffer, pipeline_output)
            logger.debug("> Queue at step %s: %s", step, {k: list(v.keys()) for k, v in inputs_buffer.items()})

            component_name, inputs = inputs_buffer.popitem(last=False)  # FIFO

            # Make sure it didn't run too many times already
            self._check_max_loops(component_name)

            # **** IS IT MY TURN YET? ****
            # Check if the component should be run or not
            action = self._calculate_action(name=component_name, inputs=inputs, inputs_buffer=inputs_buffer)

            # This component is missing data: let's put it back in the queue and wait.
            if action == "wait":
                if not inputs_buffer:
                    # Nothing else can run, so the missing input can never arrive: deadlock.
                    raise PipelineRuntimeError(
                        f"'{component_name}' is stuck waiting for input, but there are no other components to run. "
                        "This is likely a Canals bug. Open an issue at https://github.com/deepset-ai/canals."
                    )

                inputs_buffer[component_name] = inputs
                continue

            # This component did not receive the input it needs: it must be on a skipped branch. Let's not run it.
            if action == "skip":
                self.graph.nodes[component_name]["visits"] += 1
                # Skipping propagates downstream so dependents don't wait forever.
                inputs_buffer = self._skip_downstream_unvisited_nodes(
                    component_name=component_name, inputs_buffer=inputs_buffer
                )
                continue

            if action == "remove":
                # This component has no reason to stay in the run queue, so drop it. For example, this
                # can happen to components that are connected to skipped branches of the pipeline.
                continue

            # **** RUN THE NODE ****
            # It is our turn! The node is ready to run and all necessary inputs are present
            output = self._run_component(name=component_name, inputs=inputs)

            # **** PROCESS THE OUTPUT ****
            # The node ran successfully. Let's store or distribute the output it produced, if it's valid.
            if not self.graph.out_edges(component_name):
                # Leaf node: its output is part of the pipeline's result.
                # Note: if a node outputs many times (like in loops), the output will be overwritten
                pipeline_output[component_name] = output
            else:
                # Interior node: feed the output to the downstream components' input buffers.
                inputs_buffer = self._route_output(
                    node_results=output, node_name=component_name, inputs_buffer=inputs_buffer
                )

        if debug:
            # Record the final state after the loop drained the buffer.
            self._record_pipeline_step(step + 1, inputs_buffer, pipeline_output)

            # Save to json
            os.makedirs(self.debug_path, exist_ok=True)
            with open(self.debug_path / "data.json", "w", encoding="utf-8") as datafile:
                json.dump(self.debug, datafile, indent=4, default=str)

            # Store in the output
            pipeline_output["_debug"] = self.debug  # type: ignore

        logger.info("Pipeline executed successfully.")
        return pipeline_output
494

495
    def _record_pipeline_step(self, step, inputs_buffer, pipeline_output):
        """
        Stores a snapshot of this step into the self.debug dictionary of the pipeline.
        """
        # Render from a deep copy so building the diagram can't touch the live graph.
        diagram = _convert_for_debug(deepcopy(self.graph))
        self.debug[step] = {
            "time": datetime.datetime.now(),
            "inputs_buffer": list(inputs_buffer.items()),
            "pipeline_output": pipeline_output,
            "diagram": diagram,
        }
506

507
    def _clear_visits_count(self):
1✔
508
        """
509
        Make sure all nodes's visits count is zero.
510
        """
511
        for node in self.graph.nodes:
1✔
512
            self.graph.nodes[node]["visits"] = 0
1✔
513

514
    def _check_max_loops(self, component_name: str):
1✔
515
        """
516
        Verify whether this component run too many times.
517
        """
518
        if self.graph.nodes[component_name]["visits"] > self.max_loops_allowed:
1✔
519
            raise PipelineMaxLoops(
1✔
520
                f"Maximum loops count ({self.max_loops_allowed}) exceeded for component '{component_name}'."
521
            )
522

523
    # This function is complex so it contains quite some logic, it needs tons of information
524
    # regarding a component to understand what action it should take so we have many local
525
    # variables and to keep things simple we also have multiple returns.
526
    # In the end this amount of information makes it easier to understand the internal logic so
527
    # we chose to ignore these pylint warnings.
528
    def _calculate_action(  # pylint: disable=too-many-locals, too-many-return-statements
1✔
529
        self, name: str, inputs: Dict[str, Any], inputs_buffer: Dict[str, Any]
530
    ) -> Literal["run", "wait", "skip", "remove"]:
531
        """
532
        Calculates the action to take for the component specified by `name`.
533
        There are four possible actions:
534
            * run
535
            * wait
536
            * skip
537
            * remove
538

539
        The below conditions are evaluated in this order.
540

541
        Component will run if at least one of the following statements is true:
542
            * It received all mandatory inputs
543
            * It received all mandatory inputs and it has no optional inputs
544
            * It received all mandatory inputs and all optional inputs are skipped
545
            * It received all mandatory inputs and some optional inputs and the rest are skipped
546
            * It received some of its inputs and the others are defaulted
547
            * It's the first component of the pipeline
548

549
        Component will wait if:
550
            * It received some of its inputs and the other are not skipped
551
            * It received all mandatory inputs and some optional inputs have not been skipped
552

553
        Component will be skipped if:
554
            * It never ran nor waited
555

556
        Component will be removed if:
557
            * It ran or waited at least once but can't do it again
558

559
        If none of the above condition is met a PipelineRuntimeError is raised.
560

561
        For simplicity sake input components that create a cycle, or components that already ran
562
        and don't create a cycle are considered as skipped.
563

564
        Args:
565
            name: Name of the component
566
            inputs: Values that the component will take as input
567
            inputs_buffer: Other components' inputs
568

569
        Returns:
570
            Action to take for component specifing whether it should run, wait, skip or be removed
571

572
        Raises:
573
            PipelineRuntimeError: If action to take can't be determined
574
        """
575

576
        # Upstream components/socket pairs the current component is connected to
577
        input_components = {
1✔
578
            from_node: data["to_socket"].name for from_node, _, data in self.graph.in_edges(name, data=True)
579
        }
580
        # Sockets that have received inputs from upstream components
581
        received_input_sockets = set(inputs.keys())
1✔
582

583
        # All components inputs, whether they're connected, default or pipeline inputs
584
        input_sockets: Dict[str, InputSocket] = self.graph.nodes[name]["input_sockets"].keys()
1✔
585
        optional_input_sockets = {
1✔
586
            socket.name for socket in self.graph.nodes[name]["input_sockets"].values() if socket.is_optional
587
        }
588
        mandatory_input_sockets = {
1✔
589
            socket.name for socket in self.graph.nodes[name]["input_sockets"].values() if not socket.is_optional
590
        }
591

592
        # Components that are in the inputs buffer and have no inputs assigned are considered skipped
593
        skipped_components = {n for n, v in inputs_buffer.items() if not v}
1✔
594

595
        # Sockets that have their upstream component marked as skipped
596
        skipped_optional_input_sockets = {
1✔
597
            sockets["to_socket"].name
598
            for from_node, _, sockets in self.graph.in_edges(name, data=True)
599
            if from_node in skipped_components and sockets["to_socket"].name in optional_input_sockets
600
        }
601

602
        for from_node, socket in input_components.items():
1✔
603
            if socket not in optional_input_sockets:
1✔
604
                continue
1✔
605
            loops_back = networkx.has_path(self.graph, name, from_node)
1✔
606
            has_run = self.graph.nodes[from_node]["visits"] > 0
1✔
607
            if loops_back or has_run:
1✔
608
                # Consider all input components that loop back to current component
609
                # or that have already run at least once as skipped.
610
                # This must be done to correctly handle cycles in the pipeline or we
611
                # would reach a dead lock in components that have multiple inputs and
612
                # one of these forms a cycle.
613
                skipped_optional_input_sockets.add(socket)
1✔
614

615
        ##############
616
        # RUN CHECKS #
617
        ##############
618
        if (
1✔
619
            mandatory_input_sockets.issubset(received_input_sockets)
620
            and input_sockets == received_input_sockets | mandatory_input_sockets | skipped_optional_input_sockets
621
        ):
622
            # We received all mandatory inputs and:
623
            #   * There are no optional inputs or
624
            #   * All optional inputs are skipped or
625
            #   * We received part of the optional inputs, the rest are skipped
626
            if not optional_input_sockets:
1✔
627
                logger.debug("Component '%s' is ready to run. All mandatory inputs received.", name)
1✔
628
            else:
629
                logger.debug(
1✔
630
                    "Component '%s' is ready to run. All mandatory inputs received, skipped optional inputs: %s",
631
                    name,
632
                    skipped_optional_input_sockets,
633
                )
634
            return "run"
1✔
635

636
        if set(input_components.values()).issubset(received_input_sockets):
1✔
637
            # We have data from each connected input component.
638
            # We reach this when the current component is the first of the pipeline or
639
            # when it has defaults and all its input components have run.
640
            logger.debug("Component '%s' is ready to run. All expected inputs were received.", name)
1✔
641
            return "run"
1✔
642

643
        ###############
644
        # WAIT CHECKS #
645
        ###############
646
        if mandatory_input_sockets == received_input_sockets and skipped_optional_input_sockets.issubset(
1✔
647
            optional_input_sockets
648
        ):
649
            # We received all of the inputs we need, but some optional inputs have not been run or skipped yet
650
            logger.debug(
×
651
                "Component '%s' is waiting. All mandatory inputs received, some optional are not skipped: %s",
652
                name,
653
                optional_input_sockets - skipped_optional_input_sockets,
654
            )
655
            return "wait"
×
656

657
        if any(self.graph.nodes[n]["visits"] == 0 for n in input_components.keys()):
1✔
658
            # Some upstream component that must send input to the current component has yet to run.
659
            logger.debug(
1✔
660
                "Component '%s' is waiting. Missing inputs: %s",
661
                name,
662
                set(input_components.values()),
663
            )
664
            return "wait"
1✔
665

666
        ###############
667
        # SKIP CHECKS #
668
        ###############
669
        if self.graph.nodes[name]["visits"] == 0:
1✔
670
            # It's the first time visiting this component, if it can't run nor wait
671
            # it's fine skipping it at this point.
672
            logger.debug("Component '%s' is skipped. It can't run nor wait.", name)
1✔
673
            return "skip"
1✔
674

675
        #################
676
        # REMOVE CHECKS #
677
        #################
678
        if self.graph.nodes[name]["visits"] > 0:
×
679
            # This component has already been visited at least once. If it can't run nor wait
680
            # there is no reason to skip it again. So we it must be removed.
681
            logger.debug("Component '%s' is removed. It can't run, wait or skip.", name)
×
682
            return "remove"
×
683

684
        # Can't determine action to take
685
        raise PipelineRuntimeError(
×
686
            f"Can't determine Component '{name}' action. "
687
            f"Mandatory input sockets: {mandatory_input_sockets}, "
688
            f"optional input sockets: {optional_input_sockets}, "
689
            f"received input: {list(inputs.keys())}, "
690
            f"input components: {list(input_components.keys())}, "
691
            f"skipped components: {skipped_components}, "
692
            f"skipped optional inputs: {skipped_optional_input_sockets}."
693
            f"This is likely a Canals bug. Please open an issue at https://github.com/deepset-ai/canals.",
694
        )
695

696
    def _skip_downstream_unvisited_nodes(self, component_name: str, inputs_buffer: OrderedDict) -> OrderedDict:
1✔
697
        """
698
        When a component is skipped, put all downstream nodes in the inputs buffer too: the might be skipped too,
699
        unless they are merge nodes. They will be evaluated later by the pipeline execution loop.
700
        """
701
        downstream_nodes = [e[1] for e in self.graph.out_edges(component_name)]
1✔
702
        for downstream_node in downstream_nodes:
1✔
703
            if downstream_node in inputs_buffer:
1✔
704
                continue
1✔
705
            if self.graph.nodes[downstream_node]["visits"] == 0:
1✔
706
                # Skip downstream nodes only if they never been visited
707
                inputs_buffer[downstream_node] = {}
1✔
708
        return inputs_buffer
1✔
709

710
    def _run_component(self, name: str, inputs: Dict[str, Any]) -> Dict[str, Any]:
1✔
711
        """
712
        Once we're confident this component is ready to run, run it and collect the output.
713
        """
714
        self.graph.nodes[name]["visits"] += 1
1✔
715
        instance = self.graph.nodes[name]["instance"]
1✔
716
        try:
1✔
717
            logger.info("* Running %s (visits: %s)", name, self.graph.nodes[name]["visits"])
1✔
718
            logger.debug("   '%s' inputs: %s", name, inputs)
1✔
719

720
            outputs = instance.run(**inputs)
1✔
721

722
            # Unwrap the output
723
            logger.debug("   '%s' outputs: %s\n", name, outputs)
1✔
724

725
            # Make sure the component returned a dictionary
726
            if not isinstance(outputs, dict):
1✔
727
                raise PipelineRuntimeError(
1✔
728
                    f"Component '{name}' returned a value of type "
729
                    f"'{getattr(type(outputs), '__name__', str(type(outputs)))}' instead of a dict. "
730
                    "Components must always return dictionaries: check the the documentation."
731
                )
732

733
        except Exception as e:
1✔
734
            raise PipelineRuntimeError(
1✔
735
                f"{name} raised '{e.__class__.__name__}: {e}' \nInputs: {inputs}\n\n"
736
                "See the stacktrace above for more information."
737
            ) from e
738

739
        return outputs
1✔
740

741
    def _route_output(
        self,
        node_name: str,
        node_results: Dict[str, Any],
        inputs_buffer: OrderedDict,
    ) -> OrderedDict:
        """
        Distribute the outputs of the component into the input buffer of downstream components.

        :param node_name: the name of the component whose output is being routed.
        :param node_results: the dictionary returned by that component's run() method.
        :param inputs_buffer: the queue of components waiting for their inputs.

        Returns the updated inputs buffer.
        """
        # This is not a terminal node: find out where the output goes, to which nodes and along which edge.
        # A "decision node for a loop" has more than one outgoing edge and at least one of its
        # successors can reach it back through the graph (i.e. it sits at the exit of a cycle).
        is_decision_node_for_loop = (
            any(networkx.has_path(self.graph, edge[1], node_name) for edge in self.graph.out_edges(node_name))
            and len(self.graph.out_edges(node_name)) > 1
        )
        for edge_data in self.graph.out_edges(node_name, data=True):
            to_socket = edge_data[2]["to_socket"]
            from_socket = edge_data[2]["from_socket"]
            target_node = edge_data[1]

            # If this is a decision node and a loop is involved, we add to the input buffer only the nodes
            # that received their expected output and we leave the others out of the queue.
            # NOTE: a None value on the source socket is treated as "this branch produced no output".
            if is_decision_node_for_loop and node_results.get(from_socket.name, None) is None:
                if networkx.has_path(self.graph, target_node, node_name):
                    # In case we're choosing to leave a loop, do not put the loop's node in the buffer.
                    logger.debug(
                        "Not adding '%s' to the inputs buffer: we're leaving the loop.",
                        target_node,
                    )
                else:
                    # In case we're choosing to stay in a loop, do not put the external node in the buffer.
                    logger.debug(
                        "Not adding '%s' to the inputs buffer: we're staying in the loop.",
                        target_node,
                    )
            else:
                # In all other cases, populate the inputs buffer for all downstream nodes.

                # Create the buffer for the downstream node if it's not yet there.
                # This happens even when no value is routed below, so the downstream
                # node still gets queued for evaluation by the execution loop.
                if target_node not in inputs_buffer:
                    inputs_buffer[target_node] = {}

                # Skip Edges that did not receive any input.
                value_to_route = node_results.get(from_socket.name)
                if value_to_route is None:
                    continue

                # If the socket was marked as variadic, pile up inputs in a list
                if to_socket.is_variadic:
                    inputs_buffer[target_node].setdefault(to_socket.name, []).append(value_to_route)
                # Non-variadic input: just store the value
                else:
                    inputs_buffer[target_node][to_socket.name] = value_to_route

        return inputs_buffer
797

798
    def _prepare_inputs_buffer(self, data: Dict[str, Any]) -> OrderedDict:
1✔
799
        """
800
        Prepare the inputs buffer based on the parameters that were
801
        passed to run()
802
        """
803
        inputs_buffer: OrderedDict = OrderedDict()
1✔
804
        for node_name, input_data in data.items():
1✔
805
            for socket_name, value in input_data.items():
1✔
806
                if value is None:
1✔
807
                    continue
×
808
                if self.graph.nodes[node_name]["input_sockets"][socket_name].is_variadic:
1✔
809
                    value = [value]
1✔
810
                inputs_buffer.setdefault(node_name, {})[socket_name] = value
1✔
811
        return inputs_buffer
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc