• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

zincware / ZnTrack / 13314190432

13 Feb 2025 06:07PM UTC coverage: 85.929% (+0.2%) from 85.729%
13314190432

Pull #881

github

web-flow
Merge a5edc0312 into e83798b4c
Pull Request #881: check for correct custom Node names

224 of 297 branches covered (75.42%)

Branch coverage included in aggregate %.

7 of 7 new or added lines in 1 file covered. (100.0%)

4 existing lines in 1 file now uncovered.

1492 of 1700 relevant lines covered (87.76%)

3.51 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

91.47
/zntrack/node.py
1
import contextlib
4✔
2
import dataclasses
4✔
3
import datetime
4✔
4
import json
4✔
5
import logging
4✔
6
import pathlib
4✔
7
import typing as t
4✔
8
import uuid
4✔
9
import warnings
4✔
10

11
import typing_extensions as ty_ex
4✔
12
import znfields
4✔
13
import znflow
4✔
14

15
from zntrack.group import Group
4✔
16
from zntrack.state import NodeStatus
4✔
17
from zntrack.utils.misc import get_plugins_from_env
4✔
18

19
from dvc.stage.utils import is_valid_name
4✔
20
from dvc.stage.exceptions import InvalidStageName
4✔
21

22
from .config import NOT_AVAILABLE, ZNTRACK_LAZY_VALUE, NodeStatusEnum
4✔
23

24
try:
4✔
25
    from typing import dataclass_transform
4✔
26
except ImportError:
1✔
27
    from typing_extensions import dataclass_transform
1✔
28

29
# Module logger, named after this module.
log = logging.getLogger(name=__name__)

# Type variable bound to ``Node`` so alternate constructors such as
# ``from_rev`` are typed as returning the calling subclass.
T = t.TypeVar("T", bound="Node")
32

33
def _name_setter(self, attr_name: str, value: str) -> None:
    """Validate a proposed node name and store it on the instance.

    Parameters
    ----------
    attr_name : str
        Key under which the name is stored in ``self.__dict__``.
    value : str
        Proposed node name. ``None`` is stored without any validation.

    Raises
    ------
    InvalidStageName
        If ``value`` is rejected by DVC's ``is_valid_name``.

    """
    if value is not None:
        if not is_valid_name(value):
            raise InvalidStageName
        if "_" in value:
            # Underscores pass the DVC check but are reserved for group
            # names, so only warn instead of rejecting the name.
            warnings.warn(
                "Node name should not contain '_'."
                " This character is used for defining groups."
            )
    self.__dict__[attr_name] = value
46

47
def _name_getter(self, attr_name: str) -> str:
    """Resolve the node name for the current znflow graph context.

    Parameters
    ----------
    attr_name : str
        Key under which an explicitly assigned name is stored in
        ``self.__dict__``.

    Returns
    -------
    str
        - The stored name, if one is set and either no graph is active or
          the active graph has no active group.
        - The class name, if nothing is stored and no graph is active.
        - Otherwise, the name computed by ``graph.compute_all_node_names()``
          for ``self.uuid``.

    """
    stored = self.__dict__.get(attr_name)
    graph = znflow.get_graph()
    no_graph = graph is znflow.empty_graph

    # An explicit name is used as-is unless we are inside a group.
    if stored is not None and (no_graph or graph.active_group is None):
        return str(stored)

    if no_graph:
        # Nothing assigned and no graph active: fall back to the class name.
        return str(self.__class__.__name__)

    # Inside a graph: take the name from the graph-wide node name mapping.
    return str(graph.compute_all_node_names()[self.uuid])
73

74

75
@dataclass_transform()
@dataclasses.dataclass(kw_only=True)
class Node(znflow.Node, znfields.Base):
    """Base class for ZnTrack nodes.

    Subclasses are converted into keyword-only dataclasses automatically
    (see ``__init_subclass__``).
    """

    # Node name: validated on assignment by ``_name_setter`` and resolved
    # against the active znflow graph on access by ``_name_getter``.
    name: str | None = znfields.field(
        default=None, getter=_name_getter, setter=_name_setter
    )
    always_changed: bool = dataclasses.field(default=False, repr=False)

    _protected_ = znflow.Node._protected_ + ["nwd", "name", "state"]

    def __post_init__(self):
        # NOTE(review): ``self.name`` is read through ``_name_getter``, which
        # always returns a ``str``, so this branch looks unreachable (coverage
        # reports it as never executed). Confirm whether the raw stored value
        # (``self.__dict__.get("name")``) was intended here.
        if self.name is None:
            # Automatic node names expect the name to be None when exiting
            # the graph context, so only assign a default outside a graph.
            if znflow.get_graph() is znflow.empty_graph:
                self.name = self.__class__.__name__

    def _post_load_(self):
        """Called after `from_rev` is called.

        Override in subclasses. The default raises ``NotImplementedError``,
        which ``from_rev`` suppresses.
        """
        raise NotImplementedError

    def run(self):
        """Execute the node. Must be implemented by subclasses."""
        raise NotImplementedError

    def save(self):
        """Persist all dataclass fields through every registered plugin.

        Each plugin is entered as a context manager, and its ``save`` method
        is called once per field. Afterwards the node state is marked
        ``FINISHED``.

        Raises
        ------
        ValueError
            If a field still holds ``ZNTRACK_LAZY_VALUE`` or ``NOT_AVAILABLE``.
        Exception
            Whatever a plugin's ``save`` raises, unless that plugin sets
            ``_continue_on_error_``; in that case a warning is emitted.

        """
        for plugin in self.state.plugins.values():
            with plugin:
                for field in dataclasses.fields(self):
                    value = getattr(self, field.name)
                    if value is ZNTRACK_LAZY_VALUE or value is NOT_AVAILABLE:
                        raise ValueError(
                            f"Field '{field.name}' is not set."
                            " Please set it before saving."
                        )
                    try:
                        plugin.save(field)
                    except Exception as err:  # noqa: E722
                        if not plugin._continue_on_error_:
                            raise
                        # Best-effort plugin: keep going, but surface the
                        # underlying cause instead of dropping it.
                        warnings.warn(
                            f"Plugin {plugin.__class__.__name__} failed to"
                            f" save field {field.name}. Error: {err!r}"
                        )

        # Accessing ``state`` guarantees ``self.__dict__["state"]`` exists.
        _ = self.state
        self.__dict__["state"]["state"] = NodeStatusEnum.FINISHED

    def __init_subclass__(cls):
        # Turn every subclass into a keyword-only dataclass so its annotated
        # attributes become init fields. ``dataclasses.dataclass`` modifies
        # ``cls`` in place; Python ignores the return value here.
        return dataclasses.dataclass(cls, kw_only=True)

    @property
    def nwd(self) -> pathlib.Path:
        """Node working directory, as reported by ``self.state``."""
        return self.state.nwd

    @classmethod
    def from_rev(
        cls: t.Type[T],
        name: str | None = None,
        remote: str | None = None,
        rev: str | None = None,
        running: bool = False,
        lazy_evaluation: bool = True,
        **kwargs,
    ) -> T:
        """Load a node instance for the given remote/revision.

        Parameters
        ----------
        name : str, optional
            Node name. Defaults to the class name.
        remote : str, optional
            Passed through to ``NodeStatus``.
        rev : str, optional
            Passed through to ``NodeStatus``.
        running : bool
            If True, the state is ``RUNNING`` and ``_post_load_`` is skipped.
            Otherwise the state is ``FINISHED``.
        lazy_evaluation : bool
            If False, every dataclass field is read immediately after
            loading.
        **kwargs
            Accepted but currently unused.

        Returns
        -------
        T
            An instance of ``cls`` whose init fields start out as
            ``ZNTRACK_LAZY_VALUE``.

        """
        if name is None:
            name = cls.__name__
        lazy_values = {}
        for field in dataclasses.fields(cls):
            # check if the field is in the init
            if field.init:
                lazy_values[field.name] = ZNTRACK_LAZY_VALUE

        lazy_values["name"] = name
        lazy_values["always_changed"] = None  # TODO: read the state from dvc.yaml
        instance = cls(**lazy_values)

        # TODO: check if the node is finished or not.
        instance.__dict__["state"] = NodeStatus(
            remote=remote,
            rev=rev,
            state=NodeStatusEnum.RUNNING if running else NodeStatusEnum.FINISHED,
            lazy_evaluation=lazy_evaluation,
            group=Group.from_nwd(instance.nwd),
        ).to_dict()

        instance.__dict__["state"]["plugins"] = get_plugins_from_env(instance)

        with contextlib.suppress(FileNotFoundError):
            # need to update run_count after the state is set
            # TODO: do we want to set the UUID as well?
            # TODO: test that run_count is correct, when using from_rev from another
            #  commit
            with instance.state.fs.open(instance.nwd / "node-meta.json") as f:
                content = json.load(f)
                run_count = content.get("run_count", 0)
                run_time = content.get("run_time", 0)
                if node_uuid := content.get("uuid", None):
                    instance._uuid = uuid.UUID(node_uuid)
                instance.__dict__["state"]["run_count"] = run_count
                instance.__dict__["state"]["run_time"] = datetime.timedelta(
                    seconds=run_time
                )

        if not instance.state.lazy_evaluation:
            for field in dataclasses.fields(cls):
                _ = getattr(instance, field.name)

        instance._external_ = True
        if not running and hasattr(instance, "_post_load_"):
            with contextlib.suppress(NotImplementedError):
                instance._post_load_()

        return instance

    @property
    def state(self) -> NodeStatus:
        """Return a ``NodeStatus`` view of this node's stored state.

        The underlying dict is created lazily on first access, including the
        plugins returned by ``get_plugins_from_env``.
        """
        if "state" not in self.__dict__:
            self.__dict__["state"] = NodeStatus().to_dict()
            self.__dict__["state"]["plugins"] = get_plugins_from_env(self)

        return NodeStatus(**self.__dict__["state"], node=self)

    @ty_ex.deprecated("loading is handled automatically via lazy evaluation")
    def load(self):
        """Deprecated no-op kept for backward compatibility."""
        pass
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc