• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

run-llama / llama_deploy / 15192340923

22 May 2025 04:52PM UTC coverage: 83.388% (+0.04%) from 83.349%
15192340923

Pull #507

github

web-flow
Merge 41c9c3086 into 31ec7f8eb
Pull Request #507: chore: remove orchestrators

12 of 19 new or added lines in 1 file covered. (63.16%)

1 existing line in 1 file now uncovered.

2555 of 3064 relevant lines covered (83.39%)

1.67 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

53.41
/llama_deploy/control_plane/server.py
1
import asyncio
2✔
2
import json
2✔
3
import uuid
2✔
4
from logging import getLogger
2✔
5
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
2✔
6

7
import uvicorn
2✔
8
from fastapi import FastAPI, HTTPException
2✔
9
from fastapi.middleware.cors import CORSMiddleware
2✔
10
from fastapi.responses import StreamingResponse
2✔
11
from llama_index.core.storage.kvstore import SimpleKVStore
2✔
12
from llama_index.core.storage.kvstore.types import BaseKVStore
2✔
13

14
from llama_deploy.control_plane.base import BaseControlPlane
2✔
15
from llama_deploy.message_consumers.base import (
2✔
16
    BaseMessageQueueConsumer,
17
    StartConsumingCallable,
18
)
19
from llama_deploy.message_consumers.remote import RemoteMessageConsumer
2✔
20
from llama_deploy.message_queues.base import AbstractMessageQueue, PublishCallback
2✔
21
from llama_deploy.messages.base import QueueMessage
2✔
22
from llama_deploy.types import (
2✔
23
    ActionTypes,
24
    EventDefinition,
25
    ServiceDefinition,
26
    SessionDefinition,
27
    TaskDefinition,
28
    TaskResult,
29
    TaskStream,
30
)
31

32
from .config import ControlPlaneConfig, parse_state_store_uri
2✔
33
from .utils import get_result_key, get_stream_key
2✔
34

35
logger = getLogger(__name__)
2✔
36

37
CONTROL_PLANE_MESSAGE_TYPE = "control_plane"
2✔
38

39

40
class ControlPlaneServer(BaseControlPlane):
2✔
41
    """Control plane server.
42

43
    The control plane is responsible for managing the state of the system, including:
44
    - Registering services.
45
    - Submitting tasks.
46
    - Managing task state.
47
    - Handling service completion.
48
    - Launching the control plane server.
49

50
    Args:
51
        message_queue (AbstractMessageQueue): Message queue for the system.
52
        publish_callback (Optional[PublishCallback], optional): Callback for publishing messages. Defaults to None.
53
        state_store (Optional[BaseKVStore], optional): State store for the system. Defaults to None.
54
        config (Optional[ControlPlaneConfig], optional): Configuration for the control plane; a default ControlPlaneConfig is used when omitted. Defaults to None.
55

56
    Examples:
57
        ```python
58
        from llama_deploy import ControlPlaneServer
59
        from llama_deploy import SimpleMessageQueue
60
        from llama_index.llms.openai import OpenAI
61

62
        control_plane = ControlPlaneServer(
63
            SimpleMessageQueue(),
64
        )
66
        ```
67
    """
68

69
    def __init__(
        self,
        message_queue: AbstractMessageQueue,
        publish_callback: PublishCallback | None = None,
        state_store: BaseKVStore | None = None,
        config: ControlPlaneConfig | None = None,
    ) -> None:
        """Set up state storage, the message queue handle and the HTTP routes.

        Args:
            message_queue: Message queue the control plane publishes to and
                consumes from.
            publish_callback: Optional callback exposed through the
                ``publish_callback`` property.
            state_store: Optional key-value store for services, sessions and
                tasks. Mutually exclusive with ``config.state_store_uri``.
            config: Optional settings; a default ``ControlPlaneConfig`` is
                created when omitted.

        Raises:
            ValueError: If both ``state_store`` and ``config.state_store_uri``
                are provided.
        """
        self._config = config or ControlPlaneConfig()

        if state_store is not None and self._config.state_store_uri is not None:
            raise ValueError("Please use either 'state_store' or 'state_store_uri'.")

        # Test against None (as the guard above does) instead of truthiness, so
        # an explicitly supplied store that happens to evaluate falsy is not
        # silently replaced by a fresh in-memory store.
        if state_store is not None:
            self._state_store = state_store
        elif self._config.state_store_uri:
            self._state_store = parse_state_store_uri(self._config.state_store_uri)
        else:
            self._state_store = SimpleKVStore()

        self._message_queue = message_queue
        self._publisher_id = f"{self.__class__.__qualname__}-{uuid.uuid4()}"
        self._publish_callback = publish_callback

        self.app = FastAPI()
        if self._config.cors_origins:
            self.app.add_middleware(
                CORSMiddleware,
                allow_origins=self._config.cors_origins,
                allow_methods=["*"],
                allow_headers=["*"],
            )

        # (path, endpoint, HTTP method, tag) — registered in exactly this order.
        routes = (
            ("/", self.home, "GET", "Control Plane"),
            ("/process_message", self.process_message, "POST", "Control Plane"),
            ("/queue_config", self.get_message_queue_config, "GET", "Message Queue"),
            ("/services/register", self.register_service, "POST", "Services"),
            ("/services/deregister", self.deregister_service, "POST", "Services"),
            ("/services/{service_name}", self.get_service, "GET", "Services"),
            ("/services", self.get_all_services, "GET", "Services"),
            ("/sessions/{session_id}", self.get_session, "GET", "Sessions"),
            ("/sessions/create", self.create_session, "POST", "Sessions"),
            ("/sessions/{session_id}/delete", self.delete_session, "POST", "Sessions"),
            (
                "/sessions/{session_id}/tasks",
                self.add_task_to_session,
                "POST",
                "Sessions",
            ),
            ("/sessions", self.get_all_sessions, "GET", "Sessions"),
            ("/sessions/{session_id}/tasks", self.get_session_tasks, "GET", "Sessions"),
            (
                "/sessions/{session_id}/current_task",
                self.get_current_task,
                "GET",
                "Sessions",
            ),
            (
                "/sessions/{session_id}/tasks/{task_id}/result",
                self.get_task_result,
                "GET",
                "Sessions",
            ),
            (
                "/sessions/{session_id}/tasks/{task_id}/result_stream",
                self.get_task_result_stream,
                "GET",
                "Sessions",
            ),
            (
                "/sessions/{session_id}/tasks/{task_id}/send_event",
                self.send_event,
                "POST",
                "Sessions",
            ),
            ("/sessions/{session_id}/state", self.get_session_state, "GET", "Sessions"),
            (
                "/sessions/{session_id}/state",
                self.update_session_state,
                "POST",
                "Sessions",
            ),
        )
        for path, endpoint, method, tag in routes:
            self.app.add_api_route(path, endpoint, methods=[method], tags=[tag])
211

212
    @property
    def message_queue(self) -> AbstractMessageQueue:
        """Message queue instance supplied to the constructor."""
        queue = self._message_queue
        return queue
215

216
    @property
    def publisher_id(self) -> str:
        """Identifier built in the constructor from the class name and a UUID."""
        identifier = self._publisher_id
        return identifier
219

220
    @property
    def publish_callback(self) -> Optional[PublishCallback]:
        """Publish callback supplied to the constructor, or None."""
        callback = self._publish_callback
        return callback
223

224
    async def process_message(self, message: QueueMessage) -> None:
        """Route an incoming queue message to the handler for its action.

        Args:
            message: Message whose ``data`` holds the payload for the action.

        Raises:
            ValueError: If ``message.data`` is empty, or the action is not one
                of NEW_TASK, COMPLETED_TASK or TASK_STREAM.
        """
        payload = message.data
        if not payload:
            raise ValueError(f"Invalid field 'data' in QueueMessage: {message.data}")

        action = message.action

        if action == ActionTypes.NEW_TASK:
            new_task = TaskDefinition(**payload)
            # A task arriving without a session gets a brand-new one.
            if new_task.session_id is None:
                new_task.session_id = await self.create_session()
            await self.add_task_to_session(new_task.session_id, new_task)
            return

        if action == ActionTypes.COMPLETED_TASK:
            await self.handle_service_completion(TaskResult(**payload))
            return

        if action == ActionTypes.TASK_STREAM:
            await self.add_stream_to_session(TaskStream(**payload))
            return

        raise ValueError(f"Action {action} not supported by control plane")
241

242
    def as_consumer(self) -> BaseMessageQueueConsumer:
        """Build a remote consumer that targets this server's ``/process_message`` URL."""
        endpoint = f"{self._config.url}/process_message"
        consumer = RemoteMessageConsumer(
            id_=self.publisher_id,
            url=endpoint,
            message_type=CONTROL_PLANE_MESSAGE_TYPE,
        )
        return consumer
248

249
    async def launch_server(self) -> None:
        """Serve the FastAPI app with uvicorn until the task is cancelled.

        On ``asyncio.CancelledError`` the server is shut down and the
        cancellation is not re-raised.
        """
        settings = self._config
        # internal_host / internal_port take precedence when they are set;
        # otherwise fall back to host / port.
        host = settings.internal_host or settings.host
        port = settings.internal_port or settings.port
        logger.info(f"Launching control plane server at {host}:{port}")

        class _NoSignalServer(uvicorn.Server):
            # Overridden as a no-op so uvicorn does not install signal handlers.
            def install_signal_handlers(self) -> None:
                pass

        server = _NoSignalServer(uvicorn.Config(self.app, host=host, port=port))
        try:
            await server.serve()
        except asyncio.CancelledError:
            self._running = False
            # return_exceptions=True keeps a failing shutdown from propagating.
            await asyncio.gather(server.shutdown(), return_exceptions=True)
266

267
    async def home(self) -> Dict[str, str]:
        """Return a summary of selected config values for the root endpoint."""
        settings = self._config
        summary: Dict[str, str] = {}
        # running / step_interval are stringified; the store keys are passed through.
        summary["running"] = str(settings.running)
        summary["step_interval"] = str(settings.step_interval)
        summary["services_store_key"] = settings.services_store_key
        summary["tasks_store_key"] = settings.tasks_store_key
        summary["session_store_key"] = settings.session_store_key
        return summary
275

276
    async def register_service(
        self, service_def: ServiceDefinition
    ) -> ControlPlaneConfig:
        """Store a service definition and return the control plane config.

        The dumped definition is written to the services collection under
        ``service_def.service_name``.
        """
        key = service_def.service_name
        record = service_def.model_dump()
        await self._state_store.aput(
            key, record, collection=self._config.services_store_key
        )
        return self._config
285

286
    async def deregister_service(self, service_name: str) -> None:
        """Delete the entry stored under ``service_name`` in the services collection."""
        services_collection = self._config.services_store_key
        await self._state_store.adelete(service_name, collection=services_collection)
290

291
    async def get_service(self, service_name: str) -> ServiceDefinition:
        """Load and validate the definition stored for ``service_name``.

        Raises:
            HTTPException: 404 when the store returns None for this name.
        """
        stored = await self._state_store.aget(
            service_name, collection=self._config.services_store_key
        )
        if stored is not None:
            return ServiceDefinition.model_validate(stored)

        raise HTTPException(status_code=404, detail="Service not found")
299

300
    async def get_all_services(self) -> Dict[str, ServiceDefinition]:
        """Return every entry of the services collection as validated definitions."""
        raw_services = await self._state_store.aget_all(
            collection=self._config.services_store_key
        )

        services: Dict[str, ServiceDefinition] = {}
        for name, raw in raw_services.items():
            services[name] = ServiceDefinition.model_validate(raw)
        return services
309

310
    async def create_session(self) -> str:
        """Create a default ``SessionDefinition``, store it, and return its id."""
        new_session = SessionDefinition()
        new_id = new_session.session_id
        await self._state_store.aput(
            new_id,
            new_session.model_dump(),
            collection=self._config.session_store_key,
        )
        return new_id
319

320
    async def get_session(self, session_id: str) -> SessionDefinition:
        """Load and validate the session stored under ``session_id``.

        Raises:
            HTTPException: 404 when the store returns None for this id.
        """
        stored = await self._state_store.aget(
            session_id, collection=self._config.session_store_key
        )
        if stored is not None:
            return SessionDefinition.model_validate(stored)

        raise HTTPException(status_code=404, detail="Session not found")
328

329
    async def delete_session(self, session_id: str) -> None:
        """Delete the entry stored under ``session_id`` in the session collection."""
        session_collection = self._config.session_store_key
        await self._state_store.adelete(session_id, collection=session_collection)
333

334
    async def get_all_sessions(self) -> Dict[str, SessionDefinition]:
        """Return every entry of the session collection as validated definitions."""
        raw_sessions = await self._state_store.aget_all(
            collection=self._config.session_store_key
        )

        sessions: Dict[str, SessionDefinition] = {}
        for sid, raw in raw_sessions.items():
            sessions[sid] = SessionDefinition.model_validate(raw)
        return sessions
343

344
    async def get_session_tasks(self, session_id: str) -> List[TaskDefinition]:
        """Return the task definitions for each id in the session's ``task_ids``.

        Tasks are fetched one at a time, in the order they appear in the session.
        """
        session = await self.get_session(session_id)
        return [await self.get_task(tid) for tid in session.task_ids]
350

351
    async def get_current_task(self, session_id: str) -> Optional[TaskDefinition]:
        """Return the task for the last id in the session's ``task_ids``, or None if it is empty."""
        session = await self.get_session(session_id)
        task_ids = session.task_ids
        if len(task_ids) > 0:
            return await self.get_task(task_ids[-1])
        return None
356

357
    async def add_task_to_session(
        self, session_id: str, task_def: TaskDefinition
    ) -> str:
        """Attach a task to an existing session, persist both, and dispatch it.

        Args:
            session_id: Id of the session that should own the task.
            task_def: Task to add; its ``session_id`` is filled in when unset.

        Returns:
            The id of the task after it was passed to ``send_task_to_service``.

        Raises:
            HTTPException: 404 when the session is not in the store; 400 when
                ``task_def.session_id`` is set to a different session.
        """
        sessions_key = self._config.session_store_key
        stored_session = await self._state_store.aget(
            session_id, collection=sessions_key
        )
        if stored_session is None:
            raise HTTPException(status_code=404, detail="Session not found")

        # Adopt the target session when the task does not name one yet.
        if not task_def.session_id:
            task_def.session_id = session_id

        if task_def.session_id != session_id:
            msg = f"Wrong task definition: task.session_id is {task_def.session_id} but should be {session_id}"
            raise HTTPException(status_code=400, detail=msg)

        # Record the task id on the session first, then store the task itself.
        session = SessionDefinition(**stored_session)
        session.task_ids.append(task_def.task_id)
        await self._state_store.aput(
            session_id, session.model_dump(), collection=sessions_key
        )
        await self._state_store.aput(
            task_def.task_id,
            task_def.model_dump(),
            collection=self._config.tasks_store_key,
        )

        dispatched = await self.send_task_to_service(task_def)
        return dispatched.task_id
388

389
    async def send_task_to_service(self, task_def: TaskDefinition) -> TaskDefinition:
        """Publish the next messages for a task and persist the session state.

        Args:
            task_def: Task to dispatch; it must carry a ``session_id``.

        Returns:
            The same ``task_def`` that was passed in.

        Raises:
            ValueError: If ``task_def.session_id`` is None.
        """
        session_id = task_def.session_id
        if session_id is None:
            raise ValueError(f"Task with id {task_def.task_id} has no session")

        session = await self.get_session(session_id)
        next_messages, session_state = await self.get_next_messages(
            task_def, session.state
        )

        logger.debug(f"Sending task {task_def.task_id} to services: {next_messages}")

        # Messages are published one after another, in the order returned.
        for outgoing in next_messages:
            await self.publish(outgoing)

        session.state.update(session_state)
        await self._state_store.aput(
            session_id,
            session.model_dump(),
            collection=self._config.session_store_key,
        )

        return task_def
413

414
    async def handle_service_completion(
        self,
        task_result: TaskResult,
    ) -> None:
        """Record a task result in its session, then re-dispatch and store the task.

        Args:
            task_result: Result whose ``task_id`` identifies the stored task.

        Raises:
            ValueError: If the stored task has no ``session_id``.
        """
        task = await self.get_task(task_result.task_id)
        if task.session_id is None:
            raise ValueError(f"Task with id {task_result.task_id} has no session")

        # Fold the result into the session state and persist the session.
        session = await self.get_session(task.session_id)
        updated_state = await self.add_result_to_state(task_result, session.state)
        session.state.update(updated_state)
        await self._state_store.aput(
            session.session_id,
            session.model_dump(),
            collection=self._config.session_store_key,
        )

        # Give the dispatcher a chance to emit follow-up messages, then persist
        # the task definition it returns.
        task = await self.send_task_to_service(task)
        await self._state_store.aput(
            task.task_id,
            task.model_dump(),
            collection=self._config.tasks_store_key,
        )
442

443
    async def get_task(self, task_id: str) -> TaskDefinition:
        """Load the task definition stored under ``task_id``.

        Raises:
            HTTPException: 404 when the store returns None for this id.
        """
        stored = await self._state_store.aget(
            task_id, collection=self._config.tasks_store_key
        )
        if stored is not None:
            return TaskDefinition(**stored)

        raise HTTPException(status_code=404, detail="Task not found")
451

452
    async def get_task_result(
        self, task_id: str, session_id: str
    ) -> Optional[TaskResult]:
        """Look up a task's result in its session state.

        Args:
            task_id (str): The ID of the task to get the result for.
            session_id (str): The ID of the session the task belongs to.

        Returns:
            Optional[TaskResult]: The stored result, or None when the session
                state has no result key for the task or the stored result
                carries a different ``task_id``.

        Raises:
            HTTPException: 500 when the stored value is not a TaskResult,
                dict or str.
        """
        session = await self.get_session(session_id)

        result_key = get_result_key(task_id)
        if result_key not in session.state:
            return None

        stored = session.state[result_key]
        # The stored value may be the model itself, its dict form, or a JSON string.
        if isinstance(stored, TaskResult):
            result = stored
        elif isinstance(stored, dict):
            result = TaskResult(**stored)
        elif isinstance(stored, str):
            result = TaskResult(**json.loads(stored))
        else:
            raise HTTPException(status_code=500, detail="Unexpected result type")

        # Sanity check: only hand back a result that belongs to the requested task.
        if result.task_id == task_id:
            return result

        logger.debug(
            f"Retrieved result did not match requested task_id: {str(result)}"
        )
        return None
487

488
    async def add_stream_to_session(self, task_stream: TaskStream) -> None:
        """Append a stream event to its task's list in the session state and persist it.

        Raises:
            ValueError: If ``task_stream.session_id`` is None.
        """
        session_id = task_stream.session_id
        if session_id is None:
            raise ValueError(
                f"Task stream with id {task_stream.task_id} has no session"
            )

        session = await self.get_session(session_id)

        # Events for a task accumulate in a list under the task's stream key.
        events = session.state.get(get_stream_key(task_stream.task_id), [])
        events.append(task_stream.model_dump())
        session.state[get_stream_key(task_stream.task_id)] = events

        await self._state_store.aput(
            session_id,
            session.model_dump(),
            collection=self._config.session_store_key,
        )
508

509
    async def get_task_result_stream(
        self, session_id: str, task_id: str
    ) -> StreamingResponse:
        """Stream a task's events as newline-delimited JSON until a result exists.

        Args:
            session_id: Id of the session that owns the task.
            task_id: Id of the task whose stream is requested.

        Returns:
            A ``StreamingResponse`` with media type ``application/x-ndjson``.

        Raises:
            HTTPException: 404 when the session state has no stream key for
                the task (or when ``get_session`` cannot find the session).
        """
        session = await self.get_session(session_id)

        stream_key = get_stream_key(task_id)
        if stream_key not in session.state:
            raise HTTPException(status_code=404, detail="Task stream not found")

        async def event_generator(
            session: SessionDefinition, stream_key: str
        ) -> AsyncGenerator[str, None]:
            try:
                consumed = 0
                while True:
                    # Re-read the session on every pass to pick up new events.
                    session = await self.get_session(session_id)
                    pending = sorted(
                        session.state[stream_key][consumed:],
                        key=lambda x: x["index"],
                    )
                    for item in pending:
                        if isinstance(item, TaskStream):
                            event = item
                        elif isinstance(item, dict):
                            event = TaskStream(**item)
                        elif isinstance(item, str):
                            event = TaskStream(**json.loads(item))
                        else:
                            raise ValueError("Unexpected result type in stream")

                        yield json.dumps(event.data) + "\n"

                    # Stop once the task has a final result.
                    if await self.get_task_result(task_id, session_id) is not None:
                        return

                    consumed += len(pending)
                    # Small delay to prevent a tight polling loop.
                    await asyncio.sleep(self._config.step_interval)
            except Exception as e:
                logger.error(
                    f"Error in event stream for session {session_id}, task {task_id}: {str(e)}"
                )
                # Report the failure to the client as a final NDJSON line.
                yield json.dumps({"error": str(e)}) + "\n"

        return StreamingResponse(
            event_generator(session, stream_key),
            media_type="application/x-ndjson",
        )
556

557
    async def send_event(
        self,
        session_id: str,
        task_id: str,
        event_def: EventDefinition,
    ) -> None:
        """Publish a SEND_EVENT message addressed to the event's service.

        The event is wrapped in a ``TaskDefinition`` whose ``input`` is
        ``event_def.event_obj_str``; the message type is ``event_def.service_id``.
        """
        payload = TaskDefinition(
            task_id=task_id,
            session_id=session_id,
            input=event_def.event_obj_str,
            service_id=event_def.service_id,
        ).model_dump()

        await self.publish(
            QueueMessage(
                type=event_def.service_id,
                action=ActionTypes.SEND_EVENT,
                data=payload,
            )
        )
575

576
    async def get_session_state(self, session_id: str) -> Dict[str, Any]:
        """Return the ``state`` dict of the given session.

        Raises:
            HTTPException: 404 when the session's ``task_ids`` is None (a
                missing session is reported by ``get_session``).
        """
        session = await self.get_session(session_id)
        if session.task_ids is not None:
            return session.state

        raise HTTPException(status_code=404, detail="Session not found")
582

583
    async def update_session_state(
        self, session_id: str, state: Dict[str, Any]
    ) -> None:
        """Merge ``state`` into the session's state dict and persist the session."""
        session = await self.get_session(session_id)
        session.state.update(state)

        dumped = session.model_dump()
        await self._state_store.aput(
            session_id, dumped, collection=self._config.session_store_key
        )
592

593
    async def get_message_queue_config(self) -> Dict[str, dict]:
        """
        Return the config of the message queue in use.

        Returns:
            Dict[str, dict]: A single-entry dict mapping the config object's
                class name to its dumped config.
        """
        cfg = self._message_queue.as_config()
        cfg_name = cfg.__class__.__name__
        return {cfg_name: cfg.model_dump()}
602

603
    async def register_to_message_queue(self) -> StartConsumingCallable:
        """Register this control plane as a consumer on its namespaced topic."""
        queue = self.message_queue
        consumer = self.as_consumer()
        topic = self.get_topic(CONTROL_PLANE_MESSAGE_TYPE)
        return await queue.register_consumer(consumer, topic=topic)
607

608
    def get_topic(self, msg_type: str) -> str:
        """Return ``msg_type`` prefixed with the configured topic namespace and a dot."""
        namespace = self._config.topic_namespace
        return "{}.{}".format(namespace, msg_type)
610

611
    async def get_next_messages(
        self, task_def: TaskDefinition, state: Dict[str, Any]
    ) -> Tuple[List[QueueMessage], Dict[str, Any]]:
        """Build the messages to publish next for a task.

        The task's ``service_id`` is used as the destination message type.
        ``state`` is modified in place: an empty entry is created under the
        task id when none exists.

        Args:
            task_def: Task to route; it must carry a ``service_id``.
            state: Session state dict for the task's session.

        Returns:
            A tuple ``(messages, state)``. ``messages`` is empty when the state
            already holds a result for the task; otherwise it contains a single
            NEW_TASK message addressed to ``task_def.service_id``.

        Raises:
            ValueError: If ``task_def.service_id`` is None.
        """
        if task_def.service_id is None:
            raise ValueError(
                "Task definition must have a service_id specified to identify a service"
            )

        if task_def.task_id not in state:
            state[task_def.task_id] = {}

        # A stored result means the task is finished: nothing left to send.
        if state.get(get_result_key(task_def.task_id)) is not None:
            return [], state

        destination = task_def.service_id
        destination_messages = [
            QueueMessage(
                type=destination,
                action=ActionTypes.NEW_TASK,
                data=task_def.model_dump(),
            )
        ]

        return destination_messages, state
641

642
    async def add_result_to_state(
        self, result: TaskResult, state: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Store a task result in ``state`` and return that same dict.

        ``state`` is modified in place: the ``"retries"`` counter is
        incremented (starting at 0 when absent) and the result is saved under
        the task's result key.
        """
        # TODO: detect failures + retries
        state["retries"] = state.get("retries", -1) + 1

        result_key = get_result_key(result.task_id)
        state[result_key] = result
        return state
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2026 Coveralls, Inc