• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / canals / 5834361039

11 Aug 2023 03:31PM UTC coverage: 93.466% (-0.06%) from 93.524%
5834361039

Pull #82

github

web-flow
Merge 5f0b60f21 into 19f2e8fac
Pull Request #82: feat: remove `init_parameters` decorator

178 of 183 branches covered (97.27%)

Branch coverage included in aggregate %.

666 of 720 relevant lines covered (92.5%)

0.93 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

89.81
canals/pipeline/pipeline.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4
from typing import Optional, Any, Dict, List, Literal, Union
1✔
5

6
import os
1✔
7
import json
1✔
8
import datetime
1✔
9
import logging
1✔
10
from pathlib import Path
1✔
11
from copy import deepcopy
1✔
12
from collections import OrderedDict
1✔
13

14
import networkx
1✔
15

16
from canals.component import component, Component
1✔
17
from canals.errors import (
1✔
18
    PipelineError,
19
    PipelineConnectError,
20
    PipelineMaxLoops,
21
    PipelineRuntimeError,
22
    PipelineValidationError,
23
)
24
from canals.pipeline.draw import _draw, _convert_for_debug, RenderingEngines
1✔
25
from canals.pipeline.sockets import InputSocket, OutputSocket
1✔
26
from canals.pipeline.validation import _validate_pipeline_input
1✔
27
from canals.pipeline.connections import _parse_connection_name, _find_unambiguous_connection
1✔
28
from canals.utils import _type_name
1✔
29

30
# Module-level logger; all pipeline orchestration messages go through it.
logger = logging.getLogger(__name__)
31

32

33
class Pipeline:
1✔
34
    """
35
    Components orchestration engine.
36

37
    Builds a graph of components and orchestrates their execution according to the execution graph.
38
    """
39

40
    def __init__(
        self,
        metadata: Optional[Dict[str, Any]] = None,
        max_loops_allowed: int = 100,
        debug_path: Union[Path, str] = Path(".canals_debug/"),
    ):
        """
        Creates the Pipeline.

        Args:
            metadata: arbitrary dictionary to store metadata about this pipeline. Make sure all the
                values in this dictionary can be serialized and deserialized if you wish to save this
                pipeline to file with `save_pipelines()/load_pipelines()`.
            max_loops_allowed: how many times the pipeline can run the same node before throwing an exception.
            debug_path: when debug is enabled in `run()`, where to save the debug data.
        """
        # Execution graph: one node per component, one edge per socket-to-socket connection.
        self.graph = networkx.MultiDiGraph()
        # Per-step snapshots collected only when run() is called with debug=True.
        self.debug: Dict[int, Dict[str, Any]] = {}
        self.debug_path = Path(debug_path)
        self.max_loops_allowed = max_loops_allowed
        self.metadata = metadata if metadata else {}
61

62
    def __eq__(self, other) -> bool:
        """
        Equal pipelines share every metadata, node and edge, but they're not required to use
        the same node instances: this allows pipeline saved and then loaded back to be equal to themselves.
        """
        # Guard clauses: reject anything that isn't a comparable, fully-built Pipeline.
        if not isinstance(other, type(self)):
            return False
        if getattr(self, "metadata") != getattr(other, "metadata"):
            return False
        if getattr(self, "max_loops_allowed") != getattr(other, "max_loops_allowed"):
            return False
        if not hasattr(self, "graph") or not hasattr(other, "graph"):
            return False

        # Compare topology first, then node contents, then graph-level attributes.
        # Evaluation order matches the original short-circuit chain.
        if self.graph.adj != other.graph.adj:
            return False
        if self._comparable_nodes_list(self.graph) != self._comparable_nodes_list(other.graph):
            return False
        return self.graph.graph == other.graph.graph
81

82
    def to_dict(self) -> Dict[str, Any]:
        """
        Returns this Pipeline instance as a dictionary.
        This is meant to be an intermediate representation but it can be also used to save a pipeline to file.
        """
        serialized_components = {}
        for component_name, instance in self.graph.nodes(data="instance"):
            serialized_components[component_name] = instance.to_dict()

        connections = []
        for sender, receiver, edge_key in self.graph.edges:
            # Edge keys are stored as "<sender socket>/<receiver socket>".
            sender_socket, receiver_socket = edge_key.split("/")
            connections.append(
                {
                    "sender": f"{sender}.{sender_socket}",
                    "receiver": f"{receiver}.{receiver_socket}",
                }
            )

        return {
            "metadata": self.metadata,
            "max_loops_allowed": self.max_loops_allowed,
            "components": serialized_components,
            "connections": connections,
        }
103

104
    @classmethod
    def from_dict(cls, data: Dict[str, Any], **kwargs) -> "Pipeline":
        """
        Creates a Pipeline instance from a dictionary.
        A sample `data` dictionary could be formatted like so:
        ```
        {
            "metadata": {"test": "test"},
            "max_loops_allowed": 100,
            "components": {
                "add_two": {"type": "AddFixedValue", "hash": "123", "init_parameters": {"add": 2}},
                "add_default": {"type": "AddFixedValue", "hash": "456", "init_parameters": {"add": 1}},
                "double": {"type": "Double", "hash": "789"},
            },
            "connections": [
                {"sender": "add_two.result", "receiver": "double.value"},
                {"sender": "double.value", "receiver": "add_default.value"},
            ],
        }
        ```

        Supported kwargs:
        `components`: a dictionary of {name: instance} to reuse instances of components instead of creating new ones.

        Raises:
            PipelineError: if a component entry lacks a 'type', references an unimported component
                class, or a connection entry lacks 'sender' or 'receiver'.
        """
        pipe = cls(
            metadata=data.get("metadata", {}),
            max_loops_allowed=data.get("max_loops_allowed", 100),
            debug_path=Path(data.get("debug_path", ".canals_debug/")),
        )

        components_to_reuse = kwargs.get("components", {})
        for comp_name, component_data in data.get("components", {}).items():
            if comp_name in components_to_reuse:
                # Reuse the instance the caller supplied for this name.
                instance = components_to_reuse[comp_name]
            else:
                if "type" not in component_data:
                    raise PipelineError(f"Missing 'type' in component '{comp_name}'")
                if component_data["type"] not in component.registry:
                    raise PipelineError(f"Component '{component_data['type']}' not imported.")
                # Build a fresh instance from its serialized form via the component registry.
                instance = component.registry[component_data["type"]].from_dict(component_data)
            pipe.add_component(name=comp_name, instance=instance)

        for connection in data.get("connections", []):
            if "sender" not in connection:
                raise PipelineError(f"Missing sender in connection: {connection}")
            if "receiver" not in connection:
                raise PipelineError(f"Missing receiver in connection: {connection}")
            pipe.connect(connect_from=connection["sender"], connect_to=connection["receiver"])

        return pipe
169

170
    def _comparable_nodes_list(self, graph: networkx.MultiDiGraph) -> List[Dict[str, Any]]:
        """
        Replaces instances of nodes with their class name and defaults list in order to make sure
        they're comparable.

        Args:
            graph: the pipeline graph whose nodes should be rendered comparable.

        Returns:
            A deterministically ordered list of node-attribute dictionaries where each "instance"
            entry is the component's class rather than the instance itself.
        """
        nodes = []
        for node in graph.nodes:
            # Work on a copy of the node's attribute dict: assigning into graph.nodes[node]
            # directly would permanently replace the live instance with its class.
            comparable_node = dict(graph.nodes[node])
            instance = comparable_node["instance"]
            # The check must be made on the component instance, not on the attribute dict
            # (a dict never has a "defaults" *attribute*, so the old check never fired).
            if hasattr(instance, "defaults"):
                comparable_node["defaults"] = instance.defaults
            comparable_node["instance"] = instance.__class__
            nodes.append(comparable_node)
        # Dicts are unorderable in Python 3, so list.sort() with no key raises TypeError as soon
        # as there are two or more nodes; sort on a deterministic string rendering instead.
        nodes.sort(key=repr)
        return nodes
183

184
    def add_component(self, name: str, instance: Component) -> None:
        """
        Create a component for the given component. Components are not connected to anything by default:
        use `Pipeline.connect()` to connect components together.

        Component names must be unique, but component instances can be reused if needed.

        Args:
            name: the name of the component.
            instance: the component instance.

        Returns:
            None

        Raises:
            ValueError: if a component with the same name already exists
            PipelineValidationError: if the given instance is not a Canals component
        """
        # Names identify graph nodes, so duplicates are forbidden.
        if name in self.graph.nodes:
            raise ValueError(f"A component named '{name}' already exists in this pipeline: choose another name.")

        # "_debug" is reserved for the debug payload returned by run(debug=True).
        if name == "_debug":
            raise ValueError("'_debug' is a reserved name for debug output. Choose another name.")

        # Only classes decorated with @component carry this marker attribute.
        if not hasattr(instance, "__canals_component__"):
            raise PipelineValidationError(
                f"'{type(instance)}' doesn't seem to be a component. Is this class decorated with @component?"
            )

        # Build the input/output sockets from the metadata the @component decorator
        # attached to the instance's run() method.
        declared_inputs = getattr(instance.run, "__canals_input__", {})
        declared_outputs = getattr(instance.run, "__canals_output__", {})
        input_sockets = {socket_name: InputSocket(**data) for socket_name, data in declared_inputs.items()}
        output_sockets = {socket_name: OutputSocket(**data) for socket_name, data in declared_outputs.items()}

        # Register the component in the graph, initially disconnected.
        logger.debug("Adding component '%s' (%s)", name, instance)
        self.graph.add_node(
            name,
            instance=instance,
            input_sockets=input_sockets,
            output_sockets=output_sockets,
            visits=0,
        )
231

232
    def connect(self, connect_from: str, connect_to: str) -> None:
        """
        Connects two components together. All components to connect must exist in the pipeline.
        If connecting to a component that has several output connections, specify the inputs and
        output names as 'component_name.connections_name'.

        Args:
            connect_from: the component that delivers the value. This can be either just a component
                name or can be in the format `component_name.connection_name` if the component has
                multiple outputs.
            connect_to: the component that receives the value. This can be either just a component
                name or can be in the format `component_name.connection_name` if the component has
                multiple inputs.

        Returns:
            None

        Raises:
            PipelineConnectError: if the two components cannot be connected (for example if one of the
                components is not present in the pipeline, or the connections don't match by type,
                and so on).
        """
        # Endpoints may name an explicit socket by passing 'node_name.edge_name' to connect().
        sender_name, sender_socket_name = _parse_connection_name(connect_from)
        receiver_name, receiver_socket_name = _parse_connection_name(connect_to)

        # Fetch both nodes' socket tables, failing fast for unknown components.
        try:
            sender_sockets = self.graph.nodes[sender_name]["output_sockets"]
        except KeyError as exc:
            raise ValueError(f"Component named {sender_name} not found in the pipeline.") from exc

        try:
            receiver_sockets = self.graph.nodes[receiver_name]["input_sockets"]
        except KeyError as exc:
            raise ValueError(f"Component named {receiver_name} not found in the pipeline.") from exc

        # When a socket was named explicitly, resolve it now and report the valid
        # alternatives if it does not exist.
        if sender_socket_name:
            sender_socket = sender_sockets.get(sender_socket_name)
            if not sender_socket:
                raise PipelineConnectError(
                    f"'{sender_name}.{sender_socket_name} does not exist. "
                    f"Output connections of {sender_name} are: "
                    + ", ".join(
                        [f"{name} (type {_type_name(socket.type)})" for name, socket in sender_sockets.items()]
                    )
                )
        if receiver_socket_name:
            receiver_socket = receiver_sockets.get(receiver_socket_name)
            if not receiver_socket:
                raise PipelineConnectError(
                    f"'{receiver_name}.{receiver_socket_name} does not exist. "
                    f"Input connections of {receiver_name} are: "
                    + ", ".join(
                        [f"{name} (type {_type_name(socket.type)})" for name, socket in receiver_sockets.items()]
                    )
                )

        # Narrow the candidates and look for one unambiguous connection among them.
        # Note that if there is more than one possible connection but two sockets match by name,
        # they're paired.
        sender_candidates = [sender_socket] if sender_socket_name else list(sender_sockets.values())
        receiver_candidates = [receiver_socket] if receiver_socket_name else list(receiver_sockets.values())
        sender_socket, receiver_socket = _find_unambiguous_connection(
            sender_node=sender_name,
            sender_sockets=sender_candidates,
            receiver_node=receiver_name,
            receiver_sockets=receiver_candidates,
        )

        # Wire the components together on the resolved pair of sockets.
        self._direct_connect(
            from_node=sender_name, from_socket=sender_socket, to_node=receiver_name, to_socket=receiver_socket
        )
294

295
    def _direct_connect(self, from_node: str, from_socket: OutputSocket, to_node: str, to_socket: InputSocket) -> None:
        """
        Directly connect socket to socket. This method does not type-check the connections: use
        'Pipeline.connect()' instead (which uses 'find_unambiguous_connection()' to validate types).
        """
        # An input socket accepts exactly one upstream sender; output sockets may fan out
        # to any number of receivers, so only the receiving side needs this guard.
        if to_socket.sender:
            raise PipelineConnectError(
                f"Cannot connect '{from_node}.{from_socket.name}' with '{to_node}.{to_socket.name}': "
                f"{to_node}.{to_socket.name} is already connected to {to_socket.sender}.\n"
            )

        logger.debug("Connecting '%s.%s' to '%s.%s'", from_node, from_socket.name, to_node, to_socket.name)
        # The edge key encodes both socket names so that parallel edges between the
        # same pair of nodes remain distinguishable in the MultiDiGraph.
        self.graph.add_edge(
            from_node,
            to_node,
            key=f"{from_socket.name}/{to_socket.name}",
            conn_type=_type_name(from_socket.type),
            from_socket=from_socket,
            to_socket=to_socket,
        )

        # Record which node feeds this socket; the guard above relies on it.
        to_socket.sender = from_node
322

323
    def get_component(self, name: str) -> Component:
        """
        Returns an instance of a component.

        Args:
            name: the name of the component

        Returns:
            The instance of that component.

        Raises:
            ValueError: if a component with that name is not present in the pipeline.
        """
        # EAFP: attempt the lookup and translate a miss into a caller-friendly error.
        node_attributes = self.graph.nodes
        try:
            return node_attributes[name]["instance"]
        except KeyError as exc:
            raise ValueError(f"Component named {name} not found in the pipeline.") from exc
340

341
    def draw(self, path: Path, engine: RenderingEngines = "mermaid-img") -> None:
        """
        Draws the pipeline. Requires either `graphviz` as a system dependency, or an internet connection for Mermaid.
        Run `pip install canals[graphviz]` or `pip install canals[mermaid]` to install missing dependencies.

        Args:
            path: where to save the diagram.
            engine: which format to save the graph as. Accepts 'graphviz', 'mermaid-text', 'mermaid-img'.
                Default is 'mermaid-img'.

        Returns:
            None

        Raises:
            ImportError: if `engine='graphviz'` and `pygraphviz` is not installed.
            HTTPConnectionError: (and similar) if the internet connection is down or other connection issues.
        """
        # Route the socket summary through the logger instead of print(): writing to stdout
        # from library code pollutes the caller's output (the print was a debugging leftover).
        sockets = {
            comp: "\n".join(f"{name}: {socket}" for name, socket in data.get("input_sockets", {}).items())
            for comp, data in self.graph.nodes(data=True)
        }
        logger.debug("Drawing pipeline. Input sockets: %s", sockets)
        # Draw a deep copy so the renderer cannot mutate the live execution graph.
        _draw(graph=deepcopy(self.graph), path=path, engine=engine)
364

365
    def warm_up(self):
        """
        Make sure all nodes are warm.

        It's the node's responsibility to make sure this method can be called at every
        `Pipeline.run()` without re-initializing everything.
        """
        # Only components that expose a warm_up() hook participate.
        for node_name, instance in self.graph.nodes(data="instance"):
            if hasattr(instance, "warm_up"):
                logger.info("Warming up component %s...", node_name)
                instance.warm_up()
376

377
    def run(self, data: Dict[str, Any], debug: bool = False) -> Dict[str, Any]:
        """
        Runs the pipeline.

        Args:
            data: the inputs to give to the input components of the Pipeline.
            debug: whether to collect and return debug information.

        Returns:
            A dictionary with the outputs of the output components of the Pipeline. When `debug=True`,
            the per-step snapshots are also saved to `self.debug_path` and returned under the
            reserved "_debug" key.

        Raises:
            PipelineRuntimeError: if any of the components fail or return unexpected output.
        """
        # **** The Pipeline.run() algorithm ****
        #
        # Nodes are run as soon as an input for them appears in the inputs buffer.
        # When there's more than a node at once in the buffer (which means some
        # branches are running in parallel or that there are loops) they are selected to
        # run in FIFO order by the `inputs_buffer` OrderedDict.
        #
        # Inputs are labeled with the name of the node they're aimed for:
        #
        #   ```
        #   inputs_buffer[target_node] = {"input_name": input_value, ...}
        #   ```
        #
        # Nodes should wait until all the necessary input data has arrived before running.
        # If they're popped from the input_buffer before they're ready, they're put back in.
        # If the pipeline has branches of different lengths, it's possible that a node has to
        # wait a bit and let other nodes pass before receiving all the input data it needs.
        #
        # Cheatsheet for networkx data access:
        # - Name of the node       # self.graph.nodes  (List[str])
        # - Node instance          # self.graph.nodes[node]["instance"]
        # - Input nodes            # [e[0] for e in self.graph.in_edges(node)]
        # - Output nodes           # [e[1] for e in self.graph.out_edges(node)]
        # - Output edges           # [e[2]["label"] for e in self.graph.out_edges(node, data=True)]

        data = _validate_pipeline_input(self.graph, input_values=data)
        self._clear_visits_count()
        self.warm_up()

        logger.info("Pipeline execution started.")
        # Queue only the inputs that actually carry a value: None means "not provided".
        inputs_buffer = OrderedDict(
            {
                node: {key: value for key, value in input_data.items() if value is not None}
                for node, input_data in data.items()
            }
        )
        pipeline_output: Dict[str, Dict[str, Any]] = {}

        if debug:
            logger.info("Debug mode ON.")
        self.debug = {}

        # *** PIPELINE EXECUTION LOOP ***
        # We select the nodes to run by popping them in FIFO order from the inputs buffer.
        step = 0
        while inputs_buffer:
            step += 1
            if debug:
                self._record_pipeline_step(step, inputs_buffer, pipeline_output)
            logger.debug("> Queue at step %s: %s", step, {k: list(v.keys()) for k, v in inputs_buffer.items()})

            component_name, inputs = inputs_buffer.popitem(last=False)  # FIFO

            # Make sure it didn't run too many times already (loop protection).
            self._check_max_loops(component_name)

            # **** IS IT MY TURN YET? ****
            # Check if the component should be run or not.
            action = self._calculate_action(name=component_name, inputs=inputs, inputs_buffer=inputs_buffer)

            # This component is missing data: let's put it back in the queue and wait.
            if action == "wait":
                if not inputs_buffer:
                    # Nothing else can run, so the missing input can never arrive: deadlock.
                    raise PipelineRuntimeError(
                        f"'{component_name}' is stuck waiting for input, but there are no other components to run. "
                        "This is likely a Canals bug. Open an issue at https://github.com/deepset-ai/canals."
                    )

                inputs_buffer[component_name] = inputs
                continue

            # This component did not receive the input it needs: it must be on a skipped branch. Let's not run it.
            if action == "skip":
                self.graph.nodes[component_name]["visits"] += 1
                inputs_buffer = self._skip_downstream_unvisited_nodes(
                    component_name=component_name, inputs_buffer=inputs_buffer
                )
                continue

            if action == "remove":
                # This component has no reason of being in the run queue and we need to remove it.
                # For example, this can happen to components connected to skipped branches of the pipeline.
                continue

            # **** RUN THE NODE ****
            # It is our turn! The node is ready to run and all necessary inputs are present.
            output = self._run_component(name=component_name, inputs=inputs)

            # **** PROCESS THE OUTPUT ****
            # The node ran successfully. Let's store or distribute the output it produced, if it's valid.
            if not self.graph.out_edges(component_name):
                # Leaf node: its output is a pipeline output.
                # Note: if a node outputs many times (like in loops), the output will be overwritten.
                pipeline_output[component_name] = output
            else:
                inputs_buffer = self._route_output(
                    node_results=output, node_name=component_name, inputs_buffer=inputs_buffer
                )

        if debug:
            # Record one final snapshot after the queue drained.
            self._record_pipeline_step(step + 1, inputs_buffer, pipeline_output)

            # Save the collected snapshots to json under debug_path.
            os.makedirs(self.debug_path, exist_ok=True)
            with open(self.debug_path / "data.json", "w", encoding="utf-8") as datafile:
                json.dump(self.debug, datafile, indent=4, default=str)

            # Also expose the snapshots to the caller under the reserved "_debug" key.
            pipeline_output["_debug"] = self.debug  # type: ignore

        logger.info("Pipeline executed successfully.")
        return pipeline_output
506

507
    def _record_pipeline_step(self, step, inputs_buffer, pipeline_output):
        """
        Stores a snapshot of this step into the self.debug dictionary of the pipeline.
        """
        # Convert a deep copy of the graph so the snapshot cannot alias live state.
        snapshot_diagram = _convert_for_debug(deepcopy(self.graph))
        self.debug[step] = {
            "time": datetime.datetime.now(),
            "inputs_buffer": list(inputs_buffer.items()),
            "pipeline_output": pipeline_output,
            "diagram": snapshot_diagram,
        }
518

519
    def _clear_visits_count(self):
        """
        Make sure all nodes's visits count is zero.
        """
        # Mutate each node's attribute dict in place; data=True yields live views.
        for _, node_data in self.graph.nodes(data=True):
            node_data["visits"] = 0
525

526
    def _check_max_loops(self, component_name: str):
        """
        Verify whether this component ran too many times, raising PipelineMaxLoops if so.
        """
        visits = self.graph.nodes[component_name]["visits"]
        if visits > self.max_loops_allowed:
            raise PipelineMaxLoops(
                f"Maximum loops count ({self.max_loops_allowed}) exceeded for component '{component_name}'."
            )
534

535
    # This function is complex so it contains quite some logic, it needs tons of information
536
    # regarding a component to understand what action it should take so we have many local
537
    # variables and to keep things simple we also have multiple returns.
538
    # In the end this amount of information makes it easier to understand the internal logic so
539
    # we chose to ignore these pylint warnings.
540
    def _calculate_action(  # pylint: disable=too-many-locals, too-many-return-statements
1✔
541
        self, name: str, inputs: Dict[str, Any], inputs_buffer: Dict[str, Any]
542
    ) -> Literal["run", "wait", "skip", "remove"]:
543
        """
544
        Calculates the action to take for the component specified by `name`.
545
        There are four possible actions:
546
            * run
547
            * wait
548
            * skip
549
            * remove
550

551
        The below conditions are evaluated in this order.
552

553
        Component will run if at least one of the following statements is true:
554
            * It received all mandatory inputs
555
            * It received all mandatory inputs and it has no optional inputs
556
            * It received all mandatory inputs and all optional inputs are skipped
557
            * It received all mandatory inputs and some optional inputs and the rest are skipped
558
            * It received some of its inputs and the others are defaulted
559
            * It's the first component of the pipeline
560

561
        Component will wait if:
562
            * It received some of its inputs and the other are not skipped
563
            * It received all mandatory inputs and some optional inputs have not been skipped
564

565
        Component will be skipped if:
566
            * It never ran nor waited
567

568
        Component will be removed if:
569
            * It ran or waited at least once but can't do it again
570

571
        If none of the above condition is met a PipelineRuntimeError is raised.
572

573
        For simplicity sake input components that create a cycle, or components that already ran
574
        and don't create a cycle are considered as skipped.
575

576
        Args:
577
            name: Name of the component
578
            inputs: Values that the component will take as input
579
            inputs_buffer: Other components' inputs
580

581
        Returns:
582
            Action to take for component specifing whether it should run, wait, skip or be removed
583

584
        Raises:
585
            PipelineRuntimeError: If action to take can't be determined
586
        """
587

588
        # Upstream components/socket pairs the current component is connected to
589
        input_components = {
1✔
590
            from_node: data["to_socket"].name for from_node, _, data in self.graph.in_edges(name, data=True)
591
        }
592
        # Sockets that have received inputs from upstream components
593
        received_input_sockets = set(inputs.keys())
1✔
594

595
        # All components inputs, whether they're connected, default or pipeline inputs
596
        input_sockets: Dict[str, InputSocket] = self.graph.nodes[name]["input_sockets"].keys()
1✔
597
        optional_input_sockets = {
1✔
598
            socket.name for socket in self.graph.nodes[name]["input_sockets"].values() if socket.is_optional
599
        }
600
        mandatory_input_sockets = {
1✔
601
            socket.name for socket in self.graph.nodes[name]["input_sockets"].values() if not socket.is_optional
602
        }
603

604
        # Components that are in the inputs buffer and have no inputs assigned are considered skipped
605
        skipped_components = {n for n, v in inputs_buffer.items() if not v}
1✔
606

607
        # Sockets that have their upstream component marked as skipped
608
        skipped_optional_input_sockets = {
1✔
609
            sockets["to_socket"].name
610
            for from_node, _, sockets in self.graph.in_edges(name, data=True)
611
            if from_node in skipped_components and sockets["to_socket"].name in optional_input_sockets
612
        }
613

614
        for from_node, socket in input_components.items():
1✔
615
            if socket not in optional_input_sockets:
1✔
616
                continue
1✔
617
            loops_back = networkx.has_path(self.graph, name, from_node)
1✔
618
            has_run = self.graph.nodes[from_node]["visits"] > 0
1✔
619
            if loops_back or has_run:
1✔
620
                # Consider all input components that loop back to current component
621
                # or that have already run at least once as skipped.
622
                # This must be done to correctly handle cycles in the pipeline or we
623
                # would reach a dead lock in components that have multiple inputs and
624
                # one of these forms a cycle.
625
                skipped_optional_input_sockets.add(socket)
1✔
626

627
        ##############
628
        # RUN CHECKS #
629
        ##############
630
        if (
1✔
631
            mandatory_input_sockets.issubset(received_input_sockets)
632
            and input_sockets == received_input_sockets | mandatory_input_sockets | skipped_optional_input_sockets
633
        ):
634
            # We received all mandatory inputs and:
635
            #   * There are no optional inputs or
636
            #   * All optional inputs are skipped or
637
            #   * We received part of the optional inputs, the rest are skipped
638
            if not optional_input_sockets:
1✔
639
                logger.debug("Component '%s' is ready to run. All mandatory inputs received.", name)
1✔
640
            else:
641
                logger.debug(
1✔
642
                    "Component '%s' is ready to run. All mandatory inputs received, skipped optional inputs: %s",
643
                    name,
644
                    skipped_optional_input_sockets,
645
                )
646
            return "run"
1✔
647

648
        if set(input_components.values()).issubset(received_input_sockets):
1✔
649
            # We have data from each connected input component.
650
            # We reach this when the current component is the first of the pipeline or
651
            # when it has defaults and all its input components have run.
652
            logger.debug("Component '%s' is ready to run. All expected inputs were received.", name)
1✔
653
            return "run"
1✔
654

655
        ###############
656
        # WAIT CHECKS #
657
        ###############
658
        if mandatory_input_sockets == received_input_sockets and skipped_optional_input_sockets.issubset(
1✔
659
            optional_input_sockets
660
        ):
661
            # We received all of the inputs we need, but some optional inputs have not been run or skipped yet
662
            logger.debug(
×
663
                "Component '%s' is waiting. All mandatory inputs received, some optional are not skipped: %s",
664
                name,
665
                optional_input_sockets - skipped_optional_input_sockets,
666
            )
667
            return "wait"
×
668

669
        if any(self.graph.nodes[n]["visits"] == 0 for n in input_components.keys()):
1✔
670
            # Some upstream component that must send input to the current component has yet to run.
671
            logger.debug(
1✔
672
                "Component '%s' is waiting. Missing inputs: %s",
673
                name,
674
                set(input_components.values()),
675
            )
676
            return "wait"
1✔
677

678
        ###############
679
        # SKIP CHECKS #
680
        ###############
681
        if self.graph.nodes[name]["visits"] == 0:
1✔
682
            # It's the first time visiting this component, if it can't run nor wait
683
            # it's fine skipping it at this point.
684
            logger.debug("Component '%s' is skipped. It can't run nor wait.", name)
1✔
685
            return "skip"
1✔
686

687
        #################
688
        # REMOVE CHECKS #
689
        #################
690
        if self.graph.nodes[name]["visits"] > 0:
×
691
            # This component has already been visited at least once. If it can't run nor wait
692
            # there is no reason to skip it again. So we it must be removed.
693
            logger.debug("Component '%s' is removed. It can't run, wait or skip.", name)
×
694
            return "remove"
×
695

696
        # Can't determine action to take
697
        raise PipelineRuntimeError(
×
698
            f"Can't determine Component '{name}' action. "
699
            f"Mandatory input sockets: {mandatory_input_sockets}, "
700
            f"optional input sockets: {optional_input_sockets}, "
701
            f"received input: {list(inputs.keys())}, "
702
            f"input components: {list(input_components.keys())}, "
703
            f"skipped components: {skipped_components}, "
704
            f"skipped optional inputs: {skipped_optional_input_sockets}."
705
            f"This is likely a Canals bug. Please open an issue at https://github.com/deepset-ai/canals.",
706
        )
707

708
    def _skip_downstream_unvisited_nodes(self, component_name: str, inputs_buffer: OrderedDict) -> OrderedDict:
1✔
709
        """
710
        When a component is skipped, put all downstream nodes in the inputs buffer too: the might be skipped too,
711
        unless they are merge nodes. They will be evaluated later by the pipeline execution loop.
712
        """
713
        downstream_nodes = [e[1] for e in self.graph.out_edges(component_name)]
1✔
714
        for downstream_node in downstream_nodes:
1✔
715
            if downstream_node in inputs_buffer:
1✔
716
                continue
1✔
717
            if self.graph.nodes[downstream_node]["visits"] == 0:
1✔
718
                # Skip downstream nodes only if they never been visited
719
                inputs_buffer[downstream_node] = {}
1✔
720
        return inputs_buffer
1✔
721

722
    def _run_component(self, name: str, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Once we're confident this component is ready to run, run it and collect the output.

        Increments the component's visit counter, then invokes the component
        instance's `run()` method with `inputs` unpacked as keyword arguments.

        :param name: the name of the component node in the graph.
        :param inputs: the keyword arguments to pass to the component's `run()` method.
        :return: the dictionary of outputs produced by the component.
        :raises PipelineRuntimeError: if the component raises any exception while running.
        """
        self.graph.nodes[name]["visits"] += 1
        instance = self.graph.nodes[name]["instance"]
        try:
            logger.info("* Running %s (visits: %s)", name, self.graph.nodes[name]["visits"])
            logger.debug("   '%s' inputs: %s", name, inputs)

            outputs = instance.run(**inputs)

            # Unwrap the output
            logger.debug("   '%s' outputs: %s\n", name, outputs)

        except Exception as e:
            # Wrap any component failure so the caller gets pipeline-level context
            # (component name and the inputs it received) along with the original cause.
            raise PipelineRuntimeError(
                f"{name} raised '{e.__class__.__name__}: {e}' \nInputs: {inputs}\n\n"
                "See the stacktrace above for more information."
            ) from e

        return outputs
752

753
    def _route_output(
        self,
        node_name: str,
        node_results: Dict[str, Any],
        inputs_buffer: OrderedDict,
    ) -> OrderedDict:
        """
        Distribute the outputs of the component into the input buffer of downstream components.

        For regular nodes, every downstream node gets a buffer entry, and each socket
        value that was actually produced is stored under the receiving socket's name.
        For decision nodes involved in a loop, targets whose expected output was not
        produced are deliberately kept out of the buffer.

        Returns the updated inputs buffer.
        """
        # This is not a terminal node: find out where the output goes, to which nodes
        # and along which edge. A "decision node for a loop" has more than one outgoing
        # edge and at least one path from a successor back to itself.
        out_degree = len(self.graph.out_edges(node_name))
        is_decision_node_for_loop = out_degree > 1 and any(
            networkx.has_path(self.graph, successor, node_name)
            for _, successor in self.graph.out_edges(node_name)
        )

        for _, target_node, connection in self.graph.out_edges(node_name, data=True):
            from_socket = connection["from_socket"]
            to_socket = connection["to_socket"]
            value_to_route = node_results.get(from_socket.name, None)

            # If this is a decision node and a loop is involved, we add to the input buffer
            # only the nodes that received their expected output and we leave the others out.
            if is_decision_node_for_loop and value_to_route is None:
                if networkx.has_path(self.graph, target_node, node_name):
                    # In case we're choosing to leave a loop, do not put the loop's node in the buffer.
                    logger.debug(
                        "Not adding '%s' to the inputs buffer: we're leaving the loop.",
                        target_node,
                    )
                else:
                    # In case we're choosing to stay in a loop, do not put the external node in the buffer.
                    logger.debug(
                        "Not adding '%s' to the inputs buffer: we're staying in the loop.",
                        target_node,
                    )
                continue

            # In all other cases, populate the inputs buffer for all downstream nodes,
            # storing a value only for the sockets that actually received one.
            inputs_buffer.setdefault(target_node, {})
            if value_to_route is not None:
                inputs_buffer[target_node][to_socket.name] = value_to_route

        return inputs_buffer
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc