|
import json
import logging
import os
from typing import Dict, Iterator, List, Optional, Union

import requests
from dotenv import load_dotenv
|
|
|
|
|
load_dotenv() |
|
|
|
|
|
logging.basicConfig( |
|
level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" |
|
) |
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
API_KEY_ENV_VAR = "GOOGLE_API_KEY" |
|
BASE_URL = "https://generativelanguage.googleapis.com/v1beta/models/" |
|
DEFAULT_MODEL_ID = "gemini-2.0-flash" |
|
|
|
|
|
def _get_api_key() -> Optional[str]: |
|
""" |
|
Retrieves the Google API key from environment variables. |
|
|
|
Returns: |
|
Optional[str]: The API key if found, otherwise None. |
|
""" |
|
api_key = os.getenv(API_KEY_ENV_VAR) |
|
if not api_key: |
|
logger.error(f"API key not found. Set the variable '{API_KEY_ENV_VAR}'.") |
|
return api_key |
|
|
|
|
|
def _format_payload_for_gemini( |
|
messages: List[Dict], temperature: float, max_tokens: int |
|
) -> Optional[Dict]: |
|
""" |
|
Formats the message history and configuration into a valid payload for the Gemini REST API. |
|
|
|
This function performs two critical tasks: |
|
1. Separates the 'system' instruction from the main conversation history. |
|
2. Consolidates consecutive 'user' messages into a single block to comply with |
|
the Gemini API's requirement of alternating 'user' and 'model' roles. |
|
|
|
Args: |
|
messages (List[Dict]): A list of message dictionaries, potentially including a 'system' role. |
|
temperature (float): The generation temperature. |
|
max_tokens (int): The maximum number of tokens to generate. |
|
|
|
Returns: |
|
Optional[Dict]: A fully formed payload dictionary ready for the API, or None if the |
|
conversation history is empty. |
|
""" |
|
system_instruction = None |
|
conversation_history = [] |
|
|
|
for msg in messages: |
|
if msg.get("role") == "system": |
|
system_instruction = {"parts": [{"text": msg.get("content", "")}]} |
|
else: |
|
conversation_history.append(msg) |
|
|
|
if not conversation_history: |
|
return None |
|
|
|
consolidated_contents = [] |
|
current_block = None |
|
for msg in conversation_history: |
|
role = "model" if msg.get("role") == "assistant" else "user" |
|
content = msg.get("content", "") |
|
|
|
if current_block and current_block["role"] == "user" and role == "user": |
|
current_block["parts"][0]["text"] += "\n" + content |
|
else: |
|
if current_block: |
|
consolidated_contents.append(current_block) |
|
current_block = {"role": role, "parts": [{"text": content}]} |
|
|
|
if current_block: |
|
consolidated_contents.append(current_block) |
|
|
|
payload = { |
|
"contents": consolidated_contents, |
|
"safetySettings": [ |
|
{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"}, |
|
{"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"}, |
|
{"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"}, |
|
{"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"}, |
|
], |
|
"generationConfig": {"temperature": temperature, "maxOutputTokens": max_tokens}, |
|
} |
|
if system_instruction: |
|
payload["system_instruction"] = system_instruction |
|
|
|
return payload |
|
|
|
|
|
def call_gemini_api( |
|
messages: List[Dict], stream: bool = False, temperature: float = 0.7, max_tokens: int = 2048 |
|
) -> Union[Iterator[str], str]: |
|
""" |
|
Calls the Google Gemini REST API with the provided messages and parameters. |
|
|
|
This is the main public function of the module. It handles API key retrieval, |
|
payload formatting, making the HTTP request, and processing the response. |
|
|
|
Args: |
|
messages (List[Dict]): The list of messages forming the conversation context. |
|
stream (bool): If True, streams the response. (Currently not implemented). |
|
temperature (float): The generation temperature (creativity). |
|
max_tokens (int): The maximum number of tokens for the response. |
|
|
|
Returns: |
|
Union[Iterator[str], str]: An iterator of response chunks if streaming, or a single |
|
response string otherwise. Returns an error string on failure. |
|
""" |
|
api_key = _get_api_key() |
|
if not api_key: |
|
error_msg = "Error: Google API key not configured." |
|
return iter([error_msg]) if stream else error_msg |
|
|
|
payload = _format_payload_for_gemini(messages, temperature, max_tokens) |
|
if not payload or not payload.get("contents"): |
|
error_msg = "Error: Conversation is empty or malformed after processing." |
|
return iter([error_msg]) if stream else error_msg |
|
|
|
stream_param = "streamGenerateContent" if stream else "generateContent" |
|
request_url = f"{BASE_URL}{DEFAULT_MODEL_ID}:{stream_param}?key={api_key}" |
|
headers = {"Content-Type": "application/json"} |
|
|
|
try: |
|
response = requests.post( |
|
request_url, headers=headers, json=payload, stream=stream, timeout=180 |
|
) |
|
response.raise_for_status() |
|
|
|
if stream: |
|
|
|
pass |
|
else: |
|
data = response.json() |
|
|
|
if data.get("candidates") and data["candidates"][0].get("content", {}).get("parts"): |
|
return data["candidates"][0]["content"]["parts"][0]["text"] |
|
else: |
|
logger.warning( |
|
f"Gemini's response does not contain 'candidates'. Full response: {data}" |
|
) |
|
return "[BLOCKED OR EMPTY RESPONSE]" |
|
|
|
except requests.exceptions.HTTPError as e: |
|
err_msg = f"API HTTP Error ({e.response.status_code}): {e.response.text[:500]}" |
|
logger.error(err_msg, exc_info=False) |
|
return f"Error: {err_msg}" |
|
except Exception as e: |
|
logger.exception("Unexpected error while calling Gemini API:") |
|
return f"Error: {e}" |
|
|