• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

SwissDataScienceCenter / renku-python / 5948296099

23 Aug 2023 07:23AM UTC coverage: 85.801% (+0.04%) from 85.766%
5948296099

Pull #3601

github-actions

olevski
chore: run poetry lock
Pull Request #3601: hotfix: v2.6.1

40 of 48 new or added lines in 10 files covered. (83.33%)

285 existing lines in 25 files now uncovered.

25875 of 30157 relevant lines covered (85.8%)

4.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

51.93
/renku/infrastructure/git_merger.py
1
#
2
# Copyright 2017-2023 - Swiss Data Science Center (SDSC)
3
# A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and
4
# Eidgenössische Technische Hochschule Zürich (ETHZ).
5
#
6
# Licensed under the Apache License, Version 2.0 (the "License");
7
# you may not use this file except in compliance with the License.
8
# You may obtain a copy of the License at
9
#
10
#     http://www.apache.org/licenses/LICENSE-2.0
11
#
12
# Unless required by applicable law or agreed to in writing, software
13
# distributed under the License is distributed on an "AS IS" BASIS,
14
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
# See the License for the specific language governing permissions and
16
# limitations under the License.
17
"""Merge strategies."""
4✔
18

19
import os
4✔
20
import shutil
4✔
21
from json import JSONDecodeError
4✔
22
from pathlib import Path
4✔
23
from tempfile import mkdtemp
4✔
24
from typing import List, NamedTuple, Optional, Union, cast
4✔
25

26
from BTrees.OOBTree import BTree, Bucket, TreeSet
4✔
27
from deepdiff import DeepDiff
4✔
28
from persistent import Persistent
4✔
29
from persistent.list import PersistentList
4✔
30
from zc.relation.catalog import Catalog
4✔
31

32
from renku.core import errors
4✔
33
from renku.core.constant import DATABASE_PATH, RENKU_HOME
4✔
34
from renku.core.util import communication
4✔
35
from renku.domain_model.dataset import Dataset, Url
4✔
36
from renku.domain_model.project import Project
4✔
37
from renku.domain_model.project_context import project_context
4✔
38
from renku.domain_model.workflow.plan import AbstractPlan
4✔
39
from renku.infrastructure.database import Database, Index
4✔
40
from renku.infrastructure.repository import Repository
4✔
41
from renku.version import __version__
4✔
42

43

44
class RemoteEntry(NamedTuple):
    """Handle on the metadata database of one branch that is being merged in.

    One entry is created per remote branch by ``GitMerger._setup_worktrees``.
    """

    # Git reference of the remote branch (value of a ``GITHEAD*`` environment variable).
    reference: str
    # Database loaded from the metadata directory inside the branch's worktree.
    database: Database
    # Temporary directory that holds the branch's worktree.
    path: Path
    # Repository object opened on the worktree directory.
    repository: Repository
51

52

53
class GitMerger:
4✔
54
    """Git metadata merger."""
55

56
    def merge(self, local: Path, remote: Path, base: Path) -> None:
        """Merge two renku metadata entries together.

        The local entry (and the base entry, when it can be read) is loaded from the project database. The remote
        entry is then looked up in the database of every branch being merged and folded into the local one. The
        result is persisted to ``local``.

        Args:
            local(Path): Path of the local version of the entry; the merge result is persisted here.
            remote(Path): Path of the remote version of the entry.
            base(Path): Path of the base version of the entry.

        Raises:
            errors.MetadataMergeError: If none of the merge branches contains the remote object, or if merging the
                objects fails or is aborted.
        """
        repository = project_context.repository
        self.remote_entries: List[RemoteEntry] = []
        merged = False

        try:
            # NOTE: Worktrees are created inside the ``try`` so that entries which were already registered get
            # cleaned up even when setting up a later branch (or anything below) fails.
            self._setup_worktrees(repository)
            self.local_database = project_context.database

            local_object = self.local_database.get_from_path(str(project_context.path / local))
            try:
                base_object: Optional[Persistent] = self.local_database.get_from_path(str(project_context.path / base))
            except (errors.ObjectNotFoundError, JSONDecodeError):
                # NOTE: No usable base version, merge without a common ancestor
                base_object = None

            for entry in self.remote_entries:
                # NOTE: Loop through all remote merge branches (Octo merge) and try to merge them
                try:
                    self.remote_database = entry.database
                    remote_object = self.remote_database.get_from_path(str(project_context.path / remote))

                    # NOTE: treat merge result as new local for subsequent merges
                    local_object = self.merge_objects(local_object, remote_object, base_object)
                    merged = True
                except errors.ObjectNotFoundError:
                    continue
        finally:
            # NOTE: cleanup worktrees
            for entry in self.remote_entries:
                try:
                    repository.remove_worktree(entry.path)
                finally:
                    # NOTE: always delete the temporary directory, even if git could not remove the worktree
                    shutil.rmtree(entry.path, ignore_errors=True)

        if not merged:
            raise errors.MetadataMergeError("Couldn't merge metadata: remote object not found in merge branches.")

        self.local_database.persist_to_path(local_object, local)
94

95
    def _setup_worktrees(self, repository):
        """Setup git worktrees for the remote branches.

        For every environment variable whose name starts with ``GITHEAD``, a worktree for the reference stored in
        its value is created in a fresh temporary directory. Only the metadata database path is checked out, and a
        ``RemoteEntry`` for it is appended to ``self.remote_entries``.

        Args:
            repository: Repository in which the worktrees are created.

        Raises:
            Exception: Any error raised while creating or checking out a worktree is re-raised after the
                half-created worktree and its temporary directory have been cleaned up.
        """

        # NOTE: Get remote branches
        remote_branches = [value for name, value in os.environ.items() if name.startswith("GITHEAD")]

        database_path = Path(RENKU_HOME) / DATABASE_PATH

        for remote_branch in remote_branches:
            # NOTE: Create a new shallow worktree for each remote branch, could be several in case of an octo merge
            worktree_path = Path(mkdtemp())
            try:
                repository.create_worktree(worktree_path, reference=remote_branch, checkout=False)
                remote_repository = Repository(worktree_path)
                remote_repository.checkout(sparse=[database_path])

                self.remote_entries.append(
                    RemoteEntry(
                        remote_branch,
                        Database.from_path(worktree_path / database_path),
                        worktree_path,
                        remote_repository,
                    )
                )
            except Exception:
                # NOTE: cleanup worktree; best effort, so that the original error is the one that propagates
                try:
                    repository.remove_worktree(worktree_path)
                except Exception:  # nosec
                    pass
                # NOTE: mkdtemp() never deletes the directory on its own, do not leave it behind
                shutil.rmtree(worktree_path, ignore_errors=True)
                raise
126

127
    def merge_objects(self, local: Persistent, remote: Persistent, base: Optional[Persistent]) -> Persistent:
        """Dispatch the merge of two database objects to the strategy matching their type.

        Args:
            local(Persistent): Object from the local branch.
            remote(Persistent): Object from the remote branch.
            base(Optional[Persistent]): Base version of the object; only passed on when merging projects.

        Returns:
            Persistent: The merged object as returned by the selected strategy.

        Raises:
            errors.MetadataMergeError: If ``local`` is not an instance of ``remote``'s type or if the type has no
                automated merge strategy.
        """
        if not isinstance(local, type(remote)):
            raise errors.MetadataMergeError(f"Cannot merge {local} and {remote}: disparate types.")

        # NOTE: Order matters, the first matching strategy wins.
        # NOTE(review): ``Index`` objects are routed to ``merge_btrees``; ``merge_indices`` is not used here.
        two_way_strategies = (
            ((BTree, Index, Bucket), self.merge_btrees),
            (TreeSet, self.merge_treesets),
            (Catalog, self.merge_catalogs),
        )
        for supported_types, strategy in two_way_strategies:
            if isinstance(local, supported_types):
                return strategy(local, remote)  # type: ignore[arg-type]

        if isinstance(local, Project):
            # NOTE: Projects are the only objects that get a three-way merge
            return self.merge_projects(local, remote, cast(Optional[Project], base))

        raise errors.MetadataMergeError(
            f"Cannot merge {local} and {remote}: type not supported for automated merge."
        )
143

144
    def merge_btrees(
        self, local: Union[BTree, Index, Bucket], remote: Union[BTree, Index, Bucket]
    ) -> Union[BTree, Index, Bucket]:
        """Merge the entries of ``remote`` into ``local`` and return ``local``.

        Keys that only exist in ``remote`` are copied over. Keys whose values have different ``_p_oid`` on both
        sides are resolved through ``_compare_objects``; if that cannot decide, the user is asked which one to keep.

        Args:
            local(Union[BTree, Index, Bucket]): Mapping from the local branch; modified in place.
            remote(Union[BTree, Index, Bucket]): Mapping from the remote branch.

        Returns:
            Union[BTree, Index, Bucket]: The updated ``local`` mapping.

        Raises:
            errors.MetadataMergeError: If the user aborts or answers with an unknown option.
        """
        # NOTE: Snapshot object ids before touching ``local`` so that conflicts are computed on the original state
        local_oids = {key: getattr(value, "_p_oid", None) for key, value in local.items()}
        remote_oids = {key: getattr(value, "_p_oid", None) for key, value in remote.items()}

        conflicting_keys = [key for key, oid in local_oids.items() if remote_oids.get(key, oid) != oid]

        # NOTE: New entries in remote, we can just copy them over
        for key in remote_oids:
            if key not in local_oids:
                local[key] = remote[key]

        # NOTE: Merge conflicts!
        for key in conflicting_keys:
            ours, theirs = local[key], remote[key]

            resolved = self._compare_objects(ours, theirs)
            if resolved is not None:
                local[key] = resolved
                continue

            # NOTE: Load the objects' state before diffing them
            ours._p_activate()
            theirs._p_activate()

            raw_diff = DeepDiff(ours, theirs, exclude_types=[Database]).pretty()
            diff_lines = raw_diff.replace("Value of root.", "local.").splitlines()
            indented_diff = "\n".join(f"\t{line}" for line in diff_lines)
            type_name = str(type(ours)).split(".")[-1][:-2]

            answer = communication.prompt(
                "Merge conflict detected:\n"
                f"{key} ({type_name}) modified in local and remote branch.\n"
                f"Changes between local and remote:\n{indented_diff}\n"
                "Which do you want to keep?\n[l]ocal, [r]emote, [a]bort:",
                default="a",
            )

            if answer == "a":
                raise errors.MetadataMergeError("Merge aborted")
            if answer == "r":
                local[key] = theirs
            elif answer != "l":
                raise errors.MetadataMergeError(f"Invalid merge option selected: {answer}")

        return local
192

193
    def merge_treesets(self, local: TreeSet, remote: TreeSet) -> TreeSet:
        """Add every element of ``remote`` that ``local`` lacks and return ``local``.

        Args:
            local(TreeSet): Set from the local branch; modified in place.
            remote(TreeSet): Set from the remote branch.

        Returns:
            TreeSet: The updated ``local`` set.
        """
        # NOTE: Collect the missing elements first so that ``local`` is not modified while it is being queried
        missing_elements = [element for element in remote if element not in local]
        local.update(missing_elements)
        return local
197

198
    def merge_indices(self, local: Index, remote: Index) -> Index:
        """Merge two Indices.

        Entries that only exist in ``remote`` are added to ``local``. Entries whose values have different ``_p_oid``
        on both sides are resolved through ``_compare_objects``; if that cannot decide, the user is asked which one
        to keep.

        Args:
            local(Index): Index from the local branch; modified in place.
            remote(Index): Index from the remote branch.

        Returns:
            Index: The updated ``local`` index.

        Raises:
            errors.MetadataMergeError: If the user aborts or answers with an unknown option.
        """
        local_key_ids = {k: getattr(v, "_p_oid", None) for k, v in local.items()}
        remote_key_ids = {k: getattr(v, "_p_oid", None) for k, v in remote.items()}

        common_modified_keys = [k for k, i in local_key_ids.items() if remote_key_ids.get(k, i) != i]
        new_remote_keys = [k for k in remote_key_ids if k not in local_key_ids]

        for new_remote in new_remote_keys:
            # NOTE: New entries in remote, we can just copy them over
            local.add(remote.get(new_remote))

        for common_key in common_modified_keys:
            # NOTE: Merge conflicts!
            local_object = local.get(common_key)
            remote_object = remote.get(common_key)

            comparison = self._compare_objects(local_object, remote_object)

            if comparison is not None:
                if comparison is not local_object:
                    # NOTE: Replace the entry the same way as for a manual "[r]emote" choice below.
                    # NOTE(review): presumably ``Index`` has no item assignment, which is why entries are
                    # replaced with ``pop`` + ``add`` everywhere else in this method - confirm.
                    local.pop(common_key)
                    local.add(comparison)
                continue

            # NOTE: Load the objects' state before diffing them
            local_object._p_activate()
            remote_object._p_activate()
            diff = DeepDiff(local_object, remote_object, exclude_types=[Database])
            pretty_diff = diff.pretty().replace("Value of root.", "local.")
            pretty_diff = "\n".join(f"\t{line}" for line in pretty_diff.splitlines())
            # NOTE: Plain class name of the conflicting entry, only used in the message shown to the user
            entry_type = type(local_object).__name__

            action = communication.prompt(
                "Merge conflict detected:\n"
                f"{common_key} ({entry_type}) modified in local and remote branch.\n"
                f"Changes between local and remote:\n{pretty_diff}\n"
                "Which do you want to keep?\n[l]ocal, [r]emote, [a]bort:",
                default="a",
            )

            if action == "r":
                local.pop(common_key)
                local.add(remote_object)
            elif action == "a":
                raise errors.MetadataMergeError("Merge aborted")
            elif action != "l":
                raise errors.MetadataMergeError(f"Invalid merge option selected: {action}")

        return local
245

246
    def merge_catalogs(self, local: Catalog, remote: Catalog) -> Catalog:
        """Copy entries that only exist in the remote catalog into the local one.

        Entries already present in ``local`` are never overwritten. For ``_name_TO_mapping`` the check is also done
        one level deeper, i.e. missing sub-entries of an existing mapping are copied as well.

        Args:
            local(Catalog): Catalog from the local branch; modified in place.
            remote(Catalog): Catalog from the remote branch.

        Returns:
            Catalog: The updated ``local`` catalog.
        """

        def adopt_missing(attribute: str) -> None:
            """Copy top-level entries of ``remote.<attribute>`` that ``local.<attribute>`` does not have."""
            for key, value in getattr(remote, attribute).items():
                target = getattr(local, attribute)
                if key not in target:
                    target[key] = value

        adopt_missing("_EMPTY_name_TO_relcount_relset")

        for name, remote_mapping in remote._name_TO_mapping.items():
            if name in local._name_TO_mapping:
                # NOTE: Mapping exists on both sides, only take over the sub-entries that are missing locally
                for subkey, subvalue in remote_mapping.items():
                    if subkey not in local._name_TO_mapping[name]:
                        local._name_TO_mapping[name][subkey] = subvalue
            else:
                local._name_TO_mapping[name] = remote_mapping

        adopt_missing("_reltoken_name_TO_objtokenset")

        return local
265

266
    def merge_projects(self, local: Project, remote: Project, base: Optional[Project]) -> Project:
        """Merge two Project entries.

        Keywords, description, annotations and template metadata are compared against ``base`` to find out which
        side changed them. If only one side changed anything, that side is returned untouched. Otherwise the fields
        are merged into ``local``, asking the user where both sides modified the description or the template.

        Args:
            local(Project): Project from the local branch; modified in place when a real merge is needed.
            remote(Project): Project from the remote branch.
            base(Optional[Project]): Base version of the project; ``None`` makes both sides count as changed.

        Returns:
            Project: ``local`` or ``remote`` unchanged if only that side was modified, otherwise the merged ``local``.

        Raises:
            errors.MetadataMergeError: If the user aborts the template conflict or answers with an unknown option.
        """

        local_changed = (
            base is None
            or local.keywords != base.keywords
            or local.description != base.description
            or local.annotations != base.annotations
        )
        local_template_changed = base is None or local.template_metadata != base.template_metadata
        remote_changed = (
            base is None
            or remote.keywords != base.keywords
            or remote.description != base.description
            or remote.annotations != base.annotations
        )
        remote_template_changed = base is None or remote.template_metadata != base.template_metadata

        # NOTE: Only one side was modified, nothing to merge
        if (local_changed or local_template_changed) and not remote_changed and not remote_template_changed:
            return local
        elif not local_changed and not local_template_changed and (remote_changed or remote_template_changed):
            return remote

        if local_changed or remote_changed:
            # NOTE: Merge keywords
            if local.keywords != remote.keywords:
                if base is None:
                    local.keywords = list(set(local.keywords) | set(remote.keywords))
                elif local.keywords != base.keywords and remote.keywords != base.keywords:
                    # NOTE: Both sides changed the keywords: apply removals and additions of both to the base
                    removed = (set(base.keywords) - set(local.keywords)) | (set(base.keywords) - set(remote.keywords))
                    added = (set(local.keywords) - set(base.keywords)) | (set(remote.keywords) - set(base.keywords))
                    existing = set(base.keywords) - removed
                    local.keywords = list(added | existing)
                elif remote.keywords != base.keywords:
                    local.keywords = remote.keywords

            # NOTE: Merge description
            if local.description != remote.description:
                if base is None or (local.description != base.description and remote.description != base.description):
                    local.description = communication.prompt(
                        f"Project description was modified in local and remote branch.\n"
                        f"local: {local.description}\nremote: {remote.description}\nEnter merged description: ",
                        default=local.description,
                    )
                elif remote.description != base.description:
                    local.description = remote.description

            # NOTE: Merge annotations
            local.annotations = list(set(local.annotations) | set(remote.annotations))

        # NOTE: Merge versions
        if int(remote.version) > int(local.version):
            local.version = remote.version

        local.agent_version = __version__

        # NOTE: Merge template data
        if local_template_changed and remote_template_changed:
            # NOTE: Merge conflicts!
            # NOTE: Every fragment with a placeholder has to be an f-string, otherwise the placeholder text itself
            # is shown to the user instead of the template version.
            action = communication.prompt(
                "Merge conflict detected:\n Project template modified/update in both remote and local branch.\n"
                f"local: {local.template_metadata.template_source}@{local.template_metadata.template_ref}:"
                f"{local.template_metadata.template_id}, "
                f"version {local.template_metadata.template_version}\n"
                f"remote: {remote.template_metadata.template_source}@{remote.template_metadata.template_ref}:"
                f"{remote.template_metadata.template_id}, "
                f"version {remote.template_metadata.template_version}\n"
                "Which do you want to keep?\n[l]ocal, [r]emote, [a]bort:",
                default="a",
            )

            if action == "r":
                local.template_metadata = remote.template_metadata
            elif action == "a":
                raise errors.MetadataMergeError("Merge aborted")
            elif action != "l":
                raise errors.MetadataMergeError(f"Invalid merge option selected: {action}")
        elif remote_template_changed:
            local.template_metadata = remote.template_metadata

        return local
347

348
    def _compare_objects(self, local_object: Persistent, remote_object: Persistent) -> Optional[Persistent]:
        """Try to resolve a conflict between two objects without asking the user.

        Args:
            local_object(Persistent): Object from the local branch.
            remote_object(Persistent): Object from the remote branch.

        Returns:
            Optional[Persistent]: The object to keep, or ``None`` if the conflict cannot be resolved automatically.
        """
        if local_object == remote_object:
            # NOTE: Objects are the same, nothing to do
            return local_object

        # NOTE: Separate fallback objects, so that two objects without an ``id`` never count as the same
        local_id = getattr(local_object, "id", object())
        remote_id = getattr(remote_object, "id", object())
        if local_id == remote_id:
            # NOTE: Objects are the same, nothing to do
            return local_object

        # NOTE: If one side is derived from the other one, the derived (newer) object wins
        if (
            isinstance(local_object, Dataset)
            and local_object.is_derivation()
            and self._is_dataset_derived_from(local_object, cast(Dataset, remote_object), self.local_database)
        ):
            return local_object
        if (
            isinstance(remote_object, Dataset)
            and remote_object.is_derivation()
            and self._is_dataset_derived_from(remote_object, cast(Dataset, local_object), self.remote_database)
        ):
            return remote_object
        if (
            isinstance(local_object, AbstractPlan)
            and local_object.is_derivation()
            and self._is_plan_derived_from(local_object, cast(AbstractPlan, remote_object), self.local_database)
        ):
            return local_object
        if (
            isinstance(remote_object, AbstractPlan)
            and remote_object.is_derivation()
            and self._is_plan_derived_from(remote_object, cast(AbstractPlan, local_object), self.remote_database)
        ):
            return remote_object

        # NOTE: Two lists of the same kind are combined: remote items that are not in local yet get appended
        for sequence_type in (list, PersistentList):
            if isinstance(local_object, sequence_type) and isinstance(remote_object, sequence_type):
                local_object.extend(item for item in remote_object if item not in local_object)
                return local_object

        return None
376

377
    def _is_dataset_derived_from(self, potential_child: Dataset, potential_parent: Dataset, database: Database) -> bool:
        """Tell whether ``potential_parent`` is an ancestor in the derivation chain of ``potential_child``.

        Args:
            potential_child(Dataset): Dataset whose ``derived_from`` chain is followed.
            potential_parent(Dataset): Dataset that is looked for in that chain (compared by ``id``).
            database(Database): Database used to load the ancestors by id.

        Returns:
            bool: ``True`` if an ancestor with the id of ``potential_parent`` is found, ``False`` otherwise.
        """
        ancestor = potential_child
        while True:
            if not ancestor.is_derivation():
                # NOTE: Reached the start of the chain without finding the parent
                return False

            ancestor = database.get_by_id(cast(Url, ancestor.derived_from).value)
            if ancestor.id == potential_parent.id:
                return True
386

387
    def _is_plan_derived_from(
        self, potential_child: AbstractPlan, potential_parent: AbstractPlan, database: Database
    ) -> bool:
        """Tell whether ``potential_parent`` is an ancestor in the derivation chain of the plan ``potential_child``.

        Args:
            potential_child(AbstractPlan): Plan whose ``derived_from`` chain is followed.
            potential_parent(AbstractPlan): Plan that is looked for in that chain (compared by ``id``).
            database(Database): Database used to load the ancestors by id.

        Returns:
            bool: ``True`` if an ancestor with the id of ``potential_parent`` is found, ``False`` otherwise.
        """
        ancestor = potential_child
        while True:
            if not ancestor.is_derivation():
                # NOTE: Reached the start of the chain without finding the parent
                return False

            ancestor = database.get_by_id(cast(str, ancestor.derived_from))
            if ancestor.id == potential_parent.id:
                return True
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc