• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / canals / 5953113206

23 Aug 2023 03:06PM UTC coverage: 93.182% (+0.2%) from 92.958%
5953113206

push

github

web-flow
Remove handling of shared component instances on Pipeline serialization (#102)

157 of 160 branches covered (98.13%)

Branch coverage included in aggregate %.

581 of 632 relevant lines covered (91.93%)

0.92 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

89.07
canals/pipeline/pipeline.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4
from typing import Optional, Any, Dict, List, Literal, Union
1✔
5

6
import os
1✔
7
import json
1✔
8
import datetime
1✔
9
import logging
1✔
10
from pathlib import Path
1✔
11
from copy import deepcopy
1✔
12
from collections import OrderedDict
1✔
13

14
import networkx
1✔
15

16
from canals.component import component, Component
1✔
17
from canals.errors import (
1✔
18
    PipelineError,
19
    PipelineConnectError,
20
    PipelineMaxLoops,
21
    PipelineRuntimeError,
22
    PipelineValidationError,
23
)
24
from canals.pipeline.draw import _draw, _convert_for_debug, RenderingEngines
1✔
25
from canals.pipeline.sockets import InputSocket, OutputSocket
1✔
26
from canals.pipeline.validation import _validate_pipeline_input
1✔
27
from canals.pipeline.connections import _parse_connection_name, _find_unambiguous_connection
1✔
28
from canals.utils import _type_name
1✔
29

30
logger = logging.getLogger(__name__)
1✔
31

32

33
class Pipeline:
1✔
34
    """
35
    Components orchestration engine.
36

37
    Builds a graph of components and orchestrates their execution according to the execution graph.
38
    """
39

40
    def __init__(
        self,
        metadata: Optional[Dict[str, Any]] = None,
        max_loops_allowed: int = 100,
        debug_path: Union[Path, str] = Path(".canals_debug/"),
    ):
        """
        Creates the Pipeline.

        Args:
            metadata: arbitrary dictionary to store metadata about this pipeline. Make sure all the values contained in
                this dictionary can be serialized and deserialized if you wish to save this pipeline to file with
                `save_pipelines()/load_pipelines()`.
            max_loops_allowed: how many times the pipeline can run the same node before throwing an exception.
            debug_path: when debug is enabled in `run()`, where to save the debug data.
        """
        self.metadata = metadata or {}
        self.max_loops_allowed = max_loops_allowed
        # MultiDiGraph: two components may be connected through several socket pairs,
        # so parallel edges between the same pair of nodes must be allowed.
        self.graph = networkx.MultiDiGraph()
        # Per-step debug snapshots collected by run(debug=True), keyed by step number.
        self.debug: Dict[int, Dict[str, Any]] = {}
        self.debug_path = Path(debug_path)
61

62
    def __eq__(self, other) -> bool:
1✔
63
        """
64
        Equal pipelines share every metadata, node and edge, but they're not required to use
65
        the same node instances: this allows pipeline saved and then loaded back to be equal to themselves.
66
        """
67
        if (
1✔
68
            not isinstance(other, type(self))
69
            or not getattr(self, "metadata") == getattr(other, "metadata")
70
            or not getattr(self, "max_loops_allowed") == getattr(other, "max_loops_allowed")
71
            or not hasattr(self, "graph")
72
            or not hasattr(other, "graph")
73
        ):
74
            return False
×
75

76
        return (
1✔
77
            self.graph.adj == other.graph.adj
78
            and self._comparable_nodes_list(self.graph) == self._comparable_nodes_list(other.graph)
79
            and self.graph.graph == other.graph.graph
80
        )
81

82
    def to_dict(self) -> Dict[str, Any]:
1✔
83
        """
84
        Returns this Pipeline instance as a dictionary.
85
        This is meant to be an intermediate representation but it can be also used to save a pipeline to file.
86
        """
87
        components = {name: instance.to_dict() for name, instance in self.graph.nodes(data="instance")}
1✔
88
        connections = []
1✔
89
        for sender, receiver, sockets in self.graph.edges:
1✔
90
            (sender_socket, receiver_socket) = sockets.split("/")
1✔
91
            connections.append(
1✔
92
                {
93
                    "sender": f"{sender}.{sender_socket}",
94
                    "receiver": f"{receiver}.{receiver_socket}",
95
                }
96
            )
97
        return {
1✔
98
            "metadata": self.metadata,
99
            "max_loops_allowed": self.max_loops_allowed,
100
            "components": components,
101
            "connections": connections,
102
        }
103

104
    @classmethod
    def from_dict(cls, data: Dict[str, Any], **kwargs) -> "Pipeline":
        """
        Creates a Pipeline instance from a dictionary.
        A sample `data` dictionary could be formatted like so:
        ```
        {
            "metadata": {"test": "test"},
            "max_loops_allowed": 100,
            "components": {
                "add_two": {
                    "type": "AddFixedValue",
                    "init_parameters": {"add": 2},
                },
                "add_default": {
                    "type": "AddFixedValue",
                    "init_parameters": {"add": 1},
                },
                "double": {
                    "type": "Double",
                },
            },
            "connections": [
                {"sender": "add_two.result", "receiver": "double.value"},
                {"sender": "double.value", "receiver": "add_default.value"},
            ],
        }
        ```

        Supported kwargs:
        `components`: a dictionary of {name: instance} to reuse instances of components instead of creating new ones.

        Raises:
            PipelineError: if a component entry is missing its 'type', if the type is not present in the
                component registry (i.e. its module was never imported), or if a connection entry is
                missing its 'sender' or 'receiver'.
        """
        # Missing keys fall back to the Pipeline constructor defaults.
        metadata = data.get("metadata", {})
        max_loops_allowed = data.get("max_loops_allowed", 100)
        # NOTE(review): to_dict() does not serialize "debug_path", so round-tripped pipelines
        # always fall back to the default here — confirm whether that is intended.
        debug_path = Path(data.get("debug_path", ".canals_debug/"))
        pipe = cls(
            metadata=metadata,
            max_loops_allowed=max_loops_allowed,
            debug_path=debug_path,
        )
        components_to_reuse = kwargs.get("components", {})
        for name, component_data in data.get("components", {}).items():
            if name in components_to_reuse:
                # Reuse an instance
                instance = components_to_reuse[name]
            else:
                if "type" not in component_data:
                    raise PipelineError(f"Missing 'type' in component '{name}'")
                if component_data["type"] not in component.registry:
                    raise PipelineError(f"Component '{component_data['type']}' not imported.")
                # Create a new one
                instance = component.registry[component_data["type"]].from_dict(component_data)
            pipe.add_component(name=name, instance=instance)

        # Components must all exist before connections are made, hence the second pass.
        for connection in data.get("connections", []):
            if "sender" not in connection:
                raise PipelineError(f"Missing sender in connection: {connection}")
            if "receiver" not in connection:
                raise PipelineError(f"Missing receiver in connection: {connection}")
            pipe.connect(connect_from=connection["sender"], connect_to=connection["receiver"])

        return pipe
166

167
    def _comparable_nodes_list(self, graph: networkx.MultiDiGraph) -> List[Dict[str, Any]]:
1✔
168
        """
169
        Replaces instances of nodes with their class name in order to make sure they're comparable.
170
        """
171
        nodes = []
1✔
172
        for node in graph.nodes:
1✔
173
            comparable_node = graph.nodes[node]
×
174
            comparable_node["instance"] = comparable_node["instance"].__class__
×
175
            nodes.append(comparable_node)
×
176
        nodes.sort()
1✔
177
        return nodes
1✔
178

179
    def add_component(self, name: str, instance: Component) -> None:
        """
        Adds a disconnected component to the pipeline graph: use `Pipeline.connect()`
        afterwards to wire it up to other components.

        Component names must be unique, but component instances can be reused if needed.

        Args:
            name: the name of the component.
            instance: the component instance.

        Returns:
            None

        Raises:
            ValueError: if a component with the same name already exists
            PipelineValidationError: if the given instance is not a Canals component
        """
        # Guard clause: component names are unique within a pipeline.
        if name in self.graph.nodes:
            raise ValueError(f"A component named '{name}' already exists in this pipeline: choose another name.")

        # Guard clause: '_debug' is reserved for the debug section of the pipeline output.
        if name == "_debug":
            raise ValueError("'_debug' is a reserved name for debug output. Choose another name.")

        # Guard clause: only classes decorated with @component can be added.
        if not hasattr(instance, "__canals_component__"):
            raise PipelineValidationError(
                f"'{type(instance)}' doesn't seem to be a component. Is this class decorated with @component?"
            )

        # Build the sockets from the I/O metadata the @component decorator attached to run().
        input_metadata = getattr(instance.run, "__canals_input__", {})
        output_metadata = getattr(instance.run, "__canals_output__", {})
        input_sockets = {socket_name: InputSocket(**meta) for socket_name, meta in input_metadata.items()}
        output_sockets = {socket_name: OutputSocket(**meta) for socket_name, meta in output_metadata.items()}

        # Register the component as a disconnected node; 'visits' tracks how often it ran.
        logger.debug("Adding component '%s' (%s)", name, instance)
        self.graph.add_node(
            name,
            instance=instance,
            input_sockets=input_sockets,
            output_sockets=output_sockets,
            visits=0,
        )
226

227
    def connect(self, connect_from: str, connect_to: str) -> None:
        """
        Connects two components together. All components to connect must exist in the pipeline.
        If connecting to a component that has several output connections, specify the inputs and output names as
        'component_name.connections_name'.

        Args:
            connect_from: the component that delivers the value. This can be either just a component name or can be
                in the format `component_name.connection_name` if the component has multiple outputs.
            connect_to: the component that receives the value. This can be either just a component name or can be
                in the format `component_name.connection_name` if the component has multiple inputs.

        Returns:
            None

        Raises:
            ValueError: if either of the two components is not present in the pipeline.
            PipelineConnectError: if the two components cannot be connected (for example if the named sockets
                don't exist, or the connections don't match by type, and so on).
        """
        # Edges may be named explicitly by passing 'node_name.edge_name' to connect().
        from_node, from_socket_name = _parse_connection_name(connect_from)
        to_node, to_socket_name = _parse_connection_name(connect_to)

        # Get the nodes data.
        try:
            from_sockets = self.graph.nodes[from_node]["output_sockets"]
        except KeyError as exc:
            raise ValueError(f"Component named {from_node} not found in the pipeline.") from exc

        try:
            to_sockets = self.graph.nodes[to_node]["input_sockets"]
        except KeyError as exc:
            raise ValueError(f"Component named {to_node} not found in the pipeline.") from exc

        # If the name of either socket is given, get the socket.
        # The error messages list the valid sockets to help the user fix the call.
        if from_socket_name:
            from_socket = from_sockets.get(from_socket_name, None)
            if not from_socket:
                raise PipelineConnectError(
                    f"'{from_node}.{from_socket_name} does not exist. "
                    f"Output connections of {from_node} are: "
                    + ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in from_sockets.items()])
                )
        if to_socket_name:
            to_socket = to_sockets.get(to_socket_name, None)
            if not to_socket:
                raise PipelineConnectError(
                    f"'{to_node}.{to_socket_name} does not exist. "
                    f"Input connections of {to_node} are: "
                    + ", ".join([f"{name} (type {_type_name(socket.type)})" for name, socket in to_sockets.items()])
                )

        # Look for an unambiguous connection among the possible ones.
        # Note that if there is more than one possible connection but two sockets match by name, they're paired.
        # When a socket name was given explicitly, only that socket is considered.
        from_sockets = [from_socket] if from_socket_name else list(from_sockets.values())
        to_sockets = [to_socket] if to_socket_name else list(to_sockets.values())
        from_socket, to_socket = _find_unambiguous_connection(
            sender_node=from_node, sender_sockets=from_sockets, receiver_node=to_node, receiver_sockets=to_sockets
        )

        # Connect the components on these sockets
        self._direct_connect(from_node=from_node, from_socket=from_socket, to_node=to_node, to_socket=to_socket)
289

290
    def _direct_connect(self, from_node: str, from_socket: OutputSocket, to_node: str, to_socket: InputSocket) -> None:
        """
        Connects the two given sockets with an edge, with no type checking whatsoever:
        callers should go through 'Pipeline.connect()' instead, which validates the match
        through 'find_unambiguous_connection()'.
        """
        # A receiving socket accepts exactly one upstream sender; sending sockets may fan
        # out to any number of receivers, so only the receiving side is checked here.
        if to_socket.sender:
            raise PipelineConnectError(
                f"Cannot connect '{from_node}.{from_socket.name}' with '{to_node}.{to_socket.name}': "
                f"{to_node}.{to_socket.name} is already connected to {to_socket.sender}.\n"
            )

        logger.debug("Connecting '%s.%s' to '%s.%s'", from_node, from_socket.name, to_node, to_socket.name)

        # Encode both socket names into the multigraph edge key so that parallel edges
        # between the same two nodes remain distinct.
        connection_key = f"{from_socket.name}/{to_socket.name}"
        self.graph.add_edge(
            from_node,
            to_node,
            key=connection_key,
            conn_type=_type_name(from_socket.type),
            from_socket=from_socket,
            to_socket=to_socket,
        )

        # Remember which node feeds this socket, so later connection attempts can detect clashes.
        to_socket.sender = from_node
317

318
    def get_component(self, name: str) -> Component:
1✔
319
        """
320
        Returns an instance of a component.
321

322
        Args:
323
            name: the name of the component
324

325
        Returns:
326
            The instance of that component.
327

328
        Raises:
329
            ValueError: if a component with that name is not present in the pipeline.
330
        """
331
        try:
×
332
            return self.graph.nodes[name]["instance"]
×
333
        except KeyError as exc:
×
334
            raise ValueError(f"Component named {name} not found in the pipeline.") from exc
×
335

336
    def draw(self, path: Path, engine: RenderingEngines = "mermaid-img") -> None:
        """
        Draws the pipeline. Requires either `graphviz` as a system dependency, or an internet connection for Mermaid.
        Run `pip install canals[graphviz]` or `pip install canals[mermaid]` to install missing dependencies.

        Args:
            path: where to save the diagram.
            engine: which format to save the graph as. Accepts 'graphviz', 'mermaid-text', 'mermaid-img'.
                Default is 'mermaid-img'.

        Returns:
            None

        Raises:
            ImportError: if `engine='graphviz'` and `pygraphviz` is not installed.
            HTTPConnectionError: (and similar) if the internet connection is down or other connection issues.
        """
        # Removed a leftover debug `print()` of the input sockets that polluted stdout on
        # every call (and the dict comprehension built solely to feed it).
        # Render from a deep copy so drawing can never mutate the live pipeline graph.
        _draw(graph=deepcopy(self.graph), path=path, engine=engine)
359

360
    def warm_up(self):
1✔
361
        """
362
        Make sure all nodes are warm.
363

364
        It's the node's responsibility to make sure this method can be called at every `Pipeline.run()`
365
        without re-initializing everything.
366
        """
367
        for node in self.graph.nodes:
1✔
368
            if hasattr(self.graph.nodes[node]["instance"], "warm_up"):
1✔
369
                logger.info("Warming up component %s...", node)
×
370
                self.graph.nodes[node]["instance"].warm_up()
×
371

372
    def run(self, data: Dict[str, Any], debug: bool = False) -> Dict[str, Any]:
        """
        Runs the pipeline.

        Args:
            data: the inputs to give to the input components of the Pipeline.
            debug: whether to collect and return debug information.

        Returns:
            A dictionary with the outputs of the output components of the Pipeline.
            When `debug=True`, it also contains a "_debug" key with the per-step snapshots.

        Raises:
            PipelineRuntimeError: if any of the components fail or return unexpected output.
        """
        # **** The Pipeline.run() algorithm ****
        #
        # Nodes are run as soon as an input for them appears in the inputs buffer.
        # When there's more than a node at once in the buffer (which means some
        # branches are running in parallel or that there are loops) they are selected to
        # run in FIFO order by the `inputs_buffer` OrderedDict.
        #
        # Inputs are labeled with the name of the node they're aimed for:
        #
        #   ```
        #   inputs_buffer[target_node] = {"input_name": input_value, ...}
        #   ```
        #
        # Nodes should wait until all the necessary input data has arrived before running.
        # If they're popped from the input_buffer before they're ready, they're put back in.
        # If the pipeline has branches of different lengths, it's possible that a node has to
        # wait a bit and let other nodes pass before receiving all the input data it needs.
        #
        # Cheatsheet for networkx data access:
        # - Name of the node       # self.graph.nodes  (List[str])
        # - Node instance          # self.graph.nodes[node]["instance"]
        # - Input nodes            # [e[0] for e in self.graph.in_edges(node)]
        # - Output nodes           # [e[1] for e in self.graph.out_edges(node)]
        # - Output edges           # [e[2]["label"] for e in self.graph.out_edges(node, data=True)]

        data = _validate_pipeline_input(self.graph, input_values=data)
        self._clear_visits_count()
        self.warm_up()

        logger.info("Pipeline execution started.")
        # Drop None values: unset optional pipeline inputs must not count as received data.
        inputs_buffer = OrderedDict(
            {
                node: {key: value for key, value in input_data.items() if value is not None}
                for node, input_data in data.items()
            }
        )
        pipeline_output: Dict[str, Dict[str, Any]] = {}

        if debug:
            logger.info("Debug mode ON.")
        self.debug = {}

        # *** PIPELINE EXECUTION LOOP ***
        # We select the nodes to run by popping them in FIFO order from the inputs buffer.
        step = 0
        while inputs_buffer:
            step += 1
            if debug:
                self._record_pipeline_step(step, inputs_buffer, pipeline_output)
            logger.debug("> Queue at step %s: %s", step, {k: list(v.keys()) for k, v in inputs_buffer.items()})

            component_name, inputs = inputs_buffer.popitem(last=False)  # FIFO

            # Make sure it didn't run too many times already
            self._check_max_loops(component_name)

            # **** IS IT MY TURN YET? ****
            # Check if the component should be run or not
            action = self._calculate_action(name=component_name, inputs=inputs, inputs_buffer=inputs_buffer)

            # This component is missing data: let's put it back in the queue and wait.
            if action == "wait":
                if not inputs_buffer:
                    # Nothing else can run, so the missing input will never arrive: fail loudly
                    # instead of looping forever.
                    raise PipelineRuntimeError(
                        f"'{component_name}' is stuck waiting for input, but there are no other components to run. "
                        "This is likely a Canals bug. Open an issue at https://github.com/deepset-ai/canals."
                    )

                inputs_buffer[component_name] = inputs
                continue

            # This component did not receive the input it needs: it must be on a skipped branch. Let's not run it.
            if action == "skip":
                self.graph.nodes[component_name]["visits"] += 1
                inputs_buffer = self._skip_downstream_unvisited_nodes(
                    component_name=component_name, inputs_buffer=inputs_buffer
                )
                continue

            if action == "remove":
                # This component has no reason of being in the run queue and we need to remove it.
                # For example, this can happen to components that are connected to skipped
                # branches of the pipeline.
                continue

            # **** RUN THE NODE ****
            # It is our turn! The node is ready to run and all necessary inputs are present
            output = self._run_component(name=component_name, inputs=inputs)

            # **** PROCESS THE OUTPUT ****
            # The node run successfully. Let's store or distribute the output it produced, if it's valid.
            if not self.graph.out_edges(component_name):
                # Leaf node: its output becomes part of the pipeline output.
                # Note: if a node outputs many times (like in loops), the output will be overwritten
                pipeline_output[component_name] = output
            else:
                inputs_buffer = self._route_output(
                    node_results=output, node_name=component_name, inputs_buffer=inputs_buffer
                )

        if debug:
            # Record one final snapshot after the buffer has been fully drained.
            self._record_pipeline_step(step + 1, inputs_buffer, pipeline_output)

            # Save to json
            os.makedirs(self.debug_path, exist_ok=True)
            with open(self.debug_path / "data.json", "w", encoding="utf-8") as datafile:
                json.dump(self.debug, datafile, indent=4, default=str)

            # Store in the output
            pipeline_output["_debug"] = self.debug  # type: ignore

        logger.info("Pipeline executed successfully.")
        return pipeline_output
501

502
    def _record_pipeline_step(self, step, inputs_buffer, pipeline_output):
        """
        Stores a snapshot of this step into the self.debug dictionary of the pipeline.
        """
        # Render the diagram from a deep copy so the live graph is left untouched.
        mermaid_graph = _convert_for_debug(deepcopy(self.graph))
        snapshot = {
            "time": datetime.datetime.now(),
            "inputs_buffer": list(inputs_buffer.items()),
            "pipeline_output": pipeline_output,
            "diagram": mermaid_graph,
        }
        self.debug[step] = snapshot
513

514
    def _clear_visits_count(self):
1✔
515
        """
516
        Make sure all nodes's visits count is zero.
517
        """
518
        for node in self.graph.nodes:
1✔
519
            self.graph.nodes[node]["visits"] = 0
1✔
520

521
    def _check_max_loops(self, component_name: str):
1✔
522
        """
523
        Verify whether this component run too many times.
524
        """
525
        if self.graph.nodes[component_name]["visits"] > self.max_loops_allowed:
1✔
526
            raise PipelineMaxLoops(
1✔
527
                f"Maximum loops count ({self.max_loops_allowed}) exceeded for component '{component_name}'."
528
            )
529

530
    # This function is complex so it contains quite some logic, it needs tons of information
531
    # regarding a component to understand what action it should take so we have many local
532
    # variables and to keep things simple we also have multiple returns.
533
    # In the end this amount of information makes it easier to understand the internal logic so
534
    # we chose to ignore these pylint warnings.
535
    def _calculate_action(  # pylint: disable=too-many-locals, too-many-return-statements
1✔
536
        self, name: str, inputs: Dict[str, Any], inputs_buffer: Dict[str, Any]
537
    ) -> Literal["run", "wait", "skip", "remove"]:
538
        """
539
        Calculates the action to take for the component specified by `name`.
540
        There are four possible actions:
541
            * run
542
            * wait
543
            * skip
544
            * remove
545

546
        The below conditions are evaluated in this order.
547

548
        Component will run if at least one of the following statements is true:
549
            * It received all mandatory inputs
550
            * It received all mandatory inputs and it has no optional inputs
551
            * It received all mandatory inputs and all optional inputs are skipped
552
            * It received all mandatory inputs and some optional inputs and the rest are skipped
553
            * It received some of its inputs and the others are defaulted
554
            * It's the first component of the pipeline
555

556
        Component will wait if:
557
            * It received some of its inputs and the other are not skipped
558
            * It received all mandatory inputs and some optional inputs have not been skipped
559

560
        Component will be skipped if:
561
            * It never ran nor waited
562

563
        Component will be removed if:
564
            * It ran or waited at least once but can't do it again
565

566
        If none of the above condition is met a PipelineRuntimeError is raised.
567

568
        For simplicity sake input components that create a cycle, or components that already ran
569
        and don't create a cycle are considered as skipped.
570

571
        Args:
572
            name: Name of the component
573
            inputs: Values that the component will take as input
574
            inputs_buffer: Other components' inputs
575

576
        Returns:
577
            Action to take for component specifing whether it should run, wait, skip or be removed
578

579
        Raises:
580
            PipelineRuntimeError: If action to take can't be determined
581
        """
582

583
        # Upstream components/socket pairs the current component is connected to
584
        input_components = {
1✔
585
            from_node: data["to_socket"].name for from_node, _, data in self.graph.in_edges(name, data=True)
586
        }
587
        # Sockets that have received inputs from upstream components
588
        received_input_sockets = set(inputs.keys())
1✔
589

590
        # All components inputs, whether they're connected, default or pipeline inputs
591
        input_sockets: Dict[str, InputSocket] = self.graph.nodes[name]["input_sockets"].keys()
1✔
592
        optional_input_sockets = {
1✔
593
            socket.name for socket in self.graph.nodes[name]["input_sockets"].values() if socket.is_optional
594
        }
595
        mandatory_input_sockets = {
1✔
596
            socket.name for socket in self.graph.nodes[name]["input_sockets"].values() if not socket.is_optional
597
        }
598

599
        # Components that are in the inputs buffer and have no inputs assigned are considered skipped
600
        skipped_components = {n for n, v in inputs_buffer.items() if not v}
1✔
601

602
        # Sockets that have their upstream component marked as skipped
603
        skipped_optional_input_sockets = {
1✔
604
            sockets["to_socket"].name
605
            for from_node, _, sockets in self.graph.in_edges(name, data=True)
606
            if from_node in skipped_components and sockets["to_socket"].name in optional_input_sockets
607
        }
608

609
        for from_node, socket in input_components.items():
1✔
610
            if socket not in optional_input_sockets:
1✔
611
                continue
1✔
612
            loops_back = networkx.has_path(self.graph, name, from_node)
1✔
613
            has_run = self.graph.nodes[from_node]["visits"] > 0
1✔
614
            if loops_back or has_run:
1✔
615
                # Consider all input components that loop back to current component
616
                # or that have already run at least once as skipped.
617
                # This must be done to correctly handle cycles in the pipeline or we
618
                # would reach a dead lock in components that have multiple inputs and
619
                # one of these forms a cycle.
620
                skipped_optional_input_sockets.add(socket)
1✔
621

622
        ##############
623
        # RUN CHECKS #
624
        ##############
625
        if (
1✔
626
            mandatory_input_sockets.issubset(received_input_sockets)
627
            and input_sockets == received_input_sockets | mandatory_input_sockets | skipped_optional_input_sockets
628
        ):
629
            # We received all mandatory inputs and:
630
            #   * There are no optional inputs or
631
            #   * All optional inputs are skipped or
632
            #   * We received part of the optional inputs, the rest are skipped
633
            if not optional_input_sockets:
1✔
634
                logger.debug("Component '%s' is ready to run. All mandatory inputs received.", name)
1✔
635
            else:
636
                logger.debug(
1✔
637
                    "Component '%s' is ready to run. All mandatory inputs received, skipped optional inputs: %s",
638
                    name,
639
                    skipped_optional_input_sockets,
640
                )
641
            return "run"
1✔
642

643
        if set(input_components.values()).issubset(received_input_sockets):
1✔
644
            # We have data from each connected input component.
645
            # We reach this when the current component is the first of the pipeline or
646
            # when it has defaults and all its input components have run.
647
            logger.debug("Component '%s' is ready to run. All expected inputs were received.", name)
1✔
648
            return "run"
1✔
649

650
        ###############
651
        # WAIT CHECKS #
652
        ###############
653
        if mandatory_input_sockets == received_input_sockets and skipped_optional_input_sockets.issubset(
1✔
654
            optional_input_sockets
655
        ):
656
            # We received all of the inputs we need, but some optional inputs have not been run or skipped yet
657
            logger.debug(
×
658
                "Component '%s' is waiting. All mandatory inputs received, some optional are not skipped: %s",
659
                name,
660
                optional_input_sockets - skipped_optional_input_sockets,
661
            )
662
            return "wait"
×
663

664
        if any(self.graph.nodes[n]["visits"] == 0 for n in input_components.keys()):
1✔
665
            # Some upstream component that must send input to the current component has yet to run.
666
            logger.debug(
1✔
667
                "Component '%s' is waiting. Missing inputs: %s",
668
                name,
669
                set(input_components.values()),
670
            )
671
            return "wait"
1✔
672

673
        ###############
674
        # SKIP CHECKS #
675
        ###############
676
        if self.graph.nodes[name]["visits"] == 0:
1✔
677
            # It's the first time visiting this component, if it can't run nor wait
678
            # it's fine skipping it at this point.
679
            logger.debug("Component '%s' is skipped. It can't run nor wait.", name)
1✔
680
            return "skip"
1✔
681

682
        #################
683
        # REMOVE CHECKS #
684
        #################
685
        if self.graph.nodes[name]["visits"] > 0:
×
686
            # This component has already been visited at least once. If it can't run nor wait
687
            # there is no reason to skip it again. So we it must be removed.
688
            logger.debug("Component '%s' is removed. It can't run, wait or skip.", name)
×
689
            return "remove"
×
690

691
        # Can't determine action to take
692
        raise PipelineRuntimeError(
×
693
            f"Can't determine Component '{name}' action. "
694
            f"Mandatory input sockets: {mandatory_input_sockets}, "
695
            f"optional input sockets: {optional_input_sockets}, "
696
            f"received input: {list(inputs.keys())}, "
697
            f"input components: {list(input_components.keys())}, "
698
            f"skipped components: {skipped_components}, "
699
            f"skipped optional inputs: {skipped_optional_input_sockets}."
700
            f"This is likely a Canals bug. Please open an issue at https://github.com/deepset-ai/canals.",
701
        )
702

703
    def _skip_downstream_unvisited_nodes(self, component_name: str, inputs_buffer: OrderedDict) -> OrderedDict:
1✔
704
        """
705
        When a component is skipped, put all downstream nodes in the inputs buffer too: the might be skipped too,
706
        unless they are merge nodes. They will be evaluated later by the pipeline execution loop.
707
        """
708
        downstream_nodes = [e[1] for e in self.graph.out_edges(component_name)]
1✔
709
        for downstream_node in downstream_nodes:
1✔
710
            if downstream_node in inputs_buffer:
1✔
711
                continue
1✔
712
            if self.graph.nodes[downstream_node]["visits"] == 0:
1✔
713
                # Skip downstream nodes only if they never been visited
714
                inputs_buffer[downstream_node] = {}
1✔
715
        return inputs_buffer
1✔
716

717
    def _run_component(self, name: str, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """
        Invoke a component that was already deemed ready to run and return its output.

        :param name: the name of the component's node in the graph.
        :param inputs: keyword arguments forwarded to the component's ``run()`` method.
        :raises PipelineRuntimeError: if the component's ``run()`` raises any exception.
        """
        node = self.graph.nodes[name]
        # Count the visit before running, so the counter stays accurate even if run() fails.
        node["visits"] += 1
        instance = node["instance"]
        try:
            logger.info("* Running %s (visits: %s)", name, node["visits"])
            logger.debug("   '%s' inputs: %s", name, inputs)

            results = instance.run(**inputs)

            # Unwrap the output
            logger.debug("   '%s' outputs: %s\n", name, results)
            return results

        except Exception as e:
            # Surface component failures as pipeline-level errors carrying the offending inputs.
            raise PipelineRuntimeError(
                f"{name} raised '{e.__class__.__name__}: {e}' \nInputs: {inputs}\n\n"
                "See the stacktrace above for more information."
            ) from e
739

740
    def _route_output(
1✔
741
        self,
742
        node_name: str,
743
        node_results: Dict[str, Any],
744
        inputs_buffer: OrderedDict,
745
    ) -> OrderedDict:
746
        """
747
        Distrubute the outputs of the component into the input buffer of downstream components.
748

749
        Returns the updated inputs buffer.
750
        """
751
        # This is not a terminal node: find out where the output goes, to which nodes and along which edge
752
        is_decision_node_for_loop = (
1✔
753
            any(networkx.has_path(self.graph, edge[1], node_name) for edge in self.graph.out_edges(node_name))
754
            and len(self.graph.out_edges(node_name)) > 1
755
        )
756
        for edge_data in self.graph.out_edges(node_name, data=True):
1✔
757
            to_socket = edge_data[2]["to_socket"]
1✔
758
            from_socket = edge_data[2]["from_socket"]
1✔
759
            target_node = edge_data[1]
1✔
760

761
            # If this is a decision node and a loop is involved, we add to the input buffer only the nodes
762
            # that received their expected output and we leave the others out of the queue.
763
            if is_decision_node_for_loop and node_results.get(from_socket.name, None) is None:
1✔
764
                if networkx.has_path(self.graph, target_node, node_name):
1✔
765
                    # In case we're choosing to leave a loop, do not put the loop's node in the buffer.
766
                    logger.debug(
1✔
767
                        "Not adding '%s' to the inputs buffer: we're leaving the loop.",
768
                        target_node,
769
                    )
770
                else:
771
                    # In case we're choosing to stay in a loop, do not put the external node in the buffer.
772
                    logger.debug(
1✔
773
                        "Not adding '%s' to the inputs buffer: we're staying in the loop.",
774
                        target_node,
775
                    )
776
            else:
777
                # In all other cases, populate the inputs buffer for all downstream nodes, setting None to any
778
                # edge that did not receive input.
779
                if target_node not in inputs_buffer:
1✔
780
                    inputs_buffer[target_node] = {}  # Create the buffer for the downstream node if it's not there yet
1✔
781

782
                value_to_route = node_results.get(from_socket.name, None)
1✔
783
                if value_to_route is not None:
1✔
784
                    inputs_buffer[target_node][to_socket.name] = value_to_route
1✔
785

786
        return inputs_buffer
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc