Duibonduil commited on
Commit
aeec429
·
verified ·
1 Parent(s): 8af6ba8

Upload agent_executor.py

Browse files
Files changed (1) hide show
  1. aworld/agent_executor.py +83 -0
aworld/agent_executor.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import AsyncGenerator
2
+ from aworld.output.ui.base import AworldUI
3
+ from aworld.output.workspace import WorkSpace
4
+ from .. import (
5
+ BaseAWorldAgent,
6
+ ChatCompletionChoice,
7
+ ChatCompletionMessage,
8
+ ChatCompletionRequest,
9
+ ChatCompletionResponse,
10
+ )
11
+ from . import agent_loader
12
+ from .agent_ui_parser import AWorldAgentUI
13
+ import logging
14
+ import aworld.trace as trace
15
+ import os
16
+ import uuid
17
+ from dotenv import load_dotenv
18
+ from .agent_server import CURRENT_SERVER
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+ # bugfix for tracer exception
23
+ trace.configure()
24
+
25
+
26
+ async def stream_run(request: ChatCompletionRequest):
27
+ if not request.session_id:
28
+ request.session_id = str(uuid.uuid4())
29
+ if not request.query_id:
30
+ request.query_id = str(uuid.uuid4())
31
+
32
+ logger.info(f"Stream run agent: request={request.model_dump_json()}")
33
+ agent = agent_loader.get_agent(request.model)
34
+ instance: BaseAWorldAgent = agent.instance
35
+ env_file = os.path.join(agent.path, ".env")
36
+ if os.path.exists(env_file):
37
+ logger.info(f"Loading environment variables from {env_file}")
38
+ load_dotenv(env_file, override=True, verbose=True)
39
+
40
+ final_response: str = ""
41
+
42
+ def build_response(delta_content: str):
43
+ nonlocal final_response
44
+ final_response += delta_content
45
+ logger.info(f"Agent {agent.name} response: {delta_content}")
46
+ return ChatCompletionResponse(
47
+ choices=[
48
+ ChatCompletionChoice(
49
+ index=0,
50
+ delta=ChatCompletionMessage(
51
+ role="assistant",
52
+ content=delta_content,
53
+ trace_id=request.trace_id,
54
+ ),
55
+ )
56
+ ]
57
+ )
58
+
59
+ rich_ui = AWorldAgentUI(
60
+ session_id=request.session_id,
61
+ workspace=WorkSpace.from_local_storages(
62
+ workspace_id=request.session_id,
63
+ storage_path=os.path.join(os.curdir, "workspaces", request.session_id),
64
+ ),
65
+ )
66
+
67
+ await CURRENT_SERVER.on_chat_completion_request(request)
68
+
69
+ async for output in instance.run(request=request):
70
+ logger.info(f"Agent {agent.name} output: {output}")
71
+
72
+ if isinstance(output, str):
73
+ yield build_response(output)
74
+ else:
75
+ res = await AworldUI.parse_output(output, rich_ui)
76
+ for item in res if isinstance(res, list) else [res]:
77
+ if isinstance(item, AsyncGenerator):
78
+ async for sub_item in item:
79
+ yield build_response(sub_item)
80
+ else:
81
+ yield build_response(item)
82
+
83
+ await CURRENT_SERVER.on_chat_completion_end(request, final_response)