Spaces:
Sleeping
Sleeping
| import logging | |
| import json | |
| from typing import Dict | |
| from fastapi import APIRouter, Depends | |
| from fastapi.responses import StreamingResponse | |
| from aworld.cmd import AgentModel, ChatCompletionRequest | |
| from aworld.cmd.utils import agent_loader, agent_executor | |
| from aworld.cmd.web.web_server import get_user_id_from_jwt | |
| import aworld.trace as trace | |
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
# Router instance for this module.
# NOTE(review): no route decorators are visible in this view — confirm how
# the handlers below are attached to it.
router = APIRouter()
# URL prefix string for the agent API.
# NOTE(review): presumably read by whatever mounts `router` — confirm against the caller.
prefix = "/api/agent"
async def list_agents() -> Dict[str, AgentModel]:
    """Return the agents reported by ``agent_loader.list_agents()``.

    Returns:
        Whatever ``agent_loader.list_agents()`` yields, passed through
        unchanged; annotated as a mapping of string keys to ``AgentModel``.
    """
    # NOTE(review): no route decorator is visible in this view — confirm how
    # this handler is registered.
    available_agents = agent_loader.list_agents()
    return available_agents
async def chat_completion(
    form_data: ChatCompletionRequest, user_id: str = Depends(get_user_id_from_jwt)
) -> StreamingResponse:
    """Run a chat completion and stream the result chunks to the client.

    Args:
        form_data: The chat completion request. It is mutated in place: its
            ``user_id`` is overwritten with the dependency-provided value and
            its ``trace_id`` is set once streaming starts.
        user_id: Value resolved through the ``get_user_id_from_jwt``
            dependency.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream`` whose
        body is one ``data: <json>`` frame per chunk produced by
        ``agent_executor.stream_run``.
    """
    # The dependency-resolved id replaces whatever the request body carried.
    form_data.user_id = user_id

    async def sse_events():
        # The span stays open for the whole stream, so it only closes once
        # the last chunk has been yielded (or the generator is torn down).
        span_attributes = {"model": form_data.model}
        async with trace.span(
            "/chat/chat_completion", attributes=span_attributes
        ) as active_span:
            # Expose the trace id to the executor through the request object.
            form_data.trace_id = active_span.get_trace_id()
            async for piece in agent_executor.stream_run(form_data):
                # ensure_ascii=False keeps non-ASCII text unescaped in the frame.
                payload = json.dumps(piece.model_dump(), ensure_ascii=False)
                # Each frame is a "data:" line terminated by a blank line.
                yield f"data: {payload}\n\n"

    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
    }
    return StreamingResponse(
        sse_events(),
        media_type="text/event-stream",
        headers=sse_headers,
    )