sonyps1928 committed on
Commit
1b3fa51
·
1 Parent(s): adb694f

update app16

Files changed (2)
  1. app.py +362 -159
  2. requirements.txt +6 -4
app.py CHANGED
@@ -1,173 +1,376 @@
- import streamlit as st
- import os
- from transformers import GPT2LMHeadModel, GPT2Tokenizer
- import torch
-
- # ----------------------------
- # Page config
- # ----------------------------
- st.set_page_config(
-     page_title="GPT-2 Text Generator",
-     page_icon="🤖",
-     layout="wide"
  )
-
- # ----------------------------
- # Load environment variables
- # ----------------------------
- HF_TOKEN = os.getenv("HF_TOKEN")
- API_KEY = os.getenv("API_KEY")
- ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD")
-
- # ----------------------------
- # Model loading
- # ----------------------------
- @st.cache_resource
- def load_model():
-     """Load and cache the GPT-2 model"""
-     with st.spinner("Loading GPT-2 model..."):
-         try:
-             tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
-             model = GPT2LMHeadModel.from_pretrained("gpt2")
-             tokenizer.pad_token = tokenizer.eos_token
-             return tokenizer, model
-         except Exception as e:
-             st.error(f"Error loading model: {e}")
-             return None, None
-
- # ----------------------------
- # Text generation
- # ----------------------------
- def generate_text(prompt, max_length, temperature, tokenizer, model):
-     """Generate text using GPT-2"""
-     if not prompt:
-         return "Please enter a prompt"
-
-     if len(prompt) > 500:
-         return "Prompt too long (max 500 characters)"
-
-     try:
-         inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=300, truncation=True)
-
-         with torch.no_grad():
-             outputs = model.generate(
-                 inputs,
-                 max_length=inputs.shape[1] + max_length,
-                 temperature=temperature,
-                 do_sample=True,
-                 pad_token_id=tokenizer.eos_token_id,
-                 eos_token_id=tokenizer.eos_token_id,
-                 no_repeat_ngram_size=2
-             )
-
-         generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-         new_text = generated_text[len(prompt):].strip()
-
-         return new_text if new_text else "No text generated. Try a different prompt."
-
-     except Exception as e:
-         return f"Error generating text: {str(e)}"
-
- # ----------------------------
  # Authentication
- # ----------------------------
- def check_auth():
-     """Handle authentication"""
-     if ADMIN_PASSWORD:
-         if "authenticated" not in st.session_state:
-             st.session_state.authenticated = False
-
-         if not st.session_state.authenticated:
-             st.title("🔒 Authentication Required")
-             password = st.text_input("Enter admin password:", type="password")
-             if st.button("Login"):
-                 if password == ADMIN_PASSWORD:
-                     st.session_state.authenticated = True
-                     st.experimental_rerun()
-                 else:
-                     st.error("Invalid password")
-             return False
-     return True
-
- # ----------------------------
- # Main UI
- # ----------------------------
- def main():
-     if not check_auth():
-         return
-
-     tokenizer, model = load_model()
-     if tokenizer is None or model is None:
-         st.error("Failed to load model. Please check the logs.")
-         return
-
-     st.title("🤖 GPT-2 Text Generator")
-     st.markdown("Generate text using GPT-2 language model")
-
-     # Security status
-     col1, col2, col3 = st.columns(3)
-     with col1:
-         st.success("🔑 HF Token: Active" if HF_TOKEN else "🔑 HF Token: Not set")
-     with col2:
-         st.success("🔒 API Auth: Enabled" if API_KEY else "🔒 API Auth: Disabled")
-     with col3:
-         st.success("👀 Admin Auth: Active" if ADMIN_PASSWORD else "👀 Admin Auth: Disabled")
-
-     # Input section
-     st.subheader("📝 Input")
-     col1, col2 = st.columns([2, 1])
-
-     with col1:
-         prompt = st.text_area(
-             "Enter your prompt:",
-             placeholder="Type your text here...",
-             height=100
-         )
-         api_key = ""
-         if API_KEY:
-             api_key = st.text_input("API Key:", type="password")
-
-     with col2:
-         st.subheader("⚙️ Settings")
-         max_length = st.slider("Max Length", 20, 200, 100, 10)
-         temperature = st.slider("Temperature", 0.1, 1.5, 0.7, 0.1)
-         generate_btn = st.button("🚀 Generate Text", type="primary")
-
-     # API key validation
-     if API_KEY and generate_btn:
-         if not api_key or api_key != API_KEY:
-             st.error("🔒 Invalid or missing API key")
-             return
-
-     # Generate text
-     if generate_btn and prompt:
-         with st.spinner("Generating text..."):
-             result = generate_text(prompt, max_length, temperature, tokenizer, model)
-         st.subheader("📄 Generated Text")
-         st.text_area("Output:", value=result, height=200)
-         st.code(result)
-     elif generate_btn:
-         st.warning("Please enter a prompt")
-
-     # Example prompts
-     st.subheader("💡 Example Prompts")
-     examples = [
-         "Once upon a time in a distant galaxy,",
-         "The future of artificial intelligence is",
-         "In the heart of the ancient forest,",
-         "The detective walked into the room and noticed"
      ]
-
-     cols = st.columns(len(examples))
-     for i, example in enumerate(examples):
-         with cols[i]:
-             if st.button(f"Use Example {i+1}", key=f"ex_{i}"):
-                 st.session_state.example_prompt = example
-                 st.experimental_rerun()
-
-     if hasattr(st.session_state, 'example_prompt'):
-         st.info(f"Example selected: {st.session_state.example_prompt}")
-
  if __name__ == "__main__":
-     main()

+ import logging
+ import time
+ import random
+ from typing import Dict, Any, List, Optional
+ import uvicorn
+ from fastapi import FastAPI, HTTPException, Depends, Request
+ from fastapi.responses import JSONResponse
+ from fastapi.exception_handlers import http_exception_handler
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+ import requests
+ import json
+
+ from config import config
+
+ # Configure logging
+ logging.basicConfig(level=getattr(logging, config.LOG_LEVEL))
+ logger = logging.getLogger(__name__)
+
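The config module imported above is not included in this commit. A minimal sketch of what it plausibly looks like, assuming values come from environment variables via python-dotenv (which the new requirements.txt pins); the attribute names are taken from their uses in app.py, but every default below is illustrative only:

# config.py -- hypothetical sketch, not shipped in this commit
import os
from dotenv import load_dotenv

load_dotenv()  # pick up a local .env file if one exists

class Config:
    LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO")
    MASTER_API_KEY = os.getenv("MASTER_API_KEY", "change-me")
    # comma-separated list, e.g. GEMINI_API_KEYS="key1,key2"
    GEMINI_API_KEYS = [k.strip() for k in os.getenv("GEMINI_API_KEYS", "").split(",") if k.strip()]
    MAX_REQUESTS_PER_MINUTE = int(os.getenv("MAX_REQUESTS_PER_MINUTE", "60"))
    HOST = os.getenv("HOST", "0.0.0.0")
    PORT = int(os.getenv("PORT", "7860"))

config = Config()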
+ # FastAPI app
+ app = FastAPI(
+     title="Advanced Gemini Proxy",
+     description="OpenAI-compatible proxy for Google Gemini API",
+     version="1.0.0"
  )

+ # CORS middleware
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Custom exception handler
+ @app.exception_handler(HTTPException)
+ async def custom_http_exception_handler(request: Request, exc: HTTPException):
+     # If detail is already in OpenAI format, return as-is
+     if isinstance(exc.detail, dict) and "error" in exc.detail:
+         return JSONResponse(
+             status_code=exc.status_code,
+             content=exc.detail
+         )
+
+     # Otherwise, format as OpenAI error
+     error_response = {
+         "error": {
+             "message": str(exc.detail),
+             "type": "api_error",
+             "param": None,
+             "code": None
+         }
+     }
+
+     return JSONResponse(
+         status_code=exc.status_code,
+         content=error_response
+     )
+
+ # Security
+ security = HTTPBearer()
+
+ # Rate limiting storage (in-memory for simplicity)
+ rate_limit_storage: Dict[str, List[float]] = {}
+
+ # Pydantic models
+ class ChatMessage(BaseModel):
+     role: str
+     content: str
+
+ class ChatCompletionRequest(BaseModel):
+     model: str
+     messages: List[ChatMessage]
+     temperature: Optional[float] = 1.0
+     max_tokens: Optional[int] = None
+     stream: Optional[bool] = False
+
+ class Choice(BaseModel):
+     index: int
+     message: Dict[str, str]
+     finish_reason: str
+
+ class Usage(BaseModel):
+     prompt_tokens: int
+     completion_tokens: int
+     total_tokens: int
+
+ class ChatCompletionResponse(BaseModel):
+     id: str
+     object: str = "chat.completion"
+     created: int
+     model: str
+     choices: List[Choice]
+     usage: Usage
+
  # Authentication
+ async def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
+     if credentials.credentials != config.MASTER_API_KEY:
+         error_response = {
+             "error": {
+                 "message": "Invalid API key provided",
+                 "type": "invalid_request_error",
+                 "param": None,
+                 "code": "invalid_api_key"
+             }
+         }
+         raise HTTPException(status_code=401, detail=error_response)
+     return credentials.credentials
+
+ # Rate limiting
+ def check_rate_limit(client_ip: str) -> tuple[bool, int]:
+     now = time.time()
+     if client_ip not in rate_limit_storage:
+         rate_limit_storage[client_ip] = []
+
+     # Clean old entries
+     rate_limit_storage[client_ip] = [
+         timestamp for timestamp in rate_limit_storage[client_ip]
+         if now - timestamp < 60
      ]
+
+     current_count = len(rate_limit_storage[client_ip])
+
+     # Check limit
+     if current_count >= config.MAX_REQUESTS_PER_MINUTE:
+         # Calculate reset time
+         oldest_request = min(rate_limit_storage[client_ip])
+         reset_time = int(oldest_request + 60)
+         return False, reset_time
+
+     # Add current request
+     rate_limit_storage[client_ip].append(now)
+     return True, 0
+
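check_rate_limit implements a per-IP sliding 60-second window: every accepted request is timestamped, stale timestamps are dropped on each call, and the reported reset time is when the oldest surviving entry ages out. A self-contained re-creation of the same logic, with the limit passed in so it runs without config.py (demo values are hypothetical):

window: list[float] = []  # timestamps of accepted requests for one client

def allowed(now: float, limit: int = 2) -> bool:
    global window
    window = [t for t in window if now - t < 60]  # drop entries older than 60s
    if len(window) >= limit:
        return False  # over the limit; the caller turns this into a 429
    window.append(now)
    return True

assert allowed(0.0) and allowed(1.0)  # first two requests pass
assert not allowed(2.0)               # third inside the window is rejected
assert allowed(61.0)                  # both old entries aged out after 60s

Since the window is keyed on client IP and stored in process memory, clients behind one NAT share a budget and all counters reset on restart.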
+ # Gemini API interaction
+ def get_random_api_key() -> str:
+     return random.choice(config.GEMINI_API_KEYS)
+
+ def convert_to_gemini_format(messages: List[ChatMessage]) -> List[Dict[str, Any]]:
+     gemini_messages = []
+     for msg in messages:
+         if msg.role == "system":
+             # Handle system messages by converting to user message with instruction
+             gemini_messages.append({
+                 "role": "user",
+                 "parts": [{"text": f"System instruction: {msg.content}"}]
+             })
+         else:
+             role = "user" if msg.role == "user" else "model"
+             gemini_messages.append({
+                 "role": role,
+                 "parts": [{"text": msg.content}]
+             })
+     return gemini_messages
+
+ def estimate_tokens(text: str) -> int:
+     """Simple token estimation - roughly 1 token per 4 characters"""
+     return max(1, len(text) // 4)
+
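To make the role mapping concrete, here is how an OpenAI-style conversation comes out of convert_to_gemini_format (input shown as plain dicts for brevity; the function itself receives ChatMessage models):

# Input (OpenAI wire format):
#   [{"role": "system",    "content": "Be terse."},
#    {"role": "user",      "content": "Hi"},
#    {"role": "assistant", "content": "Hello!"}]
#
# Output of convert_to_gemini_format:
#   [{"role": "user",  "parts": [{"text": "System instruction: Be terse."}]},
#    {"role": "user",  "parts": [{"text": "Hi"}]},
#    {"role": "model", "parts": [{"text": "Hello!"}]}]

The code folds system messages into a user turn with a "System instruction:" prefix rather than using a separate system field; any role other than "user" maps to "model".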
+ def call_gemini_api(messages: List[ChatMessage], model: str, temperature: float, max_tokens: Optional[int]) -> Dict[str, Any]:
+     api_key = get_random_api_key()
+
+     # Convert model name
+     if "gpt-4" in model.lower():
+         gemini_model = "gemini-1.5-pro-latest"
+     elif "gpt-3.5" in model.lower():
+         gemini_model = "gemini-1.5-flash-latest"
+     else:
+         gemini_model = "gemini-1.5-flash-latest"  # Default fallback
+
+     url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemini_model}:generateContent"
+
+     # Convert messages
+     gemini_messages = convert_to_gemini_format(messages)
+
+     payload = {
+         "contents": gemini_messages,
+         "generationConfig": {
+             "temperature": max(0.0, min(2.0, temperature)),  # Clamp temperature
+         },
+         "safetySettings": [
+             {
+                 "category": "HARM_CATEGORY_HARASSMENT",
+                 "threshold": "BLOCK_MEDIUM_AND_ABOVE"
+             },
+             {
+                 "category": "HARM_CATEGORY_HATE_SPEECH",
+                 "threshold": "BLOCK_MEDIUM_AND_ABOVE"
+             },
+             {
+                 "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+                 "threshold": "BLOCK_MEDIUM_AND_ABOVE"
+             },
+             {
+                 "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+                 "threshold": "BLOCK_MEDIUM_AND_ABOVE"
+             }
+         ]
+     }
+
+     if max_tokens and max_tokens > 0:
+         payload["generationConfig"]["maxOutputTokens"] = min(max_tokens, 8192)  # Gemini limit
+
+     headers = {
+         "Content-Type": "application/json",
+         "x-goog-api-key": api_key
+     }
+
+     try:
+         response = requests.post(url, json=payload, headers=headers, timeout=30)
+
+         if response.status_code != 200:
+             logger.error(f"Gemini API error: {response.status_code} - {response.text}")
+             error_response = {
+                 "error": {
+                     "message": f"Gemini API error: {response.text}",
+                     "type": "api_error",
+                     "param": None,
+                     "code": "gemini_api_error"
+                 }
+             }
+             raise HTTPException(status_code=response.status_code, detail=error_response)
+
+         return response.json()
+
+     except requests.exceptions.Timeout:
+         raise HTTPException(status_code=408, detail="Request timeout")
+     except requests.exceptions.RequestException as e:
+         logger.error(f"Request error: {str(e)}")
+         raise HTTPException(status_code=500, detail="Failed to connect to Gemini API")
+
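get_random_api_key spreads traffic across the configured keys only statistically. Where even rotation matters, a round-robin variant is a small change (an alternative sketch, not what this commit ships):

from itertools import cycle

_key_cycle = cycle(config.GEMINI_API_KEYS)  # repeats the key list forever

def get_next_api_key() -> str:
    # Round-robin instead of random.choice; not synchronized across processes.
    return next(_key_cycle)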
+ # Routes
+ @app.get("/")
+ async def root():
+     return {"message": "Advanced Gemini Proxy is running!", "version": "1.0.0"}
+
+ @app.get("/health")
+ async def health_check():
+     return {"status": "healthy", "timestamp": time.time()}
+
+ @app.get("/v1/models")
+ async def list_models(api_key: str = Depends(verify_api_key)):
+     return {
+         "object": "list",
+         "data": [
+             {
+                 "id": "gpt-3.5-turbo",
+                 "object": "model",
+                 "created": int(time.time()),
+                 "owned_by": "gemini-proxy"
+             },
+             {
+                 "id": "gpt-4",
+                 "object": "model",
+                 "created": int(time.time()),
+                 "owned_by": "gemini-proxy"
+             }
+         ]
+     }
+
+ @app.post("/v1/chat/completions")
+ async def chat_completions(
+     request: ChatCompletionRequest,
+     client_request: Request,
+     api_key: str = Depends(verify_api_key)
+ ):
+     # Rate limiting
+     client_ip = client_request.client.host
+     allowed, reset_time = check_rate_limit(client_ip)
+     if not allowed:
+         error_response = {
+             "error": {
+                 "message": "Rate limit reached for requests",
+                 "type": "rate_limit_exceeded",
+                 "param": None,
+                 "code": "rate_limit_exceeded"
+             }
+         }
+         headers = {
+             "X-RateLimit-Limit": str(config.MAX_REQUESTS_PER_MINUTE),
+             "X-RateLimit-Remaining": "0",
+             "X-RateLimit-Reset": str(reset_time),
+             "Retry-After": str(60)
+         }
+         return JSONResponse(
+             status_code=429,
+             content=error_response,
+             headers=headers
+         )
+
+     # Validate request
+     if not request.messages:
+         error_response = {
+             "error": {
+                 "message": "Missing required parameter: 'messages'",
+                 "type": "invalid_request_error",
+                 "param": "messages",
+                 "code": "missing_required_parameter"
+             }
+         }
+         raise HTTPException(status_code=400, detail=error_response)
+
+     try:
+         # Call Gemini API
+         gemini_response = call_gemini_api(
+             request.messages,
+             request.model,
+             request.temperature,
+             request.max_tokens
+         )
+
+         # Extract response text
+         if "candidates" not in gemini_response or not gemini_response["candidates"]:
+             # Check for blocked content
+             if "promptFeedback" in gemini_response and "blockReason" in gemini_response["promptFeedback"]:
+                 block_reason = gemini_response["promptFeedback"]["blockReason"]
+                 raise HTTPException(status_code=400, detail=f"Content blocked: {block_reason}")
+             raise HTTPException(status_code=500, detail="No response from Gemini API")
+
+         candidate = gemini_response["candidates"][0]
+
+         # Check if response was blocked
+         if "finishReason" in candidate and candidate["finishReason"] in ["SAFETY", "RECITATION"]:
+             raise HTTPException(status_code=400, detail=f"Response blocked: {candidate['finishReason']}")
+
+         if "content" not in candidate or "parts" not in candidate["content"]:
+             raise HTTPException(status_code=500, detail="Invalid response format from Gemini API")
+
+         response_text = candidate["content"]["parts"][0]["text"]
+
+         # Calculate token usage
+         prompt_text = " ".join([msg.content for msg in request.messages])
+         prompt_tokens = estimate_tokens(prompt_text)
+         completion_tokens = estimate_tokens(response_text)
+
+         # Convert to OpenAI format
+         response = ChatCompletionResponse(
+             id=f"chatcmpl-{int(time.time())}{random.randint(1000, 9999)}",
+             created=int(time.time()),
+             model=request.model,
+             choices=[Choice(
+                 index=0,
+                 message={
+                     "role": "assistant",
+                     "content": response_text
+                 },
+                 finish_reason="stop"
+             )],
+             usage=Usage(
+                 prompt_tokens=prompt_tokens,
+                 completion_tokens=completion_tokens,
+                 total_tokens=prompt_tokens + completion_tokens
+             )
+         )
+
+         return response
+
+     except HTTPException:
+         raise
+     except Exception as e:
+         logger.error(f"Unexpected error: {str(e)}")
+         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
+
  if __name__ == "__main__":
+     logger.info(f"🚀 Starting Advanced Gemini Proxy on {config.HOST}:{config.PORT}")
+     logger.info(f"🔑 Master API Key: {config.MASTER_API_KEY[:8]}...")
+     logger.info(f"🔧 Loaded {len(config.GEMINI_API_KEYS)} Gemini API key(s)")
+
+     uvicorn.run(
+         app,
+         host=config.HOST,
+         port=config.PORT,
+         log_level=config.LOG_LEVEL.lower()
+     )
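Because the proxy mirrors the OpenAI chat-completions shape, any OpenAI-style client can talk to it. A minimal smoke test with requests, assuming a local server on port 7860 and the configured master key (both values here are placeholders):

import requests

resp = requests.post(
    "http://localhost:7860/v1/chat/completions",  # host/port are assumptions
    headers={"Authorization": "Bearer <MASTER_API_KEY>"},
    json={
        "model": "gpt-3.5-turbo",  # routed to gemini-1.5-flash-latest above
        "messages": [{"role": "user", "content": "Say hello."}],
        "temperature": 0.7,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])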
requirements.txt CHANGED
@@ -1,4 +1,6 @@
- streamlit==1.28.1
- transformers==4.44.2
- torch==2.4.1
- tokenizers==0.19.1
+ fastapi==0.104.1
+ uvicorn[standard]==0.24.0
+ requests==2.31.0
+ python-multipart==0.0.6
+ pydantic==2.5.0
+ python-dotenv==1.0.0
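With these pins the proxy starts from a plain interpreter run, since app.py invokes uvicorn itself in its __main__ block (assuming a populated .env as sketched in the config note above):

pip install -r requirements.txt
python app.py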