sonyps1928 committed on
Commit e27c591 · 1 Parent(s): 1b3fa51
Files changed (2)
  1. app.py +104 -365
  2. requirements.txt +4 -6
app.py CHANGED
@@ -1,376 +1,115 @@
- import logging
- import time
- import random
- from typing import Dict, Any, List, Optional
- import uvicorn
- from fastapi import FastAPI, HTTPException, Depends, Request
- from fastapi.responses import JSONResponse
- from fastapi.exception_handlers import http_exception_handler
- from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
- from fastapi.middleware.cors import CORSMiddleware
- from pydantic import BaseModel
- import requests
- import json
-
- from config import config
-
- # Configure logging
- logging.basicConfig(level=getattr(logging, config.LOG_LEVEL))
- logger = logging.getLogger(__name__)
-
- # FastAPI app
- app = FastAPI(
-     title="Advanced Gemini Proxy",
-     description="OpenAI-compatible proxy for Google Gemini API",
-     version="1.0.0"
- )
-
- # CORS middleware
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
- )
-
- # Custom exception handler
- @app.exception_handler(HTTPException)
- async def custom_http_exception_handler(request: Request, exc: HTTPException):
-     # If detail is already in OpenAI format, return as-is
-     if isinstance(exc.detail, dict) and "error" in exc.detail:
-         return JSONResponse(
-             status_code=exc.status_code,
-             content=exc.detail
-         )
-
-     # Otherwise, format as OpenAI error
-     error_response = {
-         "error": {
-             "message": str(exc.detail),
-             "type": "api_error",
-             "param": None,
-             "code": None
-         }
-     }
-
-     return JSONResponse(
-         status_code=exc.status_code,
-         content=error_response
-     )
-
- # Security
- security = HTTPBearer()
-
- # Rate limiting storage (in-memory for simplicity)
- rate_limit_storage: Dict[str, List[float]] = {}
-
- # Pydantic models
- class ChatMessage(BaseModel):
-     role: str
-     content: str
-
- class ChatCompletionRequest(BaseModel):
-     model: str
-     messages: List[ChatMessage]
-     temperature: Optional[float] = 1.0
-     max_tokens: Optional[int] = None
-     stream: Optional[bool] = False
-
- class Choice(BaseModel):
-     index: int
-     message: Dict[str, str]
-     finish_reason: str
-
- class Usage(BaseModel):
-     prompt_tokens: int
-     completion_tokens: int
-     total_tokens: int
-
- class ChatCompletionResponse(BaseModel):
-     id: str
-     object: str = "chat.completion"
-     created: int
-     model: str
-     choices: List[Choice]
-     usage: Usage
-
- # Authentication
- async def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
-     if credentials.credentials != config.MASTER_API_KEY:
-         error_response = {
-             "error": {
-                 "message": "Invalid API key provided",
-                 "type": "invalid_request_error",
-                 "param": None,
-                 "code": "invalid_api_key"
-             }
-         }
-         raise HTTPException(status_code=401, detail=error_response)
-     return credentials.credentials
-
- # Rate limiting
- def check_rate_limit(client_ip: str) -> tuple[bool, int]:
-     now = time.time()
-     if client_ip not in rate_limit_storage:
-         rate_limit_storage[client_ip] = []
-
-     # Clean old entries
-     rate_limit_storage[client_ip] = [
-         timestamp for timestamp in rate_limit_storage[client_ip]
-         if now - timestamp < 60
-     ]
-
-     current_count = len(rate_limit_storage[client_ip])
-
-     # Check limit
-     if current_count >= config.MAX_REQUESTS_PER_MINUTE:
-         # Calculate reset time
-         oldest_request = min(rate_limit_storage[client_ip])
-         reset_time = int(oldest_request + 60)
-         return False, reset_time
-
-     # Add current request
-     rate_limit_storage[client_ip].append(now)
-     return True, 0
-
- # Gemini API interaction
- def get_random_api_key() -> str:
-     return random.choice(config.GEMINI_API_KEYS)
-
- def convert_to_gemini_format(messages: List[ChatMessage]) -> List[Dict[str, Any]]:
-     gemini_messages = []
-     for msg in messages:
-         if msg.role == "system":
-             # Handle system messages by converting to user message with instruction
-             gemini_messages.append({
-                 "role": "user",
-                 "parts": [{"text": f"System instruction: {msg.content}"}]
-             })
-         else:
-             role = "user" if msg.role == "user" else "model"
-             gemini_messages.append({
-                 "role": role,
-                 "parts": [{"text": msg.content}]
-             })
-     return gemini_messages
-
- def estimate_tokens(text: str) -> int:
-     """Simple token estimation - roughly 1 token per 4 characters"""
-     return max(1, len(text) // 4)
-
- def call_gemini_api(messages: List[ChatMessage], model: str, temperature: float, max_tokens: Optional[int]) -> Dict[str, Any]:
-     api_key = get_random_api_key()
-
-     # Convert model name
-     if "gpt-4" in model.lower():
-         gemini_model = "gemini-1.5-pro-latest"
-     elif "gpt-3.5" in model.lower():
-         gemini_model = "gemini-1.5-flash-latest"
-     else:
-         gemini_model = "gemini-1.5-flash-latest"  # Default fallback
-
-     url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemini_model}:generateContent"
-
-     # Convert messages
-     gemini_messages = convert_to_gemini_format(messages)
-
-     payload = {
-         "contents": gemini_messages,
-         "generationConfig": {
-             "temperature": max(0.0, min(2.0, temperature)),  # Clamp temperature
-         },
-         "safetySettings": [
-             {
-                 "category": "HARM_CATEGORY_HARASSMENT",
-                 "threshold": "BLOCK_MEDIUM_AND_ABOVE"
-             },
-             {
-                 "category": "HARM_CATEGORY_HATE_SPEECH",
-                 "threshold": "BLOCK_MEDIUM_AND_ABOVE"
-             },
-             {
-                 "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
-                 "threshold": "BLOCK_MEDIUM_AND_ABOVE"
-             },
-             {
-                 "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
-                 "threshold": "BLOCK_MEDIUM_AND_ABOVE"
-             }
-         ]
-     }
-
-     if max_tokens and max_tokens > 0:
-         payload["generationConfig"]["maxOutputTokens"] = min(max_tokens, 8192)  # Gemini limit
-
-     headers = {
-         "Content-Type": "application/json",
-         "x-goog-api-key": api_key
-     }
-
-     try:
-         response = requests.post(url, json=payload, headers=headers, timeout=30)
-
-         if response.status_code != 200:
-             logger.error(f"Gemini API error: {response.status_code} - {response.text}")
-             error_response = {
-                 "error": {
-                     "message": f"Gemini API error: {response.text}",
-                     "type": "api_error",
-                     "param": None,
-                     "code": "gemini_api_error"
-                 }
-             }
-             raise HTTPException(status_code=response.status_code, detail=error_response)
-
-         return response.json()
-
-     except requests.exceptions.Timeout:
-         raise HTTPException(status_code=408, detail="Request timeout")
-     except requests.exceptions.RequestException as e:
-         logger.error(f"Request error: {str(e)}")
-         raise HTTPException(status_code=500, detail="Failed to connect to Gemini API")
-
- # Routes
- @app.get("/")
- async def root():
-     return {"message": "Advanced Gemini Proxy is running!", "version": "1.0.0"}
-
- @app.get("/health")
- async def health_check():
-     return {"status": "healthy", "timestamp": time.time()}
-
- @app.get("/v1/models")
- async def list_models(api_key: str = Depends(verify_api_key)):
-     return {
-         "object": "list",
-         "data": [
-             {
-                 "id": "gpt-3.5-turbo",
-                 "object": "model",
-                 "created": int(time.time()),
-                 "owned_by": "gemini-proxy"
-             },
-             {
-                 "id": "gpt-4",
-                 "object": "model",
-                 "created": int(time.time()),
-                 "owned_by": "gemini-proxy"
-             }
-         ]
-     }
-
- @app.post("/v1/chat/completions")
- async def chat_completions(
-     request: ChatCompletionRequest,
-     client_request: Request,
-     api_key: str = Depends(verify_api_key)
- ):
-     # Rate limiting
-     client_ip = client_request.client.host
-     allowed, reset_time = check_rate_limit(client_ip)
-     if not allowed:
-         error_response = {
-             "error": {
-                 "message": "Rate limit reached for requests",
-                 "type": "rate_limit_exceeded",
-                 "param": None,
-                 "code": "rate_limit_exceeded"
-             }
-         }
-         headers = {
-             "X-RateLimit-Limit": str(config.MAX_REQUESTS_PER_MINUTE),
-             "X-RateLimit-Remaining": "0",
-             "X-RateLimit-Reset": str(reset_time),
-             "Retry-After": str(60)
-         }
-         return JSONResponse(
-             status_code=429,
-             content=error_response,
-             headers=headers
-         )
-
-     # Validate request
-     if not request.messages:
-         error_response = {
-             "error": {
-                 "message": "Missing required parameter: 'messages'",
-                 "type": "invalid_request_error",
-                 "param": "messages",
-                 "code": "missing_required_parameter"
-             }
-         }
-         raise HTTPException(status_code=400, detail=error_response)
-
      try:
-         # Call Gemini API
-         gemini_response = call_gemini_api(
-             request.messages,
-             request.model,
-             request.temperature,
-             request.max_tokens
-         )
-
-         # Extract response text
-         if "candidates" not in gemini_response or not gemini_response["candidates"]:
-             # Check for blocked content
-             if "promptFeedback" in gemini_response and "blockReason" in gemini_response["promptFeedback"]:
-                 block_reason = gemini_response["promptFeedback"]["blockReason"]
-                 raise HTTPException(status_code=400, detail=f"Content blocked: {block_reason}")
-             raise HTTPException(status_code=500, detail="No response from Gemini API")
-
-         candidate = gemini_response["candidates"][0]
-
-         # Check if response was blocked
-         if "finishReason" in candidate and candidate["finishReason"] in ["SAFETY", "RECITATION"]:
-             raise HTTPException(status_code=400, detail=f"Response blocked: {candidate['finishReason']}")
-
-         if "content" not in candidate or "parts" not in candidate["content"]:
-             raise HTTPException(status_code=500, detail="Invalid response format from Gemini API")
-
-         response_text = candidate["content"]["parts"][0]["text"]
-
-         # Calculate token usage
-         prompt_text = " ".join([msg.content for msg in request.messages])
-         prompt_tokens = estimate_tokens(prompt_text)
-         completion_tokens = estimate_tokens(response_text)
-
-         # Convert to OpenAI format
-         response = ChatCompletionResponse(
-             id=f"chatcmpl-{int(time.time())}{random.randint(1000, 9999)}",
-             created=int(time.time()),
-             model=request.model,
-             choices=[Choice(
-                 index=0,
-                 message={
-                     "role": "assistant",
-                     "content": response_text
-                 },
-                 finish_reason="stop"
-             )],
-             usage=Usage(
-                 prompt_tokens=prompt_tokens,
-                 completion_tokens=completion_tokens,
-                 total_tokens=prompt_tokens + completion_tokens
              )
-         )
-
-         return response
-
-     except HTTPException:
-         raise
      except Exception as e:
-         logger.error(f"Unexpected error: {str(e)}")
-         raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")

  if __name__ == "__main__":
-     logger.info(f"🚀 Starting Advanced Gemini Proxy on {config.HOST}:{config.PORT}")
-     logger.info(f"🔑 Master API Key: {config.MASTER_API_KEY[:8]}...")
-     logger.info(f"🔧 Loaded {len(config.GEMINI_API_KEYS)} Gemini API key(s)")
-
-     uvicorn.run(
-         app,
-         host=config.HOST,
-         port=config.PORT,
-         log_level=config.LOG_LEVEL.lower()
-     )
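For reference, the proxy removed by this commit spoke the OpenAI chat-completions protocol (bearer auth against MASTER_API_KEY, /v1/models, /v1/chat/completions), so it could be exercised with a plain requests call. A minimal client sketch against that old interface; the host, port, and key below are placeholder assumptions, since the real values lived in config.py:

import requests

# Placeholder assumptions: the real host/port/key came from config.py.
BASE_URL = "http://localhost:8000"
MASTER_API_KEY = "sk-your-master-key"

resp = requests.post(
    f"{BASE_URL}/v1/chat/completions",
    headers={"Authorization": f"Bearer {MASTER_API_KEY}"},
    json={
        "model": "gpt-3.5-turbo",  # the proxy mapped this name to gemini-1.5-flash-latest
        "messages": [{"role": "user", "content": "Hello!"}],
        "temperature": 0.7,
    },
    timeout=30,
)
resp.raise_for_status()
# Response follows the OpenAI shape built by ChatCompletionResponse above.
print(resp.json()["choices"][0]["message"]["content"])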
 
+ import gradio as gr
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
+ import torch
+
+ # Load model and tokenizer (using smaller GPT-2 for free tier)
+ model_name = "gpt2"  # You can also use "gpt2-medium" if it fits in memory
+ tokenizer = GPT2Tokenizer.from_pretrained(model_name)
+ model = GPT2LMHeadModel.from_pretrained(model_name)
+
+ # Set pad token
+ tokenizer.pad_token = tokenizer.eos_token
+
+ def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9, top_k=50):
+     """Generate text using GPT-2"""
      try:
+         # Encode input
+         inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=512, truncation=True)
+
+         # Generate
+         with torch.no_grad():
+             outputs = model.generate(
+                 inputs,
+                 max_length=min(max_length + len(inputs[0]), 512),  # Limit total length
+                 temperature=temperature,
+                 top_p=top_p,
+                 top_k=top_k,
+                 do_sample=True,
+                 pad_token_id=tokenizer.eos_token_id,
+                 num_return_sequences=1
              )
+
+         # Decode output
+         generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+         # Return only the new generated part
+         return generated_text[len(prompt):].strip()
+
      except Exception as e:
+         return f"Error generating text: {str(e)}"
+
+ # Create Gradio interface
+ with gr.Blocks(title="GPT-2 Text Generator") as demo:
+     gr.Markdown("# GPT-2 Text Generation Server")
+     gr.Markdown("Enter a prompt and generate text using GPT-2. Free tier optimized!")
+
+     with gr.Row():
+         with gr.Column():
+             prompt_input = gr.Textbox(
+                 label="Prompt",
+                 placeholder="Enter your text prompt here...",
+                 lines=3
+             )
+
+             with gr.Row():
+                 max_length = gr.Slider(
+                     minimum=10,
+                     maximum=200,
+                     value=100,
+                     step=10,
+                     label="Max Length"
+                 )
+                 temperature = gr.Slider(
+                     minimum=0.1,
+                     maximum=2.0,
+                     value=0.7,
+                     step=0.1,
+                     label="Temperature"
+                 )
+
+             with gr.Row():
+                 top_p = gr.Slider(
+                     minimum=0.1,
+                     maximum=1.0,
+                     value=0.9,
+                     step=0.1,
+                     label="Top-p"
+                 )
+                 top_k = gr.Slider(
+                     minimum=1,
+                     maximum=100,
+                     value=50,
+                     step=1,
+                     label="Top-k"
+                 )
+
+             generate_btn = gr.Button("Generate Text", variant="primary")
+
+         with gr.Column():
+             output_text = gr.Textbox(
+                 label="Generated Text",
+                 lines=10,
+                 placeholder="Generated text will appear here..."
+             )
+
+     # Examples
+     gr.Examples(
+         examples=[
+             ["Once upon a time in a distant galaxy,"],
+             ["The future of artificial intelligence is"],
+             ["In the heart of the ancient forest,"],
+             ["The detective walked into the room and noticed"],
+         ],
+         inputs=prompt_input
+     )
+
+     # Connect the function
+     generate_btn.click(
+         fn=generate_text,
+         inputs=[prompt_input, max_length, temperature, top_p, top_k],
+         outputs=output_text
+     )

+ # Launch the app
  if __name__ == "__main__":
+     demo.launch()
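With the rewrite in place, the Space can be driven programmatically as well as through the UI. A minimal sketch using gradio_client (a separate install, not in requirements.txt); the Space ID below is an assumption, and fn_index=0 targets the single click handler since no api_name is registered:

from gradio_client import Client  # pip install gradio_client

client = Client("sonyps1928/gpt2-server")  # hypothetical Space ID

# Positional args mirror the click handler's inputs:
# prompt, max_length, temperature, top_p, top_k
result = client.predict(
    "Once upon a time in a distant galaxy,",
    100,   # max_length
    0.7,   # temperature
    0.9,   # top_p
    50,    # top_k
    fn_index=0,  # first (and only) registered event
)
print(result)

Registering the click event with an explicit api_name would make named-endpoint access sturdier than fn_index.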
 
requirements.txt CHANGED
@@ -1,6 +1,4 @@
- fastapi==0.104.1
- uvicorn[standard]==0.24.0
- requests==2.31.0
- python-multipart==0.0.6
- pydantic==2.5.0
- python-dotenv==1.0.0
+ gradio>=3.50.0
+ transformers>=4.30.0
+ torch>=2.0.0
+ tokenizers>=0.13.0
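A quick sanity check that the runtime satisfies the new minimum pins; a standard-library sketch, with package names taken from the file above:

from importlib.metadata import version

# Minimum versions pinned in requirements.txt
pins = {"gradio": "3.50.0", "transformers": "4.30.0", "torch": "2.0.0", "tokenizers": "0.13.0"}

for pkg, minimum in pins.items():
    installed = version(pkg)
    base = installed.split("+")[0]  # drop local tags like "+cpu" on torch builds
    # Naive numeric compare; fine for plain X.Y.Z versions like these.
    ok = tuple(int(p) for p in base.split(".")[:3]) >= tuple(int(p) for p in minimum.split("."))
    print(f"{pkg}: {installed} ({'OK' if ok else 'below ' + minimum})")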