• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / canals / 5902589996

18 Aug 2023 12:26PM UTC coverage: 93.556% (+0.09%) from 93.466%
5902589996

Pull #95

github

web-flow
Merge fbfe6429f into 50c1afd14
Pull Request #95: Remove all mentions of Component.defaults

177 of 182 branches covered (97.25%)

Branch coverage included in aggregate %.

665 of 718 relevant lines covered (92.62%)

0.93 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

90.03
canals/pipeline/pipeline.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4
from typing import Optional, Any, Dict, List, Literal, Union
1✔
5

6
import os
1✔
7
import json
1✔
8
import datetime
1✔
9
import logging
1✔
10
from pathlib import Path
1✔
11
from copy import deepcopy
1✔
12
from collections import OrderedDict
1✔
13

14
import networkx
1✔
15

16
from canals.component import component, Component
1✔
17
from canals.errors import (
1✔
18
    PipelineError,
19
    PipelineConnectError,
20
    PipelineMaxLoops,
21
    PipelineRuntimeError,
22
    PipelineValidationError,
23
)
24
from canals.pipeline.draw import _draw, _convert_for_debug, RenderingEngines
1✔
25
from canals.pipeline.sockets import InputSocket, OutputSocket
1✔
26
from canals.pipeline.validation import _validate_pipeline_input
1✔
27
from canals.pipeline.connections import _parse_connection_name, _find_unambiguous_connection
1✔
28
from canals.utils import _type_name
1✔
29

30
logger = logging.getLogger(__name__)
1✔
31

32

33
class Pipeline:
1✔
34
    """
35
    Components orchestration engine.
36

37
    Builds a graph of components and orchestrates their execution according to the execution graph.
38
    """
39

40
    def __init__(
        self,
        metadata: Optional[Dict[str, Any]] = None,
        max_loops_allowed: int = 100,
        debug_path: Union[Path, str] = Path(".canals_debug/"),
    ):
        """
        Initializes an empty Pipeline.

        Args:
            metadata: arbitrary dictionary to store metadata about this pipeline. Make sure all the values contained in
                this dictionary can be serialized and deserialized if you wish to save this pipeline to file with
                `save_pipelines()/load_pipelines()`.
            max_loops_allowed: how many times the pipeline can run the same node before throwing an exception.
            debug_path: when debug is enabled in `run()`, where to save the debug data.
        """
        # The execution graph: components become nodes, connections become edges.
        self.graph = networkx.MultiDiGraph()
        self.metadata = metadata if metadata else {}
        self.max_loops_allowed = max_loops_allowed
        # Per-step snapshots collected by run(debug=True), keyed by step number.
        self.debug: Dict[int, Dict[str, Any]] = {}
        self.debug_path = Path(debug_path)
61

62
    def __eq__(self, other) -> bool:
1✔
63
        """
64
        Equal pipelines share every metadata, node and edge, but they're not required to use
65
        the same node instances: this allows pipeline saved and then loaded back to be equal to themselves.
66
        """
67
        if (
1✔
68
            not isinstance(other, type(self))
69
            or not getattr(self, "metadata") == getattr(other, "metadata")
70
            or not getattr(self, "max_loops_allowed") == getattr(other, "max_loops_allowed")
71
            or not hasattr(self, "graph")
72
            or not hasattr(other, "graph")
73
        ):
74
            return False
×
75

76
        return (
1✔
77
            self.graph.adj == other.graph.adj
78
            and self._comparable_nodes_list(self.graph) == self._comparable_nodes_list(other.graph)
79
            and self.graph.graph == other.graph.graph
80
        )
81

82
    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes this Pipeline instance into a dictionary.
        This is meant to be an intermediate representation but it can be also used to save a pipeline to file.
        """
        # Each edge key encodes "sender_socket/receiver_socket" (see _direct_connect).
        connections = []
        for sender_node, receiver_node, edge_key in self.graph.edges:
            sender_socket, receiver_socket = edge_key.split("/")
            connections.append(
                {
                    "sender": f"{sender_node}.{sender_socket}",
                    "receiver": f"{receiver_node}.{receiver_socket}",
                }
            )

        # Delegate component serialization to each instance's own to_dict().
        serialized_components = {
            component_name: component_instance.to_dict()
            for component_name, component_instance in self.graph.nodes(data="instance")
        }

        return {
            "metadata": self.metadata,
            "max_loops_allowed": self.max_loops_allowed,
            "components": serialized_components,
            "connections": connections,
        }
103

104
    @classmethod
    def from_dict(cls, data: Dict[str, Any], **kwargs) -> "Pipeline":
        """
        Builds a Pipeline instance from its dictionary representation (the inverse of `to_dict()`).

        Expected `data` layout:
        ```
        {
            "metadata": {...},
            "max_loops_allowed": 100,
            "components": {
                "component_name": {"type": "ClassName", "init_parameters": {...}, ...},
                ...
            },
            "connections": [
                {"sender": "component_name.socket", "receiver": "component_name.socket"},
                ...
            ],
        }
        ```

        Supported kwargs:
        `components`: a dictionary of {name: instance} to reuse instances of components instead of creating new ones.

        Raises:
            PipelineError: if a component entry is malformed, references an unimported type,
                or a connection entry is missing its sender or receiver.
        """
        pipe = cls(
            metadata=data.get("metadata", {}),
            max_loops_allowed=data.get("max_loops_allowed", 100),
            debug_path=Path(data.get("debug_path", ".canals_debug/")),
        )

        # Instances supplied by the caller take precedence over deserialization.
        reusable_instances = kwargs.get("components", {})
        for name, component_data in data.get("components", {}).items():
            if name in reusable_instances:
                instance = reusable_instances[name]
            else:
                if "type" not in component_data:
                    raise PipelineError(f"Missing 'type' in component '{name}'")
                if component_data["type"] not in component.registry:
                    raise PipelineError(f"Component '{component_data['type']}' not imported.")
                # Deserialize through the registered component class.
                instance = component.registry[component_data["type"]].from_dict(component_data)
            pipe.add_component(name=name, instance=instance)

        for connection in data.get("connections", []):
            if "sender" not in connection:
                raise PipelineError(f"Missing sender in connection: {connection}")
            if "receiver" not in connection:
                raise PipelineError(f"Missing receiver in connection: {connection}")
            pipe.connect(connect_from=connection["sender"], connect_to=connection["receiver"])

        return pipe
169

170
    def _comparable_nodes_list(self, graph: networkx.MultiDiGraph) -> List[Dict[str, Any]]:
1✔
171
        """
172
        Replaces instances of nodes with their class name in order to make sure they're comparable.
173
        """
174
        nodes = []
1✔
175
        for node in graph.nodes:
1✔
176
            comparable_node = graph.nodes[node]
1✔
177
            comparable_node["instance"] = comparable_node["instance"].__class__
1✔
178
            nodes.append(comparable_node)
1✔
179
        nodes.sort()
1✔
180
        return nodes
1✔
181

182
    def add_component(self, name: str, instance: Component) -> None:
        """
        Adds a component to the pipeline under the given name. Components are not connected to anything
        by default: use `Pipeline.connect()` to connect components together.

        Component names must be unique, but component instances can be reused if needed.

        Args:
            name: the name of the component.
            instance: the component instance.

        Returns:
            None

        Raises:
            ValueError: if a component with the same name already exists, or the name is reserved.
            PipelineValidationError: if the given instance is not a Canals component.
        """
        # Names must be unique within the pipeline.
        if name in self.graph.nodes:
            raise ValueError(f"A component named '{name}' already exists in this pipeline: choose another name.")

        # "_debug" is reserved for the debug output key of run().
        if name == "_debug":
            raise ValueError("'_debug' is a reserved name for debug output. Choose another name.")

        # Only classes decorated with @component carry this marker.
        if not hasattr(instance, "__canals_component__"):
            raise PipelineValidationError(
                f"'{type(instance)}' doesn't seem to be a component. Is this class decorated with @component?"
            )

        # Build the input/output sockets from the metadata the @component decorator
        # attached to the instance's run() method.
        declared_inputs = getattr(instance.run, "__canals_input__", {})
        declared_outputs = getattr(instance.run, "__canals_output__", {})
        input_sockets = {socket_name: InputSocket(**socket_data) for socket_name, socket_data in declared_inputs.items()}
        output_sockets = {socket_name: OutputSocket(**socket_data) for socket_name, socket_data in declared_outputs.items()}

        # Register the component as a disconnected node of the execution graph.
        logger.debug("Adding component '%s' (%s)", name, instance)
        self.graph.add_node(
            name,
            instance=instance,
            input_sockets=input_sockets,
            output_sockets=output_sockets,
            visits=0,
        )
229

230
    def connect(self, connect_from: str, connect_to: str) -> None:
        """
        Connects two components together. All components to connect must exist in the pipeline.
        If connecting to an component that has several output connections, specify the inputs and output names as
        'component_name.connections_name'.

        Args:
            connect_from: the component that delivers the value. This can be either just a component name or can be
                in the format `component_name.connection_name` if the component has multiple outputs.
            connect_to: the component that receives the value. This can be either just a component name or can be
                in the format `component_name.connection_name` if the component has multiple inputs.

        Returns:
            None

        Raises:
            ValueError: if either component is not present in the pipeline.
            PipelineConnectError: if the two components cannot be connected (for example if a named socket does
                not exist, or the connections don't match by type, and so on).
        """
        # Edges may be named explicitly by passing 'node_name.edge_name' to connect().
        # The socket name part is None when only a component name was given.
        from_node, from_socket_name = _parse_connection_name(connect_from)
        to_node, to_socket_name = _parse_connection_name(connect_to)

        # Get the nodes data. A missing node surfaces as KeyError and is re-raised as ValueError.
        try:
            from_sockets = self.graph.nodes[from_node]["output_sockets"]
        except KeyError as exc:
            raise ValueError(f"Component named {from_node} not found in the pipeline.") from exc

        try:
            to_sockets = self.graph.nodes[to_node]["input_sockets"]
        except KeyError as exc:
            raise ValueError(f"Component named {to_node} not found in the pipeline.") from exc

        # If the name of either socket is given, get the socket.
        # The error message lists the valid sockets to help the caller fix the name.
        if from_socket_name:
            from_socket = from_sockets.get(from_socket_name, None)
            if not from_socket:
                raise PipelineConnectError(
                    f"'{from_node}.{from_socket_name} does not exist. "
                    f"Output connections of {from_node} are: "
                    + ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in from_sockets.items()])
                )
        if to_socket_name:
            to_socket = to_sockets.get(to_socket_name, None)
            if not to_socket:
                raise PipelineConnectError(
                    f"'{to_node}.{to_socket_name} does not exist. "
                    f"Input connections of {to_node} are: "
                    + ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in to_sockets.items()])
                )

        # Look for an unambiguous connection among the possible ones.
        # Note that if there is more than one possible connection but two sockets match by name, they're paired.
        # NOTE: from_sockets/to_sockets are rebound here from dicts to lists; the conditional is safe
        # because from_socket/to_socket are only read when the corresponding *_socket_name was given
        # (i.e. when the branches above assigned them).
        from_sockets = [from_socket] if from_socket_name else list(from_sockets.values())
        to_sockets = [to_socket] if to_socket_name else list(to_sockets.values())
        from_socket, to_socket = _find_unambiguous_connection(
            sender_node=from_node, sender_sockets=from_sockets, receiver_node=to_node, receiver_sockets=to_sockets
        )

        # Connect the components on these sockets (type-checking already done above).
        self._direct_connect(from_node=from_node, from_socket=from_socket, to_node=to_node, to_socket=to_socket)
292

293
    def _direct_connect(self, from_node: str, from_socket: OutputSocket, to_node: str, to_socket: InputSocket) -> None:
        """
        Directly connect socket to socket. This method does not type-check the connections: use 'Pipeline.connect()'
        instead (which uses 'find_unambiguous_connection()' to validate types).
        """
        # An input socket accepts a single sender, so refuse to overwrite an existing one.
        # Output sockets have no such restriction and may fan out to many receivers.
        if to_socket.sender:
            raise PipelineConnectError(
                f"Cannot connect '{from_node}.{from_socket.name}' with '{to_node}.{to_socket.name}': "
                f"{to_node}.{to_socket.name} is already connected to {to_socket.sender}.\n"
            )

        # Add the edge; the key encodes both socket names so to_dict() can recover them.
        logger.debug("Connecting '%s.%s' to '%s.%s'", from_node, from_socket.name, to_node, to_socket.name)
        self.graph.add_edge(
            from_node,
            to_node,
            key=f"{from_socket.name}/{to_socket.name}",
            conn_type=_type_name(from_socket.type),
            from_socket=from_socket,
            to_socket=to_socket,
        )

        # Remember which node feeds this socket, for the check above on future connections.
        to_socket.sender = from_node
320

321
    def get_component(self, name: str) -> Component:
1✔
322
        """
323
        Returns an instance of a component.
324

325
        Args:
326
            name: the name of the component
327

328
        Returns:
329
            The instance of that component.
330

331
        Raises:
332
            ValueError: if a component with that name is not present in the pipeline.
333
        """
334
        try:
×
335
            return self.graph.nodes[name]["instance"]
×
336
        except KeyError as exc:
×
337
            raise ValueError(f"Component named {name} not found in the pipeline.") from exc
×
338

339
    def draw(self, path: Path, engine: RenderingEngines = "mermaid-img") -> None:
        """
        Draws the pipeline. Requires either `graphviz` as a system dependency, or an internet connection for Mermaid.
        Run `pip install canals[graphviz]` or `pip install canals[mermaid]` to install missing dependencies.

        Args:
            path: where to save the diagram.
            engine: which format to save the graph as. Accepts 'graphviz', 'mermaid-text', 'mermaid-img'.
                Default is 'mermaid-img'.

        Returns:
            None

        Raises:
            ImportError: if `engine='graphviz'` and `pygraphviz` is not installed.
            HTTPConnectionError: (and similar) if the internet connection is down or other connection issues.
        """
        # Summarize each component's input sockets for diagnostics.
        sockets = {
            comp: "\n".join([f"{name}: {socket}" for name, socket in data.get("input_sockets", {}).items()])
            for comp, data in self.graph.nodes(data=True)
        }
        # Fix: this used to be a stray print() that dumped the socket map to stdout on every
        # draw() call. Diagnostic output belongs on the module logger instead.
        logger.debug("Pipeline input sockets: %s", sockets)
        # Draw a deepcopy so the renderer can't mutate the live execution graph.
        _draw(graph=deepcopy(self.graph), path=path, engine=engine)
362

363
    def warm_up(self):
        """
        Make sure all nodes are warm.

        It's the node's responsibility to make sure this method can be called at every `Pipeline.run()`
        without re-initializing everything.
        """
        for node in self.graph.nodes:
            instance = self.graph.nodes[node]["instance"]
            # warm_up() is optional on components: only call it where it exists.
            if hasattr(instance, "warm_up"):
                logger.info("Warming up component %s...", node)
                instance.warm_up()
374

375
    def run(self, data: Dict[str, Any], debug: bool = False) -> Dict[str, Any]:
        """
        Runs the pipeline.

        Args:
            data: the inputs to give to the input components of the Pipeline, namespaced by component.
            debug: whether to collect and return debug information (also saved under `self.debug_path`).

        Returns:
            A dictionary with the outputs of the output components of the Pipeline. When `debug=True`,
            it also contains a `_debug` key with per-step snapshots of the run.

        Raises:
            PipelineRuntimeError: if any of the components fail or return unexpected output.
            PipelineMaxLoops: if a component is run more than `max_loops_allowed` times.
        """
        # **** The Pipeline.run() algorithm ****
        #
        # Nodes are run as soon as an input for them appears in the inputs buffer.
        # When there's more than a node at once in the buffer (which means some
        # branches are running in parallel or that there are loops) they are selected to
        # run in FIFO order by the `inputs_buffer` OrderedDict.
        #
        # Inputs are labeled with the name of the node they're aimed for:
        #
        #   ````
        #   inputs_buffer[target_node] = {"input_name": input_value, ...}
        #   ```
        #
        # Nodes should wait until all the necessary input data has arrived before running.
        # If they're popped from the input_buffer before they're ready, they're put back in.
        # If the pipeline has branches of different lengths, it's possible that a node has to
        # wait a bit and let other nodes pass before receiving all the input data it needs.
        #
        # Cheatsheet for networkx data access:
        # - Name of the node       # self.graph.nodes  (List[str])
        # - Node instance          # self.graph.nodes[node]["instance"]
        # - Input nodes            # [e[0] for e in self.graph.in_edges(node)]
        # - Output nodes           # [e[1] for e in self.graph.out_edges(node)]
        # - Output edges           # [e[2]["label"] for e in self.graph.out_edges(node, data=True)]
        #
        # if debug:
        #     os.makedirs("debug", exist_ok=True)

        # Validate the caller's input against the graph, then reset per-run state.
        data = _validate_pipeline_input(self.graph, input_values=data)
        self._clear_visits_count()
        self.warm_up()

        logger.info("Pipeline execution started.")
        # Drop None values so "no input given" and "input absent" are treated the same.
        inputs_buffer = OrderedDict(
            {
                node: {key: value for key, value in input_data.items() if value is not None}
                for node, input_data in data.items()
            }
        )
        pipeline_output: Dict[str, Dict[str, Any]] = {}

        if debug:
            logger.info("Debug mode ON.")
        self.debug = {}

        # *** PIPELINE EXECUTION LOOP ***
        # We select the nodes to run by popping them in FIFO order from the inputs buffer.
        step = 0
        while inputs_buffer:
            step += 1
            if debug:
                self._record_pipeline_step(step, inputs_buffer, pipeline_output)
            logger.debug("> Queue at step %s: %s", step, {k: list(v.keys()) for k, v in inputs_buffer.items()})

            component_name, inputs = inputs_buffer.popitem(last=False)  # FIFO

            # Make sure it didn't run too many times already
            self._check_max_loops(component_name)

            # **** IS IT MY TURN YET? ****
            # Check if the component should be run or not
            action = self._calculate_action(name=component_name, inputs=inputs, inputs_buffer=inputs_buffer)

            # This component is missing data: let's put it back in the queue and wait.
            if action == "wait":
                if not inputs_buffer:
                    # What if there are no components to wait for?
                    raise PipelineRuntimeError(
                        f"'{component_name}' is stuck waiting for input, but there are no other components to run. "
                        "This is likely a Canals bug. Open an issue at https://github.com/deepset-ai/canals."
                    )

                inputs_buffer[component_name] = inputs
                continue

            # This component did not receive the input it needs: it must be on a skipped branch. Let's not run it.
            if action == "skip":
                self.graph.nodes[component_name]["visits"] += 1
                # Propagate the skip downstream so dependents don't wait forever.
                inputs_buffer = self._skip_downstream_unvisited_nodes(
                    component_name=component_name, inputs_buffer=inputs_buffer
                )
                continue

            if action == "remove":
                # This component has no reason of being in the run queue and we need to remove it. For example, this can happen to components that are connected to skipped branches of the pipeline.
                continue

            # **** RUN THE NODE ****
            # It is our turn! The node is ready to run and all necessary inputs are present
            output = self._run_component(name=component_name, inputs=inputs)

            # **** PROCESS THE OUTPUT ****
            # The node run successfully. Let's store or distribute the output it produced, if it's valid.
            if not self.graph.out_edges(component_name):
                # Leaf node: its output is part of the pipeline's final output.
                # Note: if a node outputs many times (like in loops), the output will be overwritten
                pipeline_output[component_name] = output
            else:
                # Inner node: route its output to the inputs buffer of its receivers.
                inputs_buffer = self._route_output(
                    node_results=output, node_name=component_name, inputs_buffer=inputs_buffer
                )

        if debug:
            # Record the final state (step + 1, after the loop has drained the buffer).
            self._record_pipeline_step(step + 1, inputs_buffer, pipeline_output)

            # Save to json
            os.makedirs(self.debug_path, exist_ok=True)
            with open(self.debug_path / "data.json", "w", encoding="utf-8") as datafile:
                json.dump(self.debug, datafile, indent=4, default=str)

            # Store in the output
            pipeline_output["_debug"] = self.debug  # type: ignore

        logger.info("Pipeline executed successfully.")
        return pipeline_output
504

505
    def _record_pipeline_step(self, step, inputs_buffer, pipeline_output):
        """
        Stores a snapshot of this step into the self.debug dictionary of the pipeline.
        """
        # Snapshot the graph as a debug diagram; deepcopy keeps the live graph untouched.
        snapshot = {
            "time": datetime.datetime.now(),
            "inputs_buffer": list(inputs_buffer.items()),
            "pipeline_output": pipeline_output,
            "diagram": _convert_for_debug(deepcopy(self.graph)),
        }
        self.debug[step] = snapshot
516

517
    def _clear_visits_count(self):
1✔
518
        """
519
        Make sure all nodes's visits count is zero.
520
        """
521
        for node in self.graph.nodes:
1✔
522
            self.graph.nodes[node]["visits"] = 0
1✔
523

524
    def _check_max_loops(self, component_name: str):
1✔
525
        """
526
        Verify whether this component run too many times.
527
        """
528
        if self.graph.nodes[component_name]["visits"] > self.max_loops_allowed:
1✔
529
            raise PipelineMaxLoops(
1✔
530
                f"Maximum loops count ({self.max_loops_allowed}) exceeded for component '{component_name}'."
531
            )
532

533
    # This function is complex so it contains quite some logic, it needs tons of information
534
    # regarding a component to understand what action it should take so we have many local
535
    # variables and to keep things simple we also have multiple returns.
536
    # In the end this amount of information makes it easier to understand the internal logic so
537
    # we chose to ignore these pylint warnings.
538
    def _calculate_action(  # pylint: disable=too-many-locals, too-many-return-statements
1✔
539
        self, name: str, inputs: Dict[str, Any], inputs_buffer: Dict[str, Any]
540
    ) -> Literal["run", "wait", "skip", "remove"]:
541
        """
542
        Calculates the action to take for the component specified by `name`.
543
        There are four possible actions:
544
            * run
545
            * wait
546
            * skip
547
            * remove
548

549
        The below conditions are evaluated in this order.
550

551
        Component will run if at least one of the following statements is true:
552
            * It received all mandatory inputs
553
            * It received all mandatory inputs and it has no optional inputs
554
            * It received all mandatory inputs and all optional inputs are skipped
555
            * It received all mandatory inputs and some optional inputs and the rest are skipped
556
            * It received some of its inputs and the others are defaulted
557
            * It's the first component of the pipeline
558

559
        Component will wait if:
560
            * It received some of its inputs and the other are not skipped
561
            * It received all mandatory inputs and some optional inputs have not been skipped
562

563
        Component will be skipped if:
564
            * It never ran nor waited
565

566
        Component will be removed if:
567
            * It ran or waited at least once but can't do it again
568

569
        If none of the above condition is met a PipelineRuntimeError is raised.
570

571
        For simplicity sake input components that create a cycle, or components that already ran
572
        and don't create a cycle are considered as skipped.
573

574
        Args:
575
            name: Name of the component
576
            inputs: Values that the component will take as input
577
            inputs_buffer: Other components' inputs
578

579
        Returns:
580
            Action to take for component specifing whether it should run, wait, skip or be removed
581

582
        Raises:
583
            PipelineRuntimeError: If action to take can't be determined
584
        """
585

586
        # Upstream components/socket pairs the current component is connected to
587
        input_components = {
1✔
588
            from_node: data["to_socket"].name for from_node, _, data in self.graph.in_edges(name, data=True)
589
        }
590
        # Sockets that have received inputs from upstream components
591
        received_input_sockets = set(inputs.keys())
1✔
592

593
        # All components inputs, whether they're connected, default or pipeline inputs
594
        input_sockets: Dict[str, InputSocket] = self.graph.nodes[name]["input_sockets"].keys()
1✔
595
        optional_input_sockets = {
1✔
596
            socket.name for socket in self.graph.nodes[name]["input_sockets"].values() if socket.is_optional
597
        }
598
        mandatory_input_sockets = {
1✔
599
            socket.name for socket in self.graph.nodes[name]["input_sockets"].values() if not socket.is_optional
600
        }
601

602
        # Components that are in the inputs buffer and have no inputs assigned are considered skipped
603
        skipped_components = {n for n, v in inputs_buffer.items() if not v}
1✔
604

605
        # Sockets that have their upstream component marked as skipped
606
        skipped_optional_input_sockets = {
1✔
607
            sockets["to_socket"].name
608
            for from_node, _, sockets in self.graph.in_edges(name, data=True)
609
            if from_node in skipped_components and sockets["to_socket"].name in optional_input_sockets
610
        }
611

612
        for from_node, socket in input_components.items():
1✔
613
            if socket not in optional_input_sockets:
1✔
614
                continue
1✔
615
            loops_back = networkx.has_path(self.graph, name, from_node)
1✔
616
            has_run = self.graph.nodes[from_node]["visits"] > 0
1✔
617
            if loops_back or has_run:
1✔
618
                # Consider all input components that loop back to current component
619
                # or that have already run at least once as skipped.
620
                # This must be done to correctly handle cycles in the pipeline or we
621
                # would reach a dead lock in components that have multiple inputs and
622
                # one of these forms a cycle.
623
                skipped_optional_input_sockets.add(socket)
1✔
624

625
        ##############
626
        # RUN CHECKS #
627
        ##############
628
        if (
1✔
629
            mandatory_input_sockets.issubset(received_input_sockets)
630
            and input_sockets == received_input_sockets | mandatory_input_sockets | skipped_optional_input_sockets
631
        ):
632
            # We received all mandatory inputs and:
633
            #   * There are no optional inputs or
634
            #   * All optional inputs are skipped or
635
            #   * We received part of the optional inputs, the rest are skipped
636
            if not optional_input_sockets:
1✔
637
                logger.debug("Component '%s' is ready to run. All mandatory inputs received.", name)
1✔
638
            else:
639
                logger.debug(
1✔
640
                    "Component '%s' is ready to run. All mandatory inputs received, skipped optional inputs: %s",
641
                    name,
642
                    skipped_optional_input_sockets,
643
                )
644
            return "run"
1✔
645

646
        if set(input_components.values()).issubset(received_input_sockets):
1✔
647
            # We have data from each connected input component.
648
            # We reach this when the current component is the first of the pipeline or
649
            # when it has defaults and all its input components have run.
650
            logger.debug("Component '%s' is ready to run. All expected inputs were received.", name)
1✔
651
            return "run"
1✔
652

653
        ###############
654
        # WAIT CHECKS #
655
        ###############
656
        if mandatory_input_sockets == received_input_sockets and skipped_optional_input_sockets.issubset(
1✔
657
            optional_input_sockets
658
        ):
659
            # We received all of the inputs we need, but some optional inputs have not been run or skipped yet
660
            logger.debug(
×
661
                "Component '%s' is waiting. All mandatory inputs received, some optional are not skipped: %s",
662
                name,
663
                optional_input_sockets - skipped_optional_input_sockets,
664
            )
665
            return "wait"
×
666

667
        if any(self.graph.nodes[n]["visits"] == 0 for n in input_components.keys()):
1✔
668
            # Some upstream component that must send input to the current component has yet to run.
669
            logger.debug(
1✔
670
                "Component '%s' is waiting. Missing inputs: %s",
671
                name,
672
                set(input_components.values()),
673
            )
674
            return "wait"
1✔
675

676
        ###############
677
        # SKIP CHECKS #
678
        ###############
679
        if self.graph.nodes[name]["visits"] == 0:
1✔
680
            # It's the first time visiting this component, if it can't run nor wait
681
            # it's fine skipping it at this point.
682
            logger.debug("Component '%s' is skipped. It can't run nor wait.", name)
1✔
683
            return "skip"
1✔
684

685
        #################
686
        # REMOVE CHECKS #
687
        #################
688
        if self.graph.nodes[name]["visits"] > 0:
×
689
            # This component has already been visited at least once. If it can't run nor wait
690
            # there is no reason to skip it again. So we it must be removed.
691
            logger.debug("Component '%s' is removed. It can't run, wait or skip.", name)
×
692
            return "remove"
×
693

694
        # Can't determine action to take
695
        raise PipelineRuntimeError(
×
696
            f"Can't determine Component '{name}' action. "
697
            f"Mandatory input sockets: {mandatory_input_sockets}, "
698
            f"optional input sockets: {optional_input_sockets}, "
699
            f"received input: {list(inputs.keys())}, "
700
            f"input components: {list(input_components.keys())}, "
701
            f"skipped components: {skipped_components}, "
702
            f"skipped optional inputs: {skipped_optional_input_sockets}."
703
            f"This is likely a Canals bug. Please open an issue at https://github.com/deepset-ai/canals.",
704
        )
705

706
    def _skip_downstream_unvisited_nodes(self, component_name: str, inputs_buffer: OrderedDict) -> OrderedDict:
1✔
707
        """
708
        When a component is skipped, put all downstream nodes in the inputs buffer too: the might be skipped too,
709
        unless they are merge nodes. They will be evaluated later by the pipeline execution loop.
710
        """
711
        downstream_nodes = [e[1] for e in self.graph.out_edges(component_name)]
1✔
712
        for downstream_node in downstream_nodes:
1✔
713
            if downstream_node in inputs_buffer:
1✔
714
                continue
1✔
715
            if self.graph.nodes[downstream_node]["visits"] == 0:
1✔
716
                # Skip downstream nodes only if they never been visited
717
                inputs_buffer[downstream_node] = {}
1✔
718
        return inputs_buffer
1✔
719

720
    def _run_component(self, name: str, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute a single component that was judged ready to run and collect its output.

        Increments the node's visit counter before running. Any exception raised by the
        component is re-raised wrapped in a PipelineRuntimeError carrying the inputs.
        """
        node = self.graph.nodes[name]
        node["visits"] += 1
        component_instance = node["instance"]
        try:
            logger.info("* Running %s (visits: %s)", name, node["visits"])
            logger.debug("   '%s' inputs: %s", name, inputs)

            results = component_instance.run(**inputs)

            # Unwrap the output
            logger.debug("   '%s' outputs: %s\n", name, results)

        except Exception as e:
            raise PipelineRuntimeError(
                f"{name} raised '{e.__class__.__name__}: {e}' \nInputs: {inputs}\n\n"
                "See the stacktrace above for more information."
            ) from e

        return results
742

743
    def _route_output(
        self,
        node_name: str,
        node_results: Dict[str, Any],
        inputs_buffer: OrderedDict,
    ) -> OrderedDict:
        """
        Distribute the outputs of the component into the input buffer of downstream components.

        Returns the updated inputs buffer.
        """
        # This is not a terminal node: find out where the output goes, to which nodes and along which edge.
        # The node is a "decision node for a loop" when it has several outgoing edges and at least
        # one of them leads back to this very node.
        outgoing_edges = self.graph.out_edges(node_name)
        is_decision_node_for_loop = len(outgoing_edges) > 1 and any(
            networkx.has_path(self.graph, target, node_name) for _, target in outgoing_edges
        )
        for _, target_node, connection in self.graph.out_edges(node_name, data=True):
            from_socket = connection["from_socket"]
            to_socket = connection["to_socket"]
            output_value = node_results.get(from_socket.name, None)

            # If this is a decision node and a loop is involved, we add to the input buffer only the nodes
            # that received their expected output and we leave the others out of the queue.
            if is_decision_node_for_loop and output_value is None:
                if networkx.has_path(self.graph, target_node, node_name):
                    # In case we're choosing to leave a loop, do not put the loop's node in the buffer.
                    logger.debug(
                        "Not adding '%s' to the inputs buffer: we're leaving the loop.",
                        target_node,
                    )
                else:
                    # In case we're choosing to stay in a loop, do not put the external node in the buffer.
                    logger.debug(
                        "Not adding '%s' to the inputs buffer: we're staying in the loop.",
                        target_node,
                    )
                continue

            # In all other cases, populate the inputs buffer for all downstream nodes, setting None to any
            # edge that did not receive input.
            # Create the buffer for the downstream node if it's not there yet.
            inputs_buffer.setdefault(target_node, {})
            if output_value is not None:
                inputs_buffer[target_node][to_socket.name] = output_value

        return inputs_buffer
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc