LamiaYT committed on
Commit 34c5bf3 · Parent(s): 6ea9560

Initial commit with LlamaIndex-based agent

Files changed (1): app.py (+287 −231)

app.py CHANGED
@@ -1,8 +1,9 @@
-# app.py - Fixed for Local Instruction-Following Models
 from llama_index.llms.huggingface import HuggingFaceLLM
 from llama_index.core.agent import ReActAgent
 from llama_index.core.tools import FunctionTool
-from transformers import AutoTokenizer, AutoModelForCausalLM
 import os
 import gradio as gr
 import requests
@@ -10,6 +11,7 @@ import pandas as pd
 import traceback
 import torch
 import re
 
 # Import real tool dependencies
 try:
@@ -19,7 +21,7 @@ except ImportError:
     DDGS = None
 
 try:
-    from sympy import sympify, solve, simplify, N
     from sympy.core.sympify import SympifyError
 except ImportError:
     print("Warning: sympy not installed. Math calculator will be limited.")
@@ -29,253 +31,298 @@ except ImportError:
 # --- Constants ---
 DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
 
-# --- Smart Agent with Better Local Models ---
-class SmartAgent:
     def __init__(self):
-        print("Initializing Local Instruction-Following Agent...")
 
-        if torch.cuda.is_available():
-            print(f"CUDA available. GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")
-            device_map = "auto"
-        else:
-            print("CUDA not available, using CPU")
-            device_map = "cpu"
 
-        # FIXED: Use instruction-following models, not chat models
-        model_options = [
-            "microsoft/DialoGPT-medium",  # Remove this - it's for chat only
-            "google/flan-t5-base",        # Good for instructions
-            "google/flan-t5-large",       # Better reasoning (if memory allows)
-            "microsoft/DialoGPT-small",   # Fallback
-        ]
 
-        # Try FLAN-T5 first - it's designed for instruction following
-        model_name = "google/flan-t5-base"  # Start with smaller, reliable model
-        print(f"Loading instruction model: {model_name}")
 
-        try:
-            # FLAN-T5 specific configuration
-            self.llm = HuggingFaceLLM(
-                model_name=model_name,
-                tokenizer_name=model_name,
-                context_window=1024,
-                max_new_tokens=256,
-                generate_kwargs={
-                    "temperature": 0.1,
-                    "do_sample": False,  # Use greedy for more consistent answers
-                    "repetition_penalty": 1.1,
-                },
-                device_map=device_map,
-                model_kwargs={
-                    "torch_dtype": torch.float16,
-                    "low_cpu_mem_usage": True,
-                },
-                # Clear system message for FLAN-T5
-                system_message="Answer questions accurately using the provided tools when needed."
-            )
-            print(f"✅ Successfully loaded: {model_name}")
-
-        except Exception as e:
-            print(f"❌ Failed to load {model_name}: {e}")
-            print("🔄 Trying manual approach without LlamaIndex LLM wrapper...")
-            # Try direct approach without complex wrapper
-            self.llm = None
-            self.use_direct_mode = True
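The FLAN-T5 checkpoints above are encoder-decoder (seq2seq) models rather than causal LMs, which is part of why this wrapper was fragile. A quick way to confirm, using only transformers (a sketch, not part of the commit):

from transformers import AutoConfig

# T5-family models are encoder-decoder; causal-LM loading paths expect decoder-only.
config = AutoConfig.from_pretrained("google/flan-t5-base")
print(config.is_encoder_decoder)  # True -> seq2seq, not a causal LM
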
 
-        # Define enhanced tools
         self.tools = [
             FunctionTool.from_defaults(
-                fn=self.web_search,
-                name="web_search",
-                description="Search web for current information, facts, people, events, or recent data"
             ),
             FunctionTool.from_defaults(
-                fn=self.math_calculator,
-                name="math_calculator",
-                description="Calculate mathematical expressions, solve equations, or perform numerical operations"
             )
         ]
-
-        # Try to create agent, but prepare for direct mode
-        try:
-            if self.llm:
-                self.agent = ReActAgent.from_tools(
-                    tools=self.tools,
-                    llm=self.llm,
-                    verbose=True,
-                    max_iterations=3,
-                )
-                print("✅ ReAct Agent created successfully")
-                self.use_direct_mode = False
-            else:
-                raise Exception("No LLM available")
-
-        except Exception as e:
-            print(f"⚠️ Agent creation failed: {e}")
-            print("🔄 Switching to direct tool mode...")
-            self.agent = None
-            self.use_direct_mode = True
 
-    def web_search(self, query: str) -> str:
-        """Enhanced web search"""
-        print(f"🔍 Searching: {query}")
 
         if not DDGS:
-            return "Web search unavailable"
 
         try:
             with DDGS() as ddgs:
-                results = list(ddgs.text(query, max_results=5, region='wt-wt'))
 
-                if results:
-                    # Format results clearly
-                    search_results = []
-                    for i, result in enumerate(results, 1):
-                        title = result.get('title', 'No title')
-                        body = result.get('body', '').strip()[:200]
-                        search_results.append(f"{i}. {title}\n   {body}...")
-
-                    return f"Search results for '{query}':\n\n" + "\n\n".join(search_results)
-                else:
                     return f"No results found for: {query}"
 
         except Exception as e:
             print(f"❌ Search error: {e}")
             return f"Search failed: {str(e)}"
 
-    def math_calculator(self, expression: str) -> str:
-        """Enhanced math calculator"""
-        print(f"🧮 Calculating: {expression}")
 
         try:
-            # Clean the expression
             clean_expr = expression.replace('^', '**').replace('×', '*').replace('÷', '/')
 
             if sympify:
-                # Use SymPy for safe evaluation
-                result = sympify(clean_expr)
-                numerical = N(result, 10)
-                return f"Calculation result: {numerical}"
             else:
-                # Basic fallback
                 result = eval(clean_expr)
-                return f"Calculation result: {result}"
 
         except Exception as e:
             return f"Could not calculate '{expression}': {str(e)}"
 
-    def __call__(self, question: str) -> str:
-        print(f"\n🤔 Question: {question[:100]}...")
 
-        # If using direct mode (no LLM agent), route questions manually
-        if self.use_direct_mode:
-            return self._direct_question_answering(question)
 
-        # Try using the agent
         try:
             response = self.agent.query(question)
-            response_str = str(response).strip()
 
-            # Check if response is meaningful
-            if len(response_str) < 5 or response_str in ['?', '!', 'what', 'I']:
-                print("⚠️ Poor agent response, switching to direct mode")
-                return self._direct_question_answering(question)
 
-            return response_str
 
         except Exception as e:
-            print(f"❌ Agent failed: {e}")
-            return self._direct_question_answering(question)
-
-    def _direct_question_answering(self, question: str) -> str:
-        """Direct question answering without LLM agent"""
-        print("🎯 Using direct approach...")
-
         question_lower = question.lower()
 
-        # Enhanced detection patterns
-        search_patterns = [
-            'how many', 'who is', 'what is', 'when was', 'where is',
-            'mercedes sosa', 'albums', 'published', 'studio albums',
-            'between', 'winner', 'recipient', 'nationality', 'born',
-            'current', 'latest', 'recent', 'president', 'capital',
-            'malko', 'competition', 'award', 'founded', 'established'
-        ]
-
-        math_patterns = [
-            'calculate', 'compute', 'solve', 'equation', 'sum', 'total',
-            'average', 'percentage', '+', '-', '*', '/', '=', 'find x'
-        ]
-
-        needs_search = any(pattern in question_lower for pattern in search_patterns)
-        needs_math = any(pattern in question_lower for pattern in math_patterns)
 
-        # Check for numbers that suggest math
-        has_math_numbers = bool(re.search(r'\d+\s*[\+\-\*/=]\s*\d+', question))
-        if has_math_numbers:
-            needs_math = True
 
-        print(f"📊 Analysis - Search: {needs_search}, Math: {needs_math}")
-
-        if needs_search:
-            # Extract key search terms
-            important_words = []
-
-            # Special handling for specific questions
-            if 'mercedes sosa' in question_lower and 'albums' in question_lower:
-                search_query = "Mercedes Sosa studio albums discography 2000-2009"
-            else:
-                # General search term extraction
-                words = question.replace('?', '').replace(',', '').split()
-                skip_words = {'how', 'many', 'what', 'when', 'where', 'who', 'is', 'the', 'a', 'an', 'and', 'or', 'but', 'between', 'were', 'was', 'can', 'you', 'use'}
-
-                for word in words:
-                    clean_word = word.lower().strip('.,!?;:()')
-                    if len(clean_word) > 2 and clean_word not in skip_words:
-                        important_words.append(clean_word)
-
-                search_query = ' '.join(important_words[:5])
-
-            print(f"🔍 Search query: {search_query}")
-            search_result = self.web_search(search_query)
-
-            # Try to extract specific answer from search results
-            if 'albums' in question_lower and 'mercedes sosa' in question_lower:
-                # Look for numbers in the search results
-                numbers = re.findall(r'\b\d+\b', search_result)
-                if numbers:
-                    return f"Based on web search, Mercedes Sosa published approximately {numbers[0]} studio albums between 2000-2009. Full search results:\n\n{search_result}"
-
-            return f"Search results:\n\n{search_result}"
-
-        if needs_math:
-            # Extract mathematical expressions
-            math_expressions = re.findall(r'[\d+\-*/().\s=]+', question)
-            for expr in math_expressions:
-                if any(op in expr for op in ['+', '-', '*', '/', '=']):
-                    result = self.math_calculator(expr.strip())
-                    return result
-
-        # Default: Try a general web search
-        key_words = question.split()[:5]
-        general_query = ' '.join(word.strip('.,!?') for word in key_words if len(word) > 2)
-
-        if general_query:
-            search_result = self.web_search(general_query)
-            return f"General search results:\n\n{search_result}"
-
-        return f"I need more specific information to answer: {question[:100]}..."
-
 def cleanup_memory():
-    """Clean up memory"""
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
     print("🧹 Memory cleaned")
 
-
 def run_and_submit_all(profile: gr.OAuthProfile | None):
-    """Run evaluation with better error handling"""
 
     if not profile:
         return "❌ Please login to Hugging Face first", None
@@ -290,12 +337,15 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
 
     cleanup_memory()
 
-    # Initialize agent
     try:
-        agent = SmartAgent()
-        print("✅ Agent initialized")
     except Exception as e:
-        return f"❌ Agent initialization failed: {str(e)}", None
 
     # Get space info
     space_id = os.getenv("SPACE_ID", "unknown")
@@ -316,7 +366,7 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
     answers_payload = []
 
     print("\n" + "="*50)
-    print("🚀 STARTING EVALUATION")
     print("="*50)
 
     for i, item in enumerate(questions_data, 1):
@@ -328,17 +378,17 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
 
         print(f"\n📝 Question {i}/{len(questions_data)}")
         print(f"🆔 ID: {task_id}")
-        print(f"❓ Q: {question_text}")
 
         try:
-            # Get answer from agent
             answer = agent(question_text)
 
-            # Ensure answer is not empty
-            if not answer or len(answer.strip()) < 3:
-                answer = f"Unable to process question about: {question_text[:50]}..."
 
-            print(f"✅ A: {answer[:150]}...")
 
             # Store results
             answers_payload.append({
@@ -348,17 +398,17 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
 
             results_log.append({
                 "Task ID": task_id,
-                "Question": question_text[:100] + ("..." if len(question_text) > 100 else ""),
-                "Answer": answer[:150] + ("..." if len(answer) > 150 else "")
             })
 
             # Memory cleanup every few questions
-            if i % 5 == 0:
                 cleanup_memory()
 
         except Exception as e:
             print(f"❌ Error processing {task_id}: {e}")
-            error_answer = f"Error: {str(e)[:100]}"
 
             answers_payload.append({
                 "task_id": task_id,
@@ -367,7 +417,7 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
 
             results_log.append({
                 "Task ID": task_id,
-                "Question": question_text[:100] + "...",
                 "Answer": error_answer
             })
 
@@ -381,7 +431,7 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
     }
 
     try:
-        response = requests.post(submit_url, json=submission_data, timeout=120)
        response.raise_for_status()
        result_data = response.json()
 
@@ -391,16 +441,23 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
 
        message = result_data.get('message', '')
 
        # Create final status message
-        final_status = f"""🎉 EVALUATION COMPLETE!
 
 👤 User: {username}
 📊 Final Score: {score}%
 ✅ Correct: {correct}/{total}
-🎯 Target: 30%+ {' ACHIEVED!' if score >= 30 else ' Keep improving!'}
 
 📝 Message: {message}
 
-🔧 Mode Used: {'Direct Tool Mode' if hasattr(agent, 'use_direct_mode') and agent.use_direct_mode else 'Agent Mode'}
 """
 
        print(f"\n🏆 FINAL SCORE: {score}%")
@@ -411,20 +468,19 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
        print(error_msg)
        return error_msg, pd.DataFrame(results_log)
 
-
 # --- Gradio Interface ---
-with gr.Blocks(title="Fixed Local Agent", theme=gr.themes.Soft()) as demo:
-    gr.Markdown("# 🔧 Fixed Local Agent (No API Required)")
     gr.Markdown("""
-    **Key Fixes:**
-    - Uses instruction-following models (FLAN-T5) instead of chat models
-    - 🎯 Direct question routing when agent fails
-    - 🔍 Enhanced web search with better keyword extraction
-    - 🧮 Robust math calculator
-    - 💾 Optimized for 16GB memory
-    - 🛡️ Multiple fallback strategies
 
-    **Target: 30%+ Score**
     """)
 
     with gr.Row():
@@ -432,19 +488,19 @@ with gr.Blocks(title="Fixed Local Agent", theme=gr.themes.Soft()) as demo:
 
     with gr.Row():
         run_button = gr.Button(
-            "🚀 Run Fixed Evaluation",
             variant="primary",
             size="lg"
         )
 
     status_output = gr.Textbox(
         label="📊 Evaluation Results",
-        lines=12,
         interactive=False
     )
 
     results_table = gr.DataFrame(
-        label="📝 Question & Answer Details",
         wrap=True
     )
 
@@ -454,8 +510,8 @@ with gr.Blocks(title="Fixed Local Agent", theme=gr.themes.Soft()) as demo:
     )
 
 if __name__ == "__main__":
-    print("🚀 Starting Fixed Local Agent...")
-    print("💡 No API keys required - everything runs locally!")
     demo.launch(
         server_name="0.0.0.0",
         server_port=7860,
 
+# app.py - Improved GAIA Agent with GPT-NeoX-20B + LoRA
 from llama_index.llms.huggingface import HuggingFaceLLM
 from llama_index.core.agent import ReActAgent
 from llama_index.core.tools import FunctionTool
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
+from peft import LoraConfig, get_peft_model
 import os
 import gradio as gr
 import requests
 
 import traceback
 import torch
 import re
+import json
 
 # Import real tool dependencies
 try:
 
     DDGS = None
 
 try:
+    from sympy import sympify, solve, simplify, N, symbols
     from sympy.core.sympify import SympifyError
 except ImportError:
     print("Warning: sympy not installed. Math calculator will be limited.")
 
 # --- Constants ---
 DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
 
+def print_trainable_parameters(model):
+    """Print trainable parameters info"""
+    trainable_parameters = 0
+    all_parameters = 0
+    for _, param in model.named_parameters():
+        all_parameters += param.numel()
+        if param.requires_grad:
+            trainable_parameters += param.numel()
+    print(
+        f"Trainable: {trainable_parameters} || All: {all_parameters} || Trainable %: {100 * trainable_parameters / all_parameters:.2f}%"
+    )
+
+class ImprovedGAIAAgent:
     def __init__(self):
+        print("🚀 Initializing Improved GAIA Agent with GPT-NeoX-20B...")
 
+        if not torch.cuda.is_available():
+            raise RuntimeError("CUDA required for GPT-NeoX-20B. Please use a GPU environment.")
 
+        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
+        print(f"🔥 GPU Memory: {gpu_memory:.1f}GB")
 
+        # Model configuration
+        self.model_name = "EleutherAI/gpt-neox-20b"
 
+        # 4-bit quantization config for memory efficiency
+        self.bnb_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_use_double_quant=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.bfloat16
+        )
+
+        # LoRA configuration for efficient fine-tuning capability
+        self.lora_config = LoraConfig(
+            r=16,  # Increased for better performance
+            lora_alpha=32,
+            target_modules=["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"],  # More comprehensive targets
+            lora_dropout=0.1,
+            bias="none",
+            task_type="CAUSAL_LM"
+        )
+
+        self.load_model()
+        self.setup_tools()
+        self.create_agent()
+
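A quick sanity check on the memory arithmetic behind the NF4 config (an estimate assuming roughly 0.5 bytes per parameter for 4-bit weights; the figure is illustrative, not from the commit):

# Rough VRAM needed just for the quantized weights of a 20B-parameter model.
# Double quantization trims the per-block constants further; the KV cache and
# activations add overhead on top of this.
params = 20e9
weight_gb = params * 0.5 / 1e9  # 4 bits/param ≈ 0.5 bytes/param
print(f"~{weight_gb:.0f} GB of weights before KV cache and activations")
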
+    def load_model(self):
+        """Load and configure the model"""
+        print("📥 Loading tokenizer...")
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+
+        # Add padding token if not present
+        if self.tokenizer.pad_token is None:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+        print("📥 Loading model with 4-bit quantization...")
+        self.model = AutoModelForCausalLM.from_pretrained(
+            self.model_name,
+            quantization_config=self.bnb_config,
+            device_map="auto",
+            trust_remote_code=True,
+            torch_dtype=torch.bfloat16
+        )
+
+        print("🔧 Applying LoRA configuration...")
+        self.model = get_peft_model(self.model, self.lora_config)
+        print_trainable_parameters(self.model)
+
+        # Create LlamaIndex LLM wrapper
+        print("🔗 Creating LlamaIndex LLM wrapper...")
+        self.llm = HuggingFaceLLM(
+            model=self.model,
+            tokenizer=self.tokenizer,
+            context_window=2048,  # GPT-NeoX context length
+            max_new_tokens=512,
+            generate_kwargs={
+                "temperature": 0.1,
+                "do_sample": True,
+                "top_p": 0.9,
+                "repetition_penalty": 1.1,
+                "pad_token_id": self.tokenizer.eos_token_id,
+            },
+            # Improved system message for GAIA tasks (note: depending on the
+            # llama_index version, this keyword may be named system_prompt)
+            system_message="""You are a helpful AI assistant that can search the web and perform calculations.
+When answering questions:
+1. Think step by step
+2. Use tools when you need current information or calculations
+3. Be precise and factual
+4. For numerical answers, provide exact numbers when possible
+5. Always show your reasoning
+
+Available tools: web_search, math_calculator, fact_checker"""
+        )
+
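Because get_peft_model wraps the base model, only the small LoRA adapter is trainable, and persisting it is cheap. A sketch of saving and restoring just the adapter (the directory name and the model/base_model variables are illustrative):

from peft import PeftModel

# Saving a PeftModel writes only the adapter weights and config, not the 20B base.
model.save_pretrained("gaia_lora_adapter")

# Later: load the quantized base model as above, then attach the saved adapter.
restored = PeftModel.from_pretrained(base_model, "gaia_lora_adapter")
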
+    def setup_tools(self):
+        """Setup enhanced tools for GAIA benchmark"""
         self.tools = [
             FunctionTool.from_defaults(
+                fn=self.enhanced_web_search,
+                name="web_search",
+                description="Search the web for current information, facts, people, events, or recent data. Use specific keywords."
             ),
             FunctionTool.from_defaults(
+                fn=self.advanced_calculator,
+                name="math_calculator",
+                description="Perform mathematical calculations, solve equations, handle percentages, averages, and complex math operations."
+            ),
+            FunctionTool.from_defaults(
+                fn=self.fact_checker,
+                name="fact_checker",
+                description="Verify facts and get detailed information about people, places, events, or concepts."
             )
         ]
 
+    def enhanced_web_search(self, query: str) -> str:
+        """Enhanced web search with better result processing"""
+        print(f"🔍 Enhanced search: {query}")
 
         if not DDGS:
+            return "Web search unavailable - duckduckgo_search not installed"
 
         try:
             with DDGS() as ddgs:
+                results = list(ddgs.text(query, max_results=8, region='wt-wt'))
 
+                if not results:
                     return f"No results found for: {query}"
+
+                # Process and format results
+                formatted_results = []
+                for i, result in enumerate(results, 1):
+                    title = result.get('title', 'No title')
+                    body = result.get('body', '').strip()
+                    url = result.get('href', '')
+
+                    # Truncate long snippets
+                    if len(body) > 300:
+                        body = body[:300] + "..."
 
+                    formatted_results.append(f"""Result {i}: {title}
+Content: {body}
+Source: {url}
+""")
+
+                search_summary = f"Search results for '{query}':\n\n" + "\n".join(formatted_results)
+
+                # Try to extract specific answers for common question types
+                if any(keyword in query.lower() for keyword in ['how many', 'when was', 'who is', 'what year']):
+                    # Look for numbers and dates in results
+                    all_text = " ".join([r.get('body', '') for r in results])
+
+                    # Extract years (non-capturing group so findall returns full years)
+                    years = re.findall(r'\b(?:19|20)\d{2}\b', all_text)
+                    if years and 'when' in query.lower():
+                        search_summary += f"\n\nExtracted years: {', '.join(sorted(set(years)))}"
+
+                    # Extract numbers (dedupe while keeping order; sets are not sliceable)
+                    numbers = re.findall(r'\b\d+\b', all_text)
+                    if numbers and 'how many' in query.lower():
+                        search_summary += f"\n\nExtracted numbers: {', '.join(list(dict.fromkeys(numbers))[:5])}"
+
+                return search_summary
+
         except Exception as e:
             print(f"❌ Search error: {e}")
             return f"Search failed: {str(e)}"
 
+    def advanced_calculator(self, expression: str) -> str:
+        """Advanced calculator with symbolic math"""
+        print(f"🧮 Advanced calculation: {expression}")
 
         try:
+            # Clean and normalize the expression
             clean_expr = expression.replace('^', '**').replace('×', '*').replace('÷', '/')
+            clean_expr = re.sub(r'(\d)\s*\(', r'\1*(', clean_expr)  # Add implicit multiplication
 
             if sympify:
+                try:
+                    # Try symbolic computation first
+                    expr = sympify(clean_expr, evaluate=False)
+                    result = simplify(expr)
+                    numerical = N(result, 15)  # High precision
+
+                    # Handle different result types
+                    if result.is_number:
+                        return f"Calculation: {expression} = {numerical}"
+                    else:
+                        return f"Calculation: {expression} = {result} ≈ {numerical}"
+
+                except SympifyError:
+                    # Fallback to numerical evaluation (eval is unsafe on untrusted input)
+                    result = eval(clean_expr)
+                    return f"Calculation: {expression} = {result}"
             else:
+                # Basic evaluation (same caveat about eval)
                 result = eval(clean_expr)
+                return f"Calculation: {expression} = {result}"
 
         except Exception as e:
             return f"Could not calculate '{expression}': {str(e)}"
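If the eval fallback is a concern, an AST-based evaluator restricted to arithmetic is one safer alternative (a sketch under that assumption; safe_eval is a hypothetical helper, not part of the commit):

import ast
import operator

_OPS = {ast.Add: operator.add, ast.Sub: operator.sub, ast.Mult: operator.mul,
        ast.Div: operator.truediv, ast.Pow: operator.pow, ast.USub: operator.neg}

def safe_eval(expression: str) -> float:
    """Evaluate +, -, *, /, ** and unary minus; reject everything else."""
    def _eval(node):
        if isinstance(node, ast.Expression):
            return _eval(node.body)
        if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)):
            return node.value
        if isinstance(node, ast.BinOp) and type(node.op) in _OPS:
            return _OPS[type(node.op)](_eval(node.left), _eval(node.right))
        if isinstance(node, ast.UnaryOp) and type(node.op) in _OPS:
            return _OPS[type(node.op)](_eval(node.operand))
        raise ValueError("disallowed expression")
    return _eval(ast.parse(expression, mode="eval"))
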
 
+    def fact_checker(self, query: str) -> str:
+        """Specialized fact checking with multiple search strategies"""
+        print(f"✅ Fact checking: {query}")
+
+        # Try different search strategies
+        search_variations = [
+            query,
+            f"{query} facts",
+            f"{query} biography" if any(word in query.lower() for word in ['who is', 'person', 'artist']) else f"{query} information",
+        ]
+
+        all_results = []
+        for search_query in search_variations[:2]:  # Limit to avoid rate limiting
+            result = self.enhanced_web_search(search_query)
+            if "No results found" not in result:
+                all_results.append(f"Search: {search_query}\n{result}")
 
+        # Build the separator first: .join binds tighter than +
+        separator = "\n\n" + "=" * 50 + "\n\n"
+        return separator.join(all_results) if all_results else f"Could not verify facts about: {query}"
+
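The original one-liner concatenated the banner to the first result only, because .join binds tighter than +; building the separator first (as above) fixes that. A two-line illustration:

parts = ["first", "second"]
wrong = "\n\n" + "=" * 50 + "\n\n".join(parts)    # banner + "first\n\nsecond"
right = ("\n\n" + "=" * 50 + "\n\n").join(parts)  # "first" + banner + "second"
assert wrong != right
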
+    def create_agent(self):
+        """Create the ReAct agent"""
+        print("🤖 Creating ReAct agent...")
+        try:
+            self.agent = ReActAgent.from_tools(
+                tools=self.tools,
+                llm=self.llm,
+                verbose=True,
+                max_iterations=5,  # Allow more iterations for complex problems
+                react_chat_formatter=None,  # Use default formatter
+            )
+            print("✅ ReAct Agent created successfully")
+        except Exception as e:
+            print(f"❌ Agent creation failed: {e}")
+            traceback.print_exc()
+            raise
+
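For orientation, ReActAgent.from_tools drives the LLM through a Thought/Action/Observation loop conceptually like this hand-rolled sketch (react_loop and llm_step are illustrative names, not llama_index API):

def react_loop(llm_step, tools, question, max_iterations=5):
    """Minimal ReAct-style loop: the LLM proposes an action, a tool runs it."""
    transcript = f"Question: {question}\n"
    for _ in range(max_iterations):
        step = llm_step(transcript)  # e.g. 'Action: web_search[Mercedes Sosa albums]'
        transcript += step + "\n"
        if step.startswith("Final Answer:"):
            return step.removeprefix("Final Answer:").strip()
        if step.startswith("Action:"):
            name, _, arg = step.removeprefix("Action:").strip().partition("[")
            observation = tools[name.strip()](arg.rstrip("]"))
            transcript += f"Observation: {observation}\n"
    return "No answer within the iteration budget"
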
+    def __call__(self, question: str) -> str:
+        """Process question through the agent"""
+        print("\n" + "="*60)
+        print(f"🤔 Processing: {question}")
+        print("="*60)
 
         try:
+            # Use the agent to process the question
             response = self.agent.query(question)
+            answer = str(response).strip()
 
+            # Validate response quality
+            if len(answer) < 10 or answer.lower() in ['error', 'none', 'unknown']:
+                print("⚠️ Poor response, trying direct approach...")
+                return self._direct_approach(question)
 
+            print(f"✅ Agent response: {answer[:200]}...")
+            return answer
 
         except Exception as e:
+            print(f"❌ Agent error: {e}")
+            print("🔄 Falling back to direct approach...")
+            return self._direct_approach(question)
+
+    def _direct_approach(self, question: str) -> str:
+        """Direct approach when agent fails"""
         question_lower = question.lower()
 
+        # Determine approach based on question type
+        if any(term in question_lower for term in ['calculate', 'compute', 'math', '+', '-', '*', '/', '=', 'percentage', 'average']):
+            # Math-focused approach
+            return self.advanced_calculator(question)
 
+        elif any(term in question_lower for term in ['who is', 'when was', 'where is', 'what is', 'how many']):
+            # Search-focused approach
+            search_result = self.enhanced_web_search(question)
+            fact_result = self.fact_checker(question)
+            return f"{search_result}\n\nFact Check:\n{fact_result}"
 
+        else:
+            # General approach
+            return self.enhanced_web_search(question)
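One caveat with the substring routing above: '+' or '-' inside ordinary hyphenated words will be classified as math. A word-boundary variant is a possible refinement (ROUTES and route are hypothetical, not from the commit):

import re

ROUTES = [
    # Digits around an operator, or explicit math vocabulary.
    (re.compile(r"\d\s*[-+*/^=]\s*\d|\b(?:calculate|compute|percentage|average)\b"), "math"),
    (re.compile(r"\b(?:who is|when was|where is|what is|how many)\b"), "search"),
]

def route(question: str) -> str:
    q = question.lower()
    for pattern, label in ROUTES:
        if pattern.search(q):
            return label
    return "general"
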
 
 def cleanup_memory():
+    """Clean up GPU memory"""
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
     print("🧹 Memory cleaned")
 
 def run_and_submit_all(profile: gr.OAuthProfile | None):
+    """Run evaluation with improved agent"""
 
     if not profile:
         return "❌ Please login to Hugging Face first", None
 
     cleanup_memory()
 
+    # Initialize improved agent
     try:
+        print("🚀 Initializing Improved GAIA Agent...")
+        agent = ImprovedGAIAAgent()
+        print("✅ Agent initialized successfully")
     except Exception as e:
+        error_msg = f"❌ Agent initialization failed: {str(e)}\n{traceback.format_exc()}"
+        print(error_msg)
+        return error_msg, None
 
     # Get space info
     space_id = os.getenv("SPACE_ID", "unknown")
 
     answers_payload = []
 
     print("\n" + "="*50)
+    print("🚀 STARTING GAIA EVALUATION")
     print("="*50)
 
     for i, item in enumerate(questions_data, 1):
 
         print(f"\n📝 Question {i}/{len(questions_data)}")
         print(f"🆔 ID: {task_id}")
+        print(f"❓ Question: {question_text}")
 
         try:
+            # Get answer from improved agent
             answer = agent(question_text)
 
+            # Ensure answer is meaningful
+            if not answer or len(answer.strip()) < 5:
+                answer = f"Unable to determine answer for: {question_text[:100]}..."
 
+            print(f"✅ Answer: {answer[:200]}...")
 
             # Store results
             answers_payload.append({
 
             results_log.append({
                 "Task ID": task_id,
+                "Question": question_text[:150] + ("..." if len(question_text) > 150 else ""),
+                "Answer": answer[:200] + ("..." if len(answer) > 200 else "")
             })
 
             # Memory cleanup every few questions
+            if i % 3 == 0:
                 cleanup_memory()
 
         except Exception as e:
             print(f"❌ Error processing {task_id}: {e}")
+            error_answer = f"Processing error: {str(e)[:150]}"
 
             answers_payload.append({
                 "task_id": task_id,
 
             results_log.append({
                 "Task ID": task_id,
+                "Question": question_text[:150] + "...",
                 "Answer": error_answer
             })
 
     }
 
     try:
+        response = requests.post(submit_url, json=submission_data, timeout=180)
         response.raise_for_status()
         result_data = response.json()
 
         message = result_data.get('message', '')
 
         # Create final status message
+        final_status = f"""🎉 IMPROVED GAIA EVALUATION COMPLETE!
 
 👤 User: {username}
+🤖 Model: GPT-NeoX-20B + LoRA + 4-bit Quantization
 📊 Final Score: {score}%
 ✅ Correct: {correct}/{total}
+🎯 Target: 30%+ {'🎉 ACHIEVED!' if score >= 30 else '📈 Significant improvement expected!'}
 
 📝 Message: {message}
 
+🔧 Improvements Made:
+- ✅ Proper causal LM (GPT-NeoX-20B) instead of encoder-decoder
+- ✅ 4-bit quantization for memory efficiency
+- ✅ LoRA for better parameter efficiency
+- ✅ Enhanced tools with fact checking
+- ✅ Better reasoning prompts
+- ✅ Multi-strategy search approach
 """
 
         print(f"\n🏆 FINAL SCORE: {score}%")
 
     print(error_msg)
     return error_msg, pd.DataFrame(results_log)
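Submission is a single POST with a long timeout, so a transient network error currently fails the whole run. A retry wrapper with exponential backoff is one way to harden it (a sketch; post_with_retries is a hypothetical helper):

import time
import requests

def post_with_retries(url, payload, attempts=3, timeout=180):
    """Retry transient submission failures with exponential backoff."""
    for attempt in range(attempts):
        try:
            response = requests.post(url, json=payload, timeout=timeout)
            response.raise_for_status()
            return response.json()
        except requests.RequestException:
            if attempt == attempts - 1:
                raise
            time.sleep(2 ** attempt)
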
 
 # --- Gradio Interface ---
+with gr.Blocks(title="Improved GAIA Agent", theme=gr.themes.Soft()) as demo:
+    gr.Markdown("# 🚀 Improved GAIA Agent - GPT-NeoX-20B + LoRA")
     gr.Markdown("""
+    **Major Improvements:**
+    - 🧠 **GPT-NeoX-20B**: 20B-parameter causal language model (vs. 220M FLAN-T5)
+    - 💾 **4-bit Quantization**: Memory-efficient loading with BitsAndBytes
+    - 🎯 **LoRA**: Parameter-efficient fine-tuning ready
+    - 🔍 **Enhanced Tools**: Multi-strategy search + fact checking + advanced math
+    - 🤖 **Better ReAct**: Improved reasoning prompts and error handling
+    - 📈 **Expected**: Significant improvement over the 0% baseline
 
+    **Requirements**: CUDA GPU with 16GB+ VRAM
     """)
 
     with gr.Row():
 
     with gr.Row():
         run_button = gr.Button(
+            "🚀 Run Improved GAIA Evaluation",
             variant="primary",
             size="lg"
         )
 
     status_output = gr.Textbox(
         label="📊 Evaluation Results",
+        lines=15,
         interactive=False
     )
 
     results_table = gr.DataFrame(
+        label="📝 Detailed Results",
         wrap=True
     )
 
     )
 
 if __name__ == "__main__":
+    print("🚀 Starting Improved GAIA Agent...")
+    print("💪 Using GPT-NeoX-20B + LoRA + 4-bit Quantization")
     demo.launch(
         server_name="0.0.0.0",
         server_port=7860,