• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

zincware / ZnTrack / 13105840119

01 Feb 2025 12:01PM UTC coverage: 85.729%. Remained the same
13105840119

push

github

web-flow
test `cached_property` (#877)

* add cached_property test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

223 of 296 branches covered (75.34%)

Branch coverage included in aggregate %.

1483 of 1694 relevant lines covered (87.54%)

3.5 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.52
/zntrack/node.py
1
import contextlib
4✔
2
import dataclasses
4✔
3
import datetime
4✔
4
import json
4✔
5
import logging
4✔
6
import pathlib
4✔
7
import typing as t
4✔
8
import uuid
4✔
9
import warnings
4✔
10

11
import typing_extensions as ty_ex
4✔
12
import znfields
4✔
13
import znflow
4✔
14

15
from zntrack.group import Group
4✔
16
from zntrack.state import NodeStatus
4✔
17
from zntrack.utils.misc import get_plugins_from_env
4✔
18

19
from .config import NOT_AVAILABLE, ZNTRACK_LAZY_VALUE, NodeStatusEnum
4✔
20

21
try:
4✔
22
    from typing import dataclass_transform
4✔
23
except ImportError:
1✔
24
    from typing_extensions import dataclass_transform
1✔
25

26
# Module-level logger shared by everything in this module.
log = logging.getLogger(__name__)

# Type variable used by ``Node.from_rev`` (``cls: t.Type[T]`` -> ``T``) so that
# calling it on a subclass is typed as returning that subclass.
T = t.TypeVar("T", bound="Node")
29

30

31
def _name_getter(self, attr_name: str) -> str:
    """Resolve a node's name with respect to the current znflow graph context.

    Parameters
    ----------
    attr_name : str
        Key under which an explicitly assigned name is stored in the
        instance ``__dict__``.

    Returns
    -------
    str
        * the stored name, if one is set and either no graph is active or
          the active graph has no active group;
        * otherwise, outside of any graph, the class name;
        * otherwise, the project-wide name the graph computes for this
          node's ``uuid``.
    """
    stored = self.__dict__.get(attr_name)
    graph = znflow.get_graph()
    outside_graph = graph is znflow.empty_graph

    # An explicitly stored name wins unless we are inside an active group.
    if stored is not None and (outside_graph or graph.active_group is None):
        return str(stored)

    if outside_graph:
        # No graph context at all: fall back to the class name.
        return str(self.__class__.__name__)

    # Inside a graph: delegate to the graph's project-wide naming.
    return str(graph.compute_all_node_names()[self.uuid])
57

58

59
@dataclass_transform()
@dataclasses.dataclass(kw_only=True)
class Node(znflow.Node, znfields.Base):
    """A ZnTrack Node.

    Subclasses are converted into keyword-only dataclasses automatically
    (see ``__init_subclass__``). ``save`` hands every dataclass field to each
    plugin stored in ``self.state.plugins`` (populated via
    ``get_plugins_from_env``).
    """

    name: str | None = znfields.field(
        default=None, getter=_name_getter
    )  # TODO: add setter and log warning
    always_changed: bool = dataclasses.field(default=False, repr=False)

    # NOTE(review): presumably the attributes znflow must resolve directly on
    # the instance instead of treating them as graph connections — confirm
    # against znflow's ``_protected_`` semantics.
    _protected_ = znflow.Node._protected_ + ["nwd", "name", "state"]

    def __post_init__(self):
        if self.name is None:
            # automatic node names expects the name to be None when
            # exiting the graph context.
            # NOTE(review): ``name`` is read through ``_name_getter``, which
            # always returns a ``str``, so this branch appears unreachable and
            # the warning below never fires. Confirm whether
            # ``self.__dict__.get("name") is None`` was intended.
            if znflow.get_graph() is znflow.empty_graph:
                self.name = self.__class__.__name__
                if "_" in self.name:
                    log.warning(
                        "Node name should not contain '_'."
                        " This character is used for defining groups."
                    )

    def _post_load_(self):
        """Called after `from_rev` is called."""
        raise NotImplementedError

    def run(self):
        """Execute the node; must be implemented by subclasses."""
        raise NotImplementedError

    def save(self):
        """Save every dataclass field through every configured plugin.

        Marks the node state as ``NodeStatusEnum.FINISHED`` afterwards.

        Raises
        ------
        ValueError
            If a field still holds ``ZNTRACK_LAZY_VALUE`` or ``NOT_AVAILABLE``.
        Exception
            Whatever ``plugin.save`` raises, unless the plugin sets
            ``_continue_on_error_``. In that case a warning is emitted
            instead.
        """
        for plugin in self.state.plugins.values():
            with plugin:
                for field in dataclasses.fields(self):
                    value = getattr(self, field.name)
                    if value is ZNTRACK_LAZY_VALUE or value is NOT_AVAILABLE:
                        raise ValueError(
                            f"Field '{field.name}' is not set."
                            " Please set it before saving."
                        )
                    try:
                        plugin.save(field)
                    except Exception as err:  # noqa: E722
                        if not plugin._continue_on_error_:
                            raise  # bare raise keeps the original traceback
                        # Best effort: keep saving the remaining fields, but
                        # report which plugin/field failed and why.
                        warnings.warn(
                            f"Plugin {plugin.__class__.__name__} failed to"
                            f" save field {field.name}: {err!r}",
                            stacklevel=2,
                        )

        # Accessing ``state`` guarantees ``__dict__["state"]`` exists.
        _ = self.state
        self.__dict__["state"]["state"] = NodeStatusEnum.FINISHED

    def __init_subclass__(cls):
        """Turn every subclass into a keyword-only dataclass in place."""
        # ``dataclasses.dataclass`` (without ``slots``) mutates ``cls`` in
        # place. The return value of ``__init_subclass__`` is ignored by Python.
        # NOTE(review): ``super().__init_subclass__()`` is not called, so any
        # hook defined by the base classes is skipped — confirm intentional.
        dataclasses.dataclass(cls, kw_only=True)

    @property
    def nwd(self) -> pathlib.Path:
        """Return ``self.state.nwd`` (presumably the node working directory)."""
        return self.state.nwd

    @classmethod
    def from_rev(
        cls: t.Type[T],
        name: str | None = None,
        remote: str | None = None,
        rev: str | None = None,
        running: bool = False,
        lazy_evaluation: bool = True,
        **kwargs,
    ) -> T:
        """Create an instance whose field values are resolved lazily.

        Every init field starts as ``ZNTRACK_LAZY_VALUE``. If
        ``<nwd>/node-meta.json`` exists, its ``run_count``, ``run_time`` and
        ``uuid`` entries are copied onto the instance.

        Parameters
        ----------
        name : str, optional
            Node name; defaults to the class name.
        remote, rev : str, optional
            Forwarded to ``NodeStatus``.
        running : bool
            If True, the state is ``RUNNING`` (else ``FINISHED``) and
            ``_post_load_`` is not called.
        lazy_evaluation : bool
            If False, every field is accessed once before returning.
        **kwargs
            Accepted for compatibility but currently unused.

        Returns
        -------
        T
            The new instance, flagged with ``_external_ = True``.
        """
        if name is None:
            name = cls.__name__
        lazy_values = {}
        for field in dataclasses.fields(cls):
            # check if the field is in the init
            if field.init:
                lazy_values[field.name] = ZNTRACK_LAZY_VALUE

        lazy_values["name"] = name
        lazy_values["always_changed"] = None  # TODO: read the state from dvc.yaml
        instance = cls(**lazy_values)

        # TODO: check if the node is finished or not.
        instance.__dict__["state"] = NodeStatus(
            remote=remote,
            rev=rev,
            state=NodeStatusEnum.RUNNING if running else NodeStatusEnum.FINISHED,
            lazy_evaluation=lazy_evaluation,
            group=Group.from_nwd(instance.nwd),
        ).to_dict()

        instance.__dict__["state"]["plugins"] = get_plugins_from_env(instance)

        with contextlib.suppress(FileNotFoundError):
            # need to update run_count after the state is set
            # TODO: do we want to set the UUID as well?
            # TODO: test that run_count is correct, when using from_rev from another
            #  commit
            with instance.state.fs.open(instance.nwd / "node-meta.json") as f:
                content = json.load(f)
                run_count = content.get("run_count", 0)
                run_time = content.get("run_time", 0)
                if node_uuid := content.get("uuid"):
                    instance._uuid = uuid.UUID(node_uuid)
                instance.__dict__["state"]["run_count"] = run_count
                instance.__dict__["state"]["run_time"] = datetime.timedelta(
                    seconds=run_time
                )

        if not instance.state.lazy_evaluation:
            # Eagerly trigger loading of every field.
            for field in dataclasses.fields(cls):
                _ = getattr(instance, field.name)

        instance._external_ = True
        if not running and hasattr(instance, "_post_load_"):
            # The default ``_post_load_`` raises NotImplementedError; treat
            # that as "no post-load hook".
            with contextlib.suppress(NotImplementedError):
                instance._post_load_()

        return instance

    @property
    def state(self) -> NodeStatus:
        """Return a ``NodeStatus`` view of this node's stored state.

        On first access a default ``NodeStatus`` dict (plus the plugins from
        ``get_plugins_from_env``) is stored in ``__dict__["state"]``.
        """
        if "state" not in self.__dict__:
            self.__dict__["state"] = NodeStatus().to_dict()
            self.__dict__["state"]["plugins"] = get_plugins_from_env(self)

        return NodeStatus(**self.__dict__["state"], node=self)

    @ty_ex.deprecated("loading is handled automatically via lazy evaluation")
    def load(self):
        """Deprecated no-op kept for backwards compatibility."""
        pass
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc