Spaces:
Sleeping
Sleeping
Rename AWorld-main/aworlddistributed/aworldspace/utils/job.py to aworlddistributed/aworldspace/utils/job.py
f9b02d8
verified
| import inspect | |
| import json | |
| import inspect | |
| import json | |
| import logging | |
| import time | |
| import uuid | |
| from typing import Generator, Iterator, AsyncGenerator, Optional | |
| from aworld.core.task import Task | |
| from aworld.utils.common import get_local_ip | |
| from fastapi import status, HTTPException | |
| from fastapi.concurrency import run_in_threadpool | |
| from pydantic import BaseModel | |
| from starlette.responses import StreamingResponse | |
| from aworldspace.base import AGENT_SPACE | |
| from aworldspace.utils.utils import get_last_user_message | |
| from base import OpenAIChatCompletionForm | |
async def generate_openai_chat_completion(form_data: OpenAIChatCompletionForm):
    """Run the requested pipeline and shape its output as an OpenAI chat completion.

    Resolves ``form_data.model`` against the agent-space metadata, invokes the
    matching module's ``pipe`` and converts what it returns into either a
    server-sent-event stream of ``chat.completion.chunk`` payloads (when
    ``form_data.stream`` is truthy) or a single ``chat.completion`` dict.

    Args:
        form_data: OpenAI-style request. ``model``, ``messages`` and ``stream``
            are read here; the full ``model_dump()`` is passed to the pipe as
            ``body``.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream`` when
        streaming. Otherwise a dict: the pipe's own dict / ``BaseModel`` dump,
        or a ``chat.completion`` envelope wrapping its text output.

    Raises:
        HTTPException: 404 when the model is unknown or is a ``filter`` pipeline.
    """
    messages = [message.model_dump() for message in form_data.messages]
    user_message = get_last_user_message(messages)
    agents_meta = await AGENT_SPACE.get_agents_meta()
    agent_modules = await AGENT_SPACE.get_agent_modules()

    if (
        form_data.model not in agents_meta
        or agents_meta[form_data.model]["type"] == "filter"
    ):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Pipeline {form_data.model} not found",
        )

    pipeline_id = form_data.model
    if agents_meta[form_data.model]["type"] == "manifold":
        # Model ids look like "<manifold_id>.<pipeline_id>"; the pipe comes
        # from the manifold's module and receives the sub-pipeline id.
        manifold_id, pipeline_id = pipeline_id.split(".", 1)
        pipe = agent_modules[manifold_id].pipe
    else:
        pipe = agent_modules[pipeline_id].pipe

    pipe_kwargs = {
        "user_message": user_message,
        "model_id": pipeline_id,
        "messages": messages,
        "body": form_data.model_dump(),
    }

    if form_data.stream:
        return StreamingResponse(
            _stream_pipe_output(form_data.model, pipe, pipe_kwargs),
            media_type="text/event-stream",
        )

    # Coroutine pipes must be awaited on the event loop (calling them from a
    # worker thread only creates an un-awaited coroutine); sync pipes run in
    # a worker thread so they do not block the loop.
    if inspect.iscoroutinefunction(pipe):
        res = await pipe(**pipe_kwargs)
    else:
        res = await run_in_threadpool(pipe, **pipe_kwargs)
    logging.info(f"stream:false:{res}")

    if isinstance(res, Generator):
        # Drain sync generators off the event loop.
        res = await run_in_threadpool(_join_chunks, res)
    elif isinstance(res, AsyncGenerator):
        res = _join_chunks([chunk async for chunk in res])
    return _to_chat_completion(form_data.model, res)


def _sse_event(model: str, line) -> str:
    """Format one item produced by a pipe as a single SSE ``data:`` event."""
    if isinstance(line, Task):
        # For a Task only its output metadata is forwarded, in an empty chunk.
        chunk = openai_chat_chunk_message_template(
            model, "", task_output_meta=line.outputs._metadata
        )
        return f"data: {json.dumps(chunk)}\n\n"
    if isinstance(line, BaseModel):
        line = f"data: {line.model_dump_json()}"
    elif isinstance(line, dict):
        line = f"data: {json.dumps(line)}"
    elif isinstance(line, (bytes, bytearray)):
        # Replace undecodable bytes instead of failing later on startswith().
        line = line.decode("utf-8", errors="replace")
    elif not isinstance(line, str):
        line = str(line)

    if line.startswith("data:"):
        # Already an SSE payload: only the event terminator is missing.
        return f"{line}\n\n"
    chunk = openai_chat_chunk_message_template(model, line)
    return f"data: {json.dumps(chunk)}\n\n"


async def _stream_pipe_output(model: str, pipe, pipe_kwargs: dict):
    """Invoke ``pipe`` and yield its result as OpenAI-style SSE events.

    StreamingResponse bodies are relayed as-is. Dict and ``BaseModel`` results
    are sent as one event; neither case appends ``[DONE]``. Str, iterator and
    async-generator results are chunked, then followed by a stop chunk and
    ``data: [DONE]``. Any exception ends the stream with an
    ``{"error": {"detail": ...}}`` event.
    """
    try:
        if inspect.iscoroutinefunction(pipe):
            res = await pipe(**pipe_kwargs)
        else:
            res = pipe(**pipe_kwargs)

        if isinstance(res, StreamingResponse):
            # The pipe produced its own stream: relay it untouched.
            async for data in res.body_iterator:
                yield data
            return
        if isinstance(res, dict):
            yield f"data: {json.dumps(res)}\n\n"
            return
        if isinstance(res, BaseModel):
            # Serialize pydantic results like dicts instead of dropping them.
            yield f"data: {res.model_dump_json()}\n\n"
            return

        if isinstance(res, str):
            chunk = openai_chat_chunk_message_template(model, res)
            yield f"data: {json.dumps(chunk)}\n\n"
        elif isinstance(res, Iterator):
            # NOTE(review): sync iterators are consumed on the event-loop
            # thread; a slow pipe blocks other requests while iterating.
            for line in res:
                yield _sse_event(model, line)
        elif isinstance(res, AsyncGenerator):
            async for line in res:
                yield _sse_event(model, line)
            logging.info("AsyncGenerator end...")

        if isinstance(res, (str, Iterator, AsyncGenerator)):
            finish_message = openai_chat_chunk_message_template(model, "")
            finish_message["choices"][0]["finish_reason"] = "stop"
            logging.info("Pipe-Dataline:::: DONE")
            yield f"data: {json.dumps(finish_message)}\n\n"
    except Exception as e:
        logging.exception("Error: %s", e)
        yield f"data: {json.dumps({'error': {'detail': str(e)}})}\n\n"
        return
    # An SSE event is only dispatched once terminated by a blank line.
    yield "data: [DONE]\n\n"


def _join_chunks(chunks) -> str:
    """Concatenate the f-string rendering of every chunk of an iterable."""
    return "".join(f"{chunk}" for chunk in chunks)


def _to_chat_completion(model: str, res) -> dict:
    """Shape a non-streaming pipe result into an OpenAI ``chat.completion`` dict."""
    if isinstance(res, dict):
        return res
    if isinstance(res, BaseModel):
        return res.model_dump()
    # Only text becomes message content; any other result type yields "".
    message = res if isinstance(res, str) else ""
    logging.info(f"stream:false:{message}")
    return {
        "id": f"{model}-{str(uuid.uuid4())}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model,
        "choices": [
            {
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": message,
                },
                "logprobs": None,
                "finish_reason": "stop",
            }
        ],
    }
async def call_pipeline(form_data: OpenAIChatCompletionForm):
    """Resolve ``form_data.model`` to a pipeline's ``pipe`` and return its raw result.

    Unlike ``generate_openai_chat_completion`` no OpenAI formatting is applied:
    whatever the pipe returns (or, for coroutine pipes, resolves to) is
    returned as-is.

    Dispatch:
      * coroutine-function pipe: awaited on the event loop (stream or not);
      * sync pipe, streaming: called directly on the event loop;
      * sync pipe, non-streaming: executed via ``run_in_threadpool``.

    Raises:
        HTTPException: 404 when the model is unknown or is a ``filter`` pipeline.
    """
    messages = [entry.model_dump() for entry in form_data.messages]
    user_message = get_last_user_message(messages)
    agents_meta = await AGENT_SPACE.get_agents_meta()
    agent_modules = await AGENT_SPACE.get_agent_modules()

    requested = form_data.model
    if requested not in agents_meta or agents_meta[requested]["type"] == "filter":
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Pipeline {requested} not found",
        )

    pipeline = agents_meta[requested]
    if pipeline["type"] == "manifold":
        # "<manifold_id>.<pipeline_id>": look the module up by the manifold id.
        module_key, pipeline_id = requested.split(".", 1)
    else:
        module_key = pipeline_id = requested
    pipe = agent_modules[module_key].pipe

    call_kwargs = dict(
        user_message=user_message,
        model_id=pipeline_id,
        messages=messages,
        body=form_data.model_dump(),
    )

    if inspect.iscoroutinefunction(pipe):
        return await pipe(**call_kwargs)
    if form_data.stream:
        # Sync pipe invoked inline; its result (e.g. a generator) goes back untouched.
        return pipe(**call_kwargs)
    return await run_in_threadpool(pipe, **call_kwargs)
def openai_chat_chunk_message_template(
    model: str,
    content: Optional[str] = None,
    tool_calls: Optional[list[dict]] = None,
    usage: Optional[dict] = None,
    **kwargs
) -> dict:
    """Build one ``chat.completion.chunk`` payload for a streaming response.

    Starts from ``openai_chat_message_template(model, **kwargs)`` (so extra
    keyword arguments such as ``task_output_meta`` are forwarded there) and
    fills in ``choices[0]["delta"]``.

    Args:
        model: model/pipeline name echoed in the payload.
        content: text delta; included only when truthy.
        tool_calls: tool-call delta; included only when truthy.
        usage: copied to the top-level ``usage`` key only when truthy.

    Returns:
        The chunk dict. If neither ``content`` nor ``tool_calls`` is truthy,
        the delta is empty and ``finish_reason`` is set to ``"stop"``.
    """
    chunk = openai_chat_message_template(model, **kwargs)
    chunk["object"] = "chat.completion.chunk"

    first_choice = chunk["choices"][0]
    first_choice["index"] = 0
    delta = {}
    first_choice["delta"] = delta
    if content:
        delta["content"] = content
    if tool_calls:
        delta["tool_calls"] = tool_calls

    # Nothing to send: report this chunk as a stop chunk.
    if not delta:
        first_choice["finish_reason"] = "stop"
    if usage:
        chunk["usage"] = usage
    return chunk
def openai_chat_message_template(model: str, **kwargs):
    """Return the base chat-completion payload shared by every response chunk.

    Args:
        model: model/pipeline name; also used as the prefix of the random ``id``.
        **kwargs: only ``task_output_meta`` is read (``None`` when absent).

    Returns:
        A dict with, in order: ``id`` (``"<model>-<uuid4>"``), ``created``
        (integer seconds from ``time.time()``), ``model``, ``node_id`` (from
        ``get_local_ip()``), ``task_output_meta``, and ``choices`` holding one
        entry with ``index`` 0 and ``None`` for ``logprobs``/``finish_reason``.
    """
    return dict(
        id=f"{model}-{uuid.uuid4()}",
        created=int(time.time()),
        model=model,
        node_id=get_local_ip(),
        task_output_meta=kwargs.get("task_output_meta"),
        choices=[dict(index=0, logprobs=None, finish_reason=None)],
    )