# First_agent_template/tools/arxiv_tool.py
import urllib.request
import urllib.parse
import xml.etree.ElementTree as ET
from datetime import datetime, timedelta
import json
import os

from smolagents.tools import Tool


class ArxivSearchTool(Tool):
    name = "search_arxiv"
    description = "Search ArXiv for papers matching the query"
    inputs = {
        'query': {
            'type': 'string',
            'description': 'The search query for papers',
            'nullable': True
        },
        'max_results': {
            'type': 'integer',
            'description': 'Maximum number of results to return',
            'nullable': True
        }
    }
    output_type = "string"

    def forward(self, query: str = "artificial intelligence",
                max_results: int = 50) -> str:
        try:
            # Construct the API URL
            base_url = 'http://export.arxiv.org/api/query?'
            query_params = {
                'search_query': query,
                'start': 0,
                'max_results': max_results
            }

            # Create the full URL
            url = base_url + urllib.parse.urlencode(query_params)

            # Make the request
            response = urllib.request.urlopen(url)
            data = response.read().decode('utf-8')

            # Parse the Atom XML response
            root = ET.fromstring(data)

            # Define the Atom namespaces
            ns = {'atom': 'http://www.w3.org/2005/Atom',
                  'arxiv': 'http://arxiv.org/schemas/atom'}

            # Format results as a string
            formatted_results = "## ArXiv Search Results\n\n"
            for entry in root.findall('atom:entry', ns):
                title = entry.find('atom:title', ns).text.strip()
                authors = [author.find('atom:name', ns).text
                           for author in entry.findall('atom:author', ns)]
                summary = entry.find('atom:summary', ns).text.strip() if entry.find('atom:summary', ns) is not None else ''
                published = entry.find('atom:published', ns).text.strip()
                paper_id = entry.find('atom:id', ns).text.strip()
                pdf_url = next((link.get('href') for link in entry.findall('atom:link', ns)
                                if link.get('type') == 'application/pdf'), None)
                categories = [cat.get('term') for cat in entry.findall('atom:category', ns)]

                formatted_results += f"### {title}\n"
                formatted_results += f"- Authors: {', '.join(authors)}\n"
                formatted_results += f"- Published: {published}\n"
                formatted_results += f"- Categories: {', '.join(categories)}\n"
                formatted_results += f"- PDF: {pdf_url}\n"
                formatted_results += f"- Summary: {summary}\n\n"

            return formatted_results
        except Exception as e:
            return f"Error searching ArXiv: {str(e)}"


class LatestPapersTool(Tool):
    name = "get_latest_papers"
    description = "Get papers from the last N days from saved results"
    inputs = {
        'days_back': {
            'type': 'integer',
            'description': 'Number of days to look back',
            'nullable': True
        }
    }
    output_type = "string"

    def forward(self, days_back: int = 1) -> str:
        try:
            papers = []
            base_dir = "daily_papers"

            # Get dates to check
            dates = [
                (datetime.now() - timedelta(days=i)).strftime("%Y-%m-%d")
                for i in range(days_back)
            ]

            # Load papers for each date
            for date in dates:
                file_path = os.path.join(base_dir, f"ai_papers_{date}.json")
                if os.path.exists(file_path):
                    with open(file_path, 'r', encoding='utf-8') as f:
                        day_papers = json.load(f)
                        papers.extend(day_papers)

            # Format results as a string
            formatted_results = f"## Latest Papers (Past {days_back} days)\n\n"
            for paper in papers:
                formatted_results += f"### {paper.get('title', 'Untitled')}\n"
                formatted_results += f"- Authors: {', '.join(paper.get('authors', ['Unknown']))}\n"
                formatted_results += f"- Published: {paper.get('published', 'Unknown')}\n"
                formatted_results += f"- Categories: {', '.join(paper.get('categories', []))}\n"
                if paper.get('pdf_url'):
                    formatted_results += f"- PDF: {paper['pdf_url']}\n"
                if paper.get('summary'):
                    formatted_results += f"- Summary: {paper['summary']}\n"
                formatted_results += "\n"

            return formatted_results if papers else f"No papers found in the last {days_back} days."
        except Exception as e:
            return f"Error getting latest papers: {str(e)}"