|
import logging |
|
import re |
|
from smolagents import Tool |
|
from smolagents.default_tools import DuckDuckGoSearchTool, WikipediaSearchTool |
|
|
|
# Module-level logger named after this module; used by SmartSearchTool to trace each search step.
logger = logging.getLogger(__name__)
|
|
|
|
|
class SmartSearchTool(Tool):
    """Web search that is enriched with Wikipedia content when possible.

    The query is first sent to the DuckDuckGo search tool. If the returned
    text contains a link to a Wikipedia article, the article title is taken
    from that link and looked up with the Wikipedia tool, and both results are
    returned together. In every other case only the web result is returned.
    """

    name = "smart_search"
    description = """A smart search tool that first performs a web search and then, if a Wikipedia article is found,
uses Wikipedia search for more reliable information."""
    inputs = {"query": {"type": "string", "description": "The search query to find information"}}
    output_type = "string"

    def __init__(self):
        """Create the underlying web search and Wikipedia tools."""
        super().__init__()
        # Only the top web hit is needed: it is used both as the answer and as
        # the source of a possible Wikipedia link.
        self.web_search_tool = DuckDuckGoSearchTool(max_results=1)
        self.wiki_tool = WikipediaSearchTool(
            user_agent="SmartSearchTool ([email protected])",
            language="en",
            content_type="text",
            extract_format="WIKI",
        )

    @staticmethod
    def _extract_wiki_title(text: str) -> "str | None":
        """Return the article title of the first Wikipedia link in ``text``.

        Args:
            text: Arbitrary text (typically Markdown search results) that may
                contain a ``wikipedia.org/wiki/<title>`` URL.

        Returns:
            The percent-decoded title without any ``#fragment`` or ``?query``
            suffix, or ``None`` when no usable title is present.
        """
        # Imported locally so the class does not depend on an extra
        # module-level import.
        from urllib.parse import unquote

        # Links usually appear as Markdown "(url)", so a bare ")" ends the URL.
        # One level of balanced parentheses is still accepted so that titles
        # such as "Mercury_(planet)" are not cut off at the first ")".
        match = re.search(
            r"wikipedia\.org/wiki/((?:[^()\s]|\([^()\s]*\))+)",
            text,
            re.IGNORECASE,
        )
        if match is None:
            return None

        # Cut "#section" / "?query" before decoding, so that an encoded "%23"
        # or "%3F" that belongs to the title itself survives.
        raw_title = re.split(r"[#?]", match.group(1), maxsplit=1)[0]
        title = unquote(raw_title)
        return title or None

    def forward(self, query: str) -> str:
        """Run the web search and append Wikipedia content when available.

        Args:
            query: The search query to find information.

        Returns:
            The web search result, followed by the Wikipedia result when a
            Wikipedia article was linked and could be retrieved.
        """
        logger.info(f"Starting smart search for query: {query}")

        web_result = self.web_search_tool.forward(query)
        logger.info(f"Web search result: {web_result[:100]}...")

        if "wikipedia.org" in web_result.lower():
            logger.info("Wikipedia link found in web search results")

            wiki_title = self._extract_wiki_title(web_result)
            if wiki_title:
                logger.info(f"Extracted Wikipedia title: {wiki_title}")

                wiki_result = self.wiki_tool.forward(wiki_title)
                logger.info(f"Wikipedia search result: {(wiki_result or '')[:100]}...")

                # NOTE(review): the Wikipedia tool appears to report lookup
                # failures as plain strings rather than raising; the
                # "Error fetching Wikipedia" prefix should be confirmed against
                # the installed smolagents version.
                lookup_failed = (
                    not wiki_result
                    or "No Wikipedia page found" in wiki_result
                    or wiki_result.startswith("Error fetching Wikipedia")
                )
                if not lookup_failed:
                    logger.info("Successfully retrieved Wikipedia content")
                    return f"Web search result:\n{web_result}\n\nWikipedia result:\n{wiki_result}"

                logger.warning("Wikipedia search failed or returned no results")
            else:
                logger.warning("Could not extract Wikipedia title from URL")

        # Fallback: no Wikipedia link, no extractable title, or lookup failed.
        logger.info("Returning web search result only")
        return f"Web search result:\n{web_result}"
|
|
|
|
|
def main(query: str) -> str:
    """Run a single query through SmartSearchTool and print the outcome.

    Intended for manual testing from the command line: logging is configured
    at INFO level so the tool's progress messages are visible.

    Args:
        query: The search query to test.

    Returns:
        The text produced by the tool for the query.
    """
    logging.basicConfig(
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        level=logging.INFO,
    )

    search_output = SmartSearchTool().forward(query)

    # Frame the result between two 80-character rules for readability.
    rule = "-" * 80
    print("\nSearch Results:", rule, search_output, rule, sep="\n")

    return search_output
|
|
|
|
|
if __name__ == "__main__":
    import sys

    # Everything after the script name is joined into one search query.
    cli_words = sys.argv[1:]
    if not cli_words:
        print("Usage: python tool.py <search query>")
        print("Example: python tool.py 'Mercedes Sosa discography'")
    else:
        query = " ".join(cli_words)
        main(query)
|
|