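"""SmartSearchTool: a smolagents Tool that locates a topic's Wikipedia page via
web search and returns the page's cleaned wiki markup, with tables flattened
into tab-separated text."""
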
import logging
import re
from typing import Optional

import requests
from smolagents import Tool
from smolagents.default_tools import DuckDuckGoSearchTool

logger = logging.getLogger(__name__)


class SmartSearchTool(Tool):
    name = "smart_search"
    description = """A smart search tool that searches Wikipedia for information."""
    inputs = {
        "query": {
            "type": "string",
            "description": "The search query to find information",
        }
    }
    output_type = "string"

    def __init__(self):
        super().__init__()
        self.web_search_tool = DuckDuckGoSearchTool(max_results=1)
        self.api_url = "https://en.wikipedia.org/w/api.php"
        self.headers = {
            "User-Agent": "SmartSearchTool/1.0 (https://github.com/yourusername/yourproject; [email protected])"
        }

    def get_wikipedia_page(self, title: str) -> Optional[str]:
        """Get the raw wiki markup of a Wikipedia page."""
        try:
            params = {
                "action": "query",
                "prop": "revisions",
                "rvprop": "content",
                "rvslots": "main",
                "format": "json",
                "titles": title,
                "redirects": 1,
            }
            # A timeout keeps a stalled connection from hanging the tool
            response = requests.get(
                self.api_url, params=params, headers=self.headers, timeout=30
            )
            response.raise_for_status()
            data = response.json()

            # Return the content of the first page that has revisions
            pages = data.get("query", {}).get("pages", {})
            for page_data in pages.values():
                if "revisions" in page_data:
                    return page_data["revisions"][0]["slots"]["main"]["*"]
            return None
        except Exception as e:
            logger.error(f"Error getting Wikipedia page: {e}")
            return None
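
    # The request above corresponds to a URL of this shape (illustrative):
    #   https://en.wikipedia.org/w/api.php?action=query&prop=revisions
    #     &rvprop=content&rvslots=main&format=json&titles=<title>&redirects=1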

    def clean_wiki_content(self, content: str) -> str:
        """Clean Wikipedia content by removing markup and formatting."""
        # Remove citation markers like [1]
        content = re.sub(r"\[\d+\]", "", content)
        # Remove edit links
        content = re.sub(r"\[edit\]", "", content)
        # Remove file links
        content = re.sub(r"\[\[File:.*?\]\]", "", content)
        # Convert [[target|label]] and [[target]] links to plain text
        content = re.sub(r"\[\[(?:[^|\]]*\|)?([^\]]+)\]\]", r"\1", content)
        # Remove HTML comments
        content = re.sub(r"<!--.*?-->", "", content, flags=re.DOTALL)
        # Remove single-line templates
        content = re.sub(r"\{\{.*?\}\}", "", content)
        # Remove <small> spans and their contents
        content = re.sub(r"<small>.*?</small>", "", content)
        # Collapse runs of blank lines
        content = re.sub(r"\n\s*\n", "\n\n", content)
        return content.strip()
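
    # Illustrative example (hypothetical input, not from the source):
    #   clean_wiki_content("[[Argentina|Argentine]] singer[1]")
    #   -> "Argentine singer"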

    def format_wiki_table(self, table_content: str) -> str:
        """Format a Wikipedia table into readable tab-separated text."""
        rows = table_content.strip().split("\n")
        formatted_rows = []
        current_row = []
        for row in rows:
            # Empty rows, separators ("|-"), and captions ("|+") end the row
            if not row.strip() or row.startswith("|-") or row.startswith("|+"):
                if current_row:
                    formatted_rows.append("\t".join(current_row))
                    current_row = []
                continue
            # Skip the table open/close markers so they don't become cells
            if row.startswith("{|") or row.startswith("|}"):
                continue
            # Split the row into cells using | or ! as separators
            cells = []
            cell_parts = re.split(r"[|!]", row)
            for cell in cell_parts[1:]:  # Skip the first empty part
                cell = cell.strip()
                # Remove references and line breaks before stripping other tags,
                # so their contents don't leak into the cell text
                cell = re.sub(r"<ref.*?</ref>", "", cell)  # References
                cell = re.sub(r"<ref.*?/>", "", cell)  # Empty references
                cell = re.sub(r"<br\s*/?>", " ", cell)  # Line breaks -> spaces
                cell = re.sub(r"<.*?>", "", cell)  # Remaining HTML tags
                cell = re.sub(r"\[\[.*?\|(.*?)\]\]", r"\1", cell)  # Piped links
                cell = re.sub(r"\[\[(.*?)\]\]", r"\1", cell)  # Simple links
                cell = re.sub(r"\{\{.*?\}\}", "", cell)  # Templates
                # Strip common cell attributes such as rowspan="2" or style="..."
                for attr in (
                    "rowspan", "colspan", "class", "style", "align", "width",
                    "bgcolor", "valign", "border", "cellpadding", "cellspacing",
                ):
                    cell = re.sub(rf'{attr}=".*?"', "", cell)
                cell = re.sub(r"\s+", " ", cell)  # Normalize whitespace
                cells.append(cell)
            if cells:
                current_row.extend(cells)
        if current_row:
            formatted_rows.append("\t".join(current_row))
        return "\n".join(formatted_rows)

    def extract_wikipedia_title(self, search_result: str) -> Optional[str]:
        """Extract the Wikipedia page title from a search result."""
        # Look for Wikipedia links in the format [Title - Wikipedia](url)
        wiki_match = re.search(
            r"\[([^\]]+)\s*-\s*Wikipedia\]\(https://en\.wikipedia\.org/wiki/[^)]+\)",
            search_result,
        )
        if wiki_match:
            return wiki_match.group(1).strip()
        return None
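
    # Illustrative example (hypothetical input, not from the source):
    #   extract_wikipedia_title(
    #       "[Mercedes Sosa - Wikipedia](https://en.wikipedia.org/wiki/Mercedes_Sosa)"
    #   )
    #   -> "Mercedes Sosa"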

    def forward(self, query: str) -> str:
        logger.info(f"Starting smart search for query: {query}")

        # First do a web search to find the Wikipedia page
        search_result = self.web_search_tool.forward(query)
        logger.info(f"Web search results: {search_result[:100]}...")

        # Extract the Wikipedia page title from the search results
        wiki_title = self.extract_wikipedia_title(search_result)
        if not wiki_title:
            return f"Could not find a Wikipedia page in search results for '{query}'."

        # Get the Wikipedia page content
        page_content = self.get_wikipedia_page(wiki_title)
        if not page_content:
            return f"Could not find Wikipedia page for '{wiki_title}'."

        # Walk the markup line by line, buffering table blocks ({| ... |})
        # so they can be formatted separately from prose sections
        formatted_content = []
        current_section = []
        in_table = False
        table_content = []
        for line in page_content.split("\n"):
            if line.startswith("{|"):
                in_table = True
                table_content = [line]
            elif line.startswith("|}"):
                in_table = False
                table_content.append(line)
                formatted_table = self.format_wiki_table("\n".join(table_content))
                if formatted_table:
                    current_section.append(formatted_table)
            elif in_table:
                table_content.append(line)
            elif line.strip():
                current_section.append(line)
            elif current_section:
                formatted_content.append("\n".join(current_section))
                current_section = []
        if current_section:
            formatted_content.append("\n".join(current_section))

        # Clean and return the formatted content
        cleaned_content = self.clean_wiki_content("\n\n".join(formatted_content))
        return f"Wikipedia content for '{wiki_title}':\n\n{cleaned_content}"


def main(query: str) -> str:
    """
    Test function to run the SmartSearchTool directly.

    Args:
        query: The search query to test

    Returns:
        The search results
    """
    # Configure logging
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Create and run the tool
    tool = SmartSearchTool()
    result = tool.forward(query)

    # Print the result
    print("\nSearch Results:")
    print("-" * 80)
    print(result)
    print("-" * 80)
    return result
if __name__ == "__main__":
import sys
if len(sys.argv) > 1:
query = " ".join(sys.argv[1:])
main(query)
else:
print("Usage: python tool.py <search query>")
print("Example: python tool.py 'Mercedes Sosa discography'")