# Scraped page header (not code), kept as a comment:
# Duibonduil's picture
# Rename AWorld-main/aworlddistributed/base.py to aworlddistributed/base.py
# 897ed32 verified
# raw
# history blame
# 3.13 kB
import uuid
from typing import Any
from typing import List, Optional
from datetime import datetime
from pydantic import BaseModel, ConfigDict
from pydantic import Field
class AworldTask(BaseModel):
    """A task submitted for an AWorld agent, together with its execution state.

    ``status`` defaults to ``"INIT"`` and is changed by the ``mark_running``,
    ``mark_failed`` and ``mark_success`` helpers to ``"RUNNING"``,
    ``"FAILED"`` and ``"SUCCESS"`` respectively.
    """

    # Declared Optional because the default is None; previously annotated as
    # plain ``str``, which made an explicit ``task_id=None`` fail validation.
    task_id: Optional[str] = Field(default=None, description="task id")
    agent_id: Optional[str] = Field(default=None, description="agent id")
    agent_input: Optional[str] = Field(default=None, description="agent input")
    session_id: Optional[str] = Field(default=None, description="session id")
    user_id: Optional[str] = Field(default=None, description="user id")

    # LLM connection settings; all optional.
    llm_provider: Optional[str] = Field(default=None, description="llm provider")
    llm_model_name: Optional[str] = Field(default=None, description="llm model name")
    llm_api_key: Optional[str] = Field(default=None, description="llm api key")
    llm_base_url: Optional[str] = Field(default=None, description="llm base url")
    llm_custom_input: Optional[str] = Field(default=None, description="custom_input")
    task_system_prompt: Optional[str] = Field(default=None, description="task_system_prompt")
    mcp_servers: Optional[list[str]] = Field(default=None, description="mcp_servers")

    node_id: Optional[str] = Field(default=None, description="execute task node_id")
    client_id: Optional[str] = Field(default=None, description="submit client ip")
    # Values used by this class: INIT (default), RUNNING, FAILED, SUCCESS.
    status: Optional[str] = Field(
        default="INIT",
        description="task status: INIT (default), RUNNING, FAILED or SUCCESS",
    )

    # NOTE(review): presumably the number of prior messages kept as context — confirm.
    history_messages: Optional[int] = Field(default=100, description="history_message")
    max_steps: Optional[int] = Field(default=100, description="max_steps")
    max_retries: Optional[int] = Field(default=5, description="max_retries use Exponential backoff with jitter")
    # default_factory gives every instance its own dict.
    ext_info: Optional[dict] = Field(default_factory=dict, description="custom")
    created_at: Optional[datetime] = Field(default=None, description="created time")
    updated_at: Optional[datetime] = Field(default=None, description="updated time")

    def mark_running(self) -> None:
        """Set ``status`` to ``"RUNNING"`` (``updated_at`` is not touched)."""
        self.status = 'RUNNING'

    def mark_failed(self) -> None:
        """Set ``status`` to ``"FAILED"`` (``updated_at`` is not touched)."""
        self.status = 'FAILED'

    def mark_success(self) -> None:
        """Set ``status`` to ``"SUCCESS"`` (``updated_at`` is not touched)."""
        self.status = 'SUCCESS'
class AworldTaskResult(BaseModel):
    """Pairs an ``AworldTask`` with a server host and an arbitrary result payload.

    Every field is optional and defaults to ``None``.
    """

    # Optional because the default is None; the previous bare ``AworldTask``
    # annotation rejected an explicit ``task=None``.
    task: Optional[AworldTask] = Field(default=None, description="task")
    # NOTE(review): presumably the host of the server that ran the task — confirm.
    server_host: Optional[str] = Field(default=None, description="aworld server id")
    # Untyped result payload; any value is accepted.
    data: Any = Field(default=None, description="result data")
class AworldTaskForm(BaseModel):
    """Submission form: a batch id plus an optional ``AworldTask`` and submitter info.

    ``batch_id`` defaults to a fresh UUID4 string for each instance.
    """

    # default_factory runs per instance. The previous ``default=str(uuid.uuid4())``
    # was evaluated once at import time, so every form shared the same batch_id.
    batch_id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="batch_id")
    task: Optional[AworldTask] = Field(default=None, description="task")
    user_id: Optional[str] = Field(default=None, description="user id")
    client_id: Optional[str] = Field(default=None, description="submit client ip")
class OpenAIChatMessage(BaseModel):
    """A single chat message with a role and content.

    ``extra="allow"`` keeps any additional keys supplied by the caller.
    """

    model_config = ConfigDict(extra="allow")

    # Required role string.
    role: str = Field(...)
    # Required; either a string or a list (list items are not validated).
    content: str | List = Field(...)
class OpenAIChatCompletionForm(BaseModel):
    """Chat-completion request: streaming flag, model name and message list.

    ``extra="allow"`` keeps any additional request keys.
    """

    model_config = ConfigDict(extra="allow")

    # Defaults to True when the caller omits it.
    stream: bool = Field(default=True)
    model: str = Field(...)
    # Each entry is validated as an OpenAIChatMessage.
    messages: List[OpenAIChatMessage] = Field(...)
class FilterForm(BaseModel):
    """Payload holding a required ``body`` dict and an optional ``user`` dict.

    ``extra="allow"`` keeps any additional keys supplied by the caller.
    """

    model_config = ConfigDict(extra="allow")

    body: dict = Field(...)
    user: Optional[dict] = Field(default=None)