• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

zincware / ZnTrack / 13015170160

28 Jan 2025 04:37PM UTC coverage: 85.815% (+2.8%) from 83.058%
13015170160

push

github

web-flow
replace `poetry` with `uv` (#872)

* add log message

* move to uv

* update lint sections

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update urls

* update tests to use uv

* updates

* add h5py

* build documentation test in CI

* update rtd to use uv

* add `dvc-s3`

* add sphinx

* add furo

* add docs packages

* install pandoc

* use sudo

* add publish action

* add missing sections

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

223 of 295 branches covered (75.59%)

Branch coverage included in aggregate %.

3 of 5 new or added lines in 1 file covered. (60.0%)

1 existing line in 1 file now uncovered.

1483 of 1693 relevant lines covered (87.6%)

3.5 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

88.33
/zntrack/node.py
1
import contextlib
4✔
2
import dataclasses
4✔
3
import datetime
4✔
4
import json
4✔
5
import logging
4✔
6
import pathlib
4✔
7
import typing as t
4✔
8
import uuid
4✔
9
import warnings
4✔
10

11
import typing_extensions as te
4✔
12
import znfields
4✔
13
import znflow
4✔
14

15
from zntrack.group import Group
4✔
16
from zntrack.state import NodeStatus
4✔
17
from zntrack.utils.misc import get_plugins_from_env
4✔
18

19
from .config import NOT_AVAILABLE, ZNTRACK_LAZY_VALUE, NodeStatusEnum
4✔
20

21
try:
4✔
22
    from typing import dataclass_transform
4✔
23
except ImportError:
1✔
24
    from typing_extensions import dataclass_transform
1✔
25

26
# Module-level logger for this file.
log = logging.getLogger(__name__)

# Type variable bound to ``Node`` so that alternate constructors such as
# ``from_rev`` (annotated ``cls: t.Type[T]`` / ``-> T``) return the subclass type.
T = t.TypeVar("T", bound="Node")
29

30

31
def _name_getter(self, name):
4✔
32
    value = self.__dict__[name]
4✔
33
    if value is not None:
4✔
34
        return value
4✔
35
    # find the value based on the current project context
36
    graph = znflow.get_graph()
4✔
37
    if graph is znflow.empty_graph:
4✔
38
        return self.__class__.__name__
4✔
39

40
    return graph.compute_all_node_names()[self.uuid]
4✔
41

42

43
@dataclass_transform()
@dataclasses.dataclass(kw_only=True)
class Node(znflow.Node, znfields.Base):
    """A Node.

    Subclasses are turned into keyword-only dataclasses automatically via
    ``__init_subclass__``. The ``name`` field is resolved through
    ``_name_getter``, and runtime status is kept as a dict in
    ``self.__dict__["state"]`` and exposed through the ``state`` property.
    """

    name: str | None = znfields.field(
        default=None, getter=_name_getter
    )  # TODO: add setter and log warning
    always_changed: bool = dataclasses.field(default=False, repr=False)

    _protected_ = znflow.Node._protected_ + ["nwd", "name", "state"]

    def __post_init__(self):
        """Fill in a default name when the node is created outside a graph."""
        if self.name is None:
            # Automatic node names expect the name to be None when
            # exiting the graph context, so only assign the class name
            # when no graph is active.
            # NOTE(review): ``self.name`` goes through ``_name_getter``, which
            # substitutes a name whenever the stored value is None, so this
            # branch may be unreachable. Checking ``self.__dict__["name"]``
            # may have been intended; confirm before relying on the warning.
            if znflow.get_graph() is znflow.empty_graph:
                self.name = self.__class__.__name__
                if "_" in self.name:
                    log.warning(
                        "Node name should not contain '_'."
                        " This character is used for defining groups."
                    )

    def _post_load_(self):
        """Called after `from_rev` is called."""
        raise NotImplementedError

    def run(self):
        """Execute the node; subclasses must override this."""
        raise NotImplementedError

    def save(self):
        """Save every dataclass field through each configured plugin.

        Raises:
            ValueError: If a field still holds ``ZNTRACK_LAZY_VALUE`` or
                ``NOT_AVAILABLE``.
            Exception: Whatever a plugin's ``save`` raises, unless that plugin
                sets ``_continue_on_error_``. In that case a warning is
                emitted and saving continues.
        """
        for plugin in self.state.plugins.values():
            with plugin:
                for field in dataclasses.fields(self):
                    value = getattr(self, field.name)
                    if value is ZNTRACK_LAZY_VALUE or value is NOT_AVAILABLE:
                        raise ValueError(
                            f"Field '{field.name}' is not set. Please set it before saving."
                        )
                    try:
                        plugin.save(field)
                    except Exception as err:
                        if not plugin._continue_on_error_:
                            raise
                        # Best-effort plugin: report the failure, including
                        # the underlying error, and move on.
                        warnings.warn(
                            f"Plugin {plugin.__class__.__name__} failed to save"
                            f" field {field.name}. Error: {err}"
                        )

        # Accessing the property initializes ``__dict__["state"]`` if missing.
        _ = self.state
        self.__dict__["state"]["state"] = NodeStatusEnum.FINISHED

    def __init_subclass__(cls):
        """Apply a keyword-only dataclass transform to every subclass."""
        return dataclasses.dataclass(cls, kw_only=True)

    @property
    def nwd(self) -> pathlib.Path:
        """Node working directory, as reported by ``self.state.nwd``."""
        return self.state.nwd

    @classmethod
    def from_rev(
        cls: t.Type[T],
        name: str | None = None,
        remote: str | None = None,
        rev: str | None = None,
        running: bool = False,
        lazy_evaluation: bool = True,
        **kwargs,
    ) -> T:
        """Create an instance whose init fields start as lazy placeholders.

        Args:
            name: Node name; defaults to the class name.
            remote: Forwarded to ``NodeStatus`` as ``remote``.
            rev: Forwarded to ``NodeStatus`` as ``rev``.
            running: If True, the state is ``RUNNING`` and ``_post_load_`` is
                not called. Otherwise the state is ``FINISHED``.
            lazy_evaluation: If False, every dataclass field is accessed once
                before returning.
            **kwargs: Accepted but not used by this method.

        Returns:
            The new instance, marked with ``_external_ = True``.
        """
        if name is None:
            name = cls.__name__
        # Every init field starts as the lazy sentinel; presumably it is
        # resolved on first attribute access (see the eager loop below).
        lazy_values = {
            field.name: ZNTRACK_LAZY_VALUE
            for field in dataclasses.fields(cls)
            if field.init
        }

        lazy_values["name"] = name
        lazy_values["always_changed"] = None  # TODO: read the state from dvc.yaml
        instance = cls(**lazy_values)

        # TODO: check if the node is finished or not.
        instance.__dict__["state"] = NodeStatus(
            remote=remote,
            rev=rev,
            state=NodeStatusEnum.RUNNING if running else NodeStatusEnum.FINISHED,
            lazy_evaluation=lazy_evaluation,
            group=Group.from_nwd(instance.nwd),
        ).to_dict()

        instance.__dict__["state"]["plugins"] = get_plugins_from_env(instance)

        # A missing node-meta.json is tolerated; the defaults stay in place.
        with contextlib.suppress(FileNotFoundError):
            # need to update run_count after the state is set
            # TODO: do we want to set the UUID as well?
            # TODO: test that run_count is correct, when using from_rev from another
            #  commit
            with instance.state.fs.open(instance.nwd / "node-meta.json") as f:
                content = json.load(f)
                run_count = content.get("run_count", 0)
                run_time = content.get("run_time", 0)
                if node_uuid := content.get("uuid", None):
                    instance._uuid = uuid.UUID(node_uuid)
                instance.__dict__["state"]["run_count"] = run_count
                instance.__dict__["state"]["run_time"] = datetime.timedelta(
                    seconds=run_time
                )

        if not instance.state.lazy_evaluation:
            # Eager mode: touch every field so its value is loaded now.
            for field in dataclasses.fields(cls):
                _ = getattr(instance, field.name)

        instance._external_ = True
        if not running and hasattr(instance, "_post_load_"):
            # The base implementation raises NotImplementedError; ignore it.
            with contextlib.suppress(NotImplementedError):
                instance._post_load_()

        return instance

    @property
    def state(self) -> NodeStatus:
        """Return a ``NodeStatus`` view, creating the state dict on first use."""
        if "state" not in self.__dict__:
            self.__dict__["state"] = NodeStatus().to_dict()
            self.__dict__["state"]["plugins"] = get_plugins_from_env(self)

        return NodeStatus(**self.__dict__["state"], node=self)

    @te.deprecated("loading is handled automatically via lazy evaluation")
    def load(self):
        """Deprecated no-op kept for backward compatibility."""
        pass
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc