Spaces:
Build error
Build error
Refactor app.py and update import paths in test_agent.py to improve code organization. Introduce new files for agent configuration, graph definition, and tools, enhancing the overall structure and functionality of the agent system.
Browse files- agent.py → api/runner.py +8 -20
- app.py +0 -1
- configuration.py → services/configuration.py +0 -0
- graph.py → services/graph.py +13 -17
- tools.py → services/tools.py +0 -0
- test_agent.py +23 -3
agent.py → api/runner.py
RENAMED
@@ -4,7 +4,7 @@ import uuid
|
|
4 |
|
5 |
from langgraph.types import Command
|
6 |
|
7 |
-
from graph import agent_graph
|
8 |
|
9 |
# Configure logging
|
10 |
logging.basicConfig(level=logging.INFO) # Default to INFO level
|
@@ -86,32 +86,20 @@ class AgentRunner:
|
|
86 |
}
|
87 |
logger.info(f"Initial state: {initial_state}")
|
88 |
|
89 |
-
# Use stream to get
|
90 |
logger.info("Starting graph stream for initial question")
|
91 |
for chunk in self.graph.stream(initial_state, config):
|
92 |
logger.debug(f"Received chunk: {chunk}")
|
93 |
-
|
94 |
if isinstance(chunk, dict):
|
95 |
if "__interrupt__" in chunk:
|
96 |
logger.info("Detected interrupt in stream")
|
97 |
logger.info(f"Interrupt details: {chunk['__interrupt__']}")
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
logger.debug(f"Received resume result: {result}")
|
105 |
-
if isinstance(result, dict):
|
106 |
-
answer = self._extract_answer(result)
|
107 |
-
if answer:
|
108 |
-
self.last_state = result
|
109 |
-
return answer
|
110 |
-
else:
|
111 |
-
answer = self._extract_answer(chunk)
|
112 |
-
if answer:
|
113 |
-
self.last_state = chunk
|
114 |
-
return answer
|
115 |
else:
|
116 |
logger.debug(f"Skipping chunk without answer: {chunk}")
|
117 |
else:
|
|
|
4 |
|
5 |
from langgraph.types import Command
|
6 |
|
7 |
+
from services.graph import agent_graph
|
8 |
|
9 |
# Configure logging
|
10 |
logging.basicConfig(level=logging.INFO) # Default to INFO level
|
|
|
86 |
}
|
87 |
logger.info(f"Initial state: {initial_state}")
|
88 |
|
89 |
+
# Use stream to get results
|
90 |
logger.info("Starting graph stream for initial question")
|
91 |
for chunk in self.graph.stream(initial_state, config):
|
92 |
logger.debug(f"Received chunk: {chunk}")
|
|
|
93 |
if isinstance(chunk, dict):
|
94 |
if "__interrupt__" in chunk:
|
95 |
logger.info("Detected interrupt in stream")
|
96 |
logger.info(f"Interrupt details: {chunk['__interrupt__']}")
|
97 |
+
# Let the graph handle the interrupt naturally
|
98 |
+
continue
|
99 |
+
answer = self._extract_answer(chunk)
|
100 |
+
if answer:
|
101 |
+
self.last_state = chunk
|
102 |
+
return answer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
else:
|
104 |
logger.debug(f"Skipping chunk without answer: {chunk}")
|
105 |
else:
|
app.py
CHANGED
@@ -3,7 +3,6 @@ import os
|
|
3 |
import gradio as gr
|
4 |
import pandas as pd
|
5 |
import requests
|
6 |
-
|
7 |
from agent import AgentRunner
|
8 |
|
9 |
# (Keep Constants as is)
|
|
|
3 |
import gradio as gr
|
4 |
import pandas as pd
|
5 |
import requests
|
|
|
6 |
from agent import AgentRunner
|
7 |
|
8 |
# (Keep Constants as is)
|
configuration.py → services/configuration.py
RENAMED
File without changes
|
graph.py → services/graph.py
RENAMED
@@ -6,14 +6,14 @@ from datetime import datetime
|
|
6 |
from typing import Dict, List, Optional, TypedDict, Union
|
7 |
|
8 |
import yaml
|
|
|
9 |
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
10 |
from langchain_core.runnables import RunnableConfig
|
11 |
from langgraph.graph import END, StateGraph
|
12 |
from langgraph.types import interrupt
|
13 |
from smolagents import CodeAgent, LiteLLMModel
|
14 |
|
15 |
-
from
|
16 |
-
from tools import tools
|
17 |
|
18 |
# Configure logging
|
19 |
logging.basicConfig(level=logging.INFO)
|
@@ -33,7 +33,7 @@ else:
|
|
33 |
litellm.drop_params = True
|
34 |
|
35 |
# Load default prompt templates from local file
|
36 |
-
current_dir = os.path.dirname(os.path.abspath(__file__))
|
37 |
prompts_dir = os.path.join(current_dir, "prompts")
|
38 |
yaml_path = os.path.join(prompts_dir, "code_agent.yaml")
|
39 |
|
@@ -182,9 +182,7 @@ class StepCallbackNode:
|
|
182 |
logger.info(f"Current answer: {state['answer']}")
|
183 |
return state
|
184 |
elif user_input.lower() == "c":
|
185 |
-
#
|
186 |
-
if state["answer"]:
|
187 |
-
state["is_complete"] = True
|
188 |
return state
|
189 |
else:
|
190 |
logger.warning("Invalid input. Please use 'c', 'q', or 'i'.")
|
@@ -192,9 +190,7 @@ class StepCallbackNode:
|
|
192 |
|
193 |
except Exception as e:
|
194 |
logger.warning(f"Error during interrupt: {str(e)}")
|
195 |
-
#
|
196 |
-
if state["answer"]:
|
197 |
-
state["is_complete"] = True
|
198 |
return state
|
199 |
|
200 |
|
@@ -213,19 +209,19 @@ def build_agent_graph(agent: AgentNode) -> StateGraph:
|
|
213 |
# Add conditional edges for callback
|
214 |
def should_continue(state: AgentState) -> str:
|
215 |
"""Determine the next node based on state."""
|
216 |
-
# If we have
|
217 |
-
if state["answer"]
|
218 |
-
logger.info(
|
219 |
-
return
|
220 |
|
221 |
# If we have an answer but it's not complete, continue
|
222 |
-
if state["
|
223 |
logger.info(f"Found answer but not complete: {state['answer']}")
|
224 |
return "agent"
|
225 |
|
226 |
-
# If we have
|
227 |
-
logger.info("
|
228 |
-
return
|
229 |
|
230 |
workflow.add_conditional_edges(
|
231 |
"callback", should_continue, {END: END, "agent": "agent"}
|
|
|
6 |
from typing import Dict, List, Optional, TypedDict, Union
|
7 |
|
8 |
import yaml
|
9 |
+
from services.configuration import Configuration
|
10 |
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
11 |
from langchain_core.runnables import RunnableConfig
|
12 |
from langgraph.graph import END, StateGraph
|
13 |
from langgraph.types import interrupt
|
14 |
from smolagents import CodeAgent, LiteLLMModel
|
15 |
|
16 |
+
from services.tools import tools
|
|
|
17 |
|
18 |
# Configure logging
|
19 |
logging.basicConfig(level=logging.INFO)
|
|
|
33 |
litellm.drop_params = True
|
34 |
|
35 |
# Load default prompt templates from local file
|
36 |
+
current_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
37 |
prompts_dir = os.path.join(current_dir, "prompts")
|
38 |
yaml_path = os.path.join(prompts_dir, "code_agent.yaml")
|
39 |
|
|
|
182 |
logger.info(f"Current answer: {state['answer']}")
|
183 |
return state
|
184 |
elif user_input.lower() == "c":
|
185 |
+
# Continue without marking as complete
|
|
|
|
|
186 |
return state
|
187 |
else:
|
188 |
logger.warning("Invalid input. Please use 'c', 'q', or 'i'.")
|
|
|
190 |
|
191 |
except Exception as e:
|
192 |
logger.warning(f"Error during interrupt: {str(e)}")
|
193 |
+
# Continue without marking as complete
|
|
|
|
|
194 |
return state
|
195 |
|
196 |
|
|
|
209 |
# Add conditional edges for callback
|
210 |
def should_continue(state: AgentState) -> str:
|
211 |
"""Determine the next node based on state."""
|
212 |
+
# If we have no answer, continue
|
213 |
+
if not state["answer"]:
|
214 |
+
logger.info("No answer found, continuing")
|
215 |
+
return "agent"
|
216 |
|
217 |
# If we have an answer but it's not complete, continue
|
218 |
+
if not state["is_complete"]:
|
219 |
logger.info(f"Found answer but not complete: {state['answer']}")
|
220 |
return "agent"
|
221 |
|
222 |
+
# If we have an answer and it's complete, we're done
|
223 |
+
logger.info(f"Found complete answer: {state['answer']}")
|
224 |
+
return END
|
225 |
|
226 |
workflow.add_conditional_edges(
|
227 |
"callback", should_continue, {END: END, "agent": "agent"}
|
tools.py → services/tools.py
RENAMED
File without changes
|
test_agent.py
CHANGED
@@ -2,7 +2,7 @@ import logging
|
|
2 |
|
3 |
import pytest
|
4 |
|
5 |
-
from
|
6 |
|
7 |
# Configure test logger
|
8 |
test_logger = logging.getLogger("test_agent")
|
@@ -194,9 +194,29 @@ def test_simple_math_calculation_with_steps():
|
|
194 |
|
195 |
# Verify final answer
|
196 |
expected_result = 1302.678
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
197 |
assert (
|
198 |
-
|
199 |
-
), f"Response should contain the result {expected_result}"
|
|
|
|
|
200 |
assert (
|
201 |
"final_answer" in response.lower()
|
202 |
), "Response should indicate it's using final_answer"
|
|
|
2 |
|
3 |
import pytest
|
4 |
|
5 |
+
from api.runner import AgentRunner
|
6 |
|
7 |
# Configure test logger
|
8 |
test_logger = logging.getLogger("test_agent")
|
|
|
194 |
|
195 |
# Verify final answer
|
196 |
expected_result = 1302.678
|
197 |
+
|
198 |
+
# Extract all numbers from the response
|
199 |
+
import re
|
200 |
+
|
201 |
+
# First check for LaTeX formatting
|
202 |
+
latex_match = re.search(r"\\boxed{([^}]+)}", response)
|
203 |
+
if latex_match:
|
204 |
+
# Extract number from LaTeX box
|
205 |
+
latex_content = latex_match.group(1)
|
206 |
+
numbers = re.findall(r"\d+\.?\d*", latex_content)
|
207 |
+
else:
|
208 |
+
# Extract all numbers from the response
|
209 |
+
numbers = re.findall(r"\d+\.?\d*", response)
|
210 |
+
|
211 |
+
assert numbers, "Response should contain at least one number"
|
212 |
+
|
213 |
+
# Check if any number matches the expected result
|
214 |
+
has_correct_result = any(abs(float(n) - expected_result) < 0.001 for n in numbers)
|
215 |
assert (
|
216 |
+
has_correct_result
|
217 |
+
), f"Response should contain the result {expected_result}, got {response}"
|
218 |
+
|
219 |
+
# Verify the response indicates it's a final answer
|
220 |
assert (
|
221 |
"final_answer" in response.lower()
|
222 |
), "Response should indicate it's using final_answer"
|