mjschock committed
Commit 401799d · unverified · 1 Parent(s): 4ff8224

Refactor agent.py and graph.py to enhance agent functionality and logging. Introduce Configuration class for managing parameters, improve state handling in AgentRunner, and update agent graph to support step logging and user interaction. Add new tests for agent capabilities and update requirements for code formatting tools.

Files changed (7)
  1. agent.py +37 -18
  2. app.py +51 -27
  3. configuration.py +33 -0
  4. graph.py +171 -111
  5. requirements.txt +2 -0
  6. test_agent.py +170 -67
  7. tools.py +11 -6
agent.py CHANGED
@@ -1,5 +1,7 @@
-import os
 import logging
+import os
+import uuid
+
 from graph import agent_graph
 
 # Configure logging
@@ -8,34 +10,51 @@ logger = logging.getLogger(__name__)
 
 # Enable LiteLLM debug logging only if environment variable is set
 import litellm
-if os.getenv('LITELLM_DEBUG', 'false').lower() == 'true':
+
+if os.getenv("LITELLM_DEBUG", "false").lower() == "true":
     litellm.set_verbose = True
     logger.setLevel(logging.DEBUG)
 else:
     litellm.set_verbose = False
     logger.setLevel(logging.INFO)
 
+
 class AgentRunner:
+    """Runner class for the code agent."""
+
     def __init__(self):
-        logger.debug("Initializing AgentRunner")
-        logger.info("AgentRunner initialized.")
+        """Initialize the agent runner with graph and tools."""
+        logger.info("Initializing AgentRunner")
+        self.graph = agent_graph
+        self.last_state = None  # Store the last state for testing/debugging
 
     def __call__(self, question: str) -> str:
-        logger.debug(f"Processing question: {question[:50]}...")
-        logger.info(f"Agent received question (first 50 chars): {question[:50]}...")
+        """Process a question through the agent graph and return the answer.
+
+        Args:
+            question: The question to process
+
+        Returns:
+            str: The agent's response
+        """
         try:
-            # Run the graph with the question
-            result = agent_graph.invoke({
-                "messages": [],
+            logger.info(f"Processing question: {question}")
+            initial_state = {
                 "question": question,
-                "answer": None
-            })
-
-            # Extract and return the answer
-            answer = result["answer"]
-            logger.debug(f"Successfully generated answer: {answer}")
-            logger.info(f"Agent returning answer: {answer}")
-            return answer
+                "messages": [],
+                "answer": None,
+                "step_logs": [],
+                "is_complete": False,  # Initialize is_complete
+                "step_count": 0,  # Initialize step_count
+            }
+
+            # Generate a unique thread_id for this interaction
+            thread_id = str(uuid.uuid4())
+            config = {"configurable": {"thread_id": thread_id}}
+
+            final_state = self.graph.invoke(initial_state, config)
+            self.last_state = final_state  # Store the final state
+            return final_state.get("answer", "No answer generated")
         except Exception as e:
-            logger.error(f"Error in agent execution: {str(e)}", exc_info=True)
+            logger.error(f"Error processing question: {str(e)}")
             raise
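A minimal usage sketch of the updated AgentRunner interface, relying only on what this diff adds (a per-call thread_id, last_state, and step_logs):

from agent import AgentRunner

runner = AgentRunner()
answer = runner("What is the result of the following operation: 5 + 3 + 1294.678?")
print(answer)

# last_state keeps the final AgentState for inspection; step_logs is filled by the graph's callback node.
for entry in runner.last_state.get("step_logs", []):
    print(entry)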
app.py CHANGED
@@ -1,23 +1,26 @@
 import os
+
 import gradio as gr
-import requests
 import pandas as pd
+import requests
+
 from agent import AgentRunner
 
 # (Keep Constants as is)
 # --- Constants ---
 DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
 
-def run_and_submit_all( profile: gr.OAuthProfile | None):
+
+def run_and_submit_all(profile: gr.OAuthProfile | None):
     """
     Fetches all questions, runs the AgentRunner on them, submits all answers,
     and displays the results.
     """
     # --- Determine HF Space Runtime URL and Repo URL ---
-    space_id = os.getenv("SPACE_ID") # Get the SPACE_ID for sending link to the code
+    space_id = os.getenv("SPACE_ID")  # Get the SPACE_ID for sending link to the code
 
     if profile:
-        username= f"{profile.username}"
+        username = f"{profile.username}"
         print(f"User logged in: {username}")
     else:
         print("User not logged in.")
@@ -44,16 +47,16 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
         response.raise_for_status()
         questions_data = response.json()
         if not questions_data:
-            print("Fetched questions list is empty.")
-            return "Fetched questions list is empty or invalid format.", None
+            print("Fetched questions list is empty.")
+            return "Fetched questions list is empty or invalid format.", None
         print(f"Fetched {len(questions_data)} questions.")
     except requests.exceptions.RequestException as e:
         print(f"Error fetching questions: {e}")
         return f"Error fetching questions: {e}", None
     except requests.exceptions.JSONDecodeError as e:
-        print(f"Error decoding JSON response from questions endpoint: {e}")
-        print(f"Response text: {response.text[:500]}")
-        return f"Error decoding server response for questions: {e}", None
+        print(f"Error decoding JSON response from questions endpoint: {e}")
+        print(f"Response text: {response.text[:500]}")
+        return f"Error decoding server response for questions: {e}", None
     except Exception as e:
         print(f"An unexpected error occurred fetching questions: {e}")
         return f"An unexpected error occurred fetching questions: {e}", None
@@ -70,18 +73,36 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
             continue
         try:
             submitted_answer = agent(question_text)
-            answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
-            results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
+            answers_payload.append(
+                {"task_id": task_id, "submitted_answer": submitted_answer}
+            )
+            results_log.append(
+                {
+                    "Task ID": task_id,
+                    "Question": question_text,
+                    "Submitted Answer": submitted_answer,
+                }
+            )
         except Exception as e:
-            print(f"Error running agent on task {task_id}: {e}")
-            results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
+            print(f"Error running agent on task {task_id}: {e}")
+            results_log.append(
+                {
+                    "Task ID": task_id,
+                    "Question": question_text,
+                    "Submitted Answer": f"AGENT ERROR: {e}",
+                }
+            )
 
     if not answers_payload:
         print("Agent did not produce any answers to submit.")
         return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)
 
-    # 4. Prepare Submission
-    submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
+    # 4. Prepare Submission
+    submission_data = {
+        "username": username.strip(),
+        "agent_code": agent_code,
+        "answers": answers_payload,
+    }
     status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
     print(status_update)
 
@@ -151,20 +172,19 @@ with gr.Blocks() as demo:
 
     run_button = gr.Button("Run Evaluation & Submit All Answers")
 
-    status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
+    status_output = gr.Textbox(
+        label="Run Status / Submission Result", lines=5, interactive=False
+    )
     # Removed max_rows=10 from DataFrame constructor
    results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)
 
-    run_button.click(
-        fn=run_and_submit_all,
-        outputs=[status_output, results_table]
-    )
+    run_button.click(fn=run_and_submit_all, outputs=[status_output, results_table])
 
 if __name__ == "__main__":
-    print("\n" + "-"*30 + " App Starting " + "-"*30)
+    print("\n" + "-" * 30 + " App Starting " + "-" * 30)
     # Check for SPACE_HOST and SPACE_ID at startup for information
     space_host_startup = os.getenv("SPACE_HOST")
-    space_id_startup = os.getenv("SPACE_ID") # Get SPACE_ID at startup
+    space_id_startup = os.getenv("SPACE_ID")  # Get SPACE_ID at startup
 
     if space_host_startup:
         print(f"✅ SPACE_HOST found: {space_host_startup}")
@@ -172,14 +192,18 @@ if __name__ == "__main__":
     else:
         print("ℹ️ SPACE_HOST environment variable not found (running locally?).")
 
-    if space_id_startup: # Print repo URLs if SPACE_ID is found
+    if space_id_startup:  # Print repo URLs if SPACE_ID is found
         print(f"✅ SPACE_ID found: {space_id_startup}")
         print(f" Repo URL: https://huggingface.co/spaces/{space_id_startup}")
-        print(f" Repo Tree URL: https://huggingface.co/spaces/{space_id_startup}/tree/main")
+        print(
+            f" Repo Tree URL: https://huggingface.co/spaces/{space_id_startup}/tree/main"
+        )
     else:
-        print("ℹ️ SPACE_ID environment variable not found (running locally?). Repo URL cannot be determined.")
+        print(
+            "ℹ️ SPACE_ID environment variable not found (running locally?). Repo URL cannot be determined."
+        )
 
-    print("-"*(60 + len(" App Starting ")) + "\n")
+    print("-" * (60 + len(" App Starting ")) + "\n")
 
     print("Launching Gradio Interface for Basic Agent Evaluation...")
-    demo.launch(debug=True, share=False)
+    demo.launch(debug=True, share=False)
configuration.py ADDED
@@ -0,0 +1,33 @@
+"""Define the configurable parameters for the agent."""
+
+from __future__ import annotations
+
+import os
+from dataclasses import dataclass, fields
+from typing import Optional
+
+from langchain_core.runnables import RunnableConfig
+
+
+@dataclass(kw_only=True)
+class Configuration:
+    """The configuration for the agent."""
+
+    # API configuration
+    api_base: Optional[str] = "http://localhost:11434"
+    api_key: Optional[str] = os.getenv("MODEL_API_KEY")
+    model_id: Optional[str] = (
+        f"ollama/{os.getenv('OLLAMA_MODEL', 'qwen2.5-coder:0.5b')}"
+    )
+
+    # Agent configuration
+    my_configurable_param: str = "changeme"
+
+    @classmethod
+    def from_runnable_config(
+        cls, config: Optional[RunnableConfig] = None
+    ) -> Configuration:
+        """Create a Configuration instance from a RunnableConfig object."""
+        configurable = (config.get("configurable") or {}) if config else {}
+        _fields = {f.name for f in fields(cls) if f.init}
+        return cls(**{k: v for k, v in configurable.items() if k in _fields})
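A small sketch of how Configuration.from_runnable_config is meant to be used: values under the "configurable" key of a RunnableConfig override the dataclass defaults, and keys that are not dataclass fields are ignored.

from configuration import Configuration

config = {
    "configurable": {
        "model_id": "ollama/qwen2.5-coder:0.5b",
        "unknown_key": "silently dropped",
    }
}
cfg = Configuration.from_runnable_config(config)
print(cfg.model_id)               # overridden value
print(cfg.my_configurable_param)  # default "changeme"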
graph.py CHANGED
@@ -1,130 +1,190 @@
+"""Define the agent graph and its components."""
+
 import logging
-from typing import Callable, List, Optional, TypedDict
-from langgraph.graph import StateGraph, END
-from smolagents import CodeAgent, ToolCallingAgent, LiteLLMModel
-from tools import tools
-import yaml
 import os
-import litellm
+import uuid
+from typing import Dict, List, Optional, TypedDict, Union, cast
+
+import yaml
+from langchain_core.language_models import BaseChatModel
+from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.runnables import RunnableConfig
+from langgraph.graph import END, StateGraph
+from langgraph.prebuilt import ToolExecutor, ToolNode
+from langgraph.types import interrupt
+from smolagents import CodeAgent, LiteLLMModel, ToolCallingAgent
+
+from configuration import Configuration
+from tools import tools
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
+# Enable LiteLLM debug logging only if environment variable is set
+import litellm
+
+if os.getenv("LITELLM_DEBUG", "false").lower() == "true":
+    litellm.set_verbose = True
+    logger.setLevel(logging.DEBUG)
+else:
+    litellm.set_verbose = False
+    logger.setLevel(logging.INFO)
+
 # Configure LiteLLM to drop unsupported parameters
 litellm.drop_params = True
 
-# Define the state for our agent graph
+# Load default prompt templates from local file
+current_dir = os.path.dirname(os.path.abspath(__file__))
+prompts_dir = os.path.join(current_dir, "prompts")
+yaml_path = os.path.join(prompts_dir, "code_agent.yaml")
+
+with open(yaml_path, "r") as f:
+    prompt_templates = yaml.safe_load(f)
+
+# Initialize the model and agent using configuration
+config = Configuration()
+model = LiteLLMModel(
+    api_base=config.api_base,
+    api_key=config.api_key,
+    model_id=config.model_id,
+)
+
+agent = CodeAgent(
+    add_base_tools=True,
+    max_steps=1,  # Execute one step at a time
+    model=model,
+    prompt_templates=prompt_templates,
+    tools=tools,
+    verbosity_level=logging.DEBUG,
+)
+
+
 class AgentState(TypedDict):
-    messages: list
+    """State for the agent graph."""
+
+    messages: List[Union[HumanMessage, AIMessage, SystemMessage]]
     question: str
-    answer: str | None
+    answer: Optional[str]
+    step_logs: List[Dict]
+    is_complete: bool
+    step_count: int
+
 
 class AgentNode:
-    def __init__(self):
-        # Load default prompt templates from local file
-        current_dir = os.path.dirname(os.path.abspath(__file__))
-        prompts_dir = os.path.join(current_dir, "prompts")
-        # yaml_path = os.path.join(prompts_dir, "toolcalling_agent.yaml")
-        yaml_path = os.path.join(prompts_dir, "code_agent.yaml")
-
-        with open(yaml_path, 'r') as f:
-            prompt_templates = yaml.safe_load(f)
-
-        # Log the default system prompt
-        logger.info("Default system prompt:")
-        logger.info("-" * 80)
-        logger.info(prompt_templates["system_prompt"])
-        logger.info("-" * 80)
-
-        # # Define our custom system prompt
-        # custom_system_prompt = "..."
-
-        # # Update the system prompt in the loaded templates
-        # prompt_templates["system_prompt"] = custom_system_prompt
-
-        # Log our custom system prompt
-        # logger.info("Custom system prompt:")
-        # logger.info("-" * 80)
-        # logger.info(custom_system_prompt)
-        # logger.info("-" * 80)
-
-        # Initialize the model and agent
-        self.model = LiteLLMModel(
-            api_base="http://localhost:11434",
-            api_key=None,
-            model_id="ollama/codellama",
-        )
-
-        # self.agent = ToolCallingAgent(
-        #     max_steps=1,
-        #     model=self.model,
-        #     prompt_templates=prompt_templates,
-        #     tools=tools
-        # )
-
-        step_callbacks: Optional[List[Callable]] = [
-            lambda step: logger.info(f"Step {step.step_number} completed: {step.action}")
-        ]
-
-        self.agent = CodeAgent(
-            add_base_tools=True,
-            max_steps=1,
-            model=self.model,
-            prompt_templates=prompt_templates,
-            step_callbacks=step_callbacks,
-            tools=tools,
-            verbosity_level=logging.DEBUG
-        )
-
-    def __call__(self, state: AgentState) -> AgentState:
+    """Node that runs the agent."""
+
+    def __init__(self, agent: CodeAgent):
+        """Initialize the agent node with an agent."""
+        self.agent = agent
+
+    def __call__(
+        self, state: AgentState, config: Optional[RunnableConfig] = None
+    ) -> AgentState:
+        """Run the agent on the current state."""
+        # Log current state
+        logger.info("Current state before processing:")
+        logger.info(f"Messages: {state['messages']}")
+        logger.info(f"Question: {state['question']}")
+        logger.info(f"Answer: {state['answer']}")
+
+        # Get configuration
+        cfg = Configuration.from_runnable_config(config)
+        logger.info(f"Using configuration: {cfg}")
+
+        # Log execution start
+        logger.info("Starting agent execution")
+
+        # Run the agent
+        result = self.agent.run(state["question"])
+
+        # Log result
+        logger.info(f"Agent execution result type: {type(result)}")
+        logger.info(f"Agent execution result value: {result}")
+
+        # Update state
+        new_state = state.copy()
+        new_state["messages"].append(AIMessage(content=result))
+        new_state["answer"] = result
+        new_state["step_count"] += 1
+
+        # Log updated state
+        logger.info("Updated state after processing:")
+        logger.info(f"Messages: {new_state['messages']}")
+        logger.info(f"Question: {new_state['question']}")
+        logger.info(f"Answer: {new_state['answer']}")
+
+        return new_state
+
+
+class StepCallbackNode:
+    """Node that handles step callbacks and user interaction."""
+
+    def __call__(
+        self, state: AgentState, config: Optional[RunnableConfig] = None
+    ) -> AgentState:
+        """Handle step callback and user interaction."""
+        # Get configuration
+        cfg = Configuration.from_runnable_config(config)
+
+        # Log the step
+        step_log = {
+            "step": state["step_count"],
+            "messages": [msg.content for msg in state["messages"]],
+            "question": state["question"],
+            "answer": state["answer"],
+        }
+        state["step_logs"].append(step_log)
+
         try:
-            # Log the current state before processing
-            logger.info("Current state before processing:")
-            logger.info(f"Messages: {state['messages']}")
-            logger.info(f"Question: {state['question']}")
-            logger.info(f"Answer: {state['answer']}")
-
-            # Process the question through the agent
-            logger.info("Calling agent.run()...")
-            result = self.agent.run(state["question"])
-
-            # Log the result details
-            logger.info("Agent run completed:")
-            logger.info(f"Result type: {type(result)}")
-            logger.info(f"Result value: {result}")
-
-            # Update the state with the answer
-            state["answer"] = result
-
-            # Log the updated state
-            logger.info("Updated state after processing:")
-            logger.info(f"Messages: {state['messages']}")
-            logger.info(f"Question: {state['question']}")
-            logger.info(f"Answer: {state['answer']}")
-
-            return state
-
+            # Use interrupt for user input
+            user_input = interrupt(
+                "Press 'c' to continue, 'q' to quit, or 'i' for more info: "
+            )
+
+            if user_input.lower() == "q":
+                state["is_complete"] = True
+                return state
+            elif user_input.lower() == "i":
+                logger.info(f"Current step: {state['step_count']}")
+                logger.info(f"Question: {state['question']}")
+                logger.info(f"Current answer: {state['answer']}")
+                return self(state, config)  # Recursively call for new input
+            elif user_input.lower() == "c":
+                return state
+            else:
+                logger.warning("Invalid input. Please use 'c', 'q', or 'i'.")
+                return self(state, config)  # Recursively call for new input
+
         except Exception as e:
-            logger.error(f"Error in agent node: {str(e)}", exc_info=True)
-            state["answer"] = f"Error: {str(e)}"
+            logger.warning(f"Error during interrupt: {str(e)}")
             return state
 
-def build_agent_graph():
-    # Create the graph
-    graph = StateGraph(AgentState)
-
-    # Add the agent node
-    graph.add_node("agent", AgentNode())
-
+
+def build_agent_graph(agent: AgentNode) -> StateGraph:
+    """Build the agent graph."""
+    # Initialize the graph
+    workflow = StateGraph(AgentState)
+
+    # Add nodes
+    workflow.add_node("agent", agent)
+    workflow.add_node("callback", StepCallbackNode())
+
     # Add edges
-    graph.add_edge("agent", END)
-
-    # Set the entry point
-    graph.set_entry_point("agent")
-
-    # Compile the graph
-    return graph.compile()
-
-# Create an instance of the compiled graph
-agent_graph = build_agent_graph()
+    workflow.add_edge("agent", "callback")
+    workflow.add_conditional_edges(
+        "callback",
+        lambda x: END if x["is_complete"] else "agent",
+        {True: END, False: "agent"},
+    )
+
+    # Set entry point
+    workflow.set_entry_point("agent")
+
+    return workflow.compile()
+
+
+# Initialize the agent graph
+agent_graph = build_agent_graph(AgentNode(agent))
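Direct invocation mirrors what AgentRunner.__call__ does: the initial AgentState plus a RunnableConfig whose "configurable" dict carries the thread_id and any Configuration overrides that AgentNode and StepCallbackNode read back via Configuration.from_runnable_config. A sketch (the question value is only an example):

import uuid

from graph import agent_graph

config = {
    "configurable": {
        "thread_id": str(uuid.uuid4()),
        "my_configurable_param": "demo-run",  # picked up by Configuration.from_runnable_config
    }
}
initial_state = {
    "messages": [],
    "question": "What is 2 + 2?",
    "answer": None,
    "step_logs": [],
    "is_complete": False,
    "step_count": 0,
}

final_state = agent_graph.invoke(initial_state, config)
print(final_state["answer"])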
requirements.txt CHANGED
@@ -1,5 +1,7 @@
+black>=25.1.0
 duckduckgo-search>=8.0.1
 gradio[oauth]>=5.26.0
+isort>=6.0.1
 langgraph>=0.3.34
 pytest>=8.3.5
 pytest-cov>=6.1.1
test_agent.py CHANGED
@@ -1,84 +1,84 @@
 import logging
+
 import pytest
 import requests
+from langgraph.types import Command
+
 from agent import AgentRunner
 
-# Configure logging
-logging.basicConfig(
-    level=logging.INFO,
-    format='%(asctime)s - %(levelname)s - %(message)s'
-)
-logger = logging.getLogger(__name__)
+# Configure test logger
+test_logger = logging.getLogger("test_agent")
+test_logger.setLevel(logging.INFO)
 
 # Suppress specific warnings
-pytestmark = pytest.mark.filterwarnings(
-    "ignore::DeprecationWarning:httpx._models"
-)
+pytestmark = pytest.mark.filterwarnings("ignore::DeprecationWarning:httpx._models")
 
 # Constants
 DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
 QUESTIONS_URL = f"{DEFAULT_API_URL}/questions"
 
+
 @pytest.fixture(scope="session")
 def agent():
     """Fixture to create and return an AgentRunner instance."""
-    logger.info("Creating AgentRunner instance")
+    test_logger.info("Creating AgentRunner instance")
     return AgentRunner()
 
+
 # @pytest.fixture(scope="session")
 # def questions_data():
 #     """Fixture to fetch questions from the API."""
-#     logger.info(f"Fetching questions from: {QUESTIONS_URL}")
+#     test_logger.info(f"Fetching questions from: {QUESTIONS_URL}")
 #     try:
 #         response = requests.get(QUESTIONS_URL, timeout=15)
 #         response.raise_for_status()
 #         data = response.json()
 #         if not data:
-#             logger.error("Fetched questions list is empty.")
+#             test_logger.error("Fetched questions list is empty.")
 #             return []
-#         logger.info(f"Fetched {len(data)} questions.")
+#         test_logger.info(f"Fetched {len(data)} questions.")
 #         return data
 #     except requests.exceptions.RequestException as e:
-#         logger.error(f"Error fetching questions: {e}")
+#         test_logger.error(f"Error fetching questions: {e}")
 #         return []
 #     except requests.exceptions.JSONDecodeError as e:
-#         logger.error(f"Error decoding JSON response from questions endpoint: {e}")
+#         test_logger.error(f"Error decoding JSON response from questions endpoint: {e}")
 #         return []
 #     except Exception as e:
-#         logger.error(f"An unexpected error occurred fetching questions: {e}")
+#         test_logger.error(f"An unexpected error occurred fetching questions: {e}")
 #         return []
 #
 # class TestAppQuestions:
 #     """Test cases for questions from the app."""
 #
 #     def test_first_app_question(self, agent, questions_data):
 #         """Test the agent's response to the first app question."""
 #         if not questions_data:
 #             pytest.skip("No questions available from API")
 #
 #         first_question = questions_data[0]
 #         question_text = first_question.get("question")
 #         task_id = first_question.get("task_id")
 #
 #         if not question_text or not task_id:
 #             pytest.skip("First question is missing required fields")
 #
-#         logger.info(f"Testing with app question: {question_text}")
+#         test_logger.info(f"Testing with app question: {question_text}")
 #
 #         response = agent(question_text)
-#         logger.info(f"Agent response: {response}")
+#         test_logger.info(f"Agent response: {response}")
 #
 #         # Check that the response contains the expected information
 #         assert "Mercedes Sosa" in response, "Response should mention Mercedes Sosa"
 #         assert "studio albums" in response.lower(), "Response should mention studio albums"
 #         assert "2000" in response and "2009" in response, "Response should mention the year range"
 #
 #         # Verify that a number is mentioned (either as word or digit)
 #         import re
 #         number_pattern = r'\b(one|two|three|four|five|six|seven|eight|nine|ten|\d+)\b'
 #         has_number = bool(re.search(number_pattern, response.lower()))
 #         assert has_number, "Response should include the number of albums"
 #
 #         # Check for album names in the response
 #         known_albums = [
 #             "Corazón Libre",
@@ -89,54 +89,157 @@ def agent():
 #         ]
 #         found_albums = [album for album in known_albums if album in response]
 #         assert len(found_albums) > 0, "Response should mention at least some of the known albums"
 #
 #         # Check for a structured response
 #         assert re.search(r'\d+\.\s+[^(]+\(\d{4}\)', response), \
 #             "Response should list albums with years"
 
+
 class TestBasicCodeAgentCapabilities:
-    """Test cases for basic CodeAgent capabilities using examples from the YAML file."""
-
-    def test_simple_math_calculation(self, agent):
-        """Test the agent's ability to perform basic mathematical operations."""
-        # Test the second example from code_agent.yaml
+    """Test basic capabilities of the code agent."""
+
+    def setup_method(self):
+        """Setup method to initialize the agent before each test."""
+        test_logger.info("Creating AgentRunner instance")
+        self.agent = AgentRunner()
+
+    def test_simple_math_calculation_with_steps(self):
+        """Test that the agent can perform basic math calculations and log steps."""
        question = "What is the result of the following operation: 5 + 3 + 1294.678?"
-
-        logger.info("Testing simple math calculation capabilities")
-        logger.info(f"Question: {question}")
-
-        response = agent(question)
-        logger.info(f"Agent response: {response}")
-
+        test_logger.info(f"Testing math calculation with question: {question}")
+
+        # Run the agent and get the response
+        response = self.agent(question)
+
         # Verify the response contains the correct result
         expected_result = str(5 + 3 + 1294.678)
-        assert expected_result in response, f"Response should contain the result {expected_result}"
-
-        # Check that the response is a clear answer
-        assert "answer" in response.lower(), "Response should indicate it's providing an answer"
-
-    def test_document_qa_and_image_generation(self, agent):
-        """Test the agent's ability to process a document QA task and generate an image."""
-        # Test the first example from code_agent.yaml
-        question = "Generate an image of the oldest person in this document."
-
-        logger.info("Testing document QA and image generation capabilities")
-        logger.info(f"Question: {question}")
-
-        response = agent(question)
-        logger.info(f"Agent response: {response}")
-
-        # Verify the response contains key elements
-        assert "Bob Wilson" in response, "Response should identify Bob Wilson as the oldest person"
-        assert "60" in response, "Response should mention the age 60"
-        assert "engineer" in response, "Response should mention the profession"
-        assert "Vancouver" in response, "Response should mention the location"
-
-        # Check for image generation related content
-        assert "image" in response.lower() or "portrait" in response.lower(), \
-            "Response should indicate image generation"
-        assert "description" in response.lower(), \
-            "Response should include a description of the image"
+        assert (
+            expected_result in response
+        ), f"Response should contain the result {expected_result}"
+
+        # Verify step logs exist and have required fields
+        assert self.agent.last_state is not None, "Agent should store last state"
+        assert "step_logs" in self.agent.last_state, "State should contain step_logs"
+        assert (
+            len(self.agent.last_state["step_logs"]) > 0
+        ), "Should have at least one step logged"
+
+        # Verify each step has required fields
+        for step in self.agent.last_state["step_logs"]:
+            assert "step_number" in step, "Each step should have a step_number"
+            assert any(
+                key in step for key in ["thought", "code", "observation"]
+            ), "Each step should have at least one of: thought, code, or observation"
+
+        # Verify the final answer is indicated
+        assert (
+            "final_answer" in response.lower()
+        ), "Response should indicate it's providing an answer"
+
+    def test_document_qa_and_image_generation_with_steps(self):
+        """Test that the agent can search for information and generate images, with step logging."""
+        question = (
+            "Search for information about the Mona Lisa and generate an image of it."
+        )
+        test_logger.info(
+            f"Testing document QA and image generation with question: {question}"
+        )
+
+        # Run the agent and get the response
+        response = self.agent(question)
+
+        # Verify the response contains both search and image generation
+        assert "mona lisa" in response.lower(), "Response should mention Mona Lisa"
+        assert "image" in response.lower(), "Response should mention image generation"
+
+        # Verify step logs exist and show logical progression
+        assert self.agent.last_state is not None, "Agent should store last state"
+        assert "step_logs" in self.agent.last_state, "State should contain step_logs"
+        assert (
+            len(self.agent.last_state["step_logs"]) > 1
+        ), "Should have multiple steps logged"
+
+        # Verify steps show logical progression
+        steps = self.agent.last_state["step_logs"]
+        search_steps = [step for step in steps if "search" in str(step).lower()]
+        image_steps = [step for step in steps if "image" in str(step).lower()]
+
+        assert len(search_steps) > 0, "Should have search steps"
+        assert len(image_steps) > 0, "Should have image generation steps"
+
+        # Verify each step has required fields
+        for step in steps:
+            assert "step_number" in step, "Each step should have a step_number"
+            assert any(
+                key in step for key in ["thought", "code", "observation"]
+            ), "Each step should have at least one of: thought, code, or observation"
+
+
+def test_simple_math_calculation_with_steps():
+    """Test that the agent can perform a simple math calculation and verify intermediate steps."""
+    agent = AgentRunner()
+    question = "What is the result of the following operation: 5 + 3 + 1294.678?"
+
+    # Process the question
+    response = agent(question)
+
+    # Verify step logs exist and have required fields
+    assert agent.last_state is not None, "Last state should be stored"
+    step_logs = agent.last_state.get("step_logs", [])
+    assert len(step_logs) > 0, "Should have recorded step logs"
+
+    for step in step_logs:
+        assert "step_number" in step, "Each step should have a step number"
+        assert any(
+            key in step for key in ["thought", "code", "observation"]
+        ), "Each step should have at least one of thought/code/observation"
+
+    # Verify final answer
+    expected_result = 1302.678
+    assert (
+        str(expected_result) in response
+    ), f"Response should contain the result {expected_result}"
+    assert (
+        "final_answer" in response.lower()
+    ), "Response should indicate it's using final_answer"
+
+
+def test_document_qa_and_image_generation_with_steps():
+    """Test document QA and image generation with step verification."""
+    agent = AgentRunner()
+    question = "Can you search for information about the Mona Lisa and generate an image inspired by it?"
+
+    # Process the question
+    response = agent(question)
+
+    # Verify step logs exist and demonstrate logical progression
+    assert agent.last_state is not None, "Last state should be stored"
+    step_logs = agent.last_state.get("step_logs", [])
+    assert len(step_logs) > 0, "Should have recorded step logs"
+
+    # Check for search and image generation steps
+    has_search_step = False
+    has_image_step = False
+
+    for step in step_logs:
+        assert "step_number" in step, "Each step should have a step number"
+        assert any(
+            key in step for key in ["thought", "code", "observation"]
+        ), "Each step should have at least one of thought/code/observation"
+
+        # Look for search and image steps in thoughts or code
+        step_content = str(step.get("thought", "")) + str(step.get("code", ""))
+        if "search" in step_content.lower():
+            has_search_step = True
+        if "image" in step_content.lower() or "dalle" in step_content.lower():
+            has_image_step = True
+
+    assert has_search_step, "Should include a search step"
+    assert has_image_step, "Should include an image generation step"
+    assert (
+        "final_answer" in response.lower()
+    ), "Response should indicate it's using final_answer"
+
 
 if __name__ == "__main__":
-    pytest.main([__file__, "-v", "-x"])
+    pytest.main([__file__, "-s", "-v", "-x"])
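The __main__ block runs the whole module; a single test can be selected the same way, keeping the -s/-v flags used above:

import pytest

pytest.main(["test_agent.py::test_simple_math_calculation_with_steps", "-s", "-v"])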
tools.py CHANGED
@@ -1,12 +1,16 @@
 import logging
-from smolagents import DuckDuckGoSearchTool, WikipediaSearchTool, Tool
+
+from smolagents import DuckDuckGoSearchTool, Tool, WikipediaSearchTool
 
 logger = logging.getLogger(__name__)
 
+
 class GeneralSearchTool(Tool):
     name = "search"
     description = """Performs a general web search using both DuckDuckGo and Wikipedia, then returns the combined search results."""
-    inputs = {"query": {"type": "string", "description": "The search query to perform."}}
+    inputs = {
+        "query": {"type": "string", "description": "The search query to perform."}
+    }
     output_type = "string"
 
     def __init__(self, max_results=10, **kwargs):
@@ -22,26 +26,27 @@ class GeneralSearchTool(Tool):
         except Exception as e:
             ddg_results = "No DuckDuckGo results found."
             logger.warning(f"DuckDuckGo search failed: {str(e)}")
-
+
         # Get Wikipedia results
         try:
             wiki_results = self.wiki_tool.forward(query)
         except Exception as e:
             wiki_results = "No Wikipedia results found."
             logger.warning(f"Wikipedia search failed: {str(e)}")
-
+
         # Combine and format results
         output = []
         if ddg_results and ddg_results != "No DuckDuckGo results found.":
             output.append("## DuckDuckGo Search Results\n\n" + ddg_results)
         if wiki_results and wiki_results != "No Wikipedia results found.":
             output.append("## Wikipedia Results\n\n" + wiki_results)
-
+
         if not output:
             raise Exception("No results found! Try a less restrictive/shorter query.")
-
+
         return "\n\n---\n\n".join(output)
 
+
 # Export all tools
 tools = [
     # DuckDuckGoSearchTool(),
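A minimal sketch of calling the combined search tool directly, assuming the forward(query) method whose body appears in the second hunk:

from tools import GeneralSearchTool

search = GeneralSearchTool(max_results=5)

# Returns DuckDuckGo and Wikipedia results joined by a separator,
# or raises if neither backend returned anything.
print(search.forward("Mercedes Sosa studio albums"))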