• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

SwissDataScienceCenter / renku-python / 4145649460

pending completion
4145649460

push

github-actions

GitHub
Merge branch 'develop' into allow-ref-target-for-release-action

25096 of 28903 relevant lines covered (86.83%)

4.53 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

29.93
/renku/core/session/renkulab.py
1
# -*- coding: utf-8 -*-
2
#
3
# Copyright 2018-2022 - Swiss Data Science Center (SDSC)
4
# A partnership between École Polytechnique Fédérale de Lausanne (EPFL) and
5
# Eidgenössische Technische Hochschule Zürich (ETHZ).
6
#
7
# Licensed under the Apache License, Version 2.0 (the "License");
8
# you may not use this file except in compliance with the License.
9
# You may obtain a copy of the License at
10
#
11
#     http://www.apache.org/licenses/LICENSE-2.0
12
#
13
# Unless required by applicable law or agreed to in writing, software
14
# distributed under the License is distributed on an "AS IS" BASIS,
15
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16
# See the License for the specific language governing permissions and
17
# limitations under the License.
18
"""Renkulab (notebook service API) based interactive session provider."""
10✔
19

20
import urllib
10✔
21
from pathlib import Path
10✔
22
from time import monotonic, sleep
10✔
23
from typing import Any, Dict, List, Optional, Tuple, Union
10✔
24

25
from renku.core import errors
10✔
26
from renku.core.login import read_renku_token
10✔
27
from renku.core.plugin import hookimpl
10✔
28
from renku.core.session.utils import get_renku_project_name, get_renku_url
10✔
29
from renku.core.util import communication, requests
10✔
30
from renku.core.util.git import get_remote
10✔
31
from renku.core.util.jwt import is_token_expired
10✔
32
from renku.domain_model.project_context import project_context
10✔
33
from renku.domain_model.session import ISessionProvider, Session
10✔
34

35

36
class RenkulabSessionProvider(ISessionProvider):
    """A session provider that uses the notebook service API to launch sessions."""

    # Maximum number of seconds spent polling for a session status change or an image build.
    DEFAULT_TIMEOUT_SECONDS = 300

    def __init__(self):
        # Both URLs are resolved lazily on first use and cached afterwards.
        self.__renku_url = None
        self.__notebooks_url = None

    def _renku_url(self) -> str:
        """Get the URL of the renku instance.

        Returns:
            str: The Renku instance URL, cached after the first successful lookup.

        Raises:
            errors.RenkulabSessionGetUrlError: If the Renku URL cannot be determined.
        """
        if not self.__renku_url:
            renku_url = get_renku_url()
            if not renku_url:
                raise errors.RenkulabSessionGetUrlError()
            self.__renku_url = renku_url
        return self.__renku_url

    def _notebooks_url(self) -> str:
        """Get the url of the notebooks API (``api/notebooks`` resolved against the Renku URL)."""
        if not self.__notebooks_url:
            self.__notebooks_url = urllib.parse.urljoin(self._renku_url(), "api/notebooks")
        return self.__notebooks_url

    def _get_token(self) -> str:
        """Get the JWT token used to authenticate against Renku.

        Returns:
            str: The stored Renku token for the current Renku URL.

        Raises:
            errors.AuthenticationError: If no token is stored or the stored token is expired.
        """
        token = read_renku_token(endpoint=self._renku_url())
        if token is None:
            raise errors.AuthenticationError("Please run the renku login command to authenticate with Renku.")
        elif is_token_expired(token):
            raise errors.AuthenticationError(
                "Authentication token is expired: Please run the renku login command to authenticate with Renku."
            )
        return token

    def _auth_header(self) -> Dict[str, str]:
        """Get the authentication header with the JWT token or cookie needed to authenticate with Renku."""
        return {"Authorization": f"Bearer {self._get_token()}"}

    @staticmethod
    def _get_renku_project_name_parts() -> Dict[str, str]:
        """Get the namespace and project name of the current project's remote.

        Returns:
            Dict[str, str]: A dict with ``namespace`` and ``project`` keys. It is used both as
            query parameters when listing sessions and as part of the session creation payload.
        """
        repository = project_context.repository
        if project_context.remote.name and project_context.remote.owner:
            owner = project_context.remote.owner
            # NOTE: When a ``renku-backup-origin`` remote exists, the owner may be prefixed with
            # ``repos/``; strip that single leading prefix so only the namespace remains.
            if get_remote(repository, name="renku-backup-origin") and owner.startswith("repos/"):
                owner = owner.replace("repos/", "", 1)
            return {
                "namespace": owner,
                "project": project_context.remote.name,
            }

        # INFO: In this case the owner/name split is not available. The project name is then
        # derived from the combined name of the remote and has to be split up in the two parts:
        # everything before the last ``/`` is the namespace, the last segment is the project.
        parts = get_renku_project_name().split("/")
        return {
            "namespace": "/".join(parts[:-1]),
            "project": parts[-1],
        }

    def _wait_for_session_status(
        self,
        name: Optional[str],
        status: str,
    ):
        """Poll the notebook service until a session reaches the given status.

        For ``"stopping"`` the wait ends once the session can no longer be found (HTTP 404);
        for any other status it ends once the session's reported ``status.state`` equals ``status``.
        The session is polled every 5 seconds.

        Args:
            name: Name of the session to poll; nothing is done when it is empty.
            status: The state to wait for.

        Raises:
            errors.RenkulabSessionError: If the status is not reached within ``DEFAULT_TIMEOUT_SECONDS``.
        """
        if not name:
            return
        start = monotonic()
        while monotonic() - start < self.DEFAULT_TIMEOUT_SECONDS:
            res = self._send_renku_request(
                "get", f"{self._notebooks_url()}/servers/{name}", headers=self._auth_header()
            )
            if res.status_code == 404 and status == "stopping":
                return
            if res.status_code == 200 and status != "stopping":
                if res.json().get("status", {}).get("state") == status:
                    return
            sleep(5)
        raise errors.RenkulabSessionError(f"Waiting for the session {name} to reach status {status} timed out.")

    def _wait_for_image(
        self,
        image_name: str,
        config: Optional[Dict[str, Any]],
    ):
        """Check if an image exists, and if it does not wait for it to appear.

        Timeout after a specific period of time (``DEFAULT_TIMEOUT_SECONDS``), polling every 5 seconds.

        Raises:
            errors.RenkulabSessionError: If the image does not appear before the timeout.
        """
        start = monotonic()
        while monotonic() - start < self.DEFAULT_TIMEOUT_SECONDS:
            if self.find_image(image_name, config):
                return
            sleep(5)
        # NOTE: The trailing space after "timed out." keeps the concatenated sentences separated.
        raise errors.RenkulabSessionError(
            f"Waiting for the image {image_name} to be built timed out. "
            "Are you sure that the image was successfully built? This could be the result "
            "of problems with your Dockerfile."
        )

    def pre_start_checks(self):
        """Check if the state of the repository is as expected before starting a session.

        If the repository has uncommitted or untracked changes, the user is asked to confirm
        (aborting otherwise) and all changes are committed automatically.
        """
        repository = project_context.repository

        if repository.is_dirty(untracked_files=True):
            communication.confirm(
                "You have new uncommitted or untracked changes to your repository. "
                "Renku can automatically commit these changes so that it builds "
                "the correct environment for your session. Do you wish to proceed?",
                abort=True,
            )
            repository.add(all=True)
            repository.commit("Automated commit by Renku CLI.")

    @staticmethod
    def _remote_head_hexsha():
        """Get the head of the project's git remote.

        Raises:
            errors.GitRemoteNotFoundError: If the repository has no remote.
        """
        remote = get_remote(repository=project_context.repository)

        if remote is None:
            raise errors.GitRemoteNotFoundError()

        return remote.head

    def _send_renku_request(self, req_type: str, *args, **kwargs):
        """Send an HTTP request and turn a 401 response into an ``AuthenticationError``.

        Args:
            req_type: Name of the request function to call on the ``requests`` helper (e.g. ``"get"``).
            *args: Positional arguments forwarded to the request function.
            **kwargs: Keyword arguments forwarded to the request function.

        Returns:
            The response object for any non-401 status.

        Raises:
            errors.AuthenticationError: If the response status is 401.
        """
        res = getattr(requests, req_type)(*args, **kwargs)
        if res.status_code == 401:
            # NOTE: Check if logged in to KC but not the Renku UI
            token = read_renku_token(endpoint=self._renku_url())
            if token and not is_token_expired(token):
                raise errors.AuthenticationError(
                    f"Please log in the Renku UI at {self._renku_url()} to complete authentication with Renku"
                )
            raise errors.AuthenticationError(
                "Please run the renku login command to authenticate with Renku or to refresh your expired credentials."
            )
        return res

    def get_name(self) -> str:
        """Return session provider's name."""
        return "renkulab"

    def is_remote_provider(self) -> bool:
        """Return True for remote providers (i.e. not local Docker)."""
        return True

    def build_image(self, image_descriptor: Path, image_name: str, config: Optional[Dict[str, Any]]):
        """Builds the container image.

        The image is not built locally: if it is missing, unpushed local commits are pushed and
        the provider waits for the image to become available. ``image_descriptor`` is not used.
        """
        if self.find_image(image_name, config=config):
            return
        repository = project_context.repository
        if repository.head.commit.hexsha != self._remote_head_hexsha():
            repository.push()
        self._wait_for_image(image_name=image_name, config=config)

    def find_image(self, image_name: str, config: Optional[Dict[str, Any]]) -> bool:
        """Find the given container image.

        Returns:
            bool: True if the notebook service's ``/images`` endpoint answers with HTTP 200.
        """
        response = self._send_renku_request(
            "get",
            f"{self._notebooks_url()}/images",
            headers=self._auth_header(),
            params={"image_url": image_name},
        )
        return response.status_code == 200

    @hookimpl
    def session_provider(self) -> ISessionProvider:
        """Supported session provider.

        Returns:
            a reference to ``self``.
        """
        return self

    def session_list(self, project_name: str, config: Optional[Dict[str, Any]]) -> List[Session]:
        """Lists all the sessions currently running by the given session provider.

        Sessions are filtered by the namespace/project of the current project's remote;
        ``project_name`` and ``config`` are not used.

        Returns:
            list: a list of sessions (empty if the notebook service does not answer with HTTP 200).
        """
        sessions_res = self._send_renku_request(
            "get",
            f"{self._notebooks_url()}/servers",
            headers=self._auth_header(),
            params=self._get_renku_project_name_parts(),
        )
        if sessions_res.status_code == 200:
            return [
                Session(
                    session["name"],
                    session.get("status", {}).get("state", "unknown"),
                    self.session_url(session["name"]),
                )
                for session in sessions_res.json().get("servers", {}).values()
            ]
        return []

    def session_start(
        self,
        image_name: str,
        project_name: str,
        config: Optional[Dict[str, Any]],
        cpu_request: Optional[float] = None,
        mem_request: Optional[str] = None,
        disk_request: Optional[str] = None,
        gpu_request: Optional[str] = None,
    ) -> Tuple[str, str]:
        """Creates an interactive session.

        If the local HEAD differs from the remote head, the user is asked to confirm (aborting
        otherwise) and local commits are pushed before the session is requested. The call
        blocks until the session reports the ``running`` state.

        Returns:
            Tuple[str, str]: Provider message and a possible warning message.

        Raises:
            errors.RenkulabSessionError: If the notebook service refuses to create the session.
        """
        repository = project_context.repository

        session_commit = repository.head.commit.hexsha
        if session_commit != self._remote_head_hexsha():
            # INFO: The user is registered, the image is pinned or already available
            # but the local repository is not fully in sync with the remote
            communication.confirm(
                "You have unpushed commits that will not be present in your session. "
                "Renku can automatically push these commits so that they are present "
                "in the session you are launching. Do you wish to proceed?",
                abort=True,
            )
            repository.push()

        # Only resource requests that were actually given are sent to the notebook service.
        server_options: Dict[str, Union[str, float]] = {}
        if cpu_request:
            server_options["cpu_request"] = cpu_request
        if mem_request:
            server_options["mem_request"] = mem_request
        if gpu_request:
            server_options["gpu_request"] = int(gpu_request)
        if disk_request:
            server_options["disk_request"] = disk_request
        payload = {
            "image": image_name,
            "commit_sha": session_commit,
            "serverOptions": server_options,
            **self._get_renku_project_name_parts(),
        }
        res = self._send_renku_request(
            "post",
            f"{self._notebooks_url()}/servers",
            headers=self._auth_header(),
            json=payload,
        )
        if res.status_code in [200, 201]:
            session_name = res.json()["name"]
            self._wait_for_session_status(session_name, "running")
            return f"Session {session_name} successfully started", ""
        raise errors.RenkulabSessionError("Cannot start session via the notebook service because " + res.text)

    def session_stop(self, project_name: str, session_name: Optional[str], stop_all: bool) -> bool:
        """Stops all sessions (for the given project) or a specific interactive session.

        Each deletion request is followed by waiting until the session is gone.

        Returns:
            bool: True if at least one deletion was requested and every request returned HTTP 204.
        """
        responses = []
        if stop_all:
            sessions = self.session_list(project_name=project_name, config=None)
            for session in sessions:
                responses.append(
                    self._send_renku_request(
                        "delete", f"{self._notebooks_url()}/servers/{session.id}", headers=self._auth_header()
                    )
                )
                self._wait_for_session_status(session.id, "stopping")
        else:
            responses.append(
                self._send_renku_request(
                    "delete", f"{self._notebooks_url()}/servers/{session_name}", headers=self._auth_header()
                )
            )
            self._wait_for_session_status(session_name, "stopping")
        return all(response.status_code == 204 for response in responses) if responses else False

    def session_url(self, session_name: str) -> str:
        """Get the URL of the interactive session.

        The URL has the form ``projects/<namespace>/<project>/sessions/show/<session_name>``,
        resolved against the Renku URL.
        """
        project_name_parts = self._get_renku_project_name_parts()
        session_url_parts = [
            "projects",
            project_name_parts["namespace"],
            project_name_parts["project"],
            "sessions/show",
            session_name,
        ]
        return urllib.parse.urljoin(self._renku_url(), "/".join(session_url_parts))
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc