mjschock's picture
Refactor SmartSearchTool in tool.py to improve logging and variable naming for web search results. Update return statements to ensure consistency in terminology, enhancing clarity in the output messages.
4e02cb8 unverified
raw
history blame
2.7 kB
import logging
import re
import sys

from smolagents import Tool
from smolagents.default_tools import DuckDuckGoSearchTool, VisitWebpageTool
# Module-level logger named after this module. Level and handlers are not set
# here; the script entry point configures them via logging.basicConfig.
logger: logging.Logger = logging.getLogger(name=__name__)
class SmartSearchTool(Tool):
    """Run a web search, then fetch every URL found in the search results.

    The returned text starts with the raw search results. One section per
    unique URL follows, holding either that page's content or the error
    raised while fetching it.
    """

    name = "smart_search"
    description = """A smart search tool that first performs a web search and then visits each URL to get its content."""
    inputs = {"query": {"type": "string", "description": "The search query to find information"}}
    output_type = "string"

    def __init__(self):
        super().__init__()
        self.web_search_tool = DuckDuckGoSearchTool(max_results=1)
        # The intent is "do not truncate pages". -1 does not achieve that:
        # smolagents' truncate_content() truncates whenever len(content) > limit
        # and splices content[:limit // 2] + notice + content[-limit // 2:],
        # which for -1 yields the page (almost) twice. A limit no string can
        # exceed disables truncation instead. NOTE(review): re-check if the
        # smolagents truncation helper changes.
        self.visit_webpage_tool = VisitWebpageTool(max_output_length=sys.maxsize)

    def forward(self, query: str) -> str:
        """Search the web for ``query`` and append the content of each result page.

        Args:
            query: The search query.

        Returns:
            ``"Web search results:\\n<results>"`` on its own when no URL is found.
            Otherwise that text is followed by one ``"Content from <url>:"`` or
            ``"Error visiting <url>:"`` section per unique URL. Pages that
            return empty content are omitted.
        """
        logger.info("Starting smart search for query: %s", query)

        web_search_results = self.web_search_tool.forward(query)
        logger.info("Web search results: %s...", web_search_results[:100])

        # Stop a URL at whitespace or at an unbalanced ")" (presumably the
        # results wrap links as Markdown "[title](url)"). Balanced groups such
        # as ".../wiki/Foo_(album)" are still kept intact.
        raw_urls = re.findall(r"https?://(?:[^\s()]|\([^\s()]*\))+", web_search_results)
        # Strip sentence punctuation glued to URLs in snippet text. De-duplicate
        # while keeping first-seen order so no page is fetched twice.
        urls = list(dict.fromkeys(url.rstrip(".,;:!?") for url in raw_urls))

        if not urls:
            logger.info("No URLs found in web search result")
            return f"Web search results:\n{web_search_results}"

        contents = []
        for url in urls:
            logger.info("Visiting URL: %s", url)
            try:
                content = self.visit_webpage_tool.forward(url)
            except Exception as e:
                # Best effort: one failing page must not abort the whole
                # search. The error is reported inline in the output instead.
                logger.warning("Error visiting %s: %s", url, e)
                contents.append(f"\nError visiting {url}: {e}")
                continue
            if content:
                contents.append(f"\nContent from {url}:\n{content}")

        return f"Web search results:\n{web_search_results}\n" + "\n".join(contents)
def main(query: str) -> str:
    """Run SmartSearchTool once and print its output for manual testing.

    Configures root logging at INFO level so the tool's progress messages are
    shown. The result is printed under a "Search Results:" heading between two
    80-character dashed rules.

    Args:
        query: The search query to run.

    Returns:
        The string produced by ``SmartSearchTool().forward(query)``.
    """
    log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    logging.basicConfig(format=log_format, level=logging.INFO)

    output = SmartSearchTool().forward(query)

    # One print with newline separators emits exactly the same text as
    # printing the heading, rule, result and rule on separate lines.
    rule = "-" * 80
    print("\nSearch Results:", rule, output, rule, sep="\n")
    return output
if __name__ == "__main__":
    # All command-line arguments are joined with spaces to form one query.
    if len(sys.argv) < 2:
        # Usage errors go to stderr with exit status 2, the conventional status
        # for command-line syntax errors, so callers/scripts can detect them.
        print("Usage: python tool.py <search query>", file=sys.stderr)
        print("Example: python tool.py 'Mercedes Sosa discography'", file=sys.stderr)
        sys.exit(2)

    query = " ".join(sys.argv[1:])
    main(query)