• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

zincware / ZnTrack / 12946204970

24 Jan 2025 08:43AM UTC coverage: 83.063% (+0.07%) from 82.992%
12946204970

Pull #870

github

web-flow
Merge 2fc8509ee into 3cd192980
Pull Request #870: add `zntrack.config` to update default values

952 of 1232 branches covered (77.27%)

Branch coverage included in aggregate %.

10 of 10 new or added lines in 5 files covered. (100.0%)

9 existing lines in 1 file now uncovered.

1613 of 1856 relevant lines covered (86.91%)

3.47 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

92.7
/zntrack/plugins/dvc_plugin/__init__.py
1
import contextlib
4✔
2
import copy
4✔
3
import dataclasses
4✔
4
import json
4✔
5
import pathlib
4✔
6
import typing as t
4✔
7

8
import pandas as pd
4✔
9
import yaml
4✔
10
import znflow
4✔
11
import znflow.handler
4✔
12
import znflow.utils
4✔
13
import znjson
4✔
14

15
from zntrack import converter
4✔
16
from zntrack.config import (
4✔
17
    NOT_AVAILABLE,
18
    PARAMS_FILE_PATH,
19
    PLUGIN_EMPTY_RETRUN_VALUE,
20
    ZNTRACK_CACHE,
21
    ZNTRACK_FIELD_DUMP,
22
    ZNTRACK_FIELD_LOAD,
23
    ZNTRACK_FIELD_SUFFIX,
24
    ZNTRACK_FILE_PATH,
25
    ZNTRACK_LAZY_VALUE,
26
    ZNTRACK_OPTION,
27
    ZNTRACK_OPTION_PLOTS_CONFIG,
28
    ZnTrackOptionEnum,
29
)
30
from zntrack import config
4✔
31

32
# if t.TYPE_CHECKING:
33
from zntrack.node import Node
4✔
34
from zntrack.plugins import ZnTrackPlugin, base_getter
4✔
35
from zntrack.utils import module_handler
4✔
36
from zntrack.utils.misc import (
4✔
37
    RunDVCImportPathHandler,
38
    TempPathLoader,
39
    get_attr_always_list,
40
    sort_and_deduplicate,
41
)
42
from zntrack.utils.node_wd import NWDReplaceHandler, nwd
4✔
43

44

45
@dataclasses.dataclass
class DVCPlugin(ZnTrackPlugin):
    """Serialize a ZnTrack node into DVC configuration.

    The plugin builds the node's entries for ``params.yaml``
    (``convert_to_params_yaml``), ``dvc.yaml`` (``convert_to_dvc_yaml``) and
    ``zntrack.json`` (``convert_to_zntrack_json``), and loads/saves field
    values through the load/dump callables stored in each field's metadata.
    """

    # ------------------------------------------------------------------
    # Loading / saving field values
    # ------------------------------------------------------------------
    def getter(self, field: dataclasses.Field) -> t.Any:
        """Load ``field`` through its ``ZNTRACK_FIELD_LOAD`` callable.

        The callable is invoked as ``load(node, field.name)``; when the field
        metadata defines ``ZNTRACK_FIELD_SUFFIX`` it is also passed ``suffix=``.

        Returns:
            The loaded value, or ``PLUGIN_EMPTY_RETRUN_VALUE`` when the field
            has no load callable.
        """
        load = field.metadata.get(ZNTRACK_FIELD_LOAD)
        if load is None:
            return PLUGIN_EMPTY_RETRUN_VALUE
        suffix = field.metadata.get(ZNTRACK_FIELD_SUFFIX)
        if suffix is not None:
            return load(self.node, field.name, suffix=suffix)
        return load(self.node, field.name)

    def save(self, field: dataclasses.Field) -> None:
        """Persist ``field`` through its ``ZNTRACK_FIELD_DUMP`` callable.

        Fields without a dump callable are ignored. When the field metadata
        defines ``ZNTRACK_FIELD_SUFFIX`` it is forwarded as ``suffix=``.
        """
        dump = field.metadata.get(ZNTRACK_FIELD_DUMP)
        if dump is None:
            return
        suffix = field.metadata.get(ZNTRACK_FIELD_SUFFIX)
        if suffix is not None:
            dump(self.node, field.name, suffix=suffix)
        else:
            dump(self.node, field.name)

    # ------------------------------------------------------------------
    # params.yaml
    # ------------------------------------------------------------------
    @staticmethod
    def _is_plain_dataclass(obj: t.Any) -> bool:
        """Return True for dataclasses that are neither Nodes nor znflow connections."""
        return dataclasses.is_dataclass(obj) and not isinstance(
            obj, (Node, znflow.Connection, znflow.CombinedConnections)
        )

    @staticmethod
    def _dataclass_to_params(obj: t.Any) -> dict:
        """Serialize a plain dataclass dependency for ``params.yaml``.

        The field values are stored together with the import path of the class
        under ``_cls`` so the DataclassContainer can later recreate the
        instance with the correct parameters.
        """
        params = dataclasses.asdict(obj)
        params["_cls"] = f"{module_handler(obj)}.{obj.__class__.__name__}"
        return params

    def convert_to_params_yaml(self) -> dict | object:
        """Collect this node's section of ``params.yaml``.

        ``params`` fields are stored by value. ``deps`` fields contribute the
        serialized values of plain dataclass instances (single or inside a
        list/tuple/dict). znflow connections are skipped here because
        ``convert_to_dvc_yaml`` records them as ``deps`` paths.

        Returns:
            A mapping ``field name -> value``, or ``PLUGIN_EMPTY_RETRUN_VALUE``
            when there is nothing to store.

        Raises:
            ValueError: If a ``deps`` field holds an unsupported type.
        """
        connection_types = (znflow.Connection, znflow.CombinedConnections)
        data = {}
        for field in dataclasses.fields(self.node):
            option = field.metadata.get(ZNTRACK_OPTION)
            if option == ZnTrackOptionEnum.PARAMS:
                data[field.name] = getattr(self.node, field.name)
            elif option == ZnTrackOptionEnum.DEPS:
                content = getattr(self.node, field.name)
                if content is None:
                    continue
                if isinstance(content, (list, tuple, dict)):
                    values = content.values() if isinstance(content, dict) else content
                    serialized = []
                    for val in values:
                        if self._is_plain_dataclass(val):
                            serialized.append(self._dataclass_to_params(val))
                        elif isinstance(val, connection_types):
                            continue
                        else:
                            raise ValueError(
                                f"Found unsupported type '{type(val)}' ({val}) for DEPS field '{field.name}' in list"
                            )
                    if serialized:
                        data[field.name] = serialized
                elif self._is_plain_dataclass(content):
                    data[field.name] = self._dataclass_to_params(content)
                elif isinstance(content, connection_types):
                    continue
                else:
                    raise ValueError(
                        f"Found unsupported type '{type(content)}' ({content}) for DEPS field '{field.name}'"
                    )

        return data if data else PLUGIN_EMPTY_RETRUN_VALUE

    # ------------------------------------------------------------------
    # dvc.yaml
    # ------------------------------------------------------------------
    def _resolved_paths(
        self, field: dataclasses.Field, nwd_handler: NWDReplaceHandler
    ) -> list:
        """Return the POSIX paths of a ``*_path`` field with ``nwd`` substituted."""
        paths = nwd_handler(
            get_attr_always_list(self.node, field.name), nwd=self.node.nwd
        )
        return [pathlib.Path(x).as_posix() for x in paths]

    @staticmethod
    def _with_cache_option(field: dataclasses.Field, paths: list) -> list:
        """Wrap each path as ``{path: {"cache": False}}`` if the field disables caching."""
        if field.metadata.get(ZNTRACK_CACHE) is False:
            return [{p: {"cache": False}} for p in paths]
        return paths

    def _nwd_file(self, field: dataclasses.Field) -> str:
        """Return ``<nwd>/<field name><suffix>`` as a POSIX path string."""
        suffix = field.metadata[ZNTRACK_FIELD_SUFFIX]
        return (self.node.nwd / field.name).with_suffix(suffix).as_posix()

    def _plots_entries(self, field: dataclasses.Field, file_path: str) -> list:
        """Build the top-level ``plots`` entries for a ``plots`` field.

        The ``x`` and ``y`` columns of the field's plots config are bound to
        ``file_path``. A list of ``y`` columns yields one entry per column,
        named ``<node>_<field>_<idx>``; otherwise one ``<node>_<field>`` entry.

        Raises:
            ValueError: If the config lacks ``x`` or ``y``.
        """
        plots_config = field.metadata[ZNTRACK_OPTION_PLOTS_CONFIG].copy()
        if "x" not in plots_config or "y" not in plots_config:
            raise ValueError("Both 'x' and 'y' must be specified in the plots_config.")
        plots_config["x"] = {file_path: plots_config["x"]}
        if isinstance(plots_config["y"], list):
            entries = []
            for idx, y in enumerate(plots_config["y"]):
                cfg = copy.deepcopy(plots_config)
                cfg["y"] = {file_path: y}
                entries.append({f"{self.node.name}_{field.name}_{idx}": cfg})
            return entries
        plots_config["y"] = {file_path: plots_config["y"]}
        return [{f"{self.node.name}_{field.name}": plots_config}]

    def convert_to_dvc_yaml(self) -> dict | object:
        """Build this node's ``dvc.yaml`` stage definition and plots entries.

        Returns:
            ``{"stages": stage_dict, "plots": plots_list}``. ``stage_dict``
            holds ``cmd``, optional ``always_changed`` and the sorted,
            de-duplicated ``params``/``deps``/``outs``/``metrics`` lists.

        Raises:
            ValueError: If an ``outs_path`` is the bare ``zntrack.nwd``, a plots
                config lacks ``x``/``y``, or a ``deps`` entry has an
                unsupported type.
            NotImplementedError: If a ``deps`` connection uses getitem.
        """
        node_dict = converter.NodeConverter().encode(self.node)

        cmd = f"zntrack run {node_dict['module']}.{node_dict['cls']} --name {node_dict['name']}"
        if hasattr(self.node, "_method"):
            cmd += f" --method {self.node._method}"
        stages = {
            "cmd": cmd,
            "metrics": [
                {
                    (self.node.nwd / "node-meta.json").as_posix(): {
                        "cache": config.ALWAYS_CACHE
                    }
                }
            ],
        }
        if self.node.always_changed:
            stages["always_changed"] = True
        plots = []

        # One handler instance for the whole stage, as in the original flow.
        nwd_handler = NWDReplaceHandler()

        def add(key: str, items: list) -> None:
            stages.setdefault(key, []).extend(items)

        for field in dataclasses.fields(self.node):
            option = field.metadata.get(ZNTRACK_OPTION)
            if option == ZnTrackOptionEnum.PARAMS:
                # DVC reads the node's whole params.yaml section by its name.
                add(ZnTrackOptionEnum.PARAMS.value, [self.node.name])
            elif option == ZnTrackOptionEnum.PARAMS_PATH:
                if getattr(self.node, field.name) is None:
                    continue
                paths = self._resolved_paths(field, nwd_handler)
                add(ZnTrackOptionEnum.PARAMS.value, [{p: None} for p in paths])
            elif option == ZnTrackOptionEnum.OUTS_PATH:
                value = getattr(self.node, field.name)
                if value is None:
                    continue
                if value == nwd:
                    raise ValueError(
                        "Can not use 'zntrack.nwd' directly as an output path. Please use 'zntrack.nwd / <path/file>' instead."
                    )
                paths = self._resolved_paths(field, nwd_handler)
                add(ZnTrackOptionEnum.OUTS.value, self._with_cache_option(field, paths))
            elif option == ZnTrackOptionEnum.PLOTS_PATH:
                if getattr(self.node, field.name) is None:
                    continue
                # Plot files are tracked as regular outputs.
                paths = self._resolved_paths(field, nwd_handler)
                add(ZnTrackOptionEnum.OUTS.value, self._with_cache_option(field, paths))
            elif option == ZnTrackOptionEnum.METRICS_PATH:
                if getattr(self.node, field.name) is None:
                    continue
                paths = self._resolved_paths(field, nwd_handler)
                add(
                    ZnTrackOptionEnum.METRICS.value,
                    self._with_cache_option(field, paths),
                )
            elif option == ZnTrackOptionEnum.OUTS:
                paths = [self._nwd_file(field)]
                add(ZnTrackOptionEnum.OUTS.value, self._with_cache_option(field, paths))
            elif option == ZnTrackOptionEnum.PLOTS:
                file_path = self._nwd_file(field)
                add(
                    ZnTrackOptionEnum.OUTS.value,
                    self._with_cache_option(field, [file_path]),
                )
                if ZNTRACK_OPTION_PLOTS_CONFIG in field.metadata:
                    plots.extend(self._plots_entries(field, file_path))
            elif option == ZnTrackOptionEnum.METRICS:
                paths = [self._nwd_file(field)]
                add(
                    ZnTrackOptionEnum.METRICS.value,
                    self._with_cache_option(field, paths),
                )
            elif option == ZnTrackOptionEnum.DEPS:
                if getattr(self.node, field.name) is None:
                    continue
                dep_paths = []
                for con in get_attr_always_list(self.node, field.name):
                    if isinstance(con, znflow.Connection):
                        if con.item is not None:
                            raise NotImplementedError(
                                "znflow.Connection getitem is not supported yet."
                            )
                        dep_paths.extend(
                            converter.node_to_output_paths(con.instance, con.attribute)
                        )
                    elif isinstance(con, znflow.CombinedConnections):
                        if con.item is not None:
                            raise NotImplementedError(
                                "znflow.Connection getitem is not supported yet."
                            )
                        for _con in con.connections:
                            # Each combined connection may be sliced itself;
                            # its output path alone would over-approximate deps.
                            if getattr(_con, "item", None) is not None:
                                raise NotImplementedError(
                                    "znflow.Connection getitem is not supported yet."
                                )
                            dep_paths.extend(
                                converter.node_to_output_paths(
                                    _con.instance, _con.attribute
                                )
                            )
                    elif dataclasses.is_dataclass(con) and not isinstance(con, Node):
                        # Dataclass deps are serialized into this node's
                        # params.yaml section, so depend on it by node name.
                        add(ZnTrackOptionEnum.PARAMS.value, [self.node.name])
                    else:
                        raise ValueError(
                            f"Found unsupported type '{type(con)}' ({con}) for DEPS field '{field.name}'"
                        )

                if dep_paths:
                    add(ZnTrackOptionEnum.DEPS.value, dep_paths)
            elif option == ZnTrackOptionEnum.DEPS_PATH:
                if getattr(self.node, field.name) is None:
                    continue
                paths = [
                    pathlib.Path(c).as_posix()
                    for c in get_attr_always_list(self.node, field.name)
                ]
                RunDVCImportPathHandler()(self.node.__dict__.get(field.name))
                add(ZnTrackOptionEnum.DEPS.value, paths)

        for key in stages:
            if key in ("cmd", "always_changed"):
                continue
            stages[key] = sort_and_deduplicate(stages[key])

        return {"stages": stages, "plots": plots}

    # ------------------------------------------------------------------
    # zntrack.json
    # ------------------------------------------------------------------
    def convert_to_zntrack_json(self, graph) -> dict | object:
        """Serialize the node's path and dependency fields for ``zntrack.json``.

        The node working directory is stored under ``"nwd"``. Every
        ``*_path`` field and every ``deps`` field is stored with its raw
        ``__dict__`` value. Values are encoded with ZnTrack's znjson
        converters and decoded back, so the result holds only JSON types.

        Args:
            graph: Not used by this implementation.

        Returns:
            A JSON-compatible dict.
        """
        raw_options = (
            ZnTrackOptionEnum.PARAMS_PATH,
            ZnTrackOptionEnum.DEPS_PATH,
            ZnTrackOptionEnum.OUTS_PATH,
            ZnTrackOptionEnum.PLOTS_PATH,
            ZnTrackOptionEnum.METRICS_PATH,
            ZnTrackOptionEnum.DEPS,
        )
        data = {"nwd": self.node.nwd}
        for field in dataclasses.fields(self.node):
            if field.metadata.get(ZNTRACK_OPTION) in raw_options:
                # Read the stored value directly, bypassing the field getter.
                data[field.name] = self.node.__dict__[field.name]

        encoded = znjson.dumps(
            data,
            indent=4,
            cls=znjson.ZnEncoder.from_converters(
                [
                    converter.ConnectionConverter,
                    converter.NodeConverter,
                    converter.CombinedConnectionsConverter,
                    znjson.converter.PathlibConverter,
                    converter.DVCImportPathConverter,
                    converter.DataclassConverter,
                ],
                add_default=False,
            ),
        )
        return json.loads(encoded)
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc