mjschock's picture
Update requirements.txt to add new dependencies for enhanced functionality, including kagglehub, langchain, and llama-index packages. Refactor SmartSearchTool in tool.py to replace Wikipedia search with a webpage visiting approach, improving content retrieval from web search results. Update tool description and logging for better clarity and usability.
50aafe2 unverified
raw
history blame
2.66 kB
import logging
import re
from smolagents import Tool
from smolagents.default_tools import DuckDuckGoSearchTool, VisitWebpageTool
logger = logging.getLogger(__name__)
# Matches http(s) URLs up to the next whitespace. A closing parenthesis ends
# the URL (search results are rendered as markdown links, "[title](url)")
# unless it closes a parenthesised group that is part of the URL itself,
# e.g. ".../wiki/Mercury_(planet)".
_URL_PATTERN = re.compile(r"https?://(?:[^\s()]|\([^\s()]*\)|\()+")


class SmartSearchTool(Tool):
    """Tool that runs a web search and then fetches every URL found in the result.

    The returned string always starts with the raw web search result; when
    URLs are found in it, the content (or the error) obtained for each
    visited URL is appended after it.
    """

    name = "smart_search"
    description = """A smart search tool that first performs a web search and then visits each URL to get its content."""
    inputs = {"query": {"type": "string", "description": "The search query to find information"}}
    output_type = "string"

    def __init__(self):
        """Create the underlying web search and webpage visiting tools."""
        super().__init__()
        self.web_search_tool = DuckDuckGoSearchTool(max_results=1)
        # NOTE(review): -1 is presumably meant to disable output truncation;
        # confirm that VisitWebpageTool treats a negative limit as "no limit"
        # rather than using it directly as a length/slice bound.
        self.visit_webpage_tool = VisitWebpageTool(max_output_length=-1)

    def forward(self, query: str) -> str:
        """Search the web for *query* and append the content of each result URL.

        Args:
            query: The search query to find information.

        Returns:
            The web search result alone when it contains no URL; otherwise the
            web search result followed by one section per distinct URL holding
            either the page content or the error raised while visiting it.
        """
        logger.info("Starting smart search for query: %s", query)

        # Get web search results
        web_result = self.web_search_tool.forward(query)
        logger.info("Web search result: %s...", web_result[:100])

        # Extract URLs from the web search result. dict.fromkeys drops
        # repeated URLs while keeping first-seen order, so the same page is
        # never fetched (and appended) twice.
        urls = list(dict.fromkeys(_URL_PATTERN.findall(web_result)))
        if not urls:
            logger.info("No URLs found in web search result")
            return f"Web search result:\n{web_result}"

        # Visit each URL and get its content
        contents = []
        for url in urls:
            logger.info("Visiting URL: %s", url)
            try:
                content = self.visit_webpage_tool.forward(url)
            except Exception as e:
                # Deliberate best-effort: a failing page must not abort the
                # whole search, so the error is reported inline instead.
                logger.warning("Error visiting %s: %s", url, e)
                contents.append(f"\nError visiting {url}: {e}")
            else:
                if content:
                    contents.append(f"\nContent from {url}:\n{content}")

        # Combine all results
        return f"Web search result:\n{web_result}\n" + "\n".join(contents)
def main(query: str) -> str:
    """
    Run a single query through SmartSearchTool and print the outcome.

    Intended for manual testing from the command line: logging is set up at
    INFO level, the tool is invoked once, and the result is printed between
    two horizontal rules.

    Args:
        query: The search query to test

    Returns:
        The string produced by the tool for the query
    """
    # Configure logging
    log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)

    # Create and run the tool
    searcher = SmartSearchTool()
    output = searcher.forward(query)

    # Print the result framed by two 80-character rules
    rule = "-" * 80
    for line in ("\nSearch Results:", rule, output, rule):
        print(line)

    return output
if __name__ == "__main__":
    import sys

    # Everything after the script name forms the search query.
    cli_words = sys.argv[1:]
    if cli_words:
        query = " ".join(cli_words)
        main(query)
    else:
        # No query given: show how the script is meant to be invoked.
        for usage_line in (
            "Usage: python tool.py <search query>",
            "Example: python tool.py 'Mercedes Sosa discography'",
        ):
            print(usage_line)