mjschock commited on
Commit
81d00fe
·
unverified ·
1 Parent(s): 81917a3

Add initial implementation of AgentRunner and agent graph; include .gitignore and update requirements

Browse files
Files changed (6) hide show
  1. .gitignore +3 -0
  2. agent.py +41 -0
  3. app.py +3 -14
  4. graph.py +92 -0
  5. requirements.txt +8 -2
  6. tools.py +50 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ __pycache__
2
+ .pytest_cache
3
+ .venv
agent.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ from graph import agent_graph
4
+
5
+ # Configure logging
6
+ logging.basicConfig(level=logging.INFO) # Default to INFO level
7
+ logger = logging.getLogger(__name__)
8
+
9
+ # Enable LiteLLM debug logging only if environment variable is set
10
+ import litellm
11
+ if os.getenv('LITELLM_DEBUG', 'false').lower() == 'true':
12
+ litellm.set_verbose = True
13
+ logger.setLevel(logging.DEBUG)
14
+ else:
15
+ litellm.set_verbose = False
16
+ logger.setLevel(logging.INFO)
17
+
18
+ class AgentRunner:
19
+ def __init__(self):
20
+ logger.debug("Initializing AgentRunner")
21
+ logger.info("AgentRunner initialized.")
22
+
23
+ def __call__(self, question: str) -> str:
24
+ logger.debug(f"Processing question: {question[:50]}...")
25
+ logger.info(f"Agent received question (first 50 chars): {question[:50]}...")
26
+ try:
27
+ # Run the graph with the question
28
+ result = agent_graph.invoke({
29
+ "messages": [],
30
+ "question": question,
31
+ "answer": None
32
+ })
33
+
34
+ # Extract and return the answer
35
+ answer = result["answer"]
36
+ logger.debug(f"Successfully generated answer: {answer}")
37
+ logger.info(f"Agent returning answer: {answer}")
38
+ return answer
39
+ except Exception as e:
40
+ logger.error(f"Error in agent execution: {str(e)}", exc_info=True)
41
+ raise
app.py CHANGED
@@ -1,27 +1,16 @@
1
  import os
2
  import gradio as gr
3
  import requests
4
- import inspect
5
  import pandas as pd
 
6
 
7
  # (Keep Constants as is)
8
  # --- Constants ---
9
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
10
 
11
- # --- Basic Agent Definition ---
12
- # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
13
- class BasicAgent:
14
- def __init__(self):
15
- print("BasicAgent initialized.")
16
- def __call__(self, question: str) -> str:
17
- print(f"Agent received question (first 50 chars): {question[:50]}...")
18
- fixed_answer = "This is a default answer."
19
- print(f"Agent returning fixed answer: {fixed_answer}")
20
- return fixed_answer
21
-
22
  def run_and_submit_all( profile: gr.OAuthProfile | None):
23
  """
24
- Fetches all questions, runs the BasicAgent on them, submits all answers,
25
  and displays the results.
26
  """
27
  # --- Determine HF Space Runtime URL and Repo URL ---
@@ -40,7 +29,7 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
40
 
41
  # 1. Instantiate Agent ( modify this part to create your agent)
42
  try:
43
- agent = BasicAgent()
44
  except Exception as e:
45
  print(f"Error instantiating agent: {e}")
46
  return f"Error initializing agent: {e}", None
 
1
  import os
2
  import gradio as gr
3
  import requests
 
4
  import pandas as pd
5
+ from agent import AgentRunner
6
 
7
  # (Keep Constants as is)
8
  # --- Constants ---
9
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
10
 
 
 
 
 
 
 
 
 
 
 
 
11
  def run_and_submit_all( profile: gr.OAuthProfile | None):
12
  """
13
+ Fetches all questions, runs the AgentRunner on them, submits all answers,
14
  and displays the results.
15
  """
16
  # --- Determine HF Space Runtime URL and Repo URL ---
 
29
 
30
  # 1. Instantiate Agent ( modify this part to create your agent)
31
  try:
32
+ agent = AgentRunner()
33
  except Exception as e:
34
  print(f"Error instantiating agent: {e}")
35
  return f"Error initializing agent: {e}", None
graph.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import TypedDict
3
+ from langgraph.graph import StateGraph, END
4
+ from smolagents import ToolCallingAgent, LiteLLMModel
5
+ from tools import tools
6
+ import yaml
7
+ import importlib.resources
8
+
9
+ # Configure logging
10
+ logging.basicConfig(level=logging.INFO)
11
+ logger = logging.getLogger(__name__)
12
+
13
+ # Define the state for our agent graph
14
+ class AgentState(TypedDict):
15
+ messages: list
16
+ question: str
17
+ answer: str | None
18
+
19
+ class AgentNode:
20
+ def __init__(self):
21
+ # Load default prompt templates
22
+ prompt_templates = yaml.safe_load(
23
+ importlib.resources.files("smolagents.prompts").joinpath("toolcalling_agent.yaml").read_text()
24
+ )
25
+
26
+ # Log the default system prompt
27
+ logger.info("Default system prompt:")
28
+ logger.info("-" * 80)
29
+ logger.info(prompt_templates["system_prompt"])
30
+ logger.info("-" * 80)
31
+
32
+ # # Define our custom system prompt
33
+ # custom_system_prompt = "..."
34
+
35
+ # # Update the system prompt in the loaded templates
36
+ # prompt_templates["system_prompt"] = custom_system_prompt
37
+
38
+ # Log our custom system prompt
39
+ # logger.info("Custom system prompt:")
40
+ # logger.info("-" * 80)
41
+ # logger.info(custom_system_prompt)
42
+ # logger.info("-" * 80)
43
+
44
+ # In"itialize the model and agent
45
+ self.model = LiteLLMModel(
46
+ model="ollama/codellama",
47
+ temperature=0.0,
48
+ max_tokens=4096,
49
+ top_p=0.9,
50
+ frequency_penalty=0.0,
51
+ presence_penalty=0.0,
52
+ stop=["Observation:"],
53
+ )
54
+
55
+ self.agent = ToolCallingAgent(
56
+ model=self.model,
57
+ prompt_templates=prompt_templates,
58
+ tools=tools
59
+ )
60
+
61
+ def __call__(self, state: AgentState) -> AgentState:
62
+ try:
63
+ # Process the question through the agent
64
+ result = self.agent.run(state["question"])
65
+
66
+ # Update the state with the answer
67
+ state["answer"] = result
68
+ return state
69
+
70
+ except Exception as e:
71
+ logger.error(f"Error in agent node: {str(e)}", exc_info=True)
72
+ state["answer"] = f"Error: {str(e)}"
73
+ return state
74
+
75
+ def build_agent_graph():
76
+ # Create the graph
77
+ graph = StateGraph(AgentState)
78
+
79
+ # Add the agent node
80
+ graph.add_node("agent", AgentNode())
81
+
82
+ # Add edges
83
+ graph.add_edge("agent", END)
84
+
85
+ # Set the entry point
86
+ graph.set_entry_point("agent")
87
+
88
+ # Compile the graph
89
+ return graph.compile()
90
+
91
+ # Create an instance of the compiled graph
92
+ agent_graph = build_agent_graph()
requirements.txt CHANGED
@@ -1,2 +1,8 @@
1
- gradio
2
- requests
 
 
 
 
 
 
 
1
+ duckduckgo-search>=8.0.1
2
+ gradio[oauth]>=5.26.0
3
+ langgraph>=0.3.34
4
+ pytest>=8.3.5
5
+ pytest-cov>=6.1.1
6
+ requests>=2.32.3
7
+ smolagents[litellm]>=0.1.3
8
+ wikipedia-api>=0.8.1
tools.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from smolagents import DuckDuckGoSearchTool, WikipediaSearchTool, Tool
3
+
4
+ logger = logging.getLogger(__name__)
5
+
6
+ class GeneralSearchTool(Tool):
7
+ name = "search"
8
+ description = """Performs a general web search using both DuckDuckGo and Wikipedia, then returns the combined search results."""
9
+ inputs = {"query": {"type": "string", "description": "The search query to perform."}}
10
+ output_type = "string"
11
+
12
+ def __init__(self, max_results=10, **kwargs):
13
+ super().__init__()
14
+ self.max_results = max_results
15
+ self.ddg_tool = DuckDuckGoSearchTool()
16
+ self.wiki_tool = WikipediaSearchTool()
17
+
18
+ def forward(self, query: str) -> str:
19
+ # Get DuckDuckGo results
20
+ try:
21
+ ddg_results = self.ddg_tool.forward(query)
22
+ except Exception as e:
23
+ ddg_results = "No DuckDuckGo results found."
24
+ logger.warning(f"DuckDuckGo search failed: {str(e)}")
25
+
26
+ # Get Wikipedia results
27
+ try:
28
+ wiki_results = self.wiki_tool.forward(query)
29
+ except Exception as e:
30
+ wiki_results = "No Wikipedia results found."
31
+ logger.warning(f"Wikipedia search failed: {str(e)}")
32
+
33
+ # Combine and format results
34
+ output = []
35
+ if ddg_results and ddg_results != "No DuckDuckGo results found.":
36
+ output.append("## DuckDuckGo Search Results\n\n" + ddg_results)
37
+ if wiki_results and wiki_results != "No Wikipedia results found.":
38
+ output.append("## Wikipedia Results\n\n" + wiki_results)
39
+
40
+ if not output:
41
+ raise Exception("No results found! Try a less restrictive/shorter query.")
42
+
43
+ return "\n\n---\n\n".join(output)
44
+
45
+ # Export all tools
46
+ tools = [
47
+ # DuckDuckGoSearchTool(),
48
+ GeneralSearchTool(),
49
+ # WikipediaSearchTool(),
50
+ ]