Spaces:
Build error
Build error
File size: 8,919 Bytes
a3a55bb 2da6a11 a3a55bb 0305659 a3a55bb 0305659 2da6a11 a3a55bb 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 a3a55bb 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 2da6a11 0305659 ea0c151 2da6a11 ea0c151 2da6a11 ea0c151 2da6a11 ea0c151 2da6a11 ea0c151 2da6a11 ea0c151 2da6a11 ea0c151 2da6a11 ea0c151 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 |
import logging
import re
from typing import Optional
from urllib.parse import unquote

import requests
from smolagents import Tool
from smolagents.default_tools import DuckDuckGoSearchTool
# Module-level logger used by SmartSearchTool; main() configures output for it
# via logging.basicConfig when this file is run as a script.
logger = logging.getLogger(__name__)
class SmartSearchTool(Tool):
    """Answer a query with the text of the best-matching English Wikipedia article.

    ``forward`` runs a DuckDuckGo web search (top hit only), pulls the
    Wikipedia article title out of that hit, downloads the article's raw
    wikitext through the MediaWiki API, flattens tables into tab-separated
    rows and strips common markup. Lookup failures are reported as a
    human-readable message string.
    """

    name = "smart_search"
    description = """A smart search tool that searches Wikipedia for information."""
    inputs = {
        "query": {
            "type": "string",
            "description": "The search query to find information",
        }
    }
    output_type = "string"

    def __init__(self):
        super().__init__()
        # Only the first web-search hit is inspected for a Wikipedia link.
        self.web_search_tool = DuckDuckGoSearchTool(max_results=1)
        self.api_url = "https://en.wikipedia.org/w/api.php"
        # Wikimedia's API etiquette asks clients to identify themselves.
        self.headers = {
            "User-Agent": "SmartSearchTool/1.0 (https://github.com/yourusername/yourproject; [email protected])"
        }
        # Seconds to wait for the MediaWiki API. ``requests`` has no default
        # timeout, so without one a stalled connection would hang the call.
        self.request_timeout = 30

    def get_wikipedia_page(self, title: str) -> Optional[str]:
        """Return the raw wikitext of the article *title*, or ``None``.

        Redirects are followed. ``None`` is returned when the page does not
        exist, has no revision content, or the HTTP request / JSON decoding
        fails (such failures are logged, not raised).
        """
        params = {
            "action": "query",
            "prop": "revisions",
            "rvprop": "content",
            "rvslots": "main",
            "format": "json",
            "titles": title,
            "redirects": 1,
        }
        try:
            response = requests.get(
                self.api_url,
                params=params,
                headers=self.headers,
                timeout=self.request_timeout,
            )
            response.raise_for_status()
            data = response.json()
        except (requests.RequestException, ValueError) as e:
            # ValueError covers a non-JSON response body from ``response.json()``.
            logger.error("Error getting Wikipedia page: %s", e)
            return None

        # ``pages`` is keyed by page id; a missing page has no "revisions".
        pages = data.get("query", {}).get("pages", {})
        for page_data in pages.values():
            revisions = page_data.get("revisions") or []
            if not revisions:
                continue
            # With format=json (formatversion 1) the slot text is under "*".
            content = revisions[0].get("slots", {}).get("main", {}).get("*")
            if content is not None:
                return content
        return None

    @staticmethod
    def _remove_refs(text: str) -> str:
        """Delete ``<ref .../>`` and ``<ref>...</ref>`` footnotes (paired ones may span lines)."""
        # Self-closing refs go first so a paired match can never begin at one.
        text = re.sub(r"<ref[^>]*/>", "", text, flags=re.IGNORECASE)
        return re.sub(
            r"<ref(?:\s[^>]*)?>.*?</ref>", "", text, flags=re.DOTALL | re.IGNORECASE
        )

    @staticmethod
    def _links_to_text(text: str) -> str:
        """Replace ``[[target|label]]`` with ``label`` and ``[[target]]`` with ``target``."""
        return re.sub(r"\[\[(?:[^|\]]*\|)?([^\]]+)\]\]", r"\1", text)

    @staticmethod
    def _remove_templates(text: str) -> str:
        """Delete ``{{...}}`` templates that open and close on the same line.

        Innermost templates are removed first and the pass repeats until
        nothing changes, so nested templates leave no stray braces. Templates
        spanning several lines (e.g. infoboxes) are left in place.
        """
        previous = None
        while previous != text:
            previous = text
            text = re.sub(r"\{\{[^{}\n]*\}\}", "", text)
        return text

    def clean_wiki_content(self, content: str) -> str:
        """Strip common wikitext/HTML markup from *content* and collapse blank lines.

        Links are reduced to their display text. HTML comments, ``<ref>``
        footnotes, file/image embeds, single-line templates and ``<small>``
        spans are removed. Runs of blank lines become a single blank line.
        """
        # Bracketed citation numbers and "[edit]" links (rendered-page artifacts).
        content = re.sub(r"\[\d+\]", "", content)
        content = re.sub(r"\[edit\]", "", content)
        # HTML comments first, so markup hidden inside them is never processed.
        content = re.sub(r"<!--.*?-->", "", content, flags=re.DOTALL)
        # In wikitext, citations live in <ref> tags rather than [n] markers.
        content = self._remove_refs(content)
        # File/image embeds; the caption may hold one level of [[link]] or [url].
        content = re.sub(
            r"\[\[(?:File|Image):(?:[^\[\]]|\[\[[^\[\]]*\]\]|\[[^\[\]]*\])*\]\]",
            "",
            content,
            flags=re.IGNORECASE,
        )
        content = self._links_to_text(content)
        content = self._remove_templates(content)
        content = re.sub(r"<small>.*?</small>", "", content)
        # Normalize whitespace between paragraphs.
        content = re.sub(r"\n\s*\n", "\n\n", content)
        return content.strip()

    @staticmethod
    def _split_table_line(line: str) -> list:
        """Split a table line starting with ``|`` or ``!`` into raw cell texts.

        Cells on one line are separated by ``||``; header lines (``!``) may
        also use ``!!``. A leading attribute section such as
        ``style="..." |`` is dropped from each cell.
        """
        separator = r"!!|\|\|" if line[0] == "!" else r"\|\|"
        cells = []
        for cell in re.split(separator, line[1:]):
            # The attribute prefix may not contain brackets or braces, so the
            # pipe inside [[target|label]] or {{name|arg}} is never taken for it.
            attributes = re.match(r"[^|\[\]{}]*=[^|\[\]{}]*\|", cell)
            if attributes:
                cell = cell[attributes.end():]
            cells.append(cell)
        return cells

    @classmethod
    def _clean_table_cell(cls, cell: str) -> str:
        """Reduce one raw table cell to single-line plain text."""
        # Footnotes must go before the generic tag strip, which would otherwise
        # delete only the <ref> tags and keep the citation text between them.
        cell = cls._remove_refs(cell)
        # Line breaks become spaces before the tag strip would delete them.
        cell = re.sub(r"<br\s*/?>", " ", cell, flags=re.IGNORECASE)
        # Any other HTML tag (e.g. <small>, <span>) is dropped; its text is kept.
        cell = re.sub(r"<[^>]*>", "", cell)
        cell = cls._links_to_text(cell)
        cell = cls._remove_templates(cell)
        return re.sub(r"\s+", " ", cell).strip()

    def format_wiki_table(self, table_content: str) -> str:
        """Flatten wikitext table markup into text: one line per row, cells tab-separated.

        ``{|`` / ``|}`` delimiter lines are ignored. ``|-`` row separators and
        ``|+`` captions end the current row. A line without a leading ``|`` or
        ``!`` continues the previous cell. Every cell is reduced to plain text.
        Returns ``""`` when no cells are found.
        """
        formatted_rows = []
        current_row = []  # raw cell texts of the row being collected
        for raw_line in table_content.strip().split("\n"):
            line = raw_line.strip()
            if line.startswith(("{|", "|}")):
                # Delimiters carry only table attributes, never cell data.
                continue
            if not line or line.startswith(("|-", "|+")):
                if current_row:
                    formatted_rows.append(
                        "\t".join(self._clean_table_cell(c) for c in current_row)
                    )
                    current_row = []
                continue
            if line[0] in "|!":
                current_row.extend(self._split_table_line(line))
            elif current_row:
                current_row[-1] += " " + line
        if current_row:
            formatted_rows.append(
                "\t".join(self._clean_table_cell(c) for c in current_row)
            )
        return "\n".join(formatted_rows)

    def extract_wikipedia_title(self, search_result: str) -> Optional[str]:
        """Return the English Wikipedia article title named in *search_result*, if any.

        Search hits are markdown links. A link whose text ends in
        ``- Wikipedia`` and points at en.wikipedia.org yields that text
        (minus the suffix). Otherwise the title is decoded from the first
        en.wikipedia.org article URL. Returns ``None`` if neither is present.
        """
        # Look for Wikipedia links in the format [Title - Wikipedia](url)
        wiki_match = re.search(
            r"\[([^\]]+)\s*-\s*Wikipedia\]\(https://en\.wikipedia\.org/wiki/[^)]+\)",
            search_result,
        )
        if wiki_match:
            return wiki_match.group(1).strip()
        # Fallback for hits whose displayed title lacks the " - Wikipedia"
        # suffix: the URL path is the title, percent-encoded with underscores
        # for spaces. One level of parentheses is allowed (e.g. "Mercury_(planet)")
        # without swallowing the markdown link's closing ")".
        url_match = re.search(
            r"https://en\.wikipedia\.org/wiki/((?:[^()\s?#]|\([^()\s]*\))+)",
            search_result,
        )
        if url_match:
            title = unquote(url_match.group(1)).replace("_", " ").strip()
            if title:
                return title
        return None

    def _group_paragraphs(self, page_content: str) -> list:
        """Split wikitext into blank-line-separated paragraphs, flattening tables in place."""
        paragraphs = []
        current = []
        table_lines = []
        depth = 0  # table nesting level; 0 means not inside a table
        for line in page_content.split("\n"):
            stripped = line.strip()
            if stripped.startswith("{|"):
                # Nested tables stay inside the outer table's collected lines.
                depth += 1
                table_lines.append(line)
            elif depth:
                table_lines.append(line)
                if stripped.startswith("|}"):
                    depth -= 1
                    if depth == 0:
                        formatted_table = self.format_wiki_table("\n".join(table_lines))
                        if formatted_table:
                            current.append(formatted_table)
                        table_lines = []
            elif stripped:
                current.append(line)
            elif current:
                paragraphs.append("\n".join(current))
                current = []
        if table_lines:
            # Table never closed before end of page: keep what was collected.
            formatted_table = self.format_wiki_table("\n".join(table_lines))
            if formatted_table:
                current.append(formatted_table)
        if current:
            paragraphs.append("\n".join(current))
        return paragraphs

    def forward(self, query: str) -> str:
        """Look up *query* and return cleaned article text prefixed with its title.

        On failure (search error, no Wikipedia hit, page not retrievable) a
        short explanatory message is returned instead of raising.
        """
        logger.info("Starting smart search for query: %s", query)
        # First do a web search to find the Wikipedia page.
        try:
            search_result = self.web_search_tool.forward(query)
        except Exception as e:
            # Tool boundary: report search failures as text, like the other paths.
            logger.error("Web search failed for query %r: %s", query, e)
            return f"Web search failed for '{query}': {e}"
        logger.info("Web search results: %s...", search_result[:100])

        wiki_title = self.extract_wikipedia_title(search_result)
        if not wiki_title:
            return f"Could not find Wikipedia page in search results for '{query}'."

        page_content = self.get_wikipedia_page(wiki_title)
        if not page_content:
            return f"Could not find Wikipedia page for '{wiki_title}'."

        paragraphs = self._group_paragraphs(page_content)
        cleaned_content = self.clean_wiki_content("\n\n".join(paragraphs))
        return f"Wikipedia content for '{wiki_title}':\n\n{cleaned_content}"
def main(query: str) -> str:
    """Run SmartSearchTool once and echo its output to stdout.

    INFO-level logging is enabled first so the tool's progress messages are
    visible alongside the result.

    Args:
        query: Search query handed to ``SmartSearchTool.forward``.

    Returns:
        The tool's result string, which is also printed between two rules.
    """
    logging.basicConfig(
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
        level=logging.INFO,
    )
    search_tool = SmartSearchTool()
    output = search_tool.forward(query)
    rule = "-" * 80
    for chunk in ("\nSearch Results:", rule, output, rule):
        print(chunk)
    return output
if __name__ == "__main__":
    import sys

    if len(sys.argv) < 2:
        # Usage errors go to stderr with a non-zero exit status so shell
        # callers can tell misuse apart from a successful search.
        print("Usage: python tool.py <search query>", file=sys.stderr)
        print("Example: python tool.py 'Mercedes Sosa discography'", file=sys.stderr)
        sys.exit(1)
    # Every command-line argument is joined into one query string.
    main(" ".join(sys.argv[1:]))
|