• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

zincware / ZnTrack / 13304757769

13 Feb 2025 09:41AM UTC coverage: 85.75% (+0.02%) from 85.729%
13304757769

Pull #880

github

web-flow
Merge 060c2c47e into e83798b4c
Pull Request #880: Update documentation

223 of 296 branches covered (75.34%)

Branch coverage included in aggregate %.

81 of 82 new or added lines in 4 files covered. (98.78%)

8 existing lines in 1 file now uncovered.

1486 of 1697 relevant lines covered (87.57%)

3.5 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.83
/zntrack/plugins/dvc_plugin/__init__.py
1
import copy
4✔
2
import dataclasses
4✔
3
import json
4✔
4
import pathlib
4✔
5
import typing as t
4✔
6

7
import znflow
4✔
8
import znflow.handler
4✔
9
import znflow.utils
4✔
10
import znjson
4✔
11

12
from zntrack import config, converter
4✔
13
from zntrack.config import (
4✔
14
    PLUGIN_EMPTY_RETRUN_VALUE,
15
    ZNTRACK_CACHE,
16
    ZNTRACK_FIELD_DUMP,
17
    ZNTRACK_FIELD_LOAD,
18
    ZNTRACK_FIELD_SUFFIX,
19
    ZNTRACK_OPTION,
20
    ZNTRACK_OPTION_PLOTS_CONFIG,
21
    ZnTrackOptionEnum,
22
)
23

24
# if t.TYPE_CHECKING:
25
from zntrack.node import Node
4✔
26
from zntrack.plugins import ZnTrackPlugin
4✔
27
from zntrack.utils import module_handler
4✔
28
from zntrack.utils.misc import (
4✔
29
    RunDVCImportPathHandler,
30
    get_attr_always_list,
31
    sort_and_deduplicate,
32
)
33
from zntrack.utils.node_wd import NWDReplaceHandler, nwd
4✔
34

35

36
@dataclasses.dataclass
class DVCPlugin(ZnTrackPlugin):
    """Plugin that maps a ZnTrack node onto DVC's configuration files.

    Field values are loaded and saved through the ``ZNTRACK_FIELD_LOAD`` /
    ``ZNTRACK_FIELD_DUMP`` callables stored in the field metadata, and the
    node is converted into its ``params.yaml``, ``dvc.yaml`` and
    ``zntrack.json`` representations.
    """

    def getter(self, field: dataclasses.Field) -> t.Any:
        """Load ``field`` through the loader registered in its metadata.

        Returns:
            The loaded value, or ``PLUGIN_EMPTY_RETRUN_VALUE`` when the field
            metadata defines no ``ZNTRACK_FIELD_LOAD`` callable.
        """
        loader = field.metadata.get(ZNTRACK_FIELD_LOAD)
        if loader is None:
            return PLUGIN_EMPTY_RETRUN_VALUE
        suffix = field.metadata.get(ZNTRACK_FIELD_SUFFIX)
        if suffix is not None:
            return loader(self.node, field.name, suffix=suffix)
        return loader(self.node, field.name)

    def save(self, field: dataclasses.Field) -> None:
        """Persist ``field`` through the dump function registered in its metadata.

        Fields without a ``ZNTRACK_FIELD_DUMP`` callable are ignored. The
        optional ``ZNTRACK_FIELD_SUFFIX`` is forwarded as ``suffix=``.
        """
        dump_func = field.metadata.get(ZNTRACK_FIELD_DUMP)
        if dump_func is None:
            return
        suffix = field.metadata.get(ZNTRACK_FIELD_SUFFIX)
        if suffix is not None:
            dump_func(self.node, field.name, suffix=suffix)
        else:
            dump_func(self.node, field.name)

    @staticmethod
    def _is_plain_dataclass(obj: t.Any) -> bool:
        """Return True for dataclasses that are neither Nodes nor connections."""
        return dataclasses.is_dataclass(obj) and not isinstance(
            obj, (Node, znflow.Connection, znflow.CombinedConnections)
        )

    @staticmethod
    def _dataclass_to_params(obj: t.Any) -> dict:
        """Serialize a dataclass instance into a ``params.yaml`` entry.

        The values are saved so that the DataclassContainer can later
        recreate the instance with the same parameters; ``_cls`` records the
        ``<module>.<ClassName>`` import path for that purpose.
        """
        params = dataclasses.asdict(obj)
        params["_cls"] = f"{module_handler(obj)}.{obj.__class__.__name__}"
        return params

    def convert_to_params_yaml(self) -> dict | object:
        """Collect the node's entries for ``params.yaml``.

        * PARAMS fields are stored by value.
        * DEPS fields holding a plain dataclass are stored as its ``asdict``
          dump plus ``_cls``; list/tuple/dict containers are stored as a list
          of such dumps (dict keys are dropped). znflow connections are
          skipped, and containers without any dataclass are omitted.

        Returns:
            The collected mapping, or ``PLUGIN_EMPTY_RETRUN_VALUE`` if empty.

        Raises:
            ValueError: if a DEPS value (or container item) has any other type.
        """
        data = {}
        for field in dataclasses.fields(self.node):
            option = field.metadata.get(ZNTRACK_OPTION)
            if option == ZnTrackOptionEnum.PARAMS:
                data[field.name] = getattr(self.node, field.name)
            elif option == ZnTrackOptionEnum.DEPS:
                content = getattr(self.node, field.name)
                if content is None:
                    continue
                if isinstance(content, (list, tuple, dict)):
                    values = content.values() if isinstance(content, dict) else content
                    new_content = []
                    for val in values:
                        if self._is_plain_dataclass(val):
                            new_content.append(self._dataclass_to_params(val))
                        elif not isinstance(
                            val, (znflow.Connection, znflow.CombinedConnections)
                        ):
                            raise ValueError(
                                f"Found unsupported type '{type(val)}' ({val}) for DEPS"
                                f" field '{field.name}' in list"
                            )
                    if new_content:
                        data[field.name] = new_content
                elif self._is_plain_dataclass(content):
                    data[field.name] = self._dataclass_to_params(content)
                elif not isinstance(
                    content, (znflow.Connection, znflow.CombinedConnections)
                ):
                    raise ValueError(
                        f"Found unsupported type '{type(content)}' ({content})"
                        f" for DEPS field '{field.name}'"
                    )

        return data if data else PLUGIN_EMPTY_RETRUN_VALUE

    def _resolved_paths(self, field: dataclasses.Field, nwd_handler) -> list:
        """Return the field's path(s) with ``nwd`` substituted, as POSIX strings."""
        content = nwd_handler(
            get_attr_always_list(self.node, field.name), nwd=self.node.nwd
        )
        return [pathlib.Path(x).as_posix() for x in content]

    @staticmethod
    def _apply_cache_flag(field: dataclasses.Field, content: list) -> list:
        """Wrap each path as ``{path: {"cache": False}}`` if caching is disabled."""
        if field.metadata.get(ZNTRACK_CACHE) is False:
            return [{c: {"cache": False}} for c in content]
        return content

    def _field_file_path(self, field: dataclasses.Field) -> str:
        """Return ``<nwd>/<field name><suffix>`` for file-backed fields."""
        suffix = field.metadata[ZNTRACK_FIELD_SUFFIX]
        return (self.node.nwd / field.name).with_suffix(suffix).as_posix()

    def _plot_entries(self, field: dataclasses.Field, file_path: str) -> list:
        """Build top-level ``plots`` entries from the field's plots config.

        A list-valued ``y`` produces one entry per element, named
        ``<node>_<field>_<idx>``; otherwise a single ``<node>_<field>`` entry.

        Raises:
            ValueError: if the config does not define both ``x`` and ``y``.
        """
        plots_config = field.metadata[ZNTRACK_OPTION_PLOTS_CONFIG].copy()
        if "x" not in plots_config or "y" not in plots_config:
            raise ValueError("Both 'x' and 'y' must be specified in the plots_config.")
        # Both keys are guaranteed to exist past this point.
        plots_config["x"] = {file_path: plots_config["x"]}
        if isinstance(plots_config["y"], list):
            entries = []
            for idx, y in enumerate(plots_config["y"]):
                cfg = copy.deepcopy(plots_config)
                cfg["y"] = {file_path: y}
                entries.append({f"{self.node.name}_{field.name}_{idx}": cfg})
            return entries
        plots_config["y"] = {file_path: plots_config["y"]}
        return [{f"{self.node.name}_{field.name}": plots_config}]

    @staticmethod
    def _connection_output_paths(con: t.Any) -> list:
        """Return the output paths a (combined) znflow connection points at.

        Raises:
            NotImplementedError: if the connection, or any connection inside
                a ``CombinedConnections``, uses item access.
        """
        if isinstance(con, znflow.CombinedConnections):
            if con.item is not None:
                raise NotImplementedError(
                    "znflow.Connection getitem is not supported yet."
                )
            connections = con.connections
        else:
            connections = [con]

        paths = []
        for _con in connections:
            # Check every individual connection, matching the behaviour for a
            # single znflow.Connection.
            if _con.item is not None:
                raise NotImplementedError(
                    "znflow.Connection getitem is not supported yet."
                )
            paths.extend(converter.node_to_output_paths(_con.instance, _con.attribute))
        return paths

    def convert_to_dvc_yaml(self) -> dict | object:
        """Build the ``dvc.yaml`` stage and plot definitions for the node.

        Returns:
            ``{"stages": stage, "plots": plots}``. ``stage`` holds ``cmd``
            (and ``always_changed`` if set) plus the params/deps/outs/metrics
            lists derived from the node fields, each sorted and de-duplicated.
            ``plots`` is a list of ``{name: config}`` mappings built from
            plots configs.

        Raises:
            ValueError: if ``zntrack.nwd`` itself is used as an output path,
                a plots config lacks ``x`` or ``y``, or a DEPS value has an
                unsupported type.
            NotImplementedError: if a DEPS connection uses item access.
        """
        node_dict = converter.NodeConverter().encode(self.node)

        cmd = f"zntrack run {node_dict['module']}.{node_dict['cls']}"
        cmd += f" --name {node_dict['name']}"
        if hasattr(self.node, "_method"):
            # Optional attribute; only forwarded when the node defines it.
            cmd += f" --method {self.node._method}"
        stages = {
            "cmd": cmd,
            # Every stage reports its node-meta.json as a metric.
            "metrics": [
                {
                    (self.node.nwd / "node-meta.json").as_posix(): {
                        "cache": config.ALWAYS_CACHE
                    }
                }
            ],
        }
        if self.node.always_changed:
            stages["always_changed"] = True
        plots = []

        nwd_handler = NWDReplaceHandler()
        path_options = (
            ZnTrackOptionEnum.OUTS_PATH,
            ZnTrackOptionEnum.PLOTS_PATH,
            ZnTrackOptionEnum.METRICS_PATH,
        )
        file_options = (
            ZnTrackOptionEnum.OUTS,
            ZnTrackOptionEnum.PLOTS,
            ZnTrackOptionEnum.METRICS,
        )

        for field in dataclasses.fields(self.node):
            option = field.metadata.get(ZNTRACK_OPTION)
            if option == ZnTrackOptionEnum.PARAMS:
                stages.setdefault(ZnTrackOptionEnum.PARAMS.value, []).append(
                    self.node.name
                )
            elif option == ZnTrackOptionEnum.PARAMS_PATH:
                if getattr(self.node, field.name) is None:
                    continue
                paths = self._resolved_paths(field, nwd_handler)
                stages.setdefault(ZnTrackOptionEnum.PARAMS.value, []).extend(
                    [{path: None} for path in paths]
                )
            elif option in path_options:
                value = getattr(self.node, field.name)
                if value is None:
                    continue
                if option == ZnTrackOptionEnum.OUTS_PATH and value == nwd:
                    raise ValueError(
                        "Can not use 'zntrack.nwd' directly as an output path. "
                        "Please use 'zntrack.nwd / <path/file>' instead."
                    )
                # PLOTS_PATH files are registered as plain outputs; they get
                # no top-level ``plots`` entry.
                stage_key = (
                    ZnTrackOptionEnum.METRICS
                    if option == ZnTrackOptionEnum.METRICS_PATH
                    else ZnTrackOptionEnum.OUTS
                )
                content = self._apply_cache_flag(
                    field, self._resolved_paths(field, nwd_handler)
                )
                stages.setdefault(stage_key.value, []).extend(content)
            elif option in file_options:
                file_path = self._field_file_path(field)
                stage_key = (
                    ZnTrackOptionEnum.METRICS
                    if option == ZnTrackOptionEnum.METRICS
                    else ZnTrackOptionEnum.OUTS
                )
                stages.setdefault(stage_key.value, []).extend(
                    self._apply_cache_flag(field, [file_path])
                )
                if (
                    option == ZnTrackOptionEnum.PLOTS
                    and ZNTRACK_OPTION_PLOTS_CONFIG in field.metadata
                ):
                    plots.extend(self._plot_entries(field, file_path))
            elif option == ZnTrackOptionEnum.DEPS:
                if getattr(self.node, field.name) is None:
                    continue
                paths = []
                for con in get_attr_always_list(self.node, field.name):
                    if isinstance(con, (znflow.Connection, znflow.CombinedConnections)):
                        paths.extend(self._connection_output_paths(con))
                    elif dataclasses.is_dataclass(con) and not isinstance(con, Node):
                        # Dataclass deps are written to params.yaml by
                        # convert_to_params_yaml, so register the node's params
                        # entry exactly like the PARAMS branch does.
                        stages.setdefault(ZnTrackOptionEnum.PARAMS.value, []).append(
                            self.node.name
                        )
                    else:
                        raise ValueError(
                            f"Found unsupported type '{type(con)}' ({con}) for DEPS"
                            f" field '{field.name}'"
                        )
                if paths:
                    stages.setdefault(ZnTrackOptionEnum.DEPS.value, []).extend(paths)
            elif option == ZnTrackOptionEnum.DEPS_PATH:
                if getattr(self.node, field.name) is None:
                    continue
                content = [
                    pathlib.Path(c).as_posix()
                    for c in get_attr_always_list(self.node, field.name)
                ]
                RunDVCImportPathHandler()(self.node.__dict__.get(field.name))
                stages.setdefault(ZnTrackOptionEnum.DEPS.value, []).extend(content)

        for key in stages:
            if key not in ("cmd", "always_changed"):
                stages[key] = sort_and_deduplicate(stages[key])

        return {"stages": stages, "plots": plots}

    def convert_to_zntrack_json(self, graph) -> dict | object:
        """Serialize the node's path-like and DEPS fields to JSON-compatible data.

        The raw values are read from ``self.node.__dict__`` (bypassing
        attribute access) and encoded with ZnTrack's converters for
        connections, nodes, paths, DVC import paths and dataclasses. The
        encoded string is decoded again so only JSON types are returned.
        ``graph`` is not used by this implementation.
        """
        tracked_options = (
            ZnTrackOptionEnum.PARAMS_PATH,
            ZnTrackOptionEnum.DEPS_PATH,
            ZnTrackOptionEnum.OUTS_PATH,
            ZnTrackOptionEnum.PLOTS_PATH,
            ZnTrackOptionEnum.METRICS_PATH,
            ZnTrackOptionEnum.DEPS,
        )
        data = {"nwd": self.node.nwd}
        for field in dataclasses.fields(self.node):
            if field.metadata.get(ZNTRACK_OPTION) in tracked_options:
                data[field.name] = self.node.__dict__[field.name]

        encoded = znjson.dumps(
            data,
            indent=4,
            cls=znjson.ZnEncoder.from_converters(
                [
                    converter.ConnectionConverter,
                    converter.NodeConverter,
                    converter.CombinedConnectionsConverter,
                    znjson.converter.PathlibConverter,
                    converter.DVCImportPathConverter,
                    converter.DataclassConverter,
                ],
                add_default=False,
            ),
        )
        return json.loads(encoded)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc