Upload job.py
AWorld-main/aworlddistributed/aworldspace/job.py
ADDED
@@ -0,0 +1,259 @@
import inspect
import json
import logging
import time
import uuid
from typing import Generator, Iterator, AsyncGenerator, Optional

from aworld.core.task import Task
from aworld.utils.common import get_local_ip
from fastapi import status, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from starlette.responses import StreamingResponse

from aworldspace.base import AGENT_SPACE
from aworldspace.utils.utils import get_last_user_message
from base import OpenAIChatCompletionForm

async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
    messages = [message.model_dump() for message in form_data.messages]
    user_message = get_last_user_message(messages)
    PIPELINES = await AGENT_SPACE.get_agents_meta()
    PIPELINE_MODULES = await AGENT_SPACE.get_agent_modules()
    if (
        form_data.model not in PIPELINES
        or PIPELINES[form_data.model]["type"] == "filter"
    ):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Pipeline {form_data.model} not found",
        )

    def job():
        pipeline = PIPELINES[form_data.model]
        pipeline_id = form_data.model

        if pipeline["type"] == "manifold":
            # Manifold pipelines are addressed as "<manifold_id>.<pipeline_id>".
            manifold_id, pipeline_id = pipeline_id.split(".", 1)
            pipe = PIPELINE_MODULES[manifold_id].pipe
        else:
            pipe = PIPELINE_MODULES[pipeline_id].pipe

        def process_line(model, line):
            # Normalize a single pipe output item into an SSE "data: ..." line.
            if isinstance(line, Task):
                task_output_meta = line.outputs._metadata
                line = openai_chat_chunk_message_template(model, "", task_output_meta=task_output_meta)
                return f"data: {json.dumps(line)}\n\n"
            if isinstance(line, BaseModel):
                line = line.model_dump_json()
                line = f"data: {line}"
            if isinstance(line, dict):
                line = f"data: {json.dumps(line)}"

            try:
                line = line.decode("utf-8")
            except Exception:
                pass

            if line.startswith("data:"):
                return f"{line}\n\n"
            else:
                line = openai_chat_chunk_message_template(model, line)
                return f"data: {json.dumps(line)}\n\n"

        if form_data.stream:
            async def stream_content():
                async def execute_pipe(_pipe):
                    if inspect.iscoroutinefunction(_pipe):
                        return await _pipe(user_message=user_message,
                                           model_id=pipeline_id,
                                           messages=messages,
                                           body=form_data.model_dump())
                    else:
                        return _pipe(user_message=user_message,
                                     model_id=pipeline_id,
                                     messages=messages,
                                     body=form_data.model_dump())

                try:
                    res = await execute_pipe(pipe)

                    # Directly return if the response is a StreamingResponse
                    if isinstance(res, StreamingResponse):
                        async for data in res.body_iterator:
                            yield data
                        return
                    if isinstance(res, dict):
                        yield f"data: {json.dumps(res)}\n\n"
                        return

                except Exception as e:
                    logging.error(f"Error: {e}")
                    import traceback
                    traceback.print_exc()
                    yield f"data: {json.dumps({'error': {'detail': str(e)}})}\n\n"
                    return

                if isinstance(res, str):
                    message = openai_chat_chunk_message_template(form_data.model, res)
                    yield f"data: {json.dumps(message)}\n\n"

                if isinstance(res, Iterator):
                    for line in res:
                        yield process_line(form_data.model, line)

                if isinstance(res, AsyncGenerator):
                    async for line in res:
                        yield process_line(form_data.model, line)
                    logging.info("AsyncGenerator end...")

                if isinstance(res, str) or isinstance(res, Generator) or isinstance(res, AsyncGenerator):
                    # Emit a final empty chunk with finish_reason "stop", then the SSE terminator.
                    finish_message = openai_chat_chunk_message_template(
                        form_data.model, ""
                    )
                    finish_message["choices"][0]["finish_reason"] = "stop"
                    print("Pipe-Dataline:::: DONE")
                    yield f"data: {json.dumps(finish_message)}\n\n"
                    yield "data: [DONE]"

            return StreamingResponse(stream_content(), media_type="text/event-stream")
        else:
            res = pipe(
                user_message=user_message,
                model_id=pipeline_id,
                messages=messages,
                body=form_data.model_dump(),
            )
            logging.info(f"stream:false:{res}")

            if isinstance(res, dict):
                return res
            elif isinstance(res, BaseModel):
                return res.model_dump()
            else:
                message = ""

                if isinstance(res, str):
                    message = res

                if isinstance(res, Generator):
                    for stream in res:
                        message = f"{message}{stream}"

                logging.info(f"stream:false:{message}")
                return {
                    "id": f"{form_data.model}-{str(uuid.uuid4())}",
                    "object": "chat.completion",
                    "created": int(time.time()),
                    "model": form_data.model,
                    "choices": [
                        {
                            "index": 0,
                            "message": {
                                "role": "assistant",
                                "content": message,
                            },
                            "logprobs": None,
                            "finish_reason": "stop",
                        }
                    ],
                }

    return await run_in_threadpool(job)

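# Usage sketch (illustrative only, not part of the original upload): a FastAPI
# app elsewhere in the service is assumed to mount the handler above as an
# OpenAI-compatible chat-completions route, roughly like:
#
#     app = FastAPI()
#
#     @app.post("/v1/chat/completions")
#     async def chat_completions(form_data: OpenAIChatCompletionForm):
#         # Streaming requests yield an SSE StreamingResponse; non-streaming
#         # requests return a plain "chat.completion" dict.
#         return await generate_openai_chat_completion(form_data)

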
async def call_pipeline(form_data: OpenAIChatCompletionForm):
    messages = [message.model_dump() for message in form_data.messages]
    user_message = get_last_user_message(messages)
    PIPELINES = await AGENT_SPACE.get_agents_meta()
    PIPELINE_MODULES = await AGENT_SPACE.get_agent_modules()
    if (
        form_data.model not in PIPELINES
        or PIPELINES[form_data.model]["type"] == "filter"
    ):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Pipeline {form_data.model} not found",
        )

    pipeline = PIPELINES[form_data.model]
    pipeline_id = form_data.model

    if pipeline["type"] == "manifold":
        manifold_id, pipeline_id = pipeline_id.split(".", 1)
        pipe = PIPELINE_MODULES[manifold_id].pipe
    else:
        pipe = PIPELINE_MODULES[pipeline_id].pipe

    if form_data.stream:
        async def execute_pipe(_pipe):
            if inspect.iscoroutinefunction(_pipe):
                return await _pipe(user_message=user_message,
                                   model_id=pipeline_id,
                                   messages=messages,
                                   body=form_data.model_dump())
            else:
                return _pipe(user_message=user_message,
                             model_id=pipeline_id,
                             messages=messages,
                             body=form_data.model_dump())

        res = await execute_pipe(pipe)
        return res
    else:
        if not inspect.iscoroutinefunction(pipe):
            return await run_in_threadpool(
                pipe,
                user_message=user_message,
                model_id=pipeline_id,
                messages=messages,
                body=form_data.model_dump()
            )
        else:
            return await pipe(
                user_message=user_message,
                model_id=pipeline_id,
                messages=messages,
                body=form_data.model_dump()
            )

def openai_chat_chunk_message_template(
    model: str,
    content: Optional[str] = None,
    tool_calls: Optional[list[dict]] = None,
    usage: Optional[dict] = None,
    **kwargs
) -> dict:
    template = openai_chat_message_template(model, **kwargs)
    template["object"] = "chat.completion.chunk"

    template["choices"][0]["index"] = 0
    template["choices"][0]["delta"] = {}

    if content:
        template["choices"][0]["delta"]["content"] = content

    if tool_calls:
        template["choices"][0]["delta"]["tool_calls"] = tool_calls

    if not content and not tool_calls:
        template["choices"][0]["finish_reason"] = "stop"

    if usage:
        template["usage"] = usage
    return template

def openai_chat_message_template(model: str, **kwargs):
    return {
        "id": f"{model}-{str(uuid.uuid4())}",
        "created": int(time.time()),
        "model": model,
        "node_id": get_local_ip(),
        "task_output_meta": kwargs.get("task_output_meta"),
        "choices": [{"index": 0, "logprobs": None, "finish_reason": None}],
    }
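

# Illustrative check (not part of the original upload): running the module
# directly prints the SSE-style line that process_line() would emit for a plain
# string chunk; the model name "example_agent" is an arbitrary placeholder.
if __name__ == "__main__":
    sample = openai_chat_chunk_message_template("example_agent", "Hello")
    # sample resembles:
    # {"id": "example_agent-<uuid>", "created": <unix ts>, "model": "example_agent",
    #  "node_id": "<local ip>", "task_output_meta": None, "object": "chat.completion.chunk",
    #  "choices": [{"index": 0, "logprobs": None, "finish_reason": None,
    #               "delta": {"content": "Hello"}}]}
    print(f"data: {json.dumps(sample)}\n")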