Duibonduil commited on
Commit
8293a2b
·
verified ·
1 Parent(s): 7ba508f

Upload 4 files

Browse files
aworld/cmd/web/routers/chats.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import json
3
+ from typing import Dict
4
+ from fastapi import APIRouter, Depends
5
+ from fastapi.responses import StreamingResponse
6
+ from aworld.cmd import AgentModel, ChatCompletionRequest
7
+ from aworld.cmd.utils import agent_loader, agent_executor
8
+ from aworld.cmd.web.web_server import get_user_id_from_jwt
9
+ import aworld.trace as trace
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ router = APIRouter()
14
+
15
+ prefix = "/api/agent"
16
+
17
+
18
+ @router.get("/list")
19
+ @router.get("/models")
20
+ async def list_agents() -> Dict[str, AgentModel]:
21
+ return agent_loader.list_agents()
22
+
23
+
24
+ @router.post("/chat/completions")
25
+ async def chat_completion(
26
+ form_data: ChatCompletionRequest, user_id: str = Depends(get_user_id_from_jwt)
27
+ ) -> StreamingResponse:
28
+ # Set user_id from JWT to form_data
29
+ form_data.user_id = user_id
30
+
31
+ async def generate_stream():
32
+ async with trace.span(
33
+ "/chat/chat_completion", attributes={"model": form_data.model}
34
+ ) as span:
35
+ form_data.trace_id = span.get_trace_id()
36
+ async for chunk in agent_executor.stream_run(form_data):
37
+ yield f"data: {json.dumps(chunk.model_dump(), ensure_ascii=False)}\n\n"
38
+
39
+ return StreamingResponse(
40
+ generate_stream(),
41
+ media_type="text/event-stream",
42
+ headers={
43
+ "Cache-Control": "no-cache",
44
+ "Connection": "keep-alive",
45
+ },
46
+ )
aworld/cmd/web/routers/sessions.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import List
3
+ from fastapi import APIRouter, Depends
4
+ from pydantic import BaseModel, Field
5
+ from aworld.cmd import SessionModel
6
+ from aworld.cmd.utils.agent_server import CURRENT_SERVER
7
+ from aworld.cmd.web.web_server import get_user_id_from_jwt
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ router = APIRouter()
12
+
13
+ prefix = "/api/session"
14
+
15
+
16
+ @router.get("/list")
17
+ async def list_sessions(
18
+ user_id: str = Depends(get_user_id_from_jwt),
19
+ ) -> List[SessionModel]:
20
+ return await CURRENT_SERVER.get_session_service().list_sessions(user_id)
21
+
22
+
23
+ class CommonResponse(BaseModel):
24
+ code: int = Field(..., description="The code")
25
+ message: str = Field(..., description="The message")
26
+
27
+ @staticmethod
28
+ def success(message: str = "success"):
29
+ return CommonResponse(code=0, message=message)
30
+
31
+ @staticmethod
32
+ def error(message: str):
33
+ return CommonResponse(code=1, message=message)
34
+
35
+
36
+ class DeleteSessionRequest(BaseModel):
37
+ session_id: str = Field(..., description="The session id")
38
+
39
+
40
+ @router.post("/delete")
41
+ async def delete_session(
42
+ request: DeleteSessionRequest, user_id: str = Depends(get_user_id_from_jwt)
43
+ ) -> CommonResponse:
44
+ try:
45
+ await CURRENT_SERVER.get_session_service().delete_session(
46
+ user_id, request.session_id
47
+ )
48
+ return CommonResponse.success()
49
+ except Exception as e:
50
+ return CommonResponse.error(str(e))
aworld/cmd/web/routers/traces.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from fastapi import APIRouter
3
+ from aworld.trace.server import get_trace_server
4
+ from aworld.trace.constants import RunType
5
+ from aworld.trace.server.util import build_trace_tree
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+ router = APIRouter()
10
+
11
+ prefix = "/api/trace"
12
+
13
+
14
+ @router.get("/list")
15
+ async def list_traces():
16
+ storage = get_trace_server().get_storage()
17
+ trace_data = []
18
+ for trace_id in storage.get_all_traces():
19
+ spans = storage.get_all_spans(trace_id)
20
+ spans_sorted = sorted(spans, key=lambda x: x.start_time)
21
+ trace_tree = build_trace_tree(spans_sorted)
22
+ trace_data.append({
23
+ 'trace_id': trace_id,
24
+ 'root_span': trace_tree,
25
+ })
26
+ return {
27
+ "data": trace_data
28
+ }
29
+
30
+
31
+ @router.get("/agent")
32
+ async def get_agent_trace(trace_id: str):
33
+ storage = get_trace_server().get_storage()
34
+ spans = storage.get_all_spans(trace_id)
35
+ spans_dict = {span.span_id: span.dict() for span in spans}
36
+
37
+ filtered_spans = {}
38
+ for span_id, span in spans_dict.items():
39
+ if span.get('is_event', False) and span.get('run_type') == RunType.AGNET.value:
40
+ span['show_name'] = _get_agent_show_name(span)
41
+ filtered_spans[span_id] = span
42
+
43
+ for span in list(filtered_spans.values()):
44
+ parent_id = span['parent_id'] if span['parent_id'] else None
45
+
46
+ while parent_id and parent_id not in filtered_spans:
47
+ parent_span = spans_dict.get(parent_id)
48
+ parent_id = parent_span['parent_id'] if parent_span and parent_span['parent_id'] else None
49
+
50
+ if parent_id:
51
+ parent_span = filtered_spans.get(parent_id)
52
+ if not parent_span:
53
+ continue
54
+
55
+ if 'children' not in parent_span:
56
+ parent_span['children'] = []
57
+ parent_span['children'].append(span)
58
+
59
+ root_spans = [span for span in filtered_spans.values()
60
+ if span['parent_id'] is None or span['parent_id'] not in filtered_spans]
61
+ return {
62
+ "data": root_spans
63
+ }
64
+
65
+
66
+ def _get_agent_show_name(span: dict):
67
+ agent_name_prefix = "agent_event_"
68
+ name = span.get("name")
69
+ if name and name.startswith(agent_name_prefix):
70
+ name = name[len(agent_name_prefix):]
71
+ return name
aworld/cmd/web/routers/workspaces.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from typing import List, Optional
4
+ from pydantic import BaseModel
5
+
6
+ from fastapi import APIRouter, HTTPException, status, Query, Body
7
+
8
+ from aworld.output import WorkSpace, ArtifactType
9
+ from aworld.output.utils import load_workspace
10
+
11
+ router = APIRouter()
12
+
13
+ prefix = "/api/workspaces"
14
+
15
+ @router.get("/{workspace_id}/tree")
16
+ async def get_workspace_tree(workspace_id: str):
17
+ logging.info(f"get_workspace_tree: {workspace_id}")
18
+ workspace = await get_workspace(workspace_id)
19
+ return workspace.generate_tree_data()
20
+
21
+
22
+ class ArtifactRequest(BaseModel):
23
+ artifact_ids: Optional[List[str]] = None
24
+ artifact_types: Optional[List[str]] = None
25
+
26
+
27
+ @router.post("/{workspace_id}/artifacts")
28
+ async def get_workspace_artifacts(workspace_id: str, request: ArtifactRequest):
29
+ """
30
+ Get artifacts by workspace id and filter by a list of artifact types.
31
+ Args:
32
+ workspace_id: Workspace ID
33
+ request: Request body containing optional artifact_types list
34
+ Returns:
35
+ Dict with filtered artifacts
36
+ """
37
+ artifact_types = request.artifact_types
38
+ if artifact_types:
39
+ # Validate all types
40
+ invalid_types = [t for t in artifact_types if t not in ArtifactType.__members__]
41
+ if invalid_types:
42
+ logging.error(f"Invalid artifact_types: {invalid_types}")
43
+ raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
44
+ detail=f"Invalid artifact types: {invalid_types}")
45
+ logging.info(f"Fetching artifacts of types: {artifact_types}")
46
+ else:
47
+ logging.info(f"Fetching all artifacts (no type filter)")
48
+
49
+ workspace = await get_workspace(workspace_id)
50
+ all_artifacts = workspace.list_artifacts()
51
+ filtered_artifacts = all_artifacts
52
+ if request.artifact_ids:
53
+ filtered_artifacts = [a for a in filtered_artifacts if a.artifact_id in request.artifact_ids]
54
+ if artifact_types:
55
+ filtered_artifacts = [a for a in filtered_artifacts if a.artifact_type.name in artifact_types]
56
+
57
+ return {
58
+ "data": filtered_artifacts
59
+ }
60
+
61
+
62
+ @router.get("/{workspace_id}/file/{artifact_id}/content")
63
+ async def get_workspace_file_content(workspace_id: str, artifact_id: str):
64
+ logging.info(f"get_workspace_file_content: {workspace_id}, {artifact_id}")
65
+ workspace = await get_workspace(workspace_id)
66
+ return {
67
+ "data": workspace.get_file_content_by_artifact_id(artifact_id)
68
+ }
69
+
70
+
71
+ async def get_workspace(workspace_id: str) -> WorkSpace:
72
+ workspace_type = os.environ.get("WORKSPACE_TYPE", "local")
73
+ workspace_path = os.environ.get("WORKSPACE_PATH", "./data/workspaces")
74
+ return await load_workspace(workspace_id, workspace_type, workspace_path)