# NOTE(review): "Spaces: Sleeping" appears to be hosting-page status text captured with this file; it is not program code.
import asyncio | |
import sys | |
import streamlit as st | |
from dotenv import load_dotenv | |
import logging | |
import os | |
import traceback | |
import importlib.util | |
import utils | |
import aworld.trace as trace | |
from trace_net import generate_trace_graph_full | |
from aworld.trace.base import get_tracer_provider | |
# Read environment variables from a ".env" file in the current working directory.
load_dotenv(dotenv_path=os.path.join(os.getcwd(), ".env"))

# Root logging at INFO level, plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Put the current working directory at the front of the import search path.
sys.path.insert(0, os.getcwd())
def _load_agent_module(agent_name):
    """Import and execute the ``agent.py`` module belonging to *agent_name*.

    The package directory comes from ``utils.get_agent_package_path``; the
    module is then loaded from ``<package>/agent.py`` with ``importlib``.
    Any failure is logged and shown in the page via ``st.error``.

    Args:
        agent_name: Name of the agent to load. Callers must only pass names
            they have already validated, because the module file is executed.

    Returns:
        The executed module object, or ``None`` if it could not be loaded.
    """
    agent_package_path = utils.get_agent_package_path(agent_name)
    agent_module_file = os.path.join(agent_package_path, "agent.py")
    try:
        spec = importlib.util.spec_from_file_location(agent_name, agent_module_file)
        if spec is None or spec.loader is None:
            logger.error(
                f"Could not load spec for agent {agent_name} from {agent_module_file}"
            )
            st.error(f"Error: Could not load agent! {agent_name}")
            return None
        agent_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(agent_module)
    except Exception:
        # Executing an arbitrary agent module can fail in any way; log the
        # full traceback and degrade to an error message in the page.
        logger.exception(
            "Error loading agent %s, cwd:%s, sys.path:%s",
            agent_name,
            os.getcwd(),
            sys.path,
        )
        st.error(f"Error: Could not load agent! {agent_name}")
        return None
    return agent_module


async def _stream_agent_reply(agent, prompt):
    """Write the agent's output to the page as it arrives, then link its trace.

    Args:
        agent: Object whose ``run(prompt)`` is an async iterable of chunks.
        prompt: The user's chat message.
    """
    async with trace.span("start") as span:
        trace_id = span.get_trace_id()
        async for line in agent.run(prompt):
            st.write(line)
            await asyncio.sleep(0.1)

    # Flush after the span block has exited -- presumably so the finished span
    # is exported before the graph is generated (TODO confirm).
    get_tracer_provider().force_flush(5000)
    generate_trace_graph_full(
        trace_id,
        folder_name="trace_data",
        file_name=f"graph.{trace_id}.html",
    )
    view_page_url = f"/trace?trace_id={trace_id}"
    st.write(f"\n---\n[View Trace]({view_page_url})\n")


def agent_page():
    """Render the AWorld agent chat page.

    The selected agent is taken from the ``agent`` URL query parameter and
    mirrored into ``st.session_state.selected_agent``; sidebar buttons update
    both. When an agent is selected and the user submits a chat message, the
    agent's module is loaded, ``AWorldAgent`` is instantiated, and its output
    is streamed into the page followed by a link to the trace view.

    The agent name originates from the URL, so it is checked against
    ``utils.list_agents()`` before any module is imported and executed.
    """
    st.set_page_config(
        page_title="AWorld Agent",
        page_icon=":robot_face:",
        layout="wide",
    )
    st.markdown(
        """\
<style>
.stAppHeader { display: none; }
div[data-testid="stMarkdownContainer"] pre {
max-height: 300px;
overflow-y: auto;
}
div[data-testid="stMarkdownContainer"] img {
max-height: 500px;
}
</style>""",
        unsafe_allow_html=True,
    )

    # Keep the session state in sync with the URL query parameter.
    selected_agent_from_url = st.query_params.get("agent", None)
    if "selected_agent" not in st.session_state:
        st.session_state.selected_agent = selected_agent_from_url
        logger.info(f"Initialized selected_agent from URL: {selected_agent_from_url}")
    if selected_agent_from_url != st.session_state.selected_agent:
        st.session_state.selected_agent = selected_agent_from_url

    # Materialize once: used for the sidebar and for validating the selection.
    available_agents = list(utils.list_agents())

    with st.sidebar:
        st.title("AWorld Agents List")
        for agent in available_agents:
            if st.button(agent):
                st.query_params["agent"] = agent
                st.session_state.selected_agent = agent
                logger.info(f"selected_agent={st.session_state.selected_agent}")

    agent_name = st.session_state.selected_agent
    if not agent_name:
        st.title("AWorld Agent Chat Assistant")
        st.info("Please select an Agent from the left sidebar to start")
        return

    # The name may come straight from the URL; never import and execute a
    # module for a name that is not one of the listed agents.
    if agent_name not in available_agents:
        logger.error("Rejected unknown agent name %r", agent_name)
        st.error(f"Error: Could not load agent! {agent_name}")
        return

    st.title(f"AWorld Agent: {agent_name}")

    prompt = st.chat_input("Input message here~")
    if not prompt:
        return

    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        agent_module = _load_agent_module(agent_name)
        if agent_module is None:
            return
        agent = agent_module.AWorldAgent()
        asyncio.run(_stream_agent_reply(agent, prompt))
# Script entry: render the page; on any failure, log the traceback and show
# the error message in the UI.
try:
    agent_page()
except Exception as exc:
    failure_trace = traceback.format_exc()
    logger.error(f">>> Error: {failure_trace}")
    failure_text = str(exc)
    st.error(f"Error: {failure_text}")