Spaces:
Sleeping
Sleeping
"""Define the agent graph and its components.""" | |
import logging | |
import os | |
from typing import Dict, List, Optional, TypedDict, Union | |
import yaml | |
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage | |
from langchain_core.runnables import RunnableConfig | |
from langgraph.graph import END, StateGraph | |
from langgraph.types import interrupt | |
from smolagents import CodeAgent, LiteLLMModel | |
from configuration import Configuration | |
from tools import tools | |
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# LiteLLM verbose output and DEBUG-level module logging are opt-in through
# the LITELLM_DEBUG environment variable (case-insensitive "true").
import litellm

_litellm_debug = os.getenv("LITELLM_DEBUG", "false").lower() == "true"
litellm.set_verbose = _litellm_debug
logger.setLevel(logging.DEBUG if _litellm_debug else logging.INFO)

# Configure LiteLLM to drop unsupported parameters
litellm.drop_params = True
# Load default prompt templates from the "prompts" directory next to this file
current_dir = os.path.dirname(os.path.abspath(__file__))
prompts_dir = os.path.join(current_dir, "prompts")
yaml_path = os.path.join(prompts_dir, "code_agent.yaml")
# Decode explicitly as UTF-8 so the parsed templates do not depend on the
# platform's default locale encoding.
with open(yaml_path, "r", encoding="utf-8") as f:
    prompt_templates = yaml.safe_load(f)
# Build the LiteLLM-backed model and the code agent from the default configuration
config = Configuration()
model = LiteLLMModel(
    model_id=config.model_id,
    api_key=config.api_key,
    api_base=config.api_base,
)
agent = CodeAgent(
    tools=tools,
    model=model,
    prompt_templates=prompt_templates,
    add_base_tools=True,
    verbosity_level=logging.DEBUG,
    max_steps=1,  # Execute one step at a time
)
class AgentState(TypedDict):
    """Shared state passed between the nodes of the agent graph."""

    # Conversation so far; the agent node adds its reply here.
    messages: List[Union[HumanMessage, AIMessage, SystemMessage]]
    # Question handed to the agent on every run.
    question: str
    # Result of the most recent agent run, if any.
    answer: Optional[str]
    # One snapshot dict per pass through the callback node.
    step_logs: List[Dict]
    # Read by the conditional edge after the callback node to decide whether to stop.
    is_complete: bool
    # Incremented by the agent node after each run.
    step_count: int
class AgentNode:
    """Graph node that runs the wrapped agent on the question held in the state."""

    def __init__(self, agent: CodeAgent):
        """Store the agent this node will execute.

        Args:
            agent: The agent whose ``run`` method is invoked on each call.
        """
        self.agent = agent

    @staticmethod
    def _log_state(header: str, state: AgentState) -> None:
        """Log the messages, question and answer of ``state`` under ``header``."""
        logger.info(header)
        logger.info("Messages: %s", state["messages"])
        logger.info("Question: %s", state["question"])
        logger.info("Answer: %s", state["answer"])

    def __call__(
        self, state: AgentState, config: Optional[RunnableConfig] = None
    ) -> AgentState:
        """Run the agent on the current question and return the updated state.

        Args:
            state: Current graph state; ``state["question"]`` is passed to the
                agent.
            config: Optional runnable config; it is resolved through
                ``Configuration.from_runnable_config`` and logged.

        Returns:
            A new state dict with the agent's reply appended to ``messages``,
            the raw result stored in ``answer`` and ``step_count`` increased
            by one. The incoming ``state`` and its ``messages`` list are not
            modified.
        """
        self._log_state("Current state before processing:", state)

        # Get configuration
        cfg = Configuration.from_runnable_config(config)
        logger.info("Using configuration: %s", cfg)

        logger.info("Starting agent execution")
        # NOTE(review): run() is called with its default arguments on every
        # graph iteration; confirm whether the agent is expected to keep its
        # memory between iterations.
        result = self.agent.run(state["question"])

        logger.info("Agent execution result type: %s", type(result))
        logger.info("Agent execution result value: %s", result)

        # The agent may return a non-string value; message content must be text.
        content = result if isinstance(result, str) else str(result)

        # state.copy() is shallow, so build a fresh messages list instead of
        # appending to the list still owned by the incoming state.
        new_state = state.copy()
        new_state["messages"] = [*state["messages"], AIMessage(content=content)]
        new_state["answer"] = result
        new_state["step_count"] = state["step_count"] + 1

        self._log_state("Updated state after processing:", new_state)
        return new_state
class StepCallbackNode:
    """Node that records a step log and asks the user how to proceed."""

    _PROMPT = "Press 'c' to continue, 'q' to quit, or 'i' for more info: "

    def __call__(
        self, state: AgentState, config: Optional[RunnableConfig] = None
    ) -> AgentState:
        """Log the current step and handle the user's continue/quit/info choice.

        Args:
            state: Current graph state.
            config: Optional runnable config, resolved through
                ``Configuration.from_runnable_config``.

        Returns:
            A new state dict with one step log appended to ``step_logs``.
            ``is_complete`` is set to ``True`` when the user answers ``q``.
            The incoming ``state`` is not modified.

        Raises:
            GraphInterrupt: Propagated from ``interrupt`` so the graph can
                pause and wait for user input.
        """
        # Imported locally so this block needs no new top-of-file import.
        from langgraph.errors import GraphInterrupt

        # Get configuration
        # NOTE(review): the result is currently unused; the call is kept in
        # case from_runnable_config validates the config — confirm.
        cfg = Configuration.from_runnable_config(config)  # noqa: F841

        step_log = {
            "step": state["step_count"],
            "messages": [msg.content for msg in state["messages"]],
            "question": state["question"],
            "answer": state["answer"],
        }
        # Record the step exactly once, on a copy, no matter how many times
        # the user is re-prompted below.
        new_state = state.copy()
        new_state["step_logs"] = [*state["step_logs"], step_log]

        # Re-prompt in a loop instead of recursing so the bookkeeping above
        # is not repeated for every extra prompt.
        while True:
            try:
                user_input = interrupt(self._PROMPT)
            except GraphInterrupt:
                # interrupt() pauses the graph by raising this exception;
                # swallowing it would keep the graph from ever pausing.
                raise
            except Exception as e:
                # Best effort: any other failure lets the graph continue.
                logger.warning(f"Error during interrupt: {str(e)}")
                return new_state

            # The resume value is not guaranteed to be a string.
            choice = str(user_input).strip().lower()
            if choice == "q":
                new_state["is_complete"] = True
                return new_state
            if choice == "c":
                return new_state
            if choice == "i":
                logger.info(f"Current step: {state['step_count']}")
                logger.info(f"Question: {state['question']}")
                logger.info(f"Current answer: {state['answer']}")
            else:
                logger.warning("Invalid input. Please use 'c', 'q', or 'i'.")
def build_agent_graph(agent: AgentNode, checkpointer=None) -> StateGraph:
    """Build and compile the agent/callback workflow.

    The graph starts at the ``"agent"`` node, always moves on to the
    ``"callback"`` node, and from there either ends (when the state's
    ``is_complete`` flag is truthy) or loops back to ``"agent"``.

    Args:
        agent: Callable node registered under the name ``"agent"``.
        checkpointer: Optional checkpointer forwarded to ``compile``.
            Defaults to ``None``, which compiles the graph exactly as before.
            NOTE(review): ``interrupt`` in the callback node needs a
            checkpointer to pause and resume; confirm the hosting
            environment supplies one when this is left as ``None``.

    Returns:
        The compiled graph returned by ``workflow.compile``.
        NOTE(review): the ``StateGraph`` annotation is kept for
        compatibility; the compiled object is presumably a different
        type — confirm.
    """

    def _route_after_callback(state):
        # The returned value is looked up in the path map below, so it must
        # be one of that map's keys.
        return END if state["is_complete"] else "agent"

    # Initialize the graph
    workflow = StateGraph(AgentState)

    # Add nodes
    workflow.add_node("agent", agent)
    workflow.add_node("callback", StepCallbackNode())

    # Add edges
    workflow.add_edge("agent", "callback")
    workflow.add_conditional_edges(
        "callback",
        _route_after_callback,
        {END: END, "agent": "agent"},
    )

    # Set entry point
    workflow.set_entry_point("agent")

    return workflow.compile(checkpointer=checkpointer)
# Compile the module-level graph once at import time, wrapping the shared
# code agent in an AgentNode.
agent_graph = build_agent_graph(agent=AgentNode(agent=agent))