TutorX-MCP / mcp_server/model/gemini_flash.py
import os
import logging
from functools import wraps
from typing import Any, Callable, Dict, List, Literal, Optional, TypeVar

import google.generativeai as genai

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

T = TypeVar('T')


class ModelError(Exception):
"""Custom exception for model-related errors"""
pass


def fallback_to_15_flash(method: Callable[..., T]) -> Callable[..., T]:
    """
    Decorator that automatically falls back to Gemini 1.5 Flash if a 2.0 call fails.
    Only applies when the instance's version is '2.0' and fallback is enabled.
    """
    @wraps(method)
    async def wrapper(self: 'GeminiFlash', *args: Any, **kwargs: Any) -> T:
        # Instances that are not fallback-eligible call the method directly.
        if self.version != '2.0' or not self._should_fallback:
            return await method(self, *args, **kwargs)
        try:
            return await method(self, *args, **kwargs)
        except Exception as primary_error:
            logger.warning(f"Error with Gemini 2.0 Flash: {primary_error}")
            logger.info("Falling back to Gemini 1.5 Flash...")
            fallback = GeminiFlash(version='1.5', api_key=self.api_key, _is_fallback=True)
            try:
                return await getattr(fallback, method.__name__)(*args, **kwargs)
            except Exception as fallback_error:
                # Surface a single error if both model versions fail.
                raise ModelError(
                    f"Gemini 2.0 Flash failed ({primary_error}); "
                    f"Gemini 1.5 Flash fallback also failed ({fallback_error})"
                ) from fallback_error
    return wrapper


class GeminiFlash:
"""
Google Gemini Flash model implementation with automatic fallback from 2.0 to 1.5.
"""
SUPPORTED_VERSIONS = ['2.0', '1.5']

    def __init__(self, version: str = '2.0', api_key: Optional[str] = None, _is_fallback: bool = False):
"""
Initialize the Gemini Flash model.
Args:
version: Model version ('2.0' or '1.5')
api_key: Google AI API key. If not provided, will look for GOOGLE_API_KEY env var.
_is_fallback: Internal flag to indicate if this is a fallback instance.
"""
if version not in self.SUPPORTED_VERSIONS:
raise ValueError(f"Unsupported version: {version}. Supported versions: {self.SUPPORTED_VERSIONS}")
self.version = version
        self.api_key = api_key or os.getenv("GOOGLE_API_KEY")
if not self.api_key:
raise ValueError("GOOGLE_API_KEY environment variable not set and no API key provided")
self._should_fallback = version == '2.0' and not _is_fallback
genai.configure(api_key=self.api_key)
self.model_name = f'gemini-{version}-flash'
self.model = genai.GenerativeModel(self.model_name)

    @fallback_to_15_flash
async def generate_text(
self,
prompt: str,
temperature: float = 0.7,
max_tokens: int = 2048,
top_p: float = 0.9,
top_k: int = 40,
**kwargs
) -> str:
"""
Generate text using Gemini Flash.
Args:
prompt: The input prompt
temperature: Controls randomness (0.0 to 1.0)
max_tokens: Maximum number of tokens to generate
top_p: Nucleus sampling parameter
top_k: Top-k sampling parameter
**kwargs: Additional generation parameters
Returns:
Generated text response
Raises:
ModelError: If both 2.0 and 1.5 models fail
"""
response = await self.model.generate_content_async(
prompt,
generation_config={
'temperature': temperature,
'max_output_tokens': max_tokens,
'top_p': top_p,
'top_k': top_k,
**kwargs
}
)
return response.text

    @fallback_to_15_flash
async def chat(
self,
messages: List[Dict[Literal['role', 'content'], str]],
temperature: float = 0.7,
max_tokens: int = 2048,
**kwargs
) -> str:
"""
Chat completion using Gemini Flash.
Args:
messages: List of message dictionaries with 'role' and 'content'
temperature: Controls randomness (0.0 to 1.0)
max_tokens: Maximum number of tokens to generate
**kwargs: Additional generation parameters
Returns:
Model's response
Raises:
ModelError: If both 2.0 and 1.5 models fail
"""
        # Build the history from all but the last message; Gemini expects the
        # role 'model' (not 'assistant') for model turns.
        history = [
            {'role': 'model' if m['role'] == 'assistant' else 'user', 'parts': [m['content']]}
            for m in messages[:-1]
        ]
        chat = self.model.start_chat(history=history)
# Get response for the last message
response = await chat.send_message_async(
messages[-1]['content'],
generation_config={
'temperature': temperature,
'max_output_tokens': max_tokens,
**kwargs
}
)
return response.text
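

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal example of how this class might be driven end to end, assuming the
# GOOGLE_API_KEY environment variable is set. The prompts below are hypothetical
# placeholders; a failure of the 2.0 model transparently retries on 1.5 via the
# fallback decorator.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        model = GeminiFlash(version='2.0')

        # Single-turn generation
        text = await model.generate_text("Explain photosynthesis in one sentence.")
        print(text)

        # Multi-turn chat: earlier turns are replayed as history
        reply = await model.chat([
            {'role': 'user', 'content': 'What is the Pythagorean theorem?'},
            {'role': 'assistant', 'content': 'a^2 + b^2 = c^2 for right triangles.'},
            {'role': 'user', 'content': 'Give a worked example with legs 3 and 4.'},
        ])
        print(reply)

    asyncio.run(_demo())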