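"""smolagents ApiModel wrapper for the Google Gemini API.

Calls the generateContent REST endpoint directly with httpx, flattening chat
messages into a single text prompt.
"""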
import os
import time
import warnings
from typing import Dict, List, Optional

import httpx

from smolagents import ApiModel, ChatMessage, Tool
class GeminiApiModel(ApiModel):
    """
    ApiModel implementation that calls the Google Gemini API via direct HTTP requests.
    """
    def __init__(
        self,
        model_id: str = "gemini-pro",
        api_key: Optional[str] = None,
        **kwargs,
    ):
        """
        Initializes the GeminiApiModel.

        Args:
            model_id (str): The Gemini model ID to use (e.g., "gemini-pro").
            api_key (str, optional): Google AI Studio API key. Defaults to the
                GEMINI_API_KEY environment variable.
            **kwargs: Additional keyword arguments passed to the parent ApiModel.
        """
        self.model_id = model_id
        # Prefer an explicitly passed key; fall back to the environment variable.
        self.api_key = api_key if api_key else os.environ.get("GEMINI_API_KEY")
        if not self.api_key:
            warnings.warn(
                "GEMINI_API_KEY not provided via argument or environment variable. "
                "API calls will likely fail.",
                UserWarning,
            )
        # For simplicity this wrapper flattens all messages into a single text
        # prompt rather than mapping chat roles onto the Gemini `contents` format.
        super().__init__(
            model_id=model_id,
            flatten_messages_as_text=True,  # Flatten messages to a single text prompt
            **kwargs,
        )
    def create_client(self):
        """No dedicated client is needed; requests go through httpx directly."""
        return None  # A shared httpx.Client could be returned here if connection reuse is desired
    def __call__(
        self,
        messages: List[Dict[str, str]],
        stop_sequences: Optional[List[str]] = None,  # Forwarded as generationConfig.stopSequences
        grammar: Optional[str] = None,  # Not supported by the Gemini API; ignored
        tools_to_call_from: Optional[List[Tool]] = None,  # Tools are not supported by this basic wrapper; ignored
        **kwargs,
    ) -> ChatMessage:
""" | |
Calls the Google Gemini API with the provided messages. | |
Args: | |
messages: A list of message dictionaries (e.g., [{'role': 'user', 'content': '...'}]). | |
stop_sequences: Optional stop sequences (may not be supported). | |
grammar: Optional grammar constraint (not supported). | |
tools_to_call_from: Optional list of tools (not supported). | |
**kwargs: Additional keyword arguments. | |
Returns: | |
A ChatMessage object containing the response. | |
""" | |
        if not self.api_key:
            raise ValueError("GEMINI_API_KEY is not set.")
        # Build a single text prompt by concatenating the message contents; the
        # basic generateContent endpoint accepts a plain text part.
        prompt = self._messages_to_prompt(messages)
        prompt += (
            "\n\n"
            "If you have a result from a web search that looks helpful, please "
            "use httpx to get the HTML from the URL listed.\n"
            "You are a general AI assistant. I will ask you a question. Report "
            "your thoughts, and finish your answer with the following template: "
            "FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a "
            "number OR as few words as possible OR a comma separated list of "
            "numbers and/or strings. If you are asked for a number, don't use "
            "commas to write your number and don't use units such as $ or "
            "percent signs unless specified otherwise. If you are asked for a "
            "string, don't use articles or abbreviations (e.g. for cities), and "
            "write digits in plain text unless specified otherwise. If you are "
            "asked for a comma separated list, apply the above rules depending "
            "on whether the element to be put in the list is a number or a "
            "string."
        )
        # print(f"--- Gemini API prompt: ---\n{prompt}\n--- End of prompt ---")
        url = f"https://generativelanguage.googleapis.com/v1beta/models/{self.model_id}:generateContent?key={self.api_key}"
        headers = {"Content-Type": "application/json"}
        # Construct the payload according to the Gemini generateContent schema.
        data = {"contents": [{"parts": [{"text": prompt}]}]}
        # Build an optional generation config from recognized kwargs.
        generation_config = {}
        if "temperature" in kwargs:
            generation_config["temperature"] = kwargs["temperature"]
        if "max_output_tokens" in kwargs:
            generation_config["maxOutputTokens"] = kwargs["max_output_tokens"]
        # Stop sequences are supported via generationConfig.stopSequences.
        if stop_sequences:
            generation_config["stopSequences"] = stop_sequences
        if generation_config:
            data["generationConfig"] = generation_config
        raw_response = None
        try:
            response = httpx.post(
                url, headers=headers, json=data, timeout=120.0
            )  # Generous timeout for long generations
            time.sleep(6)  # Crude delay between calls to respect rate limits
            response.raise_for_status()
            response_json = response.json()
            raw_response = response_json  # Keep the raw payload for debugging
            # Parse the response: the generated text lives at
            # candidates[0].content.parts[0].text.
            if "candidates" in response_json and response_json["candidates"]:
                part = response_json["candidates"][0]["content"]["parts"][0]
                if "text" in part:
                    content = part["text"]
                    # If the model followed the "FINAL ANSWER: " template, keep
                    # only the text after the marker.
                    final_answer_marker = "FINAL ANSWER: "
                    if final_answer_marker in content:
                        content = content.split(final_answer_marker)[-1].strip()
                    # Token counts are reported in usageMetadata when present;
                    # default to 0 if the field is missing.
                    usage = response_json.get("usageMetadata", {})
                    self.last_input_token_count = usage.get("promptTokenCount", 0)
                    self.last_output_token_count = usage.get("candidatesTokenCount", 0)
                    return ChatMessage(
                        role="assistant", content=content, raw=raw_response
                    )
            # Fall through: the expected response structure was not found.
            error_content = f"Error or unexpected response format: {response_json}"
            return ChatMessage(
                role="assistant", content=error_content, raw=raw_response
            )
        except httpx.RequestError as exc:
            error_content = (
                f"An error occurred while requesting {exc.request.url!r}: {exc}"
            )
            return ChatMessage(
                role="assistant", content=error_content, raw={"error": str(exc)}
            )
        except httpx.HTTPStatusError as exc:
            error_content = (
                f"Error response {exc.response.status_code} while requesting "
                f"{exc.request.url!r}: {exc.response.text}"
            )
            return ChatMessage(
                role="assistant",
                content=error_content,
                raw={
                    "error": str(exc),
                    "status_code": exc.response.status_code,
                    "response_text": exc.response.text,
                },
            )
        except Exception as e:
            error_content = f"An unexpected error occurred: {e}"
            return ChatMessage(
                role="assistant", content=error_content, raw={"error": str(e)}
            )
    def _messages_to_prompt(self, messages: List[Dict[str, str]]) -> str:
        """Converts a list of messages into a single string prompt."""
        # Simple concatenation; this could be made role-aware if needed. str()
        # guards against non-string content values.
        return "\n".join(str(msg.get("content", "")) for msg in messages)