Spaces:
Sleeping
Sleeping
Upload 4 files
Browse files
aworld/cmd/web/routers/chats.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import json
|
3 |
+
from typing import Dict
|
4 |
+
from fastapi import APIRouter, Depends
|
5 |
+
from fastapi.responses import StreamingResponse
|
6 |
+
from aworld.cmd import AgentModel, ChatCompletionRequest
|
7 |
+
from aworld.cmd.utils import agent_loader, agent_executor
|
8 |
+
from aworld.cmd.web.web_server import get_user_id_from_jwt
|
9 |
+
import aworld.trace as trace
|
10 |
+
|
11 |
+
logger = logging.getLogger(__name__)
|
12 |
+
|
13 |
+
router = APIRouter()
|
14 |
+
|
15 |
+
prefix = "/api/agent"
|
16 |
+
|
17 |
+
|
18 |
+
@router.get("/list")
|
19 |
+
@router.get("/models")
|
20 |
+
async def list_agents() -> Dict[str, AgentModel]:
|
21 |
+
return agent_loader.list_agents()
|
22 |
+
|
23 |
+
|
24 |
+
@router.post("/chat/completions")
|
25 |
+
async def chat_completion(
|
26 |
+
form_data: ChatCompletionRequest, user_id: str = Depends(get_user_id_from_jwt)
|
27 |
+
) -> StreamingResponse:
|
28 |
+
# Set user_id from JWT to form_data
|
29 |
+
form_data.user_id = user_id
|
30 |
+
|
31 |
+
async def generate_stream():
|
32 |
+
async with trace.span(
|
33 |
+
"/chat/chat_completion", attributes={"model": form_data.model}
|
34 |
+
) as span:
|
35 |
+
form_data.trace_id = span.get_trace_id()
|
36 |
+
async for chunk in agent_executor.stream_run(form_data):
|
37 |
+
yield f"data: {json.dumps(chunk.model_dump(), ensure_ascii=False)}\n\n"
|
38 |
+
|
39 |
+
return StreamingResponse(
|
40 |
+
generate_stream(),
|
41 |
+
media_type="text/event-stream",
|
42 |
+
headers={
|
43 |
+
"Cache-Control": "no-cache",
|
44 |
+
"Connection": "keep-alive",
|
45 |
+
},
|
46 |
+
)
|
aworld/cmd/web/routers/sessions.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from typing import List
|
3 |
+
from fastapi import APIRouter, Depends
|
4 |
+
from pydantic import BaseModel, Field
|
5 |
+
from aworld.cmd import SessionModel
|
6 |
+
from aworld.cmd.utils.agent_server import CURRENT_SERVER
|
7 |
+
from aworld.cmd.web.web_server import get_user_id_from_jwt
|
8 |
+
|
9 |
+
logger = logging.getLogger(__name__)
|
10 |
+
|
11 |
+
router = APIRouter()
|
12 |
+
|
13 |
+
prefix = "/api/session"
|
14 |
+
|
15 |
+
|
16 |
+
@router.get("/list")
|
17 |
+
async def list_sessions(
|
18 |
+
user_id: str = Depends(get_user_id_from_jwt),
|
19 |
+
) -> List[SessionModel]:
|
20 |
+
return await CURRENT_SERVER.get_session_service().list_sessions(user_id)
|
21 |
+
|
22 |
+
|
23 |
+
class CommonResponse(BaseModel):
|
24 |
+
code: int = Field(..., description="The code")
|
25 |
+
message: str = Field(..., description="The message")
|
26 |
+
|
27 |
+
@staticmethod
|
28 |
+
def success(message: str = "success"):
|
29 |
+
return CommonResponse(code=0, message=message)
|
30 |
+
|
31 |
+
@staticmethod
|
32 |
+
def error(message: str):
|
33 |
+
return CommonResponse(code=1, message=message)
|
34 |
+
|
35 |
+
|
36 |
+
class DeleteSessionRequest(BaseModel):
|
37 |
+
session_id: str = Field(..., description="The session id")
|
38 |
+
|
39 |
+
|
40 |
+
@router.post("/delete")
|
41 |
+
async def delete_session(
|
42 |
+
request: DeleteSessionRequest, user_id: str = Depends(get_user_id_from_jwt)
|
43 |
+
) -> CommonResponse:
|
44 |
+
try:
|
45 |
+
await CURRENT_SERVER.get_session_service().delete_session(
|
46 |
+
user_id, request.session_id
|
47 |
+
)
|
48 |
+
return CommonResponse.success()
|
49 |
+
except Exception as e:
|
50 |
+
return CommonResponse.error(str(e))
|
aworld/cmd/web/routers/traces.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from fastapi import APIRouter
|
3 |
+
from aworld.trace.server import get_trace_server
|
4 |
+
from aworld.trace.constants import RunType
|
5 |
+
from aworld.trace.server.util import build_trace_tree
|
6 |
+
|
7 |
+
logger = logging.getLogger(__name__)
|
8 |
+
|
9 |
+
router = APIRouter()
|
10 |
+
|
11 |
+
prefix = "/api/trace"
|
12 |
+
|
13 |
+
|
14 |
+
@router.get("/list")
|
15 |
+
async def list_traces():
|
16 |
+
storage = get_trace_server().get_storage()
|
17 |
+
trace_data = []
|
18 |
+
for trace_id in storage.get_all_traces():
|
19 |
+
spans = storage.get_all_spans(trace_id)
|
20 |
+
spans_sorted = sorted(spans, key=lambda x: x.start_time)
|
21 |
+
trace_tree = build_trace_tree(spans_sorted)
|
22 |
+
trace_data.append({
|
23 |
+
'trace_id': trace_id,
|
24 |
+
'root_span': trace_tree,
|
25 |
+
})
|
26 |
+
return {
|
27 |
+
"data": trace_data
|
28 |
+
}
|
29 |
+
|
30 |
+
|
31 |
+
@router.get("/agent")
|
32 |
+
async def get_agent_trace(trace_id: str):
|
33 |
+
storage = get_trace_server().get_storage()
|
34 |
+
spans = storage.get_all_spans(trace_id)
|
35 |
+
spans_dict = {span.span_id: span.dict() for span in spans}
|
36 |
+
|
37 |
+
filtered_spans = {}
|
38 |
+
for span_id, span in spans_dict.items():
|
39 |
+
if span.get('is_event', False) and span.get('run_type') == RunType.AGNET.value:
|
40 |
+
span['show_name'] = _get_agent_show_name(span)
|
41 |
+
filtered_spans[span_id] = span
|
42 |
+
|
43 |
+
for span in list(filtered_spans.values()):
|
44 |
+
parent_id = span['parent_id'] if span['parent_id'] else None
|
45 |
+
|
46 |
+
while parent_id and parent_id not in filtered_spans:
|
47 |
+
parent_span = spans_dict.get(parent_id)
|
48 |
+
parent_id = parent_span['parent_id'] if parent_span and parent_span['parent_id'] else None
|
49 |
+
|
50 |
+
if parent_id:
|
51 |
+
parent_span = filtered_spans.get(parent_id)
|
52 |
+
if not parent_span:
|
53 |
+
continue
|
54 |
+
|
55 |
+
if 'children' not in parent_span:
|
56 |
+
parent_span['children'] = []
|
57 |
+
parent_span['children'].append(span)
|
58 |
+
|
59 |
+
root_spans = [span for span in filtered_spans.values()
|
60 |
+
if span['parent_id'] is None or span['parent_id'] not in filtered_spans]
|
61 |
+
return {
|
62 |
+
"data": root_spans
|
63 |
+
}
|
64 |
+
|
65 |
+
|
66 |
+
def _get_agent_show_name(span: dict):
|
67 |
+
agent_name_prefix = "agent_event_"
|
68 |
+
name = span.get("name")
|
69 |
+
if name and name.startswith(agent_name_prefix):
|
70 |
+
name = name[len(agent_name_prefix):]
|
71 |
+
return name
|
aworld/cmd/web/routers/workspaces.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
from typing import List, Optional
|
4 |
+
from pydantic import BaseModel
|
5 |
+
|
6 |
+
from fastapi import APIRouter, HTTPException, status, Query, Body
|
7 |
+
|
8 |
+
from aworld.output import WorkSpace, ArtifactType
|
9 |
+
from aworld.output.utils import load_workspace
|
10 |
+
|
11 |
+
router = APIRouter()
|
12 |
+
|
13 |
+
prefix = "/api/workspaces"
|
14 |
+
|
15 |
+
@router.get("/{workspace_id}/tree")
|
16 |
+
async def get_workspace_tree(workspace_id: str):
|
17 |
+
logging.info(f"get_workspace_tree: {workspace_id}")
|
18 |
+
workspace = await get_workspace(workspace_id)
|
19 |
+
return workspace.generate_tree_data()
|
20 |
+
|
21 |
+
|
22 |
+
class ArtifactRequest(BaseModel):
|
23 |
+
artifact_ids: Optional[List[str]] = None
|
24 |
+
artifact_types: Optional[List[str]] = None
|
25 |
+
|
26 |
+
|
27 |
+
@router.post("/{workspace_id}/artifacts")
|
28 |
+
async def get_workspace_artifacts(workspace_id: str, request: ArtifactRequest):
|
29 |
+
"""
|
30 |
+
Get artifacts by workspace id and filter by a list of artifact types.
|
31 |
+
Args:
|
32 |
+
workspace_id: Workspace ID
|
33 |
+
request: Request body containing optional artifact_types list
|
34 |
+
Returns:
|
35 |
+
Dict with filtered artifacts
|
36 |
+
"""
|
37 |
+
artifact_types = request.artifact_types
|
38 |
+
if artifact_types:
|
39 |
+
# Validate all types
|
40 |
+
invalid_types = [t for t in artifact_types if t not in ArtifactType.__members__]
|
41 |
+
if invalid_types:
|
42 |
+
logging.error(f"Invalid artifact_types: {invalid_types}")
|
43 |
+
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
|
44 |
+
detail=f"Invalid artifact types: {invalid_types}")
|
45 |
+
logging.info(f"Fetching artifacts of types: {artifact_types}")
|
46 |
+
else:
|
47 |
+
logging.info(f"Fetching all artifacts (no type filter)")
|
48 |
+
|
49 |
+
workspace = await get_workspace(workspace_id)
|
50 |
+
all_artifacts = workspace.list_artifacts()
|
51 |
+
filtered_artifacts = all_artifacts
|
52 |
+
if request.artifact_ids:
|
53 |
+
filtered_artifacts = [a for a in filtered_artifacts if a.artifact_id in request.artifact_ids]
|
54 |
+
if artifact_types:
|
55 |
+
filtered_artifacts = [a for a in filtered_artifacts if a.artifact_type.name in artifact_types]
|
56 |
+
|
57 |
+
return {
|
58 |
+
"data": filtered_artifacts
|
59 |
+
}
|
60 |
+
|
61 |
+
|
62 |
+
@router.get("/{workspace_id}/file/{artifact_id}/content")
|
63 |
+
async def get_workspace_file_content(workspace_id: str, artifact_id: str):
|
64 |
+
logging.info(f"get_workspace_file_content: {workspace_id}, {artifact_id}")
|
65 |
+
workspace = await get_workspace(workspace_id)
|
66 |
+
return {
|
67 |
+
"data": workspace.get_file_content_by_artifact_id(artifact_id)
|
68 |
+
}
|
69 |
+
|
70 |
+
|
71 |
+
async def get_workspace(workspace_id: str) -> WorkSpace:
|
72 |
+
workspace_type = os.environ.get("WORKSPACE_TYPE", "local")
|
73 |
+
workspace_path = os.environ.get("WORKSPACE_PATH", "./data/workspaces")
|
74 |
+
return await load_workspace(workspace_id, workspace_type, workspace_path)
|