File size: 8,919 Bytes
a3a55bb
 
2da6a11
 
 
a3a55bb
0305659
a3a55bb
 
 
 
 
 
0305659
2da6a11
 
 
 
 
 
a3a55bb
 
 
 
 
0305659
 
2da6a11
0305659
 
 
 
 
 
2da6a11
 
 
 
 
 
 
0305659
 
 
 
2da6a11
0305659
2da6a11
0305659
2da6a11
 
0305659
 
 
 
 
 
 
 
2da6a11
0305659
2da6a11
0305659
2da6a11
0305659
2da6a11
0305659
2da6a11
0305659
2da6a11
0305659
2da6a11
0305659
2da6a11
0305659
 
 
 
 
2da6a11
0305659
 
2da6a11
0305659
 
2da6a11
0305659
2da6a11
0305659
 
2da6a11
0305659
 
 
2da6a11
0305659
 
 
 
2da6a11
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0305659
2da6a11
0305659
 
2da6a11
0305659
2da6a11
 
0305659
2da6a11
 
0305659
 
 
 
2da6a11
 
 
 
0305659
 
 
a3a55bb
 
 
2da6a11
0305659
 
 
2da6a11
0305659
 
 
 
2da6a11
0305659
 
 
 
2da6a11
0305659
 
 
 
 
2da6a11
 
 
0305659
 
2da6a11
0305659
 
2da6a11
0305659
 
 
 
 
 
 
 
2da6a11
0305659
2da6a11
0305659
2da6a11
 
0305659
2da6a11
0305659
ea0c151
 
 
 
 
2da6a11
ea0c151
 
2da6a11
ea0c151
 
 
 
 
 
2da6a11
ea0c151
2da6a11
ea0c151
 
 
2da6a11
ea0c151
 
 
 
 
2da6a11
ea0c151
 
 
 
 
2da6a11
ea0c151
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
import logging
import re
from typing import Optional

import requests
from smolagents import Tool
from smolagents.default_tools import DuckDuckGoSearchTool

logger = logging.getLogger(__name__)


class SmartSearchTool(Tool):
    """Search tool that resolves a query to a Wikipedia page via a DuckDuckGo
    web search, fetches the page's raw wiki markup through the MediaWiki API,
    and returns it cleaned up (tables flattened to tab-separated rows).
    """

    name = "smart_search"
    description = """A smart search tool that searches Wikipedia for information."""
    inputs = {
        "query": {
            "type": "string",
            "description": "The search query to find information",
        }
    }
    output_type = "string"

    # Seconds before an unresponsive Wikipedia API request is abandoned.
    # FIX: the original request had no timeout and could hang indefinitely.
    REQUEST_TIMEOUT = 10

    def __init__(self):
        super().__init__()
        # One result is enough: we only need a single link to resolve the page title.
        self.web_search_tool = DuckDuckGoSearchTool(max_results=1)
        self.api_url = "https://en.wikipedia.org/w/api.php"
        # MediaWiki API etiquette asks for a descriptive User-Agent with contact info.
        self.headers = {
            "User-Agent": "SmartSearchTool/1.0 (https://github.com/yourusername/yourproject; [email protected])"
        }

    def get_wikipedia_page(self, title: str) -> Optional[str]:
        """Fetch the raw wiki markup of a Wikipedia page.

        Args:
            title: Exact (or redirectable) page title.

        Returns:
            The page's wiki markup, or None if the page is missing or any
            network/parsing error occurs (errors are logged, not raised).
        """
        try:
            params = {
                "action": "query",
                "prop": "revisions",
                "rvprop": "content",
                "rvslots": "main",
                "format": "json",
                "titles": title,
                "redirects": 1,  # follow redirects so near-miss titles still resolve
            }
            # FIX: pass a timeout so a stalled connection cannot hang the tool forever.
            response = requests.get(
                self.api_url,
                params=params,
                headers=self.headers,
                timeout=self.REQUEST_TIMEOUT,
            )
            response.raise_for_status()
            data = response.json()

            # The API keys pages by numeric page id; take the first entry that
            # actually carries revision content (missing pages have none).
            pages = data.get("query", {}).get("pages", {})
            for page_data in pages.values():
                if "revisions" in page_data:
                    return page_data["revisions"][0]["slots"]["main"]["*"]
            return None
        except Exception as e:
            # Best-effort by design: callers treat None as "page not found".
            logger.error(f"Error getting Wikipedia page: {e}")
            return None

    def clean_wiki_content(self, content: str) -> str:
        """Strip common wiki markup from page content and normalize whitespace.

        Args:
            content: Raw wiki markup.

        Returns:
            Plain-ish text with citations, templates, file links, HTML
            comments, and link syntax removed.
        """
        # Remove citations like [1]
        content = re.sub(r"\[\d+\]", "", content)
        # Remove edit links
        content = re.sub(r"\[edit\]", "", content)
        # Remove file links
        content = re.sub(r"\[\[File:.*?\]\]", "", content)
        # Convert [[target|label]] / [[target]] links to just their text
        content = re.sub(r"\[\[(?:[^|\]]*\|)?([^\]]+)\]\]", r"\1", content)
        # Remove HTML comments (may span lines)
        content = re.sub(r"<!--.*?-->", "", content, flags=re.DOTALL)
        # Remove {{templates}}
        content = re.sub(r"\{\{.*?\}\}", "", content)
        # Remove small tags
        content = re.sub(r"<small>.*?</small>", "", content)
        # Collapse runs of blank lines to a single blank line
        content = re.sub(r"\n\s*\n", "\n\n", content)
        return content.strip()

    @staticmethod
    def _clean_cell(cell: str) -> str:
        """Strip wiki/HTML markup and layout attributes from one table cell."""
        cell = cell.strip()
        cell = re.sub(r"<.*?>", "", cell)  # Remove HTML tags
        cell = re.sub(r"\[\[.*?\|(.*?)\]\]", r"\1", cell)  # Convert piped links
        cell = re.sub(r"\[\[(.*?)\]\]", r"\1", cell)  # Convert simple links
        cell = re.sub(r"\{\{.*?\}\}", "", cell)  # Remove templates
        cell = re.sub(r"<small>.*?</small>", "", cell)  # Remove small tags
        # Strip table-layout attributes that survive the cell split.
        cell = re.sub(r'rowspan="\d+"', "", cell)
        cell = re.sub(r'colspan="\d+"', "", cell)
        cell = re.sub(r'class=".*?"', "", cell)
        cell = re.sub(r'style=".*?"', "", cell)
        cell = re.sub(r'align=".*?"', "", cell)
        cell = re.sub(r'width=".*?"', "", cell)
        cell = re.sub(r'bgcolor=".*?"', "", cell)
        cell = re.sub(r'valign=".*?"', "", cell)
        cell = re.sub(r'border=".*?"', "", cell)
        cell = re.sub(r'cellpadding=".*?"', "", cell)
        cell = re.sub(r'cellspacing=".*?"', "", cell)
        cell = re.sub(r"<ref.*?</ref>", "", cell)  # Remove references
        cell = re.sub(r"<ref.*?/>", "", cell)  # Remove empty references
        cell = re.sub(r"<br\s*/?>", " ", cell)  # Replace line breaks with spaces
        cell = re.sub(r"\s+", " ", cell)  # Normalize whitespace
        return cell

    def format_wiki_table(self, table_content: str) -> str:
        """Flatten a wiki-markup table into tab-separated text rows.

        Args:
            table_content: The table body, including the ``{|`` / ``|}`` fence.

        Returns:
            One line per table row, cells joined by tabs; "" for empty tables.
        """
        rows = table_content.strip().split("\n")
        formatted_rows = []
        current_row = []  # cells accumulated until the next row separator

        for row in rows:
            # Blank lines, row separators (|-) and captions (|+) end the
            # current logical row.
            if not row.strip() or row.startswith("|-") or row.startswith("|+"):
                if current_row:
                    formatted_rows.append("\t".join(current_row))
                    current_row = []
                continue

            # Split the line into cells on | (data) or ! (header) markers.
            # NOTE(review): this also splits on pipes inside [[a|b]] links,
            # which can over-split linked cells — kept as-is to preserve
            # existing output.
            cell_parts = re.split(r"[|!]", row)
            cells = [self._clean_cell(part) for part in cell_parts[1:]]

            if cells:
                current_row.extend(cells)

        # Flush the final row (tables usually end with |} rather than |-).
        if current_row:
            formatted_rows.append("\t".join(current_row))

        if formatted_rows:
            return "\n".join(formatted_rows)
        return ""

    def extract_wikipedia_title(self, search_result: str) -> Optional[str]:
        """Extract a Wikipedia page title from markdown-style search output.

        Looks for a link of the form ``[Title - Wikipedia](https://en.wikipedia.org/wiki/...)``.

        Args:
            search_result: Raw text returned by the web search tool.

        Returns:
            The page title, or None if no Wikipedia link is present.
        """
        wiki_match = re.search(
            r"\[([^\]]+)\s*-\s*Wikipedia\]\(https://en\.wikipedia\.org/wiki/[^)]+\)",
            search_result,
        )
        if wiki_match:
            return wiki_match.group(1).strip()
        return None

    def forward(self, query: str) -> str:
        """Run the full search pipeline for *query*.

        Steps: web-search for a Wikipedia link, fetch that page's markup,
        flatten any tables, then clean the markup.

        Returns:
            The cleaned page content, or a human-readable failure message.
        """
        logger.info(f"Starting smart search for query: {query}")

        # Step 1: web search to discover the relevant Wikipedia page.
        search_result = self.web_search_tool.forward(query)
        logger.info(f"Web search results: {search_result[:100]}...")

        # Step 2: pull the page title out of the search snippet.
        wiki_title = self.extract_wikipedia_title(search_result)
        if not wiki_title:
            return f"Could not find Wikipedia page in search results for '{query}'."

        # Step 3: fetch the raw page markup.
        page_content = self.get_wikipedia_page(wiki_title)
        if not page_content:
            return f"Could not find Wikipedia page for '{wiki_title}'."

        # Step 4: walk the markup line by line, buffering table bodies
        # ({| ... |}) so they can be flattened separately from prose.
        formatted_content = []
        current_section = []
        in_table = False
        table_content = []

        for line in page_content.split("\n"):
            if line.startswith("{|"):
                in_table = True
                table_content = [line]
            elif line.startswith("|}"):
                in_table = False
                table_content.append(line)
                formatted_table = self.format_wiki_table("\n".join(table_content))
                if formatted_table:
                    current_section.append(formatted_table)
            elif in_table:
                table_content.append(line)
            else:
                if line.strip():
                    current_section.append(line)
                elif current_section:
                    # Blank line ends the current prose section.
                    formatted_content.append("\n".join(current_section))
                    current_section = []

        if current_section:
            formatted_content.append("\n".join(current_section))

        # Step 5: final markup cleanup over the reassembled document.
        cleaned_content = self.clean_wiki_content("\n\n".join(formatted_content))
        return f"Wikipedia content for '{wiki_title}':\n\n{cleaned_content}"


def main(query: str) -> str:
    """Run the SmartSearchTool once from the command line.

    Args:
        query: The search query to test

    Returns:
        The search results
    """
    # Basic console logging so the tool's progress messages are visible.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    # Instantiate the tool and execute the search.
    result = SmartSearchTool().forward(query)

    # Echo the result between horizontal rules for readability.
    rule = "-" * 80
    print("\nSearch Results:")
    print(rule)
    print(result)
    print(rule)

    return result


if __name__ == "__main__":
    import sys

    # Everything after the script name is treated as one query string.
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python tool.py <search query>")
        print("Example: python tool.py 'Mercedes Sosa discography'")
    else:
        main(" ".join(cli_args))