File size: 7,753 Bytes
d0c134a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
#!/usr/bin/env python3
"""
🚀 SmoLAgents Bridge for GAIA System
Integrates the smolagents framework with our existing tools for a 60+ point performance boost
"""

import os
import logging
from typing import Optional

# Try to import smolagents
try:
    from smolagents import CodeAgent, InferenceClientModel, tool, DuckDuckGoSearchTool
    from smolagents.tools import VisitWebpageTool
    SMOLAGENTS_AVAILABLE = True
except ImportError:
    SMOLAGENTS_AVAILABLE = False
    CodeAgent = None
    tool = None

# Import our existing system
from gaia_system import BasicAgent as FallbackAgent, UniversalMultimodalToolkit

logger = logging.getLogger(__name__)

class SmoLAgentsEnhancedAgent:
    """GAIA agent that answers through a smolagents ``CodeAgent`` when possible.

    If the ``smolagents`` import failed (``SMOLAGENTS_AVAILABLE`` is false) or
    the CodeAgent cannot be constructed, every question is delegated to
    ``FallbackAgent`` instead.

    Attributes:
        hf_token: Explicit token, else the ``HF_TOKEN`` environment variable.
        openai_key: Explicit key, else the ``OPENAI_API_KEY`` environment variable.
        use_smolagents: True when ``agent`` is a smolagents ``CodeAgent``.
        agent: The active agent (``CodeAgent`` or ``FallbackAgent``).
        toolkit: ``UniversalMultimodalToolkit`` backing the custom tools, or
            ``None`` when smolagents is unavailable.
        model: Model object handed to the ``CodeAgent``, or ``None``.
    """

    # Returned when the agent produced nothing usable.
    _NO_ANSWER = "Unable to provide answer"

    # Lower-case lead-ins removed from the start of an answer (first match wins).
    _ANSWER_PREFIXES = (
        "the answer is:",
        "answer:",
        "result:",
        "final answer:",
        "solution:",
    )

    def __init__(self, hf_token: Optional[str] = None, openai_key: Optional[str] = None):
        """Resolve credentials and build the smolagents agent or the fallback.

        Args:
            hf_token: Hugging Face token; defaults to the ``HF_TOKEN`` env var.
            openai_key: OpenAI key; defaults to the ``OPENAI_API_KEY`` env var.
        """
        self.hf_token = hf_token or os.getenv('HF_TOKEN')
        self.openai_key = openai_key or os.getenv('OPENAI_API_KEY')

        # Assigned on every path so cleanup() never hits a missing attribute.
        self.toolkit = None
        self.model = None
        self._fallback_agent = None

        if not SMOLAGENTS_AVAILABLE:
            print("โš ๏ธ SmoLAgents not available, using fallback system")
            self._use_fallback()
            return

        self.toolkit = UniversalMultimodalToolkit(self.hf_token, self.openai_key)

        try:
            # Create model with our priority system, then the CodeAgent on top of it.
            self.model = self._create_priority_model()
            self.agent = self._create_code_agent()
        except Exception as e:
            # Same degradation policy as query(): report and use the original system.
            print(f"โš ๏ธ SmoLAgents error: {e}, falling back to original system")
            self._use_fallback()
            return

        self.use_smolagents = True
        print("โœ… SmoLAgents GAIA System initialized")

    def _get_fallback_agent(self):
        """Return the ``FallbackAgent``, creating it on first use and reusing it after."""
        fallback = getattr(self, "_fallback_agent", None)
        if fallback is None:
            fallback = FallbackAgent(self.hf_token, self.openai_key)
            self._fallback_agent = fallback
        return fallback

    def _use_fallback(self) -> None:
        """Route all queries to the ``FallbackAgent``."""
        self.agent = self._get_fallback_agent()
        self.use_smolagents = False

    def _create_priority_model(self):
        """Create the model, trying the preferred options in priority order.

        Returns:
            The first ``InferenceClientModel`` that can be constructed.

        Raises:
            Exception: Whatever the last-resort constructor raises; failures of
                the earlier options are logged and skipped.
        """
        # NOTE(review): the smolagents docs name the model parameter
        # ``model_id``; ``model=`` is kept as in the original -- confirm
        # against the installed smolagents version.
        preferred = (
            # Priority 1: Qwen3-235B-A22B (Best for GAIA)
            dict(
                provider="fireworks-ai",
                api_key=self.hf_token,
                model="Qwen/Qwen3-235B-A22B",
            ),
            # Priority 2: DeepSeek-R1
            dict(
                model="deepseek-ai/DeepSeek-R1",
                token=self.hf_token,
            ),
        )
        for options in preferred:
            try:
                return InferenceClientModel(**options)
            except Exception as exc:
                logger.warning(
                    "Could not create model %s (%s); trying next option",
                    options["model"],
                    exc,
                )

        # Fallback: an error here propagates to the caller.
        return InferenceClientModel(
            model="meta-llama/Llama-3.1-8B-Instruct",
            token=self.hf_token
        )

    def _create_code_agent(self):
        """Create the ``CodeAgent`` with web tools plus the toolkit-backed tools.

        Returns:
            A ``CodeAgent`` using ``self.model`` and the GAIA system prompt.
        """
        tools = [
            DuckDuckGoSearchTool(),
            VisitWebpageTool(),
            self._create_calculator_tool(),
            self._create_image_analysis_tool(),
            self._create_file_download_tool(),
            self._create_pdf_tool(),
        ]

        # NOTE(review): recent smolagents releases document ``verbosity_level``
        # and ``prompt_templates`` rather than ``verbosity`` / ``system_prompt``
        # -- confirm these keywords against the installed version.
        return CodeAgent(
            tools=tools,
            model=self.model,
            system_prompt=self._get_gaia_prompt(),
            max_steps=3,
            verbosity=0
        )

    def _get_gaia_prompt(self) -> str:
        """Return the system prompt that asks for a bare final answer."""
        return """You are a GAIA benchmark expert. Use tools to solve questions step-by-step.

CRITICAL: Provide ONLY the final answer - no explanations.
Format: number OR few words OR comma-separated list
No units unless specified. No articles for strings.

Available tools:
- DuckDuckGoSearchTool: Search the web
- VisitWebpageTool: Visit URLs
- calculator: Mathematical calculations
- analyze_image: Analyze images
- download_file: Download GAIA files
- read_pdf: Extract PDF text"""

    # The inner docstrings of the tool functions below are left as written:
    # presumably the ``tool`` decorator reads them as the tool/argument
    # descriptions -- verify before rewording.

    def _create_calculator_tool(self):
        """Wrap ``toolkit.calculator`` as a smolagents tool."""
        @tool
        def calculator(expression: str) -> str:
            """Perform mathematical calculations

            Args:
                expression: Mathematical expression to evaluate
            """
            return self.toolkit.calculator(expression)
        return calculator

    def _create_image_analysis_tool(self):
        """Wrap ``toolkit.analyze_image`` as a smolagents tool."""
        @tool
        def analyze_image(image_path: str, question: str = "") -> str:
            """Analyze images and answer questions

            Args:
                image_path: Path to image file
                question: Question about the image
            """
            return self.toolkit.analyze_image(image_path, question)
        return analyze_image

    def _create_file_download_tool(self):
        """Wrap ``toolkit.download_file`` as a smolagents tool."""
        @tool
        def download_file(url: str = "", task_id: str = "") -> str:
            """Download files from URLs or GAIA tasks

            Args:
                url: URL to download from
                task_id: GAIA task ID
            """
            return self.toolkit.download_file(url, task_id)
        return download_file

    def _create_pdf_tool(self):
        """Wrap ``toolkit.read_pdf`` as a smolagents tool."""
        @tool
        def read_pdf(file_path: str) -> str:
            """Extract text from PDF documents

            Args:
                file_path: Path to PDF file
            """
            return self.toolkit.read_pdf(file_path)
        return read_pdf

    def query(self, question: str) -> str:
        """Answer ``question`` with the CodeAgent, or with the fallback agent.

        Args:
            question: The question text.

        Returns:
            The cleaned CodeAgent answer, or the fallback agent's answer when
            smolagents is disabled or the CodeAgent run raised.
        """
        if not self.use_smolagents:
            return self.agent.query(question)

        try:
            print(f"๐Ÿš€ Processing with SmoLAgents: {question[:80]}...")
            response = self.agent.run(question)
            cleaned = self._clean_response(response)
            print(f"โœ… SmoLAgents result: {cleaned}")
            return cleaned
        except Exception as e:
            print(f"โš ๏ธ SmoLAgents error: {e}, falling back to original system")
            # Fallback to original system (one shared instance, not one per failure)
            return self._get_fallback_agent().query(question)

    def _clean_response(self, response) -> str:
        """Normalize an agent result into a bare answer string.

        Args:
            response: Raw agent output. Non-string values (e.g. the integer
                ``0``) are converted with ``str`` rather than treated as
                "no answer".

        Returns:
            The answer with surrounding whitespace, one known lead-in such as
            "answer:" and trailing periods removed; a fixed placeholder when
            ``response`` is ``None`` or blank.
        """
        if response is None:
            return self._NO_ANSWER

        text = str(response).strip()
        if not text:
            return self._NO_ANSWER

        # Remove common prefixes
        lowered = text.lower()
        for prefix in self._ANSWER_PREFIXES:
            if lowered.startswith(prefix):
                text = text[len(prefix):].strip()
                break

        return text.rstrip('.')

    def clean_for_api_submission(self, response: str) -> str:
        """Clean response for GAIA API submission (compatibility method)."""
        return self._clean_response(response)

    def __call__(self, question: str) -> str:
        """Make the agent callable; equivalent to ``query(question)``."""
        return self.query(question)

    def cleanup(self) -> None:
        """Release toolkit resources, if a toolkit exists and offers ``cleanup``."""
        toolkit = getattr(self, "toolkit", None)
        release = getattr(toolkit, "cleanup", None)
        if callable(release):
            release()


def create_enhanced_agent(hf_token: str = None, openai_key: str = None) -> SmoLAgentsEnhancedAgent:
    """Build and return a ``SmoLAgentsEnhancedAgent``.

    Args:
        hf_token: Passed through as the agent's first constructor argument.
        openai_key: Passed through as the agent's second constructor argument.

    Returns:
        The newly constructed agent.
    """
    enhanced_agent = SmoLAgentsEnhancedAgent(hf_token, openai_key)
    return enhanced_agent


if __name__ == "__main__":
    # Smoke test: build the agent and push a few trivial questions through it.
    print("๐Ÿงช Testing SmoLAgents Bridge...")
    bridge = SmoLAgentsEnhancedAgent()

    sample_questions = (
        "What is 5 + 3?",
        "What is the capital of France?",
        "How many sides does a triangle have?",
    )

    for prompt in sample_questions:
        print(f"\nQ: {prompt}")
        # The query runs after the question is echoed, so its own progress
        # output appears between the Q and A lines.
        answer = bridge.query(prompt)
        print(f"A: {answer}")

    print("\nโœ… Bridge test completed!")