mjschock's picture
Refactor graph.py and test_agent.py by removing unused imports to streamline code and improve readability. This includes the removal of uuid, requests, and unnecessary components from langchain_core.
f622879 unverified
raw
history blame
5.63 kB
"""Define the agent graph and its components."""
import logging
import os
from typing import Dict, List, Optional, TypedDict, Union
import yaml
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, StateGraph
from langgraph.types import interrupt
from smolagents import CodeAgent, LiteLLMModel
from configuration import Configuration
from tools import tools
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Enable LiteLLM debug logging only if environment variable is set
import litellm

# True only when LITELLM_DEBUG is set to "true" (case-insensitive); any other
# value, or an unset variable, leaves debug output off.
_litellm_debug = os.getenv("LITELLM_DEBUG", "false").lower() == "true"

# The same flag drives LiteLLM's verbose switch and this module's log level.
litellm.set_verbose = _litellm_debug
logger.setLevel(logging.DEBUG if _litellm_debug else logging.INFO)

# Configure LiteLLM to drop unsupported parameters
litellm.drop_params = True
# Load default prompt templates from local file
# The file is resolved relative to this module (not the working directory), so
# the import works no matter where the process was started from.
current_dir = os.path.dirname(os.path.abspath(__file__))
prompts_dir = os.path.join(current_dir, "prompts")
yaml_path = os.path.join(prompts_dir, "code_agent.yaml")

# Decode explicitly as UTF-8: without an encoding, open() falls back to the
# platform locale, which can mis-decode non-ASCII prompt text on some systems.
with open(yaml_path, "r", encoding="utf-8") as f:
    prompt_templates = yaml.safe_load(f)
# Initialize the model and agent using configuration
config = Configuration()

# The model id, endpoint and credentials are all read from the configuration.
model = LiteLLMModel(
    model_id=config.model_id,
    api_base=config.api_base,
    api_key=config.api_key,
)

# Module-level agent built from the model, the project tools and the prompt
# templates loaded above.
agent = CodeAgent(
    model=model,
    tools=tools,
    add_base_tools=True,
    prompt_templates=prompt_templates,
    max_steps=1,  # Execute one step at a time
    # NOTE(review): this passes the stdlib logging constant; confirm it is on
    # the scale smolagents expects for verbosity_level.
    verbosity_level=logging.DEBUG,
)
# Message types that may appear in the conversation history.
_AgentMessage = Union[HumanMessage, AIMessage, SystemMessage]


class AgentState(TypedDict):
    """Shared state dictionary passed between the nodes of the agent graph."""

    # Conversation history; the agent node appends one AIMessage per run.
    messages: List[_AgentMessage]
    # Question handed to the agent on every run.
    question: str
    # Latest result produced by the agent, if any.
    answer: Optional[str]
    # One entry per callback-node invocation (step, messages, question, answer).
    step_logs: List[Dict]
    # Set to True when the user quits; read by the graph's conditional edge.
    is_complete: bool
    # Number of agent runs completed so far.
    step_count: int
class AgentNode:
    """Graph node that runs the wrapped agent on the question held in the state."""

    def __init__(self, agent: CodeAgent):
        """Store the agent this node will run.

        Args:
            agent: Agent whose ``run`` method is called with the question.
        """
        self.agent = agent

    def __call__(
        self, state: AgentState, config: Optional[RunnableConfig] = None
    ) -> AgentState:
        """Run the agent on the question and return an updated copy of the state.

        Args:
            state: Current graph state; ``state["question"]`` is passed to the
                agent.
            config: Optional runnable config, used only to build the
                ``Configuration`` object that is logged.

        Returns:
            A new state dict in which the agent's result has been appended to
            ``messages`` and stored in ``answer``, and ``step_count`` has been
            incremented by one. The incoming ``state`` (including its
            ``messages`` list) is left unmodified.
        """
        # Log current state
        logger.info("Current state before processing:")
        logger.info(f"Messages: {state['messages']}")
        logger.info(f"Question: {state['question']}")
        logger.info(f"Answer: {state['answer']}")

        # Get configuration
        # NOTE(review): if Configuration's string form includes api_key, the
        # log line below writes a secret to the log — confirm and redact.
        cfg = Configuration.from_runnable_config(config)
        logger.info(f"Using configuration: {cfg}")

        # Run the agent
        logger.info("Starting agent execution")
        result = self.agent.run(state["question"])

        # Log result
        logger.info(f"Agent execution result type: {type(result)}")
        logger.info(f"Agent execution result value: {result}")

        # Update state. dict.copy() is shallow, so appending to the copied
        # "messages" entry would also mutate the caller's list; build a fresh
        # list instead so the input state stays untouched.
        new_state = state.copy()
        new_state["messages"] = [*state["messages"], AIMessage(content=result)]
        new_state["answer"] = result
        new_state["step_count"] = state["step_count"] + 1

        # Log updated state
        logger.info("Updated state after processing:")
        logger.info(f"Messages: {new_state['messages']}")
        logger.info(f"Question: {new_state['question']}")
        logger.info(f"Answer: {new_state['answer']}")
        return new_state
class StepCallbackNode:
    """Node that records a step log and asks the user how to proceed."""

    def __call__(
        self, state: AgentState, config: Optional[RunnableConfig] = None
    ) -> AgentState:
        """Log the current step, then prompt the user via ``interrupt``.

        Args:
            state: Current graph state.
            config: Optional runnable config, forwarded to
                ``Configuration.from_runnable_config``.

        Returns:
            A copy of the state with one new entry in ``step_logs``. If the
            user answers ``'q'``, ``is_complete`` is set to True; ``'c'``
            returns the state otherwise unchanged. ``'i'`` and any other
            input log a message and prompt again.

        Raises:
            GraphInterrupt: Propagated from ``interrupt`` so LangGraph can
                pause the graph and wait for user input.
        """
        # Imported locally so this block does not depend on a file-level
        # import; langgraph is already a dependency of this module.
        from langgraph.errors import GraphInterrupt

        # The result is currently unused; the call is kept so that any error it
        # raises for a malformed config still surfaces here.
        Configuration.from_runnable_config(config)

        # Log the step
        step_log = {
            "step": state["step_count"],
            "messages": [msg.content for msg in state["messages"]],
            "question": state["question"],
            "answer": state["answer"],
        }
        # Work on a copy with a fresh list so the caller's state is not mutated.
        new_state = state.copy()
        new_state["step_logs"] = [*state["step_logs"], step_log]

        try:
            # Loop rather than recurse: re-prompting must not append another
            # step log or grow the call stack.
            while True:
                # Use interrupt for user input
                user_input = interrupt(
                    "Press 'c' to continue, 'q' to quit, or 'i' for more info: "
                )
                choice = user_input.lower()
                if choice == "q":
                    new_state["is_complete"] = True
                    return new_state
                if choice == "c":
                    return new_state
                if choice == "i":
                    logger.info(f"Current step: {new_state['step_count']}")
                    logger.info(f"Question: {new_state['question']}")
                    logger.info(f"Current answer: {new_state['answer']}")
                else:
                    logger.warning("Invalid input. Please use 'c', 'q', or 'i'.")
        except GraphInterrupt:
            # interrupt() pauses the graph by raising this exception; catching
            # it would silently skip the user prompt, so let it propagate.
            raise
        except Exception as e:
            # Best effort: any other failure is logged and the graph continues.
            logger.warning(f"Error during interrupt: {str(e)}")
            return new_state
def build_agent_graph(agent: AgentNode) -> StateGraph:
    """Build and compile the agent graph.

    The graph has two nodes: ``"agent"`` (the entry point) and ``"callback"``.
    After every agent run the callback node executes; the graph then either
    ends or loops back to the agent depending on ``is_complete``.

    Args:
        agent: Node callable registered under the name ``"agent"``.

    Returns:
        The object returned by ``workflow.compile()``.
    """
    # Initialize the graph
    workflow = StateGraph(AgentState)

    # Add nodes
    workflow.add_node("agent", agent)
    workflow.add_node("callback", StepCallbackNode())

    # Add edges
    workflow.add_edge("agent", "callback")
    # The router's return value is looked up in the path map, so it must return
    # one of the map's keys (True/False), not the destination node name itself.
    workflow.add_conditional_edges(
        "callback",
        lambda state: bool(state["is_complete"]),
        {True: END, False: "agent"},
    )

    # Set entry point
    workflow.set_entry_point("agent")
    return workflow.compile()
# Initialize the agent graph
# Wrap the module-level agent in a node first, then compile the graph around it.
_agent_node = AgentNode(agent)
agent_graph = build_agent_graph(_agent_node)