• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

zincware / ZnTrack / 13087729679

01 Feb 2025 10:24AM UTC coverage: 85.858% (+0.04%) from 85.815%
13087729679

Pull #868

github

web-flow
Merge fbfdfa6be into 82751ef1b
Pull Request #868: fix custom node name in group

224 of 296 branches covered (75.68%)

Branch coverage included in aggregate %.

40 of 41 new or added lines in 4 files covered. (97.56%)

2 existing lines in 1 file now uncovered.

1488 of 1698 relevant lines covered (87.63%)

3.5 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.52
/zntrack/node.py
1
import contextlib
4✔
2
import dataclasses
4✔
3
import datetime
4✔
4
import json
4✔
5
import logging
4✔
6
import pathlib
4✔
7
import typing as t
4✔
8
import uuid
4✔
9
import warnings
4✔
10

11
import typing_extensions as te
4✔
12
import znfields
4✔
13
import znflow
4✔
14

15
from zntrack.group import Group
4✔
16
from zntrack.state import NodeStatus
4✔
17
from zntrack.utils.misc import get_plugins_from_env
4✔
18

19
from .config import NOT_AVAILABLE, ZNTRACK_LAZY_VALUE, NodeStatusEnum
4✔
20

21
try:
4✔
22
    from typing import dataclass_transform
4✔
UNCOV
23
except ImportError:
1✔
UNCOV
24
    from typing_extensions import dataclass_transform
1✔
25

26
# Type variable bound to ``Node``; ``Node.from_rev`` annotates ``cls`` as
# ``t.Type[T]`` and returns ``T``, so it yields the subclass it is called on.
T = t.TypeVar("T", bound="Node")

# Logger for this module.
log = logging.getLogger(name=__name__)
29

30

31
def _name_getter(self, attr_name: str) -> str:
    """Resolve the name of a node with respect to the current graph context.

    Parameters
    ----------
    attr_name : str
        Key under which the stored (user-provided) name lives in
        ``self.__dict__``.

    Returns
    -------
    str
        - The stored name, if one is set and either no graph is active or the
          active graph has no active group.
        - The class name, if no graph is active and no name is stored.
        - Otherwise, the name the active graph computes for this node's
          ``uuid``.

    """
    graph = znflow.get_graph()
    outside_graph = graph is znflow.empty_graph
    stored = self.__dict__.get(attr_name)

    # An explicit name is used as-is unless we are inside an active group;
    # in that case the name is resolved through the graph instead.
    if stored is not None and (outside_graph or graph.active_group is None):
        return str(stored)

    if outside_graph:
        # No graph and no stored name: default to the class name.
        return str(self.__class__.__name__)

    # Inside a graph: look the name up among all node names of the graph.
    return str(graph.compute_all_node_names()[self.uuid])
57

58

59
@dataclass_transform()
@dataclasses.dataclass(kw_only=True)
class Node(znflow.Node, znfields.Base):
    """Base class for ZnTrack nodes.

    Every subclass is converted into a keyword-only dataclass by
    ``__init_subclass__``.

    Attributes
    ----------
    name : str | None
        Name of the node. Reads go through ``_name_getter``, so the value
        returned is always a ``str``: the stored name, the class name, or a
        name computed by the active graph.
    always_changed : bool
        Per-node flag, excluded from ``repr``. NOTE(review): presumably marks
        the node as always out of date — confirm.
    """

    name: str | None = znfields.field(
        default=None, getter=_name_getter
    )  # TODO: add setter and log warning
    always_changed: bool = dataclasses.field(default=False, repr=False)

    # Extends znflow's protected attribute list with ZnTrack attributes.
    # NOTE(review): presumably these are exempt from znflow's connection
    # handling inside a graph context — confirm against znflow.
    _protected_ = znflow.Node._protected_ + ["nwd", "name", "state"]

    def __post_init__(self):
        """Default the node name to the class name when created outside a graph.

        ``self.name`` is resolved through ``_name_getter``, which never returns
        ``None``. The raw stored value must therefore be read from
        ``self.__dict__`` to detect whether a name was provided. Checking
        ``self.name is None`` would make this branch unreachable.
        """
        if self.__dict__.get("name") is None:
            # Automatic node names expect the name to be None when exiting
            # the graph context, so only fill in a default when no graph is
            # active.
            if znflow.get_graph() is znflow.empty_graph:
                self.name = self.__class__.__name__
                if "_" in self.name:
                    log.warning(
                        "Node name should not contain '_'."
                        " This character is used for defining groups."
                    )

    def _post_load_(self):
        """Hook called after `from_rev` is called.

        ``from_rev`` invokes it only when ``running`` is False and suppresses
        the ``NotImplementedError`` raised by this default implementation.
        """
        raise NotImplementedError

    def run(self):
        """Execute the node. Subclasses must override this."""
        raise NotImplementedError

    def save(self):
        """Save every dataclass field through each configured plugin.

        After all plugins have run, the node state is set to
        ``NodeStatusEnum.FINISHED``.

        Raises
        ------
        ValueError
            If a field still holds ``ZNTRACK_LAZY_VALUE`` or ``NOT_AVAILABLE``.
        Exception
            Any error raised by ``plugin.save`` is re-raised unless the
            plugin sets ``_continue_on_error_``, in which case a warning is
            emitted instead.
        """
        for plugin in self.state.plugins.values():
            with plugin:
                for field in dataclasses.fields(self):
                    value = getattr(self, field.name)
                    if value is ZNTRACK_LAZY_VALUE or value is NOT_AVAILABLE:
                        raise ValueError(
                            f"Field '{field.name}' is not set. Please set it before saving."
                        )
                    try:
                        plugin.save(field)
                    except Exception:
                        # Plugins may opt into best-effort saving.
                        if plugin._continue_on_error_:
                            warnings.warn(
                                f"Plugin {plugin.__class__.__name__} failed to save field {field.name}."
                            )
                        else:
                            raise

        # Accessing ``state`` guarantees ``self.__dict__["state"]`` exists.
        _ = self.state
        self.__dict__["state"]["state"] = NodeStatusEnum.FINISHED

    def __init_subclass__(cls):
        """Turn every subclass into a keyword-only dataclass."""
        return dataclasses.dataclass(cls, kw_only=True)

    @property
    def nwd(self) -> pathlib.Path:
        """Node working directory, as reported by the node state."""
        return self.state.nwd

    @classmethod
    def from_rev(
        cls: t.Type[T],
        name: str | None = None,
        remote: str | None = None,
        rev: str | None = None,
        running: bool = False,
        lazy_evaluation: bool = True,
        **kwargs,
    ) -> T:
        """Load a node instance from stored results.

        Parameters
        ----------
        name : str | None
            Node name. Defaults to the class name.
        remote : str | None
            Passed to ``NodeStatus``. NOTE(review): presumably a repository
            or remote location — confirm.
        rev : str | None
            Passed to ``NodeStatus``. NOTE(review): presumably a revision
            identifier — confirm.
        running : bool
            If True, the state is ``RUNNING`` and ``_post_load_`` is skipped.
            Otherwise the state is ``FINISHED``.
        lazy_evaluation : bool
            If False, every dataclass field is read immediately after loading.
        **kwargs
            Accepted but currently unused.

        Returns
        -------
        T
            An instance of ``cls`` with lazily-loaded field values.
        """
        if name is None:
            name = cls.__name__
        # Every init field starts as a lazy placeholder.
        lazy_values = {}
        for field in dataclasses.fields(cls):
            # check if the field is in the init
            if field.init:
                lazy_values[field.name] = ZNTRACK_LAZY_VALUE

        lazy_values["name"] = name
        lazy_values["always_changed"] = None  # TODO: read the state from dvc.yaml
        instance = cls(**lazy_values)

        # TODO: check if the node is finished or not.
        instance.__dict__["state"] = NodeStatus(
            remote=remote,
            rev=rev,
            state=NodeStatusEnum.RUNNING if running else NodeStatusEnum.FINISHED,
            lazy_evaluation=lazy_evaluation,
            group=Group.from_nwd(instance.nwd),
        ).to_dict()

        instance.__dict__["state"]["plugins"] = get_plugins_from_env(instance)

        # Metadata is optional: a missing node-meta.json keeps the defaults.
        with contextlib.suppress(FileNotFoundError):
            # need to update run_count after the state is set
            # TODO: do we want to set the UUID as well?
            # TODO: test that run_count is correct, when using from_rev from another
            #  commit
            with instance.state.fs.open(instance.nwd / "node-meta.json") as f:
                content = json.load(f)
                run_count = content.get("run_count", 0)
                run_time = content.get("run_time", 0)
                if node_uuid := content.get("uuid", None):
                    instance._uuid = uuid.UUID(node_uuid)
                instance.__dict__["state"]["run_count"] = run_count
                instance.__dict__["state"]["run_time"] = datetime.timedelta(
                    seconds=run_time
                )

        if not instance.state.lazy_evaluation:
            # Eager mode: touch every field to force loading now.
            for field in dataclasses.fields(cls):
                _ = getattr(instance, field.name)

        instance._external_ = True
        if not running and hasattr(instance, "_post_load_"):
            with contextlib.suppress(NotImplementedError):
                instance._post_load_()

        return instance

    @property
    def state(self) -> NodeStatus:
        """Return the node status.

        On first access, the raw state dict is created in
        ``self.__dict__["state"]`` from a default ``NodeStatus``, together
        with the plugins from ``get_plugins_from_env``. A new ``NodeStatus``
        object is constructed from that dict on every access.
        """
        if "state" not in self.__dict__:
            self.__dict__["state"] = NodeStatus().to_dict()
            self.__dict__["state"]["plugins"] = get_plugins_from_env(self)

        return NodeStatus(**self.__dict__["state"], node=self)

    @te.deprecated("loading is handled automatically via lazy evaluation")
    def load(self):
        """Deprecated no-op kept for backward compatibility."""
        pass
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc