import google.generativeai as genai
from typing import Dict, Any, Optional, List, Literal, TypeVar, Callable
import os
from functools import wraps
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

T = TypeVar('T')

class ModelError(Exception):
    """Custom exception for model-related errors"""
    pass

def fallback_to_15_flash(method: Callable[..., T]) -> Callable[..., T]:
    """
    Decorator that automatically falls back to Gemini 1.5 Flash when a
    Gemini 2.0 Flash call fails. Only applies to instances created with
    version '2.0' that are not themselves fallback instances.
    """
    @wraps(method)
    async def wrapper(self: 'GeminiFlash', *args: Any, **kwargs: Any) -> T:
        # _should_fallback is True only for non-fallback 2.0 instances,
        # so a single flag check covers both conditions.
        if not self._should_fallback:
            return await method(self, *args, **kwargs)

        try:
            return await method(self, *args, **kwargs)
        except Exception as e:
            logger.warning(f"Error with Gemini 2.0 Flash: {e}")
            logger.info("Falling back to Gemini 1.5 Flash...")
            fallback = GeminiFlash(version='1.5', api_key=self.api_key, _is_fallback=True)
            try:
                return await getattr(fallback, method.__name__)(*args, **kwargs)
            except Exception as fallback_error:
                raise ModelError(f"Error with Gemini 1.5 Flash: {fallback_error}") from fallback_error
    return wrapper

class GeminiFlash:
    """
    Google Gemini Flash model implementation with automatic fallback from 2.0 to 1.5.
    """
    
    SUPPORTED_VERSIONS = ['2.0', '1.5']
    
    def __init__(self, version: str = '2.0', api_key: Optional[str] = None, _is_fallback: bool = False):
        """
        Initialize the Gemini Flash model.
        
        Args:
            version: Model version ('2.0' or '1.5')
            api_key: Google AI API key. If not provided, will look for GOOGLE_API_KEY env var.
            _is_fallback: Internal flag to indicate if this is a fallback instance.
        """
        if version not in self.SUPPORTED_VERSIONS:
            raise ValueError(f"Unsupported version: {version}. Supported versions: {self.SUPPORTED_VERSIONS}")
            
        self.version = version
        api_key="AIzaSyC3TpRUinxSCASXncgqhD1FJ6yqAq3j9rY"
        self.api_key = api_key 
        if not self.api_key:
            raise ValueError("GOOGLE_API_KEY environment variable not set and no API key provided")
        
        self._should_fallback = version == '2.0' and not _is_fallback
        genai.configure(api_key=self.api_key)
        self.model_name = f'gemini-{version}-flash'
        self.model = genai.GenerativeModel(self.model_name)
    
    @fallback_to_15_flash
    async def generate_text(
        self,
        prompt: str,
        temperature: float = 0.7,
        max_tokens: int = 2048,
        top_p: float = 0.9,
        top_k: int = 40,
        **kwargs
    ) -> str:
        """
        Generate text using Gemini Flash.
        
        Args:
            prompt: The input prompt
            temperature: Controls randomness (0.0 to 1.0)
            max_tokens: Maximum number of tokens to generate
            top_p: Nucleus sampling parameter
            top_k: Top-k sampling parameter
            **kwargs: Additional generation parameters
            
        Returns:
            Generated text response
            
        Raises:
            ModelError: If both 2.0 and 1.5 models fail
        """
        response = await self.model.generate_content_async(
            prompt,
            generation_config={
                'temperature': temperature,
                'max_output_tokens': max_tokens,
                'top_p': top_p,
                'top_k': top_k,
                **kwargs
            }
        )
        return response.text
    
    @fallback_to_15_flash
    async def chat(
        self,
        messages: List[Dict[Literal['role', 'content'], str]],
        temperature: float = 0.7,
        max_tokens: int = 2048,
        **kwargs
    ) -> str:
        """
        Chat completion using Gemini Flash.
        
        Args:
            messages: List of message dictionaries with 'role' and 'content'
            temperature: Controls randomness (0.0 to 1.0)
            max_tokens: Maximum number of tokens to generate
            **kwargs: Additional generation parameters
            
        Returns:
            Model's response
            
        Raises:
            ModelError: If both 2.0 and 1.5 models fail
        """
        # Build Gemini-style history from all but the last message instead of
        # replaying each turn through the API with blocking send_message()
        # calls; OpenAI-style 'assistant' roles map to Gemini's 'model' role.
        history = [{'role': 'model' if m['role'] == 'assistant' else 'user',
                    'parts': [m['content']]}
                   for m in messages[:-1]]
        chat = self.model.start_chat(history=history)
        
        # Get response for the last message
        response = await chat.send_message_async(
            messages[-1]['content'],
            generation_config={
                'temperature': temperature,
                'max_output_tokens': max_tokens,
                **kwargs
            }
        )
        return response.text
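

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the class API): assumes the
# google-generativeai package is installed and GOOGLE_API_KEY is exported.
# The prompts and message contents below are hypothetical.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import asyncio

    async def main() -> None:
        model = GeminiFlash(version='2.0')

        # Single-turn generation; falls back to 1.5 automatically on failure.
        print(await model.generate_text('Explain top-k sampling in one sentence.'))

        # Multi-turn chat; earlier messages are passed as conversation history.
        print(await model.chat([
            {'role': 'user', 'content': 'My name is Ada.'},
            {'role': 'assistant', 'content': 'Nice to meet you, Ada!'},
            {'role': 'user', 'content': 'What is my name?'},
        ]))

    asyncio.run(main())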