Spaces:
Running
Running
Priyanshi Saxena
commited on
Commit
·
c785b3f
1
Parent(s):
2fe0e75
fix: Add missing validate_gemini_response method and cleanup method
Browse files- Added validate_gemini_response method to AISafetyGuard
- Added cleanup method to ChartDataTool to prevent AttributeError
- Improved Gemini tool parsing logic to handle all suggested tools
- Updated Gemini model to use 'gemini-2.0-flash-lite'
- src/agent/research_agent.py +14 -0
- src/tools/chart_data_tool.py +5 -0
- src/utils/ai_safety.py +34 -0
src/agent/research_agent.py
CHANGED
@@ -417,6 +417,20 @@ Respond with only the tool names, comma-separated (no explanations)."""
|
|
417 |
'etherscan_data', 'chart_data_provider'
|
418 |
}]
|
419 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
420 |
logger.info(f"🛠️ Gemini suggested tools: {suggested_tools}")
|
421 |
|
422 |
# Step 2: Execute tools (same logic as Ollama version)
|
|
|
417 |
'etherscan_data', 'chart_data_provider'
|
418 |
}]
|
419 |
|
420 |
+
# If no valid tools found, extract from response content
|
421 |
+
if not suggested_tools:
|
422 |
+
response_text = str(tool_response).lower()
|
423 |
+
if 'cryptocompare' in response_text:
|
424 |
+
suggested_tools.append('cryptocompare_data')
|
425 |
+
if 'coingecko' in response_text:
|
426 |
+
suggested_tools.append('coingecko_data')
|
427 |
+
if 'defillama' in response_text:
|
428 |
+
suggested_tools.append('defillama_data')
|
429 |
+
if 'etherscan' in response_text:
|
430 |
+
suggested_tools.append('etherscan_data')
|
431 |
+
if 'chart' in response_text or 'visualization' in response_text:
|
432 |
+
suggested_tools.append('chart_data_provider')
|
433 |
+
|
434 |
logger.info(f"🛠️ Gemini suggested tools: {suggested_tools}")
|
435 |
|
436 |
# Step 2: Execute tools (same logic as Ollama version)
|
src/tools/chart_data_tool.py
CHANGED
@@ -393,3 +393,8 @@ class ChartDataTool(BaseTool):
|
|
393 |
"1d": 1, "7d": 7, "30d": 30, "90d": 90, "365d": 365, "1y": 365
|
394 |
}
|
395 |
return timeframe_map.get(timeframe, 30)
|
|
|
|
|
|
|
|
|
|
|
|
393 |
"1d": 1, "7d": 7, "30d": 30, "90d": 90, "365d": 365, "1y": 365
|
394 |
}
|
395 |
return timeframe_map.get(timeframe, 30)
|
396 |
+
|
397 |
+
async def cleanup(self):
|
398 |
+
"""Cleanup method for session management"""
|
399 |
+
# ChartDataTool doesn't maintain persistent connections, so nothing to clean up
|
400 |
+
pass
|
src/utils/ai_safety.py
CHANGED
@@ -130,6 +130,40 @@ class AISafetyGuard:
|
|
130 |
|
131 |
return cleaned, True, "Response is safe"
|
132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
def create_safe_prompt(self, user_query: str, tool_context: str) -> str:
|
134 |
"""Create a safety-enhanced prompt for Ollama"""
|
135 |
safety_instructions = """
|
|
|
130 |
|
131 |
return cleaned, True, "Response is safe"
|
132 |
|
133 |
+
def validate_gemini_response(self, response: str) -> Tuple[str, bool, str]:
|
134 |
+
"""
|
135 |
+
Validate Gemini response for safety and quality
|
136 |
+
Returns: (cleaned_response, is_valid, reason)
|
137 |
+
"""
|
138 |
+
if not response or not response.strip():
|
139 |
+
return "", False, "Empty response from Gemini"
|
140 |
+
|
141 |
+
# Check for dangerous content in response
|
142 |
+
dangerous_patterns = [
|
143 |
+
r'(?i)here.*is.*how.*to.*hack',
|
144 |
+
r'(?i)steps.*to.*exploit',
|
145 |
+
r'(?i)bypass.*security.*by',
|
146 |
+
r'(?i)manipulate.*market.*by',
|
147 |
+
]
|
148 |
+
|
149 |
+
for pattern in dangerous_patterns:
|
150 |
+
if re.search(pattern, response):
|
151 |
+
logger.warning(f"Blocked unsafe Gemini response: {pattern}")
|
152 |
+
return "", False, "Response contains potentially unsafe content"
|
153 |
+
|
154 |
+
# Basic response cleaning
|
155 |
+
cleaned = response.strip()
|
156 |
+
|
157 |
+
# Remove any potential HTML/JavaScript
|
158 |
+
cleaned = re.sub(r'<script.*?</script>', '', cleaned, flags=re.DOTALL | re.IGNORECASE)
|
159 |
+
cleaned = re.sub(r'<[^>]+>', '', cleaned)
|
160 |
+
|
161 |
+
# Ensure response is within reasonable length
|
162 |
+
if len(cleaned) > 10000: # 10k character limit
|
163 |
+
cleaned = cleaned[:10000] + "\n\n[Response truncated for safety]"
|
164 |
+
|
165 |
+
return cleaned, True, "Response is safe"
|
166 |
+
|
167 |
def create_safe_prompt(self, user_query: str, tool_context: str) -> str:
|
168 |
"""Create a safety-enhanced prompt for Ollama"""
|
169 |
safety_instructions = """
|