import asyncio
from collections import OrderedDict
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Awaitable, Callable, Coroutine, Dict
import uuid

import models

from python.helpers import extract_tools, files, errors, history, tokens
from python.helpers import dirty_json
from python.helpers.print_style import PrintStyle
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage, SystemMessage, BaseMessage

import python.helpers.log as Log
from python.helpers.dirty_json import DirtyJson
from python.helpers.defer import DeferredTask
from python.helpers.localization import Localization


class AgentContext:
    """One chat session: config, log, root agent (agent0) and background task.

    Live contexts register themselves in a class-level registry so they can be
    looked up by id from elsewhere in the application.
    """

    _contexts: dict[str, "AgentContext"] = {}
    _counter: int = 0

    def __init__(
        self,
        config: "AgentConfig",
        id: str | None = None,
        name: str | None = None,
        agent0: "Agent|None" = None,
        log: Log.Log | None = None,
        paused: bool = False,
        streaming_agent: "Agent|None" = None,
        created_at: datetime | None = None,
    ):
        self.id = id or str(uuid.uuid4())
        self.name = name
        self.config = config
        self.log = log or Log.Log()
        self.agent0 = agent0 or Agent(0, self.config, self)
        self.paused = paused
        self.streaming_agent = streaming_agent
        self.task: DeferredTask | None = None
        self.created_at = created_at or datetime.now()
        AgentContext._counter += 1
        self.no = AgentContext._counter

        # a context with the same id replaces the existing one in the registry
        existing = self._contexts.get(self.id, None)
        if existing:
            AgentContext.remove(self.id)
        self._contexts[self.id] = self
    @staticmethod
    def get(id: str):
        return AgentContext._contexts.get(id, None)

    @staticmethod
    def first():
        if not AgentContext._contexts:
            return None
        return next(iter(AgentContext._contexts.values()))

    @staticmethod
    def remove(id: str):
        context = AgentContext._contexts.pop(id, None)
        if context and context.task:
            context.task.kill()
        return context
    def serialize(self):
        return {
            "id": self.id,
            "name": self.name,
            "created_at": (
                Localization.get().serialize_datetime(self.created_at)
                if self.created_at
                else Localization.get().serialize_datetime(datetime.fromtimestamp(0))
            ),
            "no": self.no,
            "log_guid": self.log.guid,
            "log_version": len(self.log.updates),
            "log_length": len(self.log.logs),
            "paused": self.paused,
        }

    def get_created_at(self):
        return self.created_at
    def kill_process(self):
        if self.task:
            self.task.kill()

    def reset(self):
        self.kill_process()
        self.log.reset()
        self.agent0 = Agent(0, self.config, self)
        self.streaming_agent = None
        self.paused = False
    def nudge(self):
        self.kill_process()
        self.paused = False
        current_agent = self.streaming_agent or self.agent0
        self.task = self.run_task(current_agent.monologue)
        return self.task
    def communicate(self, msg: "UserMessage", broadcast_level: int = 1):
        self.paused = False
        current_agent = self.streaming_agent or self.agent0

        if self.task and self.task.is_alive():
            # an agent is already running; deliver the message as an intervention,
            # walking up the superior chain broadcast_level times
            intervention_agent = current_agent
            while intervention_agent and broadcast_level != 0:
                intervention_agent.intervention = msg
                broadcast_level -= 1
                intervention_agent = intervention_agent.data.get(
                    Agent.DATA_NAME_SUPERIOR, None
                )
        else:
            self.task = self.run_task(self._process_chain, current_agent, msg)

        return self.task
    def run_task(
        self, func: Callable[..., Coroutine[Any, Any, Any]], *args: Any, **kwargs: Any
    ):
        if not self.task:
            self.task = DeferredTask(
                thread_name=self.__class__.__name__,
            )
        self.task.start_task(func, *args, **kwargs)
        return self.task
    async def _process_chain(self, agent: "Agent", msg: "UserMessage|str", user=True):
        try:
            # user messages open a new history topic; subordinate responses
            # are recorded as results of the call_subordinate tool
            if user:
                agent.hist_add_user_message(msg)
            else:
                agent.hist_add_tool_result(
                    tool_name="call_subordinate", tool_result=msg
                )
            response = await agent.monologue()
            # bubble the response up through the superior chain, if any
            superior = agent.data.get(Agent.DATA_NAME_SUPERIOR, None)
            if superior:
                response = await self._process_chain(superior, response, False)
            return response
        except Exception as e:
            agent.handle_critical_exception(e)
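
# Usage sketch (illustrative only): driving a context from calling code.
# Assumes a valid AgentConfig instance named `my_config`; `result_sync()` is
# an assumption about the DeferredTask helper's interface.
#
#     context = AgentContext(config=my_config)
#     task = context.communicate(UserMessage(message="Summarize ./notes.md"))
#     result = task.result_sync()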


@dataclass
class ModelConfig:
    provider: models.ModelProvider
    name: str
    ctx_length: int = 0
    limit_requests: int = 0
    limit_input: int = 0
    limit_output: int = 0
    vision: bool = False
    kwargs: dict = field(default_factory=dict)


@dataclass
class AgentConfig:
    chat_model: ModelConfig
    utility_model: ModelConfig
    embeddings_model: ModelConfig
    browser_model: ModelConfig
    prompts_subdir: str = ""
    memory_subdir: str = ""
    knowledge_subdirs: list[str] = field(default_factory=lambda: ["default", "custom"])
    code_exec_docker_enabled: bool = False
    code_exec_docker_name: str = "A0-dev"
    code_exec_docker_image: str = "frdel/agent-zero-run:development"
    code_exec_docker_ports: dict[str, int] = field(
        default_factory=lambda: {"22/tcp": 55022, "80/tcp": 55080}
    )
    code_exec_docker_volumes: dict[str, dict[str, str]] = field(
        default_factory=lambda: {
            files.get_base_dir(): {"bind": "/a0", "mode": "rw"},
            files.get_abs_path("work_dir"): {"bind": "/root", "mode": "rw"},
        }
    )
    code_exec_ssh_enabled: bool = True
    code_exec_ssh_addr: str = "localhost"
    code_exec_ssh_port: int = 55022
    code_exec_ssh_user: str = "root"
    code_exec_ssh_pass: str = ""
    additional: Dict[str, Any] = field(default_factory=dict)
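
# Illustrative wiring sketch (not from the source): building an AgentConfig
# from ModelConfig instances. Provider members and model names below are
# placeholders; valid values depend on the ModelProvider enum in models.py.
#
#     config = AgentConfig(
#         chat_model=ModelConfig(models.ModelProvider.OPENAI, "gpt-4o"),
#         utility_model=ModelConfig(models.ModelProvider.OPENAI, "gpt-4o-mini"),
#         embeddings_model=ModelConfig(models.ModelProvider.OPENAI, "text-embedding-3-small"),
#         browser_model=ModelConfig(models.ModelProvider.OPENAI, "gpt-4o", vision=True),
#     )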


@dataclass
class UserMessage:
    message: str
    attachments: list[str] = field(default_factory=list)
    system_message: list[str] = field(default_factory=list)


class LoopData:
    def __init__(self, **kwargs):
        self.iteration = -1
        self.system = []
        self.user_message: history.Message | None = None
        self.history_output: list[history.OutputMessage] = []
        self.extras_temporary: OrderedDict[str, history.MessageContent] = OrderedDict()
        self.extras_persistent: OrderedDict[str, history.MessageContent] = OrderedDict()
        self.last_response = ""

        # any of the defaults above can be overridden via kwargs
        for key, value in kwargs.items():
            setattr(self, key, value)


class InterventionException(Exception):
    # raised when a user intervention arrives; skips the rest of the iteration
    pass


class RepairableException(Exception):
    # an error that is forwarded to the LLM as a warning so it can recover
    pass


class HandledException(Exception):
    # an error that has already been logged and reported
    pass


class Agent:

    DATA_NAME_SUPERIOR = "_superior"
    DATA_NAME_SUBORDINATE = "_subordinate"
    DATA_NAME_CTX_WINDOW = "ctx_window"

    def __init__(
        self, number: int, config: AgentConfig, context: AgentContext | None = None
    ):
        self.config = config
        self.context = context or AgentContext(config)

        self.number = number
        self.agent_name = f"Agent {self.number}"

        self.history = history.History(self)
        self.last_user_message: history.Message | None = None
        self.intervention: UserMessage | None = None
        self.data = {}

    async def monologue(self):
        while True:
            try:
                # loop data object passed to extensions throughout the loop
                self.loop_data = LoopData(user_message=self.last_user_message)
                await self.call_extensions("monologue_start", loop_data=self.loop_data)

                printer = PrintStyle(italic=True, font_color="#b3ffd9", padding=False)

                # run the message loop until the agent returns a final response
                while True:
                    self.context.streaming_agent = self  # mark self as the streaming agent
                    self.loop_data.iteration += 1

                    await self.call_extensions("message_loop_start", loop_data=self.loop_data)

                    try:
                        # prepare LLM chain (model, system prompt, history)
                        prompt = await self.prepare_prompt(loop_data=self.loop_data)

                        PrintStyle(
                            bold=True,
                            font_color="green",
                            padding=True,
                            background_color="white",
                        ).print(f"{self.agent_name}: Generating")
                        log = self.context.log.log(
                            type="agent", heading=f"{self.agent_name}: Generating"
                        )

                        async def stream_callback(chunk: str, full: str):
                            # stream the response chunk to the console and the log
                            if chunk:
                                printer.stream(chunk)
                                self.log_from_stream(full, log)

                        agent_response = await self.call_chat_model(
                            prompt, callback=stream_callback
                        )

                        await self.handle_intervention(agent_response)

                        if self.loop_data.last_response == agent_response:
                            # the model repeated itself verbatim; warn it and retry
                            self.hist_add_ai_response(agent_response)
                            warning_msg = self.read_prompt("fw.msg_repeat.md")
                            self.hist_add_warning(message=warning_msg)
                            PrintStyle(font_color="orange", padding=True).print(warning_msg)
                            self.context.log.log(type="warning", content=warning_msg)
                        else:
                            # otherwise record the response and process tool requests
                            self.hist_add_ai_response(agent_response)
                            tools_result = await self.process_tools(agent_response)
                            if tools_result:  # final response of the message loop
                                return tools_result

                    except InterventionException:
                        pass  # intervention was recorded by handle_intervention(); continue the loop
                    except RepairableException as e:
                        # forward the error to the LLM as a warning so it can retry
                        error_message = errors.format_error(e)
                        self.hist_add_warning(error_message)
                        PrintStyle(font_color="red", padding=True).print(error_message)
                        self.context.log.log(type="error", content=error_message)
                    except Exception as e:
                        self.handle_critical_exception(e)

                    finally:
                        await self.call_extensions(
                            "message_loop_end", loop_data=self.loop_data
                        )

            except InterventionException:
                pass  # start a fresh monologue after the intervention
            except Exception as e:
                self.handle_critical_exception(e)
            finally:
                self.context.streaming_agent = None
                await self.call_extensions("monologue_end", loop_data=self.loop_data)

    async def prepare_prompt(self, loop_data: LoopData) -> ChatPromptTemplate:
        # call extensions before setting prompts
        await self.call_extensions("message_loop_prompts_before", loop_data=loop_data)

        # set system prompt and message history
        loop_data.system = await self.get_system_prompt(loop_data)
        loop_data.history_output = self.history.output()

        # and allow extensions to modify them
        await self.call_extensions("message_loop_prompts_after", loop_data=loop_data)

        # render extras (persistent and temporary) as one additional user message
        extras = history.Message(
            False,
            content=self.read_prompt(
                "agent.context.extras.md",
                extras=dirty_json.stringify(
                    {**loop_data.extras_persistent, **loop_data.extras_temporary}
                ),
            ),
        ).output()
        loop_data.extras_temporary.clear()

        # convert history and extras to LangChain messages
        history_langchain: list[BaseMessage] = history.output_langchain(
            loop_data.history_output + extras
        )

        # assemble the prompt: system message first, then the conversation
        system_text = "\n\n".join(loop_data.system)
        prompt = ChatPromptTemplate.from_messages(
            [
                SystemMessage(content=system_text),
                *history_langchain,
            ]
        )

        # store the rendered context window and its token estimate for inspection
        self.set_data(
            Agent.DATA_NAME_CTX_WINDOW,
            {
                "text": prompt.format(),
                "tokens": self.history.get_tokens()
                + tokens.approximate_tokens(system_text)
                + tokens.approximate_tokens(history.output_text(extras)),
            },
        )

        return prompt
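
# Note: the assembled prompt is position-based, roughly
#   [SystemMessage(system_text), <history as Human/AI messages>, <extras message>]
# with no remaining template variables, which is why the streaming calls below
# invoke the chain with an empty input mapping, astream({}).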

    def handle_critical_exception(self, exception: Exception):
        if isinstance(exception, HandledException):
            raise exception  # already handled and logged; just re-raise
        elif isinstance(exception, asyncio.CancelledError):
            # the task was cancelled (e.g. by the user); log and wrap so the
            # error is not reported twice
            PrintStyle(font_color="white", background_color="red", padding=True).print(
                f"Context {self.context.id} terminated during message loop"
            )
            raise HandledException(exception)
        else:
            # log the error and wrap it to prevent repeated handling
            error_text = errors.error_text(exception)
            error_message = errors.format_error(exception)
            PrintStyle(font_color="red", padding=True).print(error_message)
            self.context.log.log(
                type="error",
                heading="Error",
                content=error_message,
                kvps={"text": error_text},
            )
            raise HandledException(exception)

    async def get_system_prompt(self, loop_data: LoopData) -> list[str]:
        # the system prompt is assembled entirely by "system_prompt" extensions
        system_prompt = []
        await self.call_extensions(
            "system_prompt", system_prompt=system_prompt, loop_data=loop_data
        )
        return system_prompt

    def parse_prompt(self, file: str, **kwargs):
        prompt_dir = files.get_abs_path("prompts/default")
        backup_dir = []
        if self.config.prompts_subdir:
            # use the configured prompts subdir, falling back to the defaults
            prompt_dir = files.get_abs_path("prompts", self.config.prompts_subdir)
            backup_dir.append(files.get_abs_path("prompts/default"))
        prompt = files.parse_file(
            files.get_abs_path(prompt_dir, file), _backup_dirs=backup_dir, **kwargs
        )
        return prompt

    def read_prompt(self, file: str, **kwargs) -> str:
        prompt_dir = files.get_abs_path("prompts/default")
        backup_dir = []
        if self.config.prompts_subdir:
            prompt_dir = files.get_abs_path("prompts", self.config.prompts_subdir)
            backup_dir.append(files.get_abs_path("prompts/default"))
        prompt = files.read_file(
            files.get_abs_path(prompt_dir, file), _backup_dirs=backup_dir, **kwargs
        )
        prompt = files.remove_code_fences(prompt)
        return prompt

    def get_data(self, field: str):
        return self.data.get(field, None)

    def set_data(self, field: str, value):
        self.data[field] = value

    def hist_add_message(
        self, ai: bool, content: history.MessageContent, tokens: int = 0
    ):
        return self.history.add_message(ai=ai, content=content, tokens=tokens)

    def hist_add_user_message(self, message: UserMessage, intervention: bool = False):
        self.history.new_topic()  # user message starts a new topic in history

        # intervention and regular user messages use different templates
        template = "fw.intervention.md" if intervention else "fw.user_message.md"
        content = self.parse_prompt(
            template,
            message=message.message,
            attachments=message.attachments,
            system_message=message.system_message,
        )

        # remove empty parts from the rendered template
        if isinstance(content, dict):
            content = {k: v for k, v in content.items() if v}

        msg = self.hist_add_message(False, content=content)
        self.last_user_message = msg
        return msg

    def hist_add_ai_response(self, message: str):
        self.loop_data.last_response = message
        content = self.parse_prompt("fw.ai_response.md", message=message)
        return self.hist_add_message(True, content=content)

    def hist_add_warning(self, message: history.MessageContent):
        content = self.parse_prompt("fw.warning.md", message=message)
        return self.hist_add_message(False, content=content)

    def hist_add_tool_result(self, tool_name: str, tool_result: str):
        content = self.parse_prompt(
            "fw.tool_result.md", tool_name=tool_name, tool_result=tool_result
        )
        return self.hist_add_message(False, content=content)

    def concat_messages(self, messages):
        # note: the messages argument is currently unused; the full history
        # is rendered instead
        return self.history.output_text(human_label="user", ai_label="assistant")

    def get_chat_model(self):
        return models.get_model(
            models.ModelType.CHAT,
            self.config.chat_model.provider,
            self.config.chat_model.name,
            **self.config.chat_model.kwargs,
        )

    def get_utility_model(self):
        return models.get_model(
            models.ModelType.CHAT,
            self.config.utility_model.provider,
            self.config.utility_model.name,
            **self.config.utility_model.kwargs,
        )

    def get_embedding_model(self):
        return models.get_model(
            models.ModelType.EMBEDDING,
            self.config.embeddings_model.provider,
            self.config.embeddings_model.name,
            **self.config.embeddings_model.kwargs,
        )

    async def call_utility_model(
        self,
        system: str,
        message: str,
        callback: Callable[[str], Awaitable[None]] | None = None,
        background: bool = False,
    ):
        prompt = ChatPromptTemplate.from_messages(
            [SystemMessage(content=system), HumanMessage(content=message)]
        )

        response = ""
        model = self.get_utility_model()

        # register the request with the rate limiter before streaming
        limiter = await self.rate_limiter(
            self.config.utility_model, prompt.format(), background
        )

        async for chunk in (prompt | model).astream({}):
            await self.handle_intervention()  # wait for intervention and handle it, if paused

            content = models.parse_chunk(chunk)
            limiter.add(output=tokens.approximate_tokens(content))
            response += content

            if callback:
                await callback(content)

        return response
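
# Example (illustrative): streaming a summary from the utility model. The
# system/message strings and the long_text variable are invented for this
# sketch; the callback receives each streamed chunk as it arrives.
#
#     async def on_chunk(chunk: str):
#         print(chunk, end="", flush=True)
#
#     summary = await agent.call_utility_model(
#         system="You are a terse summarizer.",
#         message=long_text,
#         callback=on_chunk,
#     )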

    async def call_chat_model(
        self,
        prompt: ChatPromptTemplate,
        callback: Callable[[str, str], Awaitable[None]] | None = None,
    ):
        response = ""
        model = self.get_chat_model()

        # register the request with the rate limiter before streaming
        limiter = await self.rate_limiter(self.config.chat_model, prompt.format())

        async for chunk in (prompt | model).astream({}):
            await self.handle_intervention()  # wait for intervention and handle it, if paused

            content = models.parse_chunk(chunk)
            limiter.add(output=tokens.approximate_tokens(content))
            response += content

            if callback:
                await callback(content, response)

        return response

    async def rate_limiter(
        self, model_config: ModelConfig, input: str, background: bool = False
    ):
        wait_log = None  # log item created lazily on the first wait

        async def wait_callback(msg: str, key: str, total: int, limit: int):
            nonlocal wait_log
            if not wait_log:
                wait_log = self.context.log.log(
                    type="util",
                    update_progress="none",
                    heading=msg,
                    model=f"{model_config.provider.value}\\{model_config.name}",
                )
            wait_log.update(heading=msg, key=key, value=total, limit=limit)
            if not background:
                self.context.log.set_progress(msg, -1)

        # register the request and wait for a slot if a limit was hit
        limiter = models.get_rate_limiter(
            model_config.provider,
            model_config.name,
            model_config.limit_requests,
            model_config.limit_input,
            model_config.limit_output,
        )
        limiter.add(input=tokens.approximate_tokens(input))
        limiter.add(requests=1)
        await limiter.wait(callback=wait_callback)
        return limiter

    async def handle_intervention(self, progress: str = ""):
        while self.context.paused:
            await asyncio.sleep(0.1)  # wait while paused
        if self.intervention:  # an intervention message is waiting to be processed
            msg = self.intervention
            self.intervention = None  # reset the intervention message
            if progress.strip():
                # keep the partial progress as an AI response before interrupting
                self.hist_add_ai_response(progress)
            self.hist_add_user_message(msg, intervention=True)
            raise InterventionException(msg)

    async def wait_if_paused(self):
        while self.context.paused:
            await asyncio.sleep(0.1)

    async def process_tools(self, msg: str):
        # search the agent message for a tool usage request
        tool_request = extract_tools.json_parse_dirty(msg)

        if tool_request is not None:
            tool_name = tool_request.get("tool_name", "")
            tool_method = None
            tool_args = tool_request.get("tool_args", {})

            # "name:method" addresses a specific method of the tool
            if ":" in tool_name:
                tool_name, tool_method = tool_name.split(":", 1)

            tool = self.get_tool(name=tool_name, method=tool_method, args=tool_args, message=msg)

            await self.handle_intervention()  # wait for intervention and handle it, if paused
            await tool.before_execution(**tool_args)
            await self.handle_intervention()
            response = await tool.execute(**tool_args)
            await self.handle_intervention()
            await tool.after_execution(response)
            await self.handle_intervention()
            if response.break_loop:  # the tool produced a final response; stop the loop
                return response.message
        else:
            msg = self.read_prompt("fw.msg_misformat.md")
            self.hist_add_warning(msg)
            PrintStyle(font_color="red", padding=True).print(msg)
            self.context.log.log(
                type="error", content=f"{self.agent_name}: Message misformat"
            )
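
# For reference, json_parse_dirty above is expected to recover a JSON object
# shaped roughly like the following from the assistant message (tool and
# arguments here are invented for illustration):
#
#     {
#         "tool_name": "some_tool:some_method",   # the ":method" part is optional
#         "tool_args": {"arg1": "value"}
#     }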

    def log_from_stream(self, stream: str, logItem: Log.LogItem):
        try:
            if len(stream) < 25:
                return  # no reason to try parsing very short messages
            response = DirtyJson.parse_string(stream)
            if isinstance(response, dict):
                # log the parsed key-value pairs once the stream looks like a dict
                logItem.update(content=stream, kvps=response)
        except Exception:
            pass  # the stream is often incomplete JSON while generating; ignore

    def get_tool(self, name: str, method: str | None, args: dict, message: str, **kwargs):
        from python.tools.unknown import Unknown
        from python.helpers.tool import Tool

        # tools are discovered by module name: python/tools/<name>.py
        classes = extract_tools.load_classes_from_folder(
            "python/tools", name + ".py", Tool
        )
        tool_class = classes[0] if classes else Unknown
        return tool_class(agent=self, name=name, method=method, args=args, message=message, **kwargs)

    async def call_extensions(self, folder: str, **kwargs) -> Any:
        from python.helpers.extension import Extension

        # run every Extension subclass found in python/extensions/<folder>
        classes = extract_tools.load_classes_from_folder(
            "python/extensions/" + folder, "*", Extension
        )
        for cls in classes:
            await cls(agent=self).execute(**kwargs)
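

# A minimal extension sketch (illustrative; assumes the Extension base class
# stores the `agent` constructor argument as self.agent and exposes the async
# execute hook used above). Placed in python/extensions/monologue_start/, it
# would run at the start of every monologue:
#
#     from python.helpers.extension import Extension
#
#     class LogMonologueStart(Extension):  # hypothetical example class
#         async def execute(self, loop_data=None, **kwargs):
#             self.agent.context.log.log(
#                 type="info", content=f"monologue started for {self.agent.agent_name}"
#             )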