• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

zincware / ZnTrack / 13324078184

14 Feb 2025 06:49AM UTC coverage: 85.928% (-0.001%) from 85.929%
13324078184

push

github

web-flow
fix `zntrack list` for nested groups (#882)

* fix `zntrack list` for nested groups

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* format

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

226 of 299 branches covered (75.59%)

Branch coverage included in aggregate %.

11 of 12 new or added lines in 1 file covered (91.67%).

2 existing lines in 1 file now uncovered.

1496 of 1705 relevant lines covered (87.74%)

3.51 hits per line

Source File
Press 'n' to go to the next uncovered line, or 'b' for the previous one.

91.47
/zntrack/node.py
1
import contextlib
4✔
2
import dataclasses
4✔
3
import datetime
4✔
4
import json
4✔
5
import logging
4✔
6
import pathlib
4✔
7
import typing as t
4✔
8
import uuid
4✔
9
import warnings
4✔
10

11
import typing_extensions as ty_ex
4✔
12
import znfields
4✔
13
import znflow
4✔
14
from dvc.stage.exceptions import InvalidStageName
4✔
15
from dvc.stage.utils import is_valid_name
4✔
16

17
from zntrack.group import Group
4✔
18
from zntrack.state import NodeStatus
4✔
19
from zntrack.utils.misc import get_plugins_from_env
4✔
20

21
from .config import NOT_AVAILABLE, ZNTRACK_LAZY_VALUE, NodeStatusEnum
4✔
22

23
try:
4✔
24
    from typing import dataclass_transform
4✔
UNCOV
25
except ImportError:
1✔
UNCOV
26
    from typing_extensions import dataclass_transform
1✔
27

28
log = logging.getLogger(__name__)

# Type variable bound to ``Node``; used so that ``Node.from_rev`` is typed as
# returning an instance of the class it is called on.
T = t.TypeVar("T", bound="Node")
31

32

33
def _name_setter(self, attr_name: str, value: str) -> None:
4✔
34
    """Check if the node name is valid."""
35

36
    if value is not None and not is_valid_name(value):
4✔
37
        raise InvalidStageName
4✔
38

39
    if value is not None and "_" in value:
4✔
40
        warnings.warn(
4✔
41
            "Node name should not contain '_'."
42
            " This character is used for defining groups."
43
        )
44

45
    self.__dict__[attr_name] = value
4✔
46

47

48
def _name_getter(self, attr_name: str) -> str:
    """Resolve the name of a node based on the current graph context.

    Parameters
    ----------
    attr_name : str
        Key under which an explicitly set name is stored in ``self.__dict__``.

    Returns
    -------
    str
        The stored name, if one is set and either no graph is active or the
        active graph has no active group. Otherwise the class name if no graph
        is active, or else the name the graph computes for this node via
        ``graph.compute_all_node_names()[self.uuid]``.
    """
    graph = znflow.get_graph()
    stored = self.__dict__.get(attr_name)
    outside_graph = graph is znflow.empty_graph

    # An explicit name wins unless we are inside an active group.
    if stored is not None and (outside_graph or graph.active_group is None):
        return str(stored)

    if outside_graph:
        return str(self.__class__.__name__)

    # Inside a graph: let the graph assign a project-wide unique name.
    return str(graph.compute_all_node_names()[self.uuid])
74

75

76
@dataclass_transform()
@dataclasses.dataclass(kw_only=True)
class Node(znflow.Node, znfields.Base):
    """Base class for ZnTrack nodes.

    Every subclass is automatically turned into a keyword-only dataclass (see
    ``__init_subclass__``). Runtime bookkeeping (run state, plugins, run count,
    run time, ...) is kept as a plain dict in ``self.__dict__["state"]`` and
    exposed as a ``NodeStatus`` through the :attr:`state` property.
    """

    name: str | None = znfields.field(
        default=None, getter=_name_getter, setter=_name_setter
    )
    always_changed: bool = dataclasses.field(default=False, repr=False)

    # Extend znflow's protected attribute names with ZnTrack-specific ones
    # (presumably so znflow accesses them directly instead of treating them as
    # graph connections; see znflow's ``_protected_`` handling).
    _protected_ = znflow.Node._protected_ + ["nwd", "name", "state"]

    def __post_init__(self):
        if self.name is None:
            # Automatic node names expect the name to still be None when the
            # graph context is exited, so only fall back to the class name
            # when no graph is active.
            if znflow.get_graph() is znflow.empty_graph:
                self.name = self.__class__.__name__

    def _post_load_(self):
        """Hook called at the end of `from_rev` (unless ``running=True``).

        The default raises ``NotImplementedError``, which `from_rev` suppresses;
        subclasses may override it to post-process loaded values.
        """
        raise NotImplementedError

    def run(self):
        """Execute the node. Must be implemented by subclasses."""
        raise NotImplementedError

    def save(self):
        """Persist all dataclass fields through every registered plugin.

        Each plugin is entered as a context manager and asked to save each
        field in turn. Afterwards the node state is set to ``FINISHED``.

        Raises:
            ValueError: If a field still holds ``ZNTRACK_LAZY_VALUE`` or
                ``NOT_AVAILABLE``.
            Exception: Any error raised by ``plugin.save`` is re-raised unless
                the plugin sets ``_continue_on_error_``. In that case a warning
                naming the plugin, the field and the error is emitted instead.
        """
        for plugin in self.state.plugins.values():
            with plugin:
                for field in dataclasses.fields(self):
                    value = getattr(self, field.name)
                    if value is ZNTRACK_LAZY_VALUE or value is NOT_AVAILABLE:
                        raise ValueError(
                            f"Field '{field.name}' is not set."
                            " Please set it before saving."
                        )
                    try:
                        plugin.save(field)
                    except Exception as err:
                        if not plugin._continue_on_error_:
                            # Bare ``raise`` keeps the original traceback.
                            raise
                        # Best-effort plugin: warn, but keep the reason so the
                        # failure is diagnosable.
                        warnings.warn(
                            f"Plugin {plugin.__class__.__name__} failed to"
                            f" save field {field.name}."
                            f" Reason: {err!r}"
                        )

        # Accessing ``state`` creates the state dict in ``__dict__`` if it does
        # not exist yet, so it can be updated in place below.
        _ = self.state
        self.__dict__["state"]["state"] = NodeStatusEnum.FINISHED

    def __init_subclass__(cls, **kwargs):
        """Turn every subclass into a keyword-only dataclass.

        Calls ``super().__init_subclass__`` so that cooperative base classes
        (and class keyword arguments) are handled per PEP 487.
        """
        super().__init_subclass__(**kwargs)
        # ``dataclasses.dataclass`` (without ``slots``) modifies ``cls`` in
        # place; the return value of ``__init_subclass__`` is ignored anyway.
        dataclasses.dataclass(cls, kw_only=True)

    @property
    def nwd(self) -> pathlib.Path:
        """Return the node working directory (``self.state.nwd``)."""
        return self.state.nwd

    @classmethod
    def from_rev(
        cls: t.Type[T],
        name: str | None = None,
        remote: str | None = None,
        rev: str | None = None,
        running: bool = False,
        lazy_evaluation: bool = True,
        **kwargs,
    ) -> T:
        """Create an instance of ``cls`` whose values are loaded from a revision.

        Args:
            name: Node name; defaults to the class name.
            remote: Passed through to ``NodeStatus(remote=...)``.
            rev: Passed through to ``NodeStatus(rev=...)``.
            running: If True, the state is ``RUNNING`` and ``_post_load_`` is
                not called; otherwise the state is ``FINISHED``.
            lazy_evaluation: If False, every dataclass field is accessed once
                before returning, forcing its evaluation.
            **kwargs: Accepted for compatibility; not used here.

        Returns:
            The loaded node instance, marked as external (``_external_``).
        """
        if name is None:
            name = cls.__name__
        # Every init field starts as a lazy placeholder.
        lazy_values = {}
        for field in dataclasses.fields(cls):
            # check if the field is in the init
            if field.init:
                lazy_values[field.name] = ZNTRACK_LAZY_VALUE

        lazy_values["name"] = name
        lazy_values["always_changed"] = None  # TODO: read the state from dvc.yaml
        instance = cls(**lazy_values)

        # TODO: check if the node is finished or not.
        instance.__dict__["state"] = NodeStatus(
            remote=remote,
            rev=rev,
            state=NodeStatusEnum.RUNNING if running else NodeStatusEnum.FINISHED,
            lazy_evaluation=lazy_evaluation,
            group=Group.from_nwd(instance.nwd),
        ).to_dict()

        instance.__dict__["state"]["plugins"] = get_plugins_from_env(instance)

        # node-meta.json is optional; without it run_count/run_time/uuid keep
        # their defaults.
        with contextlib.suppress(FileNotFoundError):
            # need to update run_count after the state is set
            # TODO: do we want to set the UUID as well?
            # TODO: test that run_count is correct, when using from_rev from another
            #  commit
            with instance.state.fs.open(instance.nwd / "node-meta.json") as f:
                content = json.load(f)
                run_count = content.get("run_count", 0)
                run_time = content.get("run_time", 0)
                if node_uuid := content.get("uuid", None):
                    instance._uuid = uuid.UUID(node_uuid)
                instance.__dict__["state"]["run_count"] = run_count
                instance.__dict__["state"]["run_time"] = datetime.timedelta(
                    seconds=run_time
                )

        if not instance.state.lazy_evaluation:
            for field in dataclasses.fields(cls):
                _ = getattr(instance, field.name)

        instance._external_ = True
        if not running and hasattr(instance, "_post_load_"):
            with contextlib.suppress(NotImplementedError):
                instance._post_load_()

        return instance

    @property
    def state(self) -> NodeStatus:
        """Return the node status, creating a default one on first access."""
        if "state" not in self.__dict__:
            self.__dict__["state"] = NodeStatus().to_dict()
            self.__dict__["state"]["plugins"] = get_plugins_from_env(self)

        return NodeStatus(**self.__dict__["state"], node=self)

    @ty_ex.deprecated("loading is handled automatically via lazy evaluation")
    def load(self):
        """Deprecated no-op; values are loaded lazily on attribute access."""
        pass
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc