• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

zincware / ZnFlow / 11546904243

18 Oct 2024 03:41PM UTC coverage: 96.804% (-0.04%) from 96.844%
11546904243

push

github

web-flow
add automatic break points based on magic method detection (#113)

* add automatic break points based on magic method detection

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix tests but undo magic methods

* fix tests

* support all magic methods

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* pre-commit fix

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

219 of 229 new or added lines in 5 files covered. (95.63%)

4 existing lines in 3 files now uncovered.

2635 of 2722 relevant lines covered (96.8%)

3.87 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

81.48
/znflow/deployment/dask_depl.py
1
"""ZnFlow deployment using Dask."""
4✔
2

3
import dataclasses
4✔
4
import typing
4✔
5
import typing as t
4✔
6
import uuid
4✔
7

8
from dask.distributed import Client, Future
4✔
9

10
from znflow import handler
4✔
11
from znflow.handler import UpdateConnectionsWithPredecessor
4✔
12
from znflow.node import Node
4✔
13

14
from .base import DeploymentBase
4✔
15

16
if typing.TYPE_CHECKING:
4✔
UNCOV
17
    pass
18

19

20
def node_submit(node, **kwargs):
    """Run a single node on a Dask worker and hand it back.

    Every public attribute of ``node`` is passed through
    ``UpdateConnectionsWithPredecessor``. Attributes the updater reports as
    changed are replaced by the updated value. ``node.run()`` is then invoked.

    Parameters
    ----------
    node: any
        the node instance to execute.
    kwargs: dict
        predecessors: dict of {uuid: result} shape. It is submitted from the
        client as {uuid: Future}, and Dask substitutes each Future with its
        computed value before this function runs. Defaults to ``{}``.

    Returns
    -------
    any:
        the same node object, with updated state after "Node.run".

    """
    predecessor_map = kwargs.get("predecessors", {})
    update_connections = UpdateConnectionsWithPredecessor()
    # TODO this information is available in the graph,
    #  no need to expensively iterate over all attributes
    # dir() is evaluated once, here, so setattr below cannot disturb iteration.
    public_names = (name for name in dir(node) if not name.startswith("_"))
    for attr_name in public_names:
        new_value = update_connections(
            getattr(node, attr_name), predecessors=predecessor_map
        )
        # The updater records on itself whether the last call changed anything.
        if update_connections.updated:
            setattr(node, attr_name, new_value)

    node.run()
    return node
49

50

51
# TODO: release the future objects
52
@dataclasses.dataclass
class DaskDeployment(DeploymentBase):
    """Deployment that executes graph nodes as tasks on a Dask cluster.

    Attributes
    ----------
    client: Client
        Dask client used to submit tasks. A new ``Client()`` is created by
        default.
    results: dict
        Maps the graph key of every submitted node to the Dask ``Future``
        holding its result. It is not an ``__init__`` argument and is not
        cleared between ``run`` calls.
    """

    client: Client = dataclasses.field(default_factory=Client)
    results: typing.Dict[uuid.UUID, Future] = dataclasses.field(
        default_factory=dict, init=False
    )

    def run(self, nodes: t.Optional[list] = None):
        """Submit nodes via ``DeploymentBase.run``, then pull results back.

        Parameters
        ----------
        nodes: list, optional
            forwarded unchanged to ``DeploymentBase.run``.
        """
        super().run(nodes)
        self._load_results()

    def _run_node(self, node_uuid):
        """Submit ``node_uuid`` to Dask after recursively submitting its predecessors.

        When ``graph.immutable_nodes`` is set, nodes already flagged
        ``available`` are not submitted again. A submitted node is flagged
        ``available``.

        Raises
        ------
        NotImplementedError
            if the node is marked ``_external_``.
        """
        node = self.graph.nodes[node_uuid]["value"]
        predecessors = list(self.graph.predecessors(node_uuid))
        for predecessor in predecessors:
            predecessor_available = self.graph.nodes[predecessor].get("available", False)
            if self.graph.immutable_nodes and predecessor_available:
                continue
            self._run_node(predecessor)

        node_available = self.graph.nodes[node_uuid].get("available", False)
        if self.graph.immutable_nodes and node_available:
            return
        if node._external_:
            raise NotImplementedError(
                "External nodes are not supported in Dask deployment"
            )

        # Only predecessors that have a future (from this or an earlier run)
        # are forwarded; iterate the short predecessor list and use O(1)
        # dict membership instead of scanning every stored result.
        predecessor_futures = {
            x: self.results[x] for x in predecessors if x in self.results
        }
        self.results[node_uuid] = self.client.submit(
            node_submit,
            node=node,
            predecessors=predecessor_futures,
            pure=False,
            key=f"{node.__class__.__name__}-{node_uuid}",
        )
        self.graph.nodes[node_uuid]["available"] = True

    def _load_results(self):
        """Wait for every submitted future and copy its result into the local node.

        ``Node`` instances get the remote instance's ``__dict__`` merged in and
        their connectors refreshed. Other graph entries get ``.result``
        assigned. Graph entries without a submitted future are skipped.
        Exceptions raised by a remote task propagate from ``Future.result()``.
        """
        # TODO: only load nodes that have actually changed
        for node_uuid in self.graph.reverse():
            # Look up by the same key _run_node stores under. An explicit
            # lookup (instead of try/except KeyError around everything) keeps
            # a KeyError raised by the remote task, or by the update code
            # below, from being silently swallowed.
            future = self.results.get(node_uuid)
            if future is None:
                continue
            node = self.graph.nodes[node_uuid]["value"]
            result = future.result()
            if isinstance(node, Node):
                node.__dict__.update(result.__dict__)
                self.graph._update_node_attributes(node, handler.UpdateConnectors())
            else:
                node.result = result.result
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc