• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

zincware / dask4dvc / 4764305061

pending completion
4764305061

Pull #25

github

GitHub
Merge 68c2654fb into 556449cd0
Pull Request #25: Update dask4dvc arguments

132 of 132 new or added lines in 5 files covered. (100.0%)

237 of 290 relevant lines covered (81.72%)

0.82 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

53.85
/dask4dvc/methods.py
1
"""Some general 'dask4dvc' methods."""
1✔
2

3

4
import contextlib
1✔
5
import typing
1✔
6
import logging
1✔
7

8
import dask.distributed
1✔
9
import dvc.lock
1✔
10
import dvc.exceptions
1✔
11
import dvc.repo
1✔
12
from dvc.repo.reproduce import _get_steps
1✔
13
import dvc.utils.strictyaml
1✔
14
import dvc.stage
1✔
15
from dvc.stage.cache import RunCacheNotFoundError
1✔
16
import random
1✔
17
import time
1✔
18
import subprocess
1✔
19

20
from dask4dvc import utils
1✔
21

22
log = logging.getLogger(__name__)
1✔
23

24

25
def _run_locked_cmd(
    repo: dvc.repo.Repo,
    func: typing.Callable,
    *args: typing.Any,
    **kwargs: typing.Any,
) -> typing.Any:
    """Retry running a DVC command until the lock is released.

    The command is attempted up to 'utils.CONFIG.retries' times. Before each
    attempt the function waits, in random intervals of up to 5 seconds, until
    'repo.lock.is_locked' is false.

    Parameters
    ----------
    repo: dvc.repo.Repo
        The DVC repository.
    func: callable
        The DVC command to run. e.g. 'repo.reproduce'
    *args
        The positional arguments for the command.
    **kwargs
        The keyword arguments for the command.

    Returns
    -------
    typing.Any: the return value of the command.

    Raises
    ------
    dvc.lock.LockError, dvc.utils.strictyaml.YAMLValidationError:
        the error of the last attempt, if every attempt failed.
    ValueError:
        if no attempt was made at all (e.g. zero retries configured).
    """
    # The name bound by 'except ... as err' is deleted when the handler exits,
    # so the most recent failure has to be kept in a separate variable to be
    # re-raised after the loop.
    last_error: Exception = ValueError("Something went wrong")
    for _ in range(utils.CONFIG.retries):
        try:
            while repo.lock.is_locked:
                time.sleep(random.random() * 5)  # between 0 and 5 seconds
            return func(*args, **kwargs)
        except (dvc.lock.LockError, dvc.utils.strictyaml.YAMLValidationError) as err:
            log.debug(err)
            last_error = err
    raise last_error
55

56

57
def _load_run_cache(repo: dvc.repo.Repo, stage: dvc.stage.Stage) -> None:
    """Restore the given stage from the DVC run cache.

    The restore is performed while holding the repository lock and inside
    the repository's SCM context.

    Parameters
    ----------
    repo: dvc.repo.Repo
        The DVC repository.
    stage: dvc.stage.Stage
        The stage to restore.

    Raises
    ------
    RunCacheNotFoundError:
        if the stage is not cached.
    """
    with dvc.repo.lock_repo(repo), repo.scm_context():
        repo.stage_cache.restore(stage=stage)
        # only reached when the restore above did not raise
        message = (
            f"Stage '{stage.addressing}' is cached - skipping run, checking out outputs "
        )
        log.info(message)
72

73

74
def submit_stage(name: str, force: bool, successors: list) -> str:
    """Submit a stage to the Dask cluster.

    Unless 'force' is set, the stage is skipped when a dry-run reproduce
    reports nothing to do, and restored from the run cache when possible.
    Otherwise its command(s) are executed in a shell and the result is
    committed to DVC.

    Parameters
    ----------
    name: str
        The name (address) of the stage.
    force: bool
        Run the stage even if it did not change or is in the run cache.
    successors: list
        Not used inside this function. The caller in this module passes the
        futures of the successor stages here - presumably so that Dask waits
        for them before starting this task (TODO confirm).

    Returns
    -------
    str: the name of the stage.

    Raises
    ------
    ValueError:
        if the dry-run reproduce returned more than one stage.
    subprocess.CalledProcessError:
        if a stage command exits with a non-zero return code.
    """
    repo = dvc.repo.Repo()

    if force:
        stages = [repo.stage.get_target(name)]
    else:
        # dvc reproduce returns the stages that are not checked out
        stages = _run_locked_cmd(repo, repo.reproduce, name, dry=True, single_item=True)

    if not stages:
        # if the stage is already checked out, we don't need to run it
        log.info(f"Stage '{name}' didn't change, skipping")
        return name

    if len(stages) > 1:
        # we use single-item, so it should never be more than 1
        raise ValueError("Something went wrong")

    stage = stages[0]

    if not force:
        with contextlib.suppress(RunCacheNotFoundError):
            # check if the stage is already in the run cache
            _run_locked_cmd(repo, _load_run_cache, repo, stage)
            return name

    # if not, run the stage
    log.info(f"Running stage '{name}': \n > {stage.cmd}")
    # 'cmd' may be a single command or a list of commands. A list must not be
    # handed to a shell=True call as a whole, because only its first entry
    # would be executed; run the commands one after another instead.
    commands = [stage.cmd] if isinstance(stage.cmd, str) else stage.cmd
    for command in commands:
        subprocess.check_call(command, shell=True)
    # add the stage to the run cache
    _run_locked_cmd(repo, repo.commit, name, force=True)

    return name
106

107

108
def parallel_submit(
    client: dask.distributed.Client, targets: list[str], force: bool
) -> typing.Dict[str, dask.distributed.Future]:
    """Submit all stages to the Dask cluster.

    Parameters
    ----------
    client: dask.distributed.Client
        The client used to submit the stages.
    targets: list[str]
        The stage names to submit. If empty, all nodes of the repository
        graph are used.
    force: bool
        Forwarded to 'submit_stage' for every submitted stage.

    Returns
    -------
    dict: maps the address of every submitted stage to its future. Nodes
        without a command are not part of the result.
    """
    repo = dvc.repo.Repo()

    if len(targets) == 0:
        selected = repo.index.graph.nodes
    else:
        selected = [repo.stage.get_target(target) for target in targets]

    steps = _get_steps(repo.index.graph, selected, downstream=False, single_item=False)

    submitted = {}
    for step in steps:
        if step.cmd is not None:
            # hand over the futures (or None) of the graph successors, which
            # must already have been visited at this point
            successor_futures = [
                submitted[other] for other in repo.index.graph.successors(step)
            ]
            submitted[step] = client.submit(
                submit_stage,
                step.addressing,
                force=force,
                successors=successor_futures,
                pure=False,
                key=step.addressing,
            )
        else:
            # if the stage doesn't have a command, e.g. a dvc tracked file
            # we don't need to run it
            submitted[step] = None

    return {
        step.addressing: future
        for step, future in submitted.items()
        if future is not None
    }
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc