File size: 6,270 Bytes
6ca25ff |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
from typing import Any, Dict, List, Optional
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from duckduckgo_search import DDGS
import re
import math
class WebSearchTool:
def __init__(self):
self.search = DDGS()
def run(self, query: str, max_results: int = 3) -> str:
"""Perform a web search and return formatted results."""
try:
results = list(self.search.text(query, max_results=max_results))
formatted_results = []
for r in results:
formatted_results.append(f"Title: {r['title']}\nSnippet: {r['body']}\nURL: {r['link']}\n")
return "\n".join(formatted_results)
except Exception as e:
return f"Error performing web search: {str(e)}"
class Calculator:
def run(self, expression: str) -> str:
"""Evaluate mathematical expressions safely."""
try:
# Remove any characters that aren't numbers, operators, or parentheses
cleaned = re.sub(r'[^0-9+\-*/().\ ]', '', expression)
# Evaluate the expression
result = eval(cleaned, {"__builtins__": {}}, {"math": math})
return str(result)
except Exception as e:
return f"Error in calculation: {str(e)}"
class GaiaAgent:
def __init__(self):
# Initialize Qwen-7B model
self.model_name = "Qwen/Qwen-7B"
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_name,
trust_remote_code=True
)
self.model = AutoModelForCausalLM.from_pretrained(
self.model_name,
device_map="auto",
trust_remote_code=True
).eval()
# Initialize tools
self.tools = {
"web_search": WebSearchTool(),
"calculator": Calculator()
}
# System prompt template
self.system_prompt = """You are a helpful AI assistant with access to the following tools:
1. web_search: Search the internet for current information
2. calculator: Perform mathematical calculations
To use a tool, respond with: <tool>tool_name|input</tool>
For example: <tool>calculator|2 + 2</tool> or <tool>web_search|latest news about AI</tool>
If you don't need any tools to answer, just provide your response directly.
Always explain your reasoning before using tools or providing final answers."""
def _generate_response(self, prompt: str, max_length: int = 2048) -> str:
"""Generate a response using the Qwen model."""
try:
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.model.device)
with torch.no_grad():
outputs = self.model.generate(
input_ids,
max_length=max_length,
num_return_sequences=1,
temperature=0.7,
do_sample=True,
pad_token_id=self.tokenizer.pad_token_id
)
response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
# Extract only the assistant's response
response = response.split(prompt)[-1].strip()
return response
except Exception as e:
return f"Error generating response: {str(e)}"
def _extract_tool_calls(self, response: str) -> List[Dict[str, str]]:
"""Extract tool calls from the response."""
tool_pattern = r'<tool>(.*?)\|(.*?)</tool>'
matches = re.finditer(tool_pattern, response)
tool_calls = []
for match in matches:
tool_name = match.group(1).strip()
tool_input = match.group(2).strip()
tool_calls.append({"name": tool_name, "input": tool_input})
return tool_calls
def _execute_tool_call(self, tool_call: Dict[str, str]) -> str:
"""Execute a single tool call and return the result."""
tool_name = tool_call["name"]
tool_input = tool_call["input"]
if tool_name not in self.tools:
return f"Error: Tool '{tool_name}' not found"
try:
result = self.tools[tool_name].run(tool_input)
return result
except Exception as e:
return f"Error executing {tool_name}: {str(e)}"
def process_question(self, question: str) -> str:
"""Process a single question and return the answer."""
# Construct the full prompt
full_prompt = f"{self.system_prompt}\n\nQuestion: {question}\n\nAnswer:"
# Get initial response
response = self._generate_response(full_prompt)
# Extract and execute any tool calls
tool_calls = self._extract_tool_calls(response)
if tool_calls:
# Execute each tool call and collect results
tool_results = []
for tool_call in tool_calls:
result = self._execute_tool_call(tool_call)
tool_results.append(f"Tool {tool_call['name']} result: {result}")
# Generate final response with tool results
tool_results_str = "\n".join(tool_results)
final_prompt = f"{full_prompt}\n{response}\n\nTool Results:\n{tool_results_str}\n\nFinal Answer:"
final_response = self._generate_response(final_prompt)
return final_response
return response
def get_answer(self, question_data: Dict[str, Any]) -> Optional[str]:
"""Process a question from the GAIA benchmark and return an answer."""
try:
# Extract the actual question from the question data
question = question_data.get("question", "")
if not question:
return None
# Process the question and get the answer
answer = self.process_question(question)
return answer
except Exception as e:
print(f"Error processing question: {str(e)}")
return None |