File size: 2,701 Bytes
a3a55bb
 
 
50aafe2
a3a55bb
 
 
 
 
 
50aafe2
a3a55bb
 
 
 
 
 
50aafe2
a3a55bb
 
 
 
50aafe2
4e02cb8
 
a3a55bb
50aafe2
4e02cb8
50aafe2
 
4e02cb8
a3a55bb
50aafe2
 
 
 
 
 
 
 
 
 
 
 
 
4e02cb8
ea0c151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import logging
import re
from smolagents import Tool
from smolagents.default_tools import DuckDuckGoSearchTool, VisitWebpageTool

logger = logging.getLogger(__name__)


class SmartSearchTool(Tool):
    """Tool that runs a web search and then fetches the pages it found.

    ``forward`` sends the query to a ``DuckDuckGoSearchTool``, extracts the
    http(s) URLs that appear in the returned text, visits each distinct URL
    with a ``VisitWebpageTool`` and returns the search text followed by the
    fetched page contents.
    """

    name = "smart_search"
    description = """A smart search tool that first performs a web search and then visits each URL to get its content."""
    inputs = {"query": {"type": "string", "description": "The search query to find information"}}
    output_type = "string"

    def __init__(self):
        super().__init__()
        # Only a single search hit is requested, which bounds how many pages
        # get fetched per query.
        self.web_search_tool = DuckDuckGoSearchTool(max_results=1)
        # NOTE(review): -1 is presumably meant as "no output limit"; confirm
        # how VisitWebpageTool treats a negative max_output_length.
        self.visit_webpage_tool = VisitWebpageTool(max_output_length=-1)

    def _extract_urls(self, text: str) -> list:
        """Return the distinct http(s) URLs found in *text*, in first-seen order."""
        # A URL quoted inside prose often drags sentence punctuation along
        # ("see https://example.com."); strip it so the real address is visited.
        candidates = (
            match.rstrip(".,;:!?'\">")
            for match in re.findall(r'https?://[^\s)]+', text)
        )
        # dict.fromkeys de-duplicates while keeping order, so a link repeated
        # in the result snippet is fetched only once.
        return list(dict.fromkeys(candidates))

    def forward(self, query: str) -> str:
        """Search the web for *query* and append the content of each result URL.

        Args:
            query: The search query to find information.

        Returns:
            The raw search results, followed by one section per visited URL
            holding either the page content or the error raised while
            visiting it. If the search text contains no URL, only the search
            results are returned.

        Raises:
            Whatever the underlying search tool raises; only page visits are
            guarded.
        """
        logger.info("Starting smart search for query: %s", query)

        # Get web search results
        web_search_results = self.web_search_tool.forward(query)
        logger.info("Web search results: %s...", web_search_results[:100])

        # Extract URLs from the web search result
        urls = self._extract_urls(web_search_results)
        if not urls:
            logger.info("No URLs found in web search result")
            return f"Web search results:\n{web_search_results}"

        # Visit each URL and get its content
        contents = []
        for url in urls:
            logger.info("Visiting URL: %s", url)
            try:
                content = self.visit_webpage_tool.forward(url)
            except Exception as e:
                # Best effort: one unreachable page must not discard the
                # search results or the other pages.
                logger.warning("Error visiting %s: %s", url, e)
                contents.append(f"\nError visiting {url}: {e}")
            else:
                if content:
                    contents.append(f"\nContent from {url}:\n{content}")

        # Combine all results
        return f"Web search results:\n{web_search_results}\n" + "\n".join(contents)


def main(query: str) -> str:
    """
    Run one query through SmartSearchTool and echo the outcome to stdout.

    Meant for exercising the tool by hand; logging is set to INFO so the
    tool's progress messages are visible.

    Args:
        query: The search query to run

    Returns:
        The string returned by the tool's ``forward`` call
    """
    # Make the tool's INFO-level progress messages visible
    log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)

    # Build the tool and execute the search in one go
    output = SmartSearchTool().forward(query)

    # Frame the result between two divider lines
    divider = "-" * 80
    print("\nSearch Results:")
    print(divider)
    print(output)
    print(divider)

    return output


if __name__ == "__main__":
    import sys

    # Everything after the script name forms the query, joined by spaces.
    cli_words = sys.argv[1:]
    if not cli_words:
        # No query supplied: show how to invoke the script.
        print("Usage: python tool.py <search query>")
        print("Example: python tool.py 'Mercedes Sosa discography'")
    else:
        main(" ".join(cli_words))