Spaces:
Build error
Build error
import logging | |
import re | |
from smolagents import Tool | |
from smolagents.default_tools import DuckDuckGoSearchTool, VisitWebpageTool | |
logger = logging.getLogger(__name__)


class SmartSearchTool(Tool):
    """Run a web search, then fetch the page behind every URL in the results.

    The tool performs a DuckDuckGo search, extracts the distinct http(s) URLs
    that appear in the search-result text, visits each one once, and returns
    the raw search results followed by the content (or the error) of every
    visit.
    """

    name = "smart_search"
    description = """A smart search tool that first performs a web search and then visits each URL to get its content."""
    inputs = {"query": {"type": "string", "description": "The search query to find information"}}
    output_type = "string"

    def __init__(self) -> None:
        # NOTE(review): kept argument-free on purpose; smolagents tool
        # serialization presumably expects hard-coded settings here - confirm
        # before turning these values into constructor parameters.
        super().__init__()
        # Only the top search result is requested.
        self.web_search_tool = DuckDuckGoSearchTool(max_results=1)
        # NOTE(review): -1 is presumably meant to disable output truncation;
        # confirm the installed smolagents version treats a negative limit so.
        self.visit_webpage_tool = VisitWebpageTool(max_output_length=-1)

    @staticmethod
    def _extract_urls(text: str) -> list[str]:
        """Return the distinct http(s) URLs found in *text*, in first-seen order.

        A URL may contain balanced parentheses (e.g. ``.../Foo_(bar)``) but
        stops at an unmatched ``)``. This means the closing paren of a Markdown
        ``[title](url)`` link is not captured. Brackets, angle brackets and
        quotes also end a URL, and trailing sentence punctuation is stripped.
        Duplicates are removed so each page is fetched only once.
        """
        # Either a plain URL character, or a parenthesised group of them.
        # The two alternatives cannot both match at one position, so the
        # pattern cannot backtrack catastrophically.
        url_pattern = r"""https?://(?:[^\s()\[\]<>"']|\([^\s()\[\]<>"']*\))+"""
        candidates = (
            match.rstrip(".,;:!?") for match in re.findall(url_pattern, text)
        )
        # dict.fromkeys de-duplicates while preserving insertion order.
        return list(dict.fromkeys(candidates))

    def forward(self, query: str) -> str:
        """Search for *query* and append the content of every result URL.

        Args:
            query: The search query to find information.

        Returns:
            ``"Web search results:\\n<results>"`` when the results contain no
            URL. Otherwise, that header followed by one ``Content from <url>``
            section (or ``Error visiting <url>`` section) per distinct URL;
            URLs whose fetched content is empty are omitted.
        """
        logger.info("Starting smart search for query: %s", query)
        web_search_results = self.web_search_tool.forward(query)
        logger.info("Web search results: %s...", web_search_results[:100])

        urls = self._extract_urls(web_search_results)
        if not urls:
            logger.info("No URLs found in web search result")
            return f"Web search results:\n{web_search_results}"

        contents = []
        for url in urls:
            logger.info("Visiting URL: %s", url)
            try:
                content = self.visit_webpage_tool.forward(url)
            except Exception as e:
                # Best effort: one failing page must not abort the whole
                # search, so the error is reported inline and the loop goes on.
                logger.warning("Error visiting %s: %s", url, e)
                contents.append(f"\nError visiting {url}: {e}")
            else:
                # Empty pages are skipped rather than reported.
                if content:
                    contents.append(f"\nContent from {url}:\n{content}")

        return f"Web search results:\n{web_search_results}\n" + "\n".join(contents)
def main(query: str) -> str:
    """Run SmartSearchTool once and echo its output to stdout.

    INFO-level logging is configured first so the tool's progress messages
    are visible. The result is then printed between two 80-character rules.

    Args:
        query: The search query to run.

    Returns:
        The string produced by ``SmartSearchTool.forward``.
    """
    logging.basicConfig(
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        level=logging.INFO,
    )

    result = SmartSearchTool().forward(query)

    # One print call emitting the heading, rule, result and rule on separate lines.
    rule = "-" * 80
    print("\nSearch Results:", rule, result, rule, sep="\n")
    return result
if __name__ == "__main__":
    import sys

    # Everything after the script name is joined into one query string.
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python tool.py <search query>")
        print("Example: python tool.py 'Mercedes Sosa discography'")
    else:
        main(" ".join(cli_args))