SwissDataScienceCenter / renku-data-services / 17240768421

26 Aug 2025 02:10PM UTC coverage: 87.15% (-0.03% from 87.181%)

Pull Request #995: feat: upgrade the apispec and generate code
Merge 4e66606ab into 86bb5f056 (github / web-flow)

10 of 10 new or added lines in 1 file covered (100.0%).
11 existing lines in 4 files are now uncovered.
21751 of 24958 relevant lines covered (87.15%).
1.53 hits per line.

Source File: /components/renku_data_services/k8s_watcher/core.py (file coverage: 80.0%)

"""K8s watcher main."""

from __future__ import annotations

import asyncio
import contextlib
from asyncio import CancelledError, Task
from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta
from typing import TYPE_CHECKING

from renku_data_services.app_config import logging
from renku_data_services.base_models.core import APIUser, InternalServiceAdmin, ServiceAdminId
from renku_data_services.base_models.metrics import MetricsService
from renku_data_services.crc.db import ResourcePoolRepository
from renku_data_services.k8s.clients import K8sClusterClient
from renku_data_services.k8s.models import GVK, K8sObject, K8sObjectFilter
from renku_data_services.k8s_watcher.db import K8sDbCache
from renku_data_services.notebooks.crs import State

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from renku_data_services.k8s.constants import ClusterId
    from renku_data_services.k8s.models import APIObjectInCluster, Cluster

type EventHandler = Callable[[APIObjectInCluster, str], Awaitable[None]]
type SyncFunc = Callable[[], Awaitable[None]]

k8s_watcher_admin_user = InternalServiceAdmin(id=ServiceAdminId.k8s_watcher)


class K8sWatcher:
    """Watch k8s events and call the handler with every event."""

    def __init__(
        self,
        handler: EventHandler,
        clusters: dict[ClusterId, Cluster],
        kinds: list[GVK],
        db_cache: K8sDbCache,
    ) -> None:
        self.__handler = handler
        self.__watch_tasks: dict[ClusterId, list[Task]] = {}
        self.__full_sync_tasks: dict[ClusterId, Task] = {}
        self.__full_sync_times: dict[ClusterId, datetime] = {}
        self.__full_sync_running: set[ClusterId] = set()
        self.__kinds = kinds
        self.__clusters = clusters
        self.__sync_period_seconds = 600
        self.__cache = db_cache

    async def __sync(self, cluster: Cluster, kind: GVK) -> None:
        """Upsert K8s objects in the cache and remove deleted objects from the cache."""
        clnt = K8sClusterClient(cluster)
        fltr = K8sObjectFilter(gvk=kind, cluster=cluster.id, namespace=cluster.namespace)
        # Upsert new / updated objects
        objects_in_k8s: dict[str, K8sObject] = {}
        async for obj in clnt.list(fltr):
            objects_in_k8s[obj.name] = obj
            await self.__cache.upsert(obj)
        # Remove objects that have been deleted from k8s but are still in cache
        async for cache_obj in self.__cache.list(fltr):
            cache_obj_is_in_k8s = objects_in_k8s.get(cache_obj.name) is not None
            if cache_obj_is_in_k8s:
                continue
            await self.__cache.delete(cache_obj)

    async def __full_sync(self, cluster: Cluster) -> None:
        """Run the full sync if it has never run or at the required interval."""
        last_sync = self.__full_sync_times.get(cluster.id)
        since_last_sync = datetime.now() - last_sync if last_sync is not None else None
        if since_last_sync is not None and since_last_sync.total_seconds() < self.__sync_period_seconds:
            return
        self.__full_sync_running.add(cluster.id)
        for kind in self.__kinds:
            logger.info(f"Starting full k8s cache sync for cluster {cluster} and kind {kind}")
            await self.__sync(cluster, kind)
        self.__full_sync_times[cluster.id] = datetime.now()
        self.__full_sync_running.remove(cluster.id)

    async def __periodic_full_sync(self, cluster: Cluster) -> None:
        """Keep trying to run the full sync."""
        while True:
            await self.__full_sync(cluster)
            await asyncio.sleep(self.__sync_period_seconds / 10)

    async def __watch_kind(self, kind: GVK, cluster: Cluster) -> None:
        while True:
            try:
                watch = cluster.api.async_watch(kind=kind.kr8s_kind, namespace=cluster.namespace)
                async for event_type, obj in watch:
                    while cluster.id in self.__full_sync_running:
                        logger.info(
                            f"Pausing k8s watch event processing for cluster {cluster} until full sync completes"
                        )
                        await asyncio.sleep(5)
                    await self.__handler(cluster.with_api_object(obj), event_type)
                    # In some cases the kr8s loop above never yields, especially when exceptions
                    # bypass async scheduling. This sleep is a last line of defence so this code
                    # does not run indefinitely and prevent another resource kind from being watched.
                    await asyncio.sleep(0)
            except Exception as e:
                logger.error(f"watch loop failed for {kind} in cluster {cluster.id}", exc_info=e)
                # Without sleeping, this can hang the code, as exceptions seem to bypass the async scheduler.
                await asyncio.sleep(1)

    def __run_single(self, cluster: Cluster) -> list[Task]:
        # The loops and error handling here will need some testing and love
        tasks = []
        for kind in self.__kinds:
            logger.info(f"watching {kind} in cluster {cluster.id}")
            tasks.append(asyncio.create_task(self.__watch_kind(kind, cluster)))

        return tasks

    async def start(self) -> None:
        """Start the watcher."""
        for cluster in sorted(self.__clusters.values(), key=lambda x: x.id):
            await self.__full_sync(cluster)
            self.__full_sync_tasks[cluster.id] = asyncio.create_task(self.__periodic_full_sync(cluster))
            self.__watch_tasks[cluster.id] = self.__run_single(cluster)

    async def wait(self) -> None:
        """Wait for all tasks.

        This is mainly used to block the main function.
        """
        all_tasks = list(self.__full_sync_tasks.values())
        for tasks in self.__watch_tasks.values():
            all_tasks.extend(tasks)
        await asyncio.gather(*all_tasks)

    async def stop(self, timeout: timedelta = timedelta(seconds=10)) -> None:
        """Stop the watcher or timeout."""

        async def stop_task(task: Task, timeout: timedelta) -> None:
            if task.done():
                return
            task.cancel()
            try:
                async with asyncio.timeout(timeout.total_seconds()):
                    with contextlib.suppress(CancelledError):
                        await task
            except TimeoutError:
                logger.error("timeout trying to cancel k8s watcher task")
                return

        for task_list in self.__watch_tasks.values():
            for task in task_list:
                await stop_task(task, timeout)
        for task in self.__full_sync_tasks.values():
            await stop_task(task, timeout)


async def collect_metrics(
    previous_obj: K8sObject | None,
    new_obj: APIObjectInCluster,
    event_type: str,
    user_id: str,
    metrics: MetricsService,
    rp_repo: ResourcePoolRepository,
) -> None:
    """Track product metrics."""
    user = APIUser(id=user_id)

    if event_type == "DELETED":
        # session stopping
        await metrics.session_stopped(user=user, metadata={"session_id": new_obj.meta.name})
        return
    previous_state = previous_obj.manifest.get("status", {}).get("state", None) if previous_obj else None
    match new_obj.obj.status.state:
        case State.Running.value if previous_state is None or previous_state == State.NotReady.value:
            # session starting
            resource_class_id = int(new_obj.obj.metadata.annotations.get("renku.io/resource_class_id"))
            resource_pool = await rp_repo.get_resource_pool_from_class(k8s_watcher_admin_user, resource_class_id)
            resource_class = await rp_repo.get_resource_class(k8s_watcher_admin_user, resource_class_id)

            await metrics.session_started(
                user=user,
                metadata={
                    "cpu": int(resource_class.cpu * 1000),
                    "memory": resource_class.memory,
                    "gpu": resource_class.gpu,
                    "storage": new_obj.obj.spec.session.storage.size,
                    "resource_class_id": resource_class_id,
                    "resource_pool_id": resource_pool.id or "",
                    "resource_class_name": f"{resource_pool.name}.{resource_class.name}",
                    "session_id": new_obj.meta.name,
                },
            )
        case State.Running.value | State.NotReady.value if previous_state == State.Hibernated.value:
            # session resumed
            await metrics.session_resumed(user, metadata={"session_id": new_obj.meta.name})
        case State.Hibernated.value if previous_state != State.Hibernated.value:
            # session hibernated
            await metrics.session_hibernated(user=user, metadata={"session_id": new_obj.meta.name})
        case _:
            pass


def k8s_object_handler(cache: K8sDbCache, metrics: MetricsService, rp_repo: ResourcePoolRepository) -> EventHandler:
    """Listen to k8s events and update the cache."""

    async def handler(obj: APIObjectInCluster, event_type: str) -> None:
        existing = await cache.get(obj.meta)
        if obj.user_id is not None:
            try:
                await collect_metrics(existing, obj, event_type, obj.user_id, metrics, rp_repo)
            except Exception as e:
                logger.error("failed to track product metrics", exc_info=e)
        if event_type == "DELETED":
            await cache.delete(obj.meta)
            return
        k8s_object = obj.to_k8s_object()
        k8s_object.user_id = obj.user_id
        await cache.upsert(k8s_object)

    return handler
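
For orientation, a minimal wiring sketch follows. It is not part of the file above: the import path is inferred from the file path shown in this report, and how the K8sDbCache, the cluster map, the GVK list, the MetricsService, and the ResourcePoolRepository are constructed is assumed to happen elsewhere.

# Hypothetical wiring sketch (assumptions noted above; not taken from this repository).
from renku_data_services.k8s_watcher.core import K8sWatcher, k8s_object_handler


async def run_watcher(cache, clusters, kinds, metrics, rp_repo) -> None:
    # cache: K8sDbCache, clusters: dict[ClusterId, Cluster], kinds: list[GVK]
    handler = k8s_object_handler(cache, metrics, rp_repo)
    watcher = K8sWatcher(handler=handler, clusters=clusters, kinds=kinds, db_cache=cache)
    await watcher.start()  # initial full sync, then periodic sync and watch tasks per cluster
    try:
        await watcher.wait()  # blocks on all watcher tasks (normally forever)
    finally:
        await watcher.stop()  # cancels each task, waiting up to the default 10s timeout


# Entry point, e.g.: asyncio.run(run_watcher(cache, clusters, kinds, metrics, rp_repo))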