Build error
Refactor agent.py and graph.py to enhance agent functionality and logging. Introduce Configuration class for managing parameters, improve state handling in AgentRunner, and update agent graph to support step logging and user interaction. Add new tests for agent capabilities and update requirements for code formatting tools.
- agent.py +37 -18
- app.py +51 -27
- configuration.py +33 -0
- graph.py +171 -111
- requirements.txt +2 -0
- test_agent.py +170 -67
- tools.py +11 -6
agent.py
CHANGED
@@ -1,5 +1,7 @@
import logging
import os
import uuid

from graph import agent_graph

# Configure logging
@@ -8,34 +10,51 @@ logger = logging.getLogger(__name__)

# Enable LiteLLM debug logging only if environment variable is set
import litellm

if os.getenv("LITELLM_DEBUG", "false").lower() == "true":
    litellm.set_verbose = True
    logger.setLevel(logging.DEBUG)
else:
    litellm.set_verbose = False
    logger.setLevel(logging.INFO)


class AgentRunner:
    """Runner class for the code agent."""

    def __init__(self):
        """Initialize the agent runner with graph and tools."""
        logger.info("Initializing AgentRunner")
        self.graph = agent_graph
        self.last_state = None  # Store the last state for testing/debugging

    def __call__(self, question: str) -> str:
        """Process a question through the agent graph and return the answer.

        Args:
            question: The question to process

        Returns:
            str: The agent's response
        """
        try:
            logger.info(f"Processing question: {question}")
            initial_state = {
                "question": question,
                "messages": [],
                "answer": None,
                "step_logs": [],
                "is_complete": False,  # Initialize is_complete
                "step_count": 0,  # Initialize step_count
            }

            # Generate a unique thread_id for this interaction
            thread_id = str(uuid.uuid4())
            config = {"configurable": {"thread_id": thread_id}}

            final_state = self.graph.invoke(initial_state, config)
            self.last_state = final_state  # Store the final state
            return final_state.get("answer", "No answer generated")
        except Exception as e:
            logger.error(f"Error processing question: {str(e)}")
            raise
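For reference, a minimal usage sketch of the refactored AgentRunner (not part of this commit; it assumes the Ollama endpoint configured in configuration.py is reachable and that prompts/code_agent.yaml exists):

from agent import AgentRunner

runner = AgentRunner()
answer = runner("What is the result of the following operation: 5 + 3?")
print(answer)

# The final graph state is kept on the runner for inspection/testing
print(runner.last_state["step_count"])
print(runner.last_state["step_logs"])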
app.py
CHANGED
@@ -1,23 +1,26 @@
import os

import gradio as gr
import pandas as pd
import requests

from agent import AgentRunner

# (Keep Constants as is)
# --- Constants ---
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"


def run_and_submit_all(profile: gr.OAuthProfile | None):
    """
    Fetches all questions, runs the AgentRunner on them, submits all answers,
    and displays the results.
    """
    # --- Determine HF Space Runtime URL and Repo URL ---
    space_id = os.getenv("SPACE_ID")  # Get the SPACE_ID for sending link to the code

    if profile:
        username = f"{profile.username}"
        print(f"User logged in: {username}")
    else:
        print("User not logged in.")
@@ -44,16 +47,16 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
        response.raise_for_status()
        questions_data = response.json()
        if not questions_data:
            print("Fetched questions list is empty.")
            return "Fetched questions list is empty or invalid format.", None
        print(f"Fetched {len(questions_data)} questions.")
    except requests.exceptions.RequestException as e:
        print(f"Error fetching questions: {e}")
        return f"Error fetching questions: {e}", None
    except requests.exceptions.JSONDecodeError as e:
        print(f"Error decoding JSON response from questions endpoint: {e}")
        print(f"Response text: {response.text[:500]}")
        return f"Error decoding server response for questions: {e}", None
    except Exception as e:
        print(f"An unexpected error occurred fetching questions: {e}")
        return f"An unexpected error occurred fetching questions: {e}", None
@@ -70,18 +73,36 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
            continue
        try:
            submitted_answer = agent(question_text)
            answers_payload.append(
                {"task_id": task_id, "submitted_answer": submitted_answer}
            )
            results_log.append(
                {
                    "Task ID": task_id,
                    "Question": question_text,
                    "Submitted Answer": submitted_answer,
                }
            )
        except Exception as e:
            print(f"Error running agent on task {task_id}: {e}")
            results_log.append(
                {
                    "Task ID": task_id,
                    "Question": question_text,
                    "Submitted Answer": f"AGENT ERROR: {e}",
                }
            )

    if not answers_payload:
        print("Agent did not produce any answers to submit.")
        return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)

    # 4. Prepare Submission
    submission_data = {
        "username": username.strip(),
        "agent_code": agent_code,
        "answers": answers_payload,
    }
    status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
    print(status_update)

@@ -151,20 +172,19 @@ with gr.Blocks() as demo:

    run_button = gr.Button("Run Evaluation & Submit All Answers")

    status_output = gr.Textbox(
        label="Run Status / Submission Result", lines=5, interactive=False
    )
    # Removed max_rows=10 from DataFrame constructor
    results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)

    run_button.click(fn=run_and_submit_all, outputs=[status_output, results_table])

if __name__ == "__main__":
    print("\n" + "-" * 30 + " App Starting " + "-" * 30)
    # Check for SPACE_HOST and SPACE_ID at startup for information
    space_host_startup = os.getenv("SPACE_HOST")
    space_id_startup = os.getenv("SPACE_ID")  # Get SPACE_ID at startup

    if space_host_startup:
        print(f"✅ SPACE_HOST found: {space_host_startup}")
@@ -172,14 +192,18 @@ if __name__ == "__main__":
    else:
        print("ℹ️ SPACE_HOST environment variable not found (running locally?).")

    if space_id_startup:  # Print repo URLs if SPACE_ID is found
        print(f"✅ SPACE_ID found: {space_id_startup}")
        print(f"   Repo URL: https://huggingface.co/spaces/{space_id_startup}")
        print(
            f"   Repo Tree URL: https://huggingface.co/spaces/{space_id_startup}/tree/main"
        )
    else:
        print(
            "ℹ️ SPACE_ID environment variable not found (running locally?). Repo URL cannot be determined."
        )

    print("-" * (60 + len(" App Starting ")) + "\n")

    print("Launching Gradio Interface for Basic Agent Evaluation...")
    demo.launch(debug=True, share=False)
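The questions endpoint that run_and_submit_all depends on can be exercised on its own; a minimal sketch using the same base URL, the /questions path from test_agent.py, and the same 15-second timeout (variable names are illustrative):

import requests

DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
questions_url = f"{DEFAULT_API_URL}/questions"

try:
    # Same fetch the app performs before running the agent on each task
    response = requests.get(questions_url, timeout=15)
    response.raise_for_status()
    questions_data = response.json()
    print(f"Fetched {len(questions_data)} questions.")
except requests.exceptions.RequestException as e:
    print(f"Error fetching questions: {e}")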
configuration.py
ADDED
@@ -0,0 +1,33 @@
"""Define the configurable parameters for the agent."""

from __future__ import annotations

import os
from dataclasses import dataclass, fields
from typing import Optional

from langchain_core.runnables import RunnableConfig


@dataclass(kw_only=True)
class Configuration:
    """The configuration for the agent."""

    # API configuration
    api_base: Optional[str] = "http://localhost:11434"
    api_key: Optional[str] = os.getenv("MODEL_API_KEY")
    model_id: Optional[str] = (
        f"ollama/{os.getenv('OLLAMA_MODEL', 'qwen2.5-coder:0.5b')}"
    )

    # Agent configuration
    my_configurable_param: str = "changeme"

    @classmethod
    def from_runnable_config(
        cls, config: Optional[RunnableConfig] = None
    ) -> Configuration:
        """Create a Configuration instance from a RunnableConfig object."""
        configurable = (config.get("configurable") or {}) if config else {}
        _fields = {f.name for f in fields(cls) if f.init}
        return cls(**{k: v for k, v in configurable.items() if k in _fields})
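A short sketch of how Configuration.from_runnable_config picks known dataclass fields out of a RunnableConfig's "configurable" dict and ignores everything else (values shown are illustrative):

from configuration import Configuration

run_config = {
    "configurable": {
        "thread_id": "demo-thread",       # not a Configuration field, ignored
        "my_configurable_param": "demo",  # matches a field, applied
    }
}

cfg = Configuration.from_runnable_config(run_config)
print(cfg.my_configurable_param)  # -> "demo"
print(cfg.model_id)  # -> "ollama/qwen2.5-coder:0.5b" unless OLLAMA_MODEL was set at import time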
graph.py
CHANGED
@@ -1,130 +1,190 @@
"""Define the agent graph and its components."""

import logging
import os
import uuid
from typing import Dict, List, Optional, TypedDict, Union, cast

import yaml
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, StateGraph
from langgraph.prebuilt import ToolExecutor, ToolNode
from langgraph.types import interrupt
from smolagents import CodeAgent, LiteLLMModel, ToolCallingAgent

from configuration import Configuration
from tools import tools

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Enable LiteLLM debug logging only if environment variable is set
import litellm

if os.getenv("LITELLM_DEBUG", "false").lower() == "true":
    litellm.set_verbose = True
    logger.setLevel(logging.DEBUG)
else:
    litellm.set_verbose = False
    logger.setLevel(logging.INFO)

# Configure LiteLLM to drop unsupported parameters
litellm.drop_params = True

# Load default prompt templates from local file
current_dir = os.path.dirname(os.path.abspath(__file__))
prompts_dir = os.path.join(current_dir, "prompts")
yaml_path = os.path.join(prompts_dir, "code_agent.yaml")

with open(yaml_path, "r") as f:
    prompt_templates = yaml.safe_load(f)

# Initialize the model and agent using configuration
config = Configuration()
model = LiteLLMModel(
    api_base=config.api_base,
    api_key=config.api_key,
    model_id=config.model_id,
)

agent = CodeAgent(
    add_base_tools=True,
    max_steps=1,  # Execute one step at a time
    model=model,
    prompt_templates=prompt_templates,
    tools=tools,
    verbosity_level=logging.DEBUG,
)


class AgentState(TypedDict):
    """State for the agent graph."""

    messages: List[Union[HumanMessage, AIMessage, SystemMessage]]
    question: str
    answer: Optional[str]
    step_logs: List[Dict]
    is_complete: bool
    step_count: int


class AgentNode:
    """Node that runs the agent."""

    def __init__(self, agent: CodeAgent):
        """Initialize the agent node with an agent."""
        self.agent = agent

    def __call__(
        self, state: AgentState, config: Optional[RunnableConfig] = None
    ) -> AgentState:
        """Run the agent on the current state."""
        # Log current state
        logger.info("Current state before processing:")
        logger.info(f"Messages: {state['messages']}")
        logger.info(f"Question: {state['question']}")
        logger.info(f"Answer: {state['answer']}")

        # Get configuration
        cfg = Configuration.from_runnable_config(config)
        logger.info(f"Using configuration: {cfg}")

        # Log execution start
        logger.info("Starting agent execution")

        # Run the agent
        result = self.agent.run(state["question"])

        # Log result
        logger.info(f"Agent execution result type: {type(result)}")
        logger.info(f"Agent execution result value: {result}")

        # Update state
        new_state = state.copy()
        new_state["messages"].append(AIMessage(content=result))
        new_state["answer"] = result
        new_state["step_count"] += 1

        # Log updated state
        logger.info("Updated state after processing:")
        logger.info(f"Messages: {new_state['messages']}")
        logger.info(f"Question: {new_state['question']}")
        logger.info(f"Answer: {new_state['answer']}")

        return new_state


class StepCallbackNode:
    """Node that handles step callbacks and user interaction."""

    def __call__(
        self, state: AgentState, config: Optional[RunnableConfig] = None
    ) -> AgentState:
        """Handle step callback and user interaction."""
        # Get configuration
        cfg = Configuration.from_runnable_config(config)

        # Log the step
        step_log = {
            "step": state["step_count"],
            "messages": [msg.content for msg in state["messages"]],
            "question": state["question"],
            "answer": state["answer"],
        }
        state["step_logs"].append(step_log)

        try:
            # Use interrupt for user input
            user_input = interrupt(
                "Press 'c' to continue, 'q' to quit, or 'i' for more info: "
            )

            if user_input.lower() == "q":
                state["is_complete"] = True
                return state
            elif user_input.lower() == "i":
                logger.info(f"Current step: {state['step_count']}")
                logger.info(f"Question: {state['question']}")
                logger.info(f"Current answer: {state['answer']}")
                return self(state, config)  # Recursively call for new input
            elif user_input.lower() == "c":
                return state
            else:
                logger.warning("Invalid input. Please use 'c', 'q', or 'i'.")
                return self(state, config)  # Recursively call for new input

        except Exception as e:
            logger.warning(f"Error during interrupt: {str(e)}")
            return state


def build_agent_graph(agent: AgentNode) -> StateGraph:
    """Build the agent graph."""
    # Initialize the graph
    workflow = StateGraph(AgentState)

    # Add nodes
    workflow.add_node("agent", agent)
    workflow.add_node("callback", StepCallbackNode())

    # Add edges
    workflow.add_edge("agent", "callback")
    workflow.add_conditional_edges(
        "callback",
        lambda x: END if x["is_complete"] else "agent",
        {True: END, False: "agent"},
    )

    # Set entry point
    workflow.set_entry_point("agent")

    return workflow.compile()


# Initialize the agent graph
agent_graph = build_agent_graph(AgentNode(agent))
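agent_graph is what AgentRunner invokes; a minimal sketch of driving the compiled graph directly, mirroring the initial state and thread_id config that agent.py builds (requires the same Ollama and prompts setup as above):

import uuid

from graph import agent_graph

initial_state = {
    "question": "What is the result of the following operation: 5 + 3?",
    "messages": [],
    "answer": None,
    "step_logs": [],
    "is_complete": False,
    "step_count": 0,
}
config = {"configurable": {"thread_id": str(uuid.uuid4())}}

final_state = agent_graph.invoke(initial_state, config)
print(final_state["answer"])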
requirements.txt
CHANGED
@@ -1,5 +1,7 @@
black>=25.1.0
duckduckgo-search>=8.0.1
gradio[oauth]>=5.26.0
isort>=6.0.1
langgraph>=0.3.34
pytest>=8.3.5
pytest-cov>=6.1.1
test_agent.py
CHANGED
@@ -1,84 +1,84 @@
import logging

import pytest
import requests
from langgraph.types import Command

from agent import AgentRunner

# Configure test logger
test_logger = logging.getLogger("test_agent")
test_logger.setLevel(logging.INFO)

# Suppress specific warnings
pytestmark = pytest.mark.filterwarnings("ignore::DeprecationWarning:httpx._models")

# Constants
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
QUESTIONS_URL = f"{DEFAULT_API_URL}/questions"


@pytest.fixture(scope="session")
def agent():
    """Fixture to create and return an AgentRunner instance."""
    test_logger.info("Creating AgentRunner instance")
    return AgentRunner()


# @pytest.fixture(scope="session")
# def questions_data():
#     """Fixture to fetch questions from the API."""
#     test_logger.info(f"Fetching questions from: {QUESTIONS_URL}")
#     try:
#         response = requests.get(QUESTIONS_URL, timeout=15)
#         response.raise_for_status()
#         data = response.json()
#         if not data:
#             test_logger.error("Fetched questions list is empty.")
#             return []
#         test_logger.info(f"Fetched {len(data)} questions.")
#         return data
#     except requests.exceptions.RequestException as e:
#         test_logger.error(f"Error fetching questions: {e}")
#         return []
#     except requests.exceptions.JSONDecodeError as e:
#         test_logger.error(f"Error decoding JSON response from questions endpoint: {e}")
#         return []
#     except Exception as e:
#         test_logger.error(f"An unexpected error occurred fetching questions: {e}")
#         return []
#
# class TestAppQuestions:
#     """Test cases for questions from the app."""
#
#     def test_first_app_question(self, agent, questions_data):
#         """Test the agent's response to the first app question."""
#         if not questions_data:
#             pytest.skip("No questions available from API")
#
#         first_question = questions_data[0]
#         question_text = first_question.get("question")
#         task_id = first_question.get("task_id")
#
#         if not question_text or not task_id:
#             pytest.skip("First question is missing required fields")
#
#         test_logger.info(f"Testing with app question: {question_text}")
#
#         response = agent(question_text)
#         test_logger.info(f"Agent response: {response}")
#
#         # Check that the response contains the expected information
#         assert "Mercedes Sosa" in response, "Response should mention Mercedes Sosa"
#         assert "studio albums" in response.lower(), "Response should mention studio albums"
#         assert "2000" in response and "2009" in response, "Response should mention the year range"
#
#         # Verify that a number is mentioned (either as word or digit)
#         import re
#         number_pattern = r'\b(one|two|three|four|five|six|seven|eight|nine|ten|\d+)\b'
#         has_number = bool(re.search(number_pattern, response.lower()))
#         assert has_number, "Response should include the number of albums"
#
#         # Check for album names in the response
#         known_albums = [
#             "Corazón Libre",
@@ -89,54 +89,157 @@ def agent():
#         ]
#         found_albums = [album for album in known_albums if album in response]
#         assert len(found_albums) > 0, "Response should mention at least some of the known albums"
#
#         # Check for a structured response
#         assert re.search(r'\d+\.\s+[^(]+\(\d{4}\)', response), \
#             "Response should list albums with years"


class TestBasicCodeAgentCapabilities:
    """Test basic capabilities of the code agent."""

    def setup_method(self):
        """Setup method to initialize the agent before each test."""
        test_logger.info("Creating AgentRunner instance")
        self.agent = AgentRunner()

    def test_simple_math_calculation_with_steps(self):
        """Test that the agent can perform basic math calculations and log steps."""
        question = "What is the result of the following operation: 5 + 3 + 1294.678?"
        test_logger.info(f"Testing math calculation with question: {question}")

        # Run the agent and get the response
        response = self.agent(question)

        # Verify the response contains the correct result
        expected_result = str(5 + 3 + 1294.678)
        assert (
            expected_result in response
        ), f"Response should contain the result {expected_result}"

        # Verify step logs exist and have required fields
        assert self.agent.last_state is not None, "Agent should store last state"
        assert "step_logs" in self.agent.last_state, "State should contain step_logs"
        assert (
            len(self.agent.last_state["step_logs"]) > 0
        ), "Should have at least one step logged"

        # Verify each step has required fields
        for step in self.agent.last_state["step_logs"]:
            assert "step_number" in step, "Each step should have a step_number"
            assert any(
                key in step for key in ["thought", "code", "observation"]
            ), "Each step should have at least one of: thought, code, or observation"

        # Verify the final answer is indicated
        assert (
            "final_answer" in response.lower()
        ), "Response should indicate it's providing an answer"

    def test_document_qa_and_image_generation_with_steps(self):
        """Test that the agent can search for information and generate images, with step logging."""
        question = (
            "Search for information about the Mona Lisa and generate an image of it."
        )
        test_logger.info(
            f"Testing document QA and image generation with question: {question}"
        )

        # Run the agent and get the response
        response = self.agent(question)

        # Verify the response contains both search and image generation
        assert "mona lisa" in response.lower(), "Response should mention Mona Lisa"
        assert "image" in response.lower(), "Response should mention image generation"

        # Verify step logs exist and show logical progression
        assert self.agent.last_state is not None, "Agent should store last state"
        assert "step_logs" in self.agent.last_state, "State should contain step_logs"
        assert (
            len(self.agent.last_state["step_logs"]) > 1
        ), "Should have multiple steps logged"

        # Verify steps show logical progression
        steps = self.agent.last_state["step_logs"]
        search_steps = [step for step in steps if "search" in str(step).lower()]
        image_steps = [step for step in steps if "image" in str(step).lower()]

        assert len(search_steps) > 0, "Should have search steps"
        assert len(image_steps) > 0, "Should have image generation steps"

        # Verify each step has required fields
        for step in steps:
            assert "step_number" in step, "Each step should have a step_number"
            assert any(
                key in step for key in ["thought", "code", "observation"]
            ), "Each step should have at least one of: thought, code, or observation"


def test_simple_math_calculation_with_steps():
    """Test that the agent can perform a simple math calculation and verify intermediate steps."""
    agent = AgentRunner()
    question = "What is the result of the following operation: 5 + 3 + 1294.678?"

    # Process the question
    response = agent(question)

    # Verify step logs exist and have required fields
    assert agent.last_state is not None, "Last state should be stored"
    step_logs = agent.last_state.get("step_logs", [])
    assert len(step_logs) > 0, "Should have recorded step logs"

    for step in step_logs:
        assert "step_number" in step, "Each step should have a step number"
        assert any(
            key in step for key in ["thought", "code", "observation"]
        ), "Each step should have at least one of thought/code/observation"

    # Verify final answer
    expected_result = 1302.678
    assert (
        str(expected_result) in response
    ), f"Response should contain the result {expected_result}"
    assert (
        "final_answer" in response.lower()
    ), "Response should indicate it's using final_answer"


def test_document_qa_and_image_generation_with_steps():
    """Test document QA and image generation with step verification."""
    agent = AgentRunner()
    question = "Can you search for information about the Mona Lisa and generate an image inspired by it?"

    # Process the question
    response = agent(question)

    # Verify step logs exist and demonstrate logical progression
    assert agent.last_state is not None, "Last state should be stored"
    step_logs = agent.last_state.get("step_logs", [])
    assert len(step_logs) > 0, "Should have recorded step logs"

    # Check for search and image generation steps
    has_search_step = False
    has_image_step = False

    for step in step_logs:
        assert "step_number" in step, "Each step should have a step number"
        assert any(
            key in step for key in ["thought", "code", "observation"]
        ), "Each step should have at least one of thought/code/observation"

        # Look for search and image steps in thoughts or code
        step_content = str(step.get("thought", "")) + str(step.get("code", ""))
        if "search" in step_content.lower():
            has_search_step = True
        if "image" in step_content.lower() or "dalle" in step_content.lower():
            has_image_step = True

    assert has_search_step, "Should include a search step"
    assert has_image_step, "Should include an image generation step"
    assert (
        "final_answer" in response.lower()
    ), "Response should indicate it's using final_answer"


if __name__ == "__main__":
    pytest.main([__file__, "-s", "-v", "-x"])
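The session-scoped agent fixture above can also be consumed by plain function-style tests; a minimal, hypothetical sketch (the test name and assertion are illustrative, not part of this commit):

def test_agent_returns_string(agent):
    """Smoke test: the fixture yields a callable AgentRunner that returns text."""
    response = agent("What is the result of the following operation: 5 + 3?")
    assert isinstance(response, str)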
tools.py
CHANGED
@@ -1,12 +1,16 @@
import logging

from smolagents import DuckDuckGoSearchTool, Tool, WikipediaSearchTool

logger = logging.getLogger(__name__)


class GeneralSearchTool(Tool):
    name = "search"
    description = """Performs a general web search using both DuckDuckGo and Wikipedia, then returns the combined search results."""
    inputs = {
        "query": {"type": "string", "description": "The search query to perform."}
    }
    output_type = "string"

    def __init__(self, max_results=10, **kwargs):
@@ -22,26 +26,27 @@ class GeneralSearchTool(Tool):
        except Exception as e:
            ddg_results = "No DuckDuckGo results found."
            logger.warning(f"DuckDuckGo search failed: {str(e)}")

        # Get Wikipedia results
        try:
            wiki_results = self.wiki_tool.forward(query)
        except Exception as e:
            wiki_results = "No Wikipedia results found."
            logger.warning(f"Wikipedia search failed: {str(e)}")

        # Combine and format results
        output = []
        if ddg_results and ddg_results != "No DuckDuckGo results found.":
            output.append("## DuckDuckGo Search Results\n\n" + ddg_results)
        if wiki_results and wiki_results != "No Wikipedia results found.":
            output.append("## Wikipedia Results\n\n" + wiki_results)

        if not output:
            raise Exception("No results found! Try a less restrictive/shorter query.")

        return "\n\n---\n\n".join(output)


# Export all tools
tools = [
    # DuckDuckGoSearchTool(),