sonyps1928 committed
Commit e27c591 · Parent(s): 1b3fa51
update

Files changed:
- app.py +104 -365
- requirements.txt +4 -6
app.py
CHANGED
@@ -1,376 +1,115 @@
-import logging
-import time
-import random
-from typing import Dict, Any, List, Optional
-import uvicorn
-from fastapi import FastAPI, HTTPException, Depends, Request
-from fastapi.responses import JSONResponse
-from fastapi.exception_handlers import http_exception_handler
-from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
-from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel
-import requests
-import json

-logger = logging.getLogger(__name__)

-app = FastAPI(
-    title="Advanced Gemini Proxy",
-    description="OpenAI-compatible proxy for Google Gemini API",
-    version="1.0.0"
-)
-
-# CORS middleware
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-# Custom exception handler
-@app.exception_handler(HTTPException)
-async def custom_http_exception_handler(request: Request, exc: HTTPException):
-    # If detail is already in OpenAI format, return as-is
-    if isinstance(exc.detail, dict) and "error" in exc.detail:
-        return JSONResponse(
-            status_code=exc.status_code,
-            content=exc.detail
-        )
-
-    # Otherwise, format as OpenAI error
-    error_response = {
-        "error": {
-            "message": str(exc.detail),
-            "type": "api_error",
-            "param": None,
-            "code": None
-        }
-    }
-
-    return JSONResponse(
-        status_code=exc.status_code,
-        content=error_response
-    )
-
-# Security
-security = HTTPBearer()
-
-# Rate limiting storage (in-memory for simplicity)
-rate_limit_storage: Dict[str, List[float]] = {}
-
-# Pydantic models
-class ChatMessage(BaseModel):
-    role: str
-    content: str
-
-class ChatCompletionRequest(BaseModel):
-    model: str
-    messages: List[ChatMessage]
-    temperature: Optional[float] = 1.0
-    max_tokens: Optional[int] = None
-    stream: Optional[bool] = False
-
-class Choice(BaseModel):
-    index: int
-    message: Dict[str, str]
-    finish_reason: str
-
-class Usage(BaseModel):
-    prompt_tokens: int
-    completion_tokens: int
-    total_tokens: int
-
-class ChatCompletionResponse(BaseModel):
-    id: str
-    object: str = "chat.completion"
-    created: int
-    model: str
-    choices: List[Choice]
-    usage: Usage
-
-# Authentication
-async def verify_api_key(credentials: HTTPAuthorizationCredentials = Depends(security)):
-    if credentials.credentials != config.MASTER_API_KEY:
-        error_response = {
-            "error": {
-                "message": "Invalid API key provided",
-                "type": "invalid_request_error",
-                "param": None,
-                "code": "invalid_api_key"
-            }
-        }
-        raise HTTPException(status_code=401, detail=error_response)
-    return credentials.credentials
-
-# Rate limiting
-def check_rate_limit(client_ip: str) -> tuple[bool, int]:
-    now = time.time()
-    if client_ip not in rate_limit_storage:
-        rate_limit_storage[client_ip] = []
-
-    # Clean old entries
-    rate_limit_storage[client_ip] = [
-        timestamp for timestamp in rate_limit_storage[client_ip]
-        if now - timestamp < 60
-    ]
-
-    current_count = len(rate_limit_storage[client_ip])
-
-    # Check limit
-    if current_count >= config.MAX_REQUESTS_PER_MINUTE:
-        # Calculate reset time
-        oldest_request = min(rate_limit_storage[client_ip])
-        reset_time = int(oldest_request + 60)
-        return False, reset_time
-
-    # Add current request
-    rate_limit_storage[client_ip].append(now)
-    return True, 0
-
-# Gemini API interaction
-def get_random_api_key() -> str:
-    return random.choice(config.GEMINI_API_KEYS)
-
-def convert_to_gemini_format(messages: List[ChatMessage]) -> List[Dict[str, Any]]:
-    gemini_messages = []
-    for msg in messages:
-        if msg.role == "system":
-            # Handle system messages by converting to user message with instruction
-            gemini_messages.append({
-                "role": "user",
-                "parts": [{"text": f"System instruction: {msg.content}"}]
-            })
-        else:
-            role = "user" if msg.role == "user" else "model"
-            gemini_messages.append({
-                "role": role,
-                "parts": [{"text": msg.content}]
-            })
-    return gemini_messages
-
-def estimate_tokens(text: str) -> int:
-    """Simple token estimation - roughly 1 token per 4 characters"""
-    return max(1, len(text) // 4)
-
-def call_gemini_api(messages: List[ChatMessage], model: str, temperature: float, max_tokens: Optional[int]) -> Dict[str, Any]:
-    api_key = get_random_api_key()
-
-    # Convert model name
-    if "gpt-4" in model.lower():
-        gemini_model = "gemini-1.5-pro-latest"
-    elif "gpt-3.5" in model.lower():
-        gemini_model = "gemini-1.5-flash-latest"
-    else:
-        gemini_model = "gemini-1.5-flash-latest"  # Default fallback
-
-    url = f"https://generativelanguage.googleapis.com/v1beta/models/{gemini_model}:generateContent"
-
-    # Convert messages
-    gemini_messages = convert_to_gemini_format(messages)
-
-    payload = {
-        "contents": gemini_messages,
-        "generationConfig": {
-            "temperature": max(0.0, min(2.0, temperature)),  # Clamp temperature
-        },
-        "safetySettings": [
-            {
-                "category": "HARM_CATEGORY_HARASSMENT",
-                "threshold": "BLOCK_MEDIUM_AND_ABOVE"
-            },
-            {
-                "category": "HARM_CATEGORY_HATE_SPEECH",
-                "threshold": "BLOCK_MEDIUM_AND_ABOVE"
-            },
-            {
-                "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
-                "threshold": "BLOCK_MEDIUM_AND_ABOVE"
-            },
-            {
-                "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
-                "threshold": "BLOCK_MEDIUM_AND_ABOVE"
-            }
-        ]
-    }
-
-    if max_tokens and max_tokens > 0:
-        payload["generationConfig"]["maxOutputTokens"] = min(max_tokens, 8192)  # Gemini limit
-
-    headers = {
-        "Content-Type": "application/json",
-        "x-goog-api-key": api_key
-    }
-
-    try:
-        response = requests.post(url, json=payload, headers=headers, timeout=30)
-
-        if response.status_code != 200:
-            logger.error(f"Gemini API error: {response.status_code} - {response.text}")
-            error_response = {
-                "error": {
-                    "message": f"Gemini API error: {response.text}",
-                    "type": "api_error",
-                    "param": None,
-                    "code": "gemini_api_error"
-                }
-            }
-            raise HTTPException(status_code=response.status_code, detail=error_response)
-
-        return response.json()
-
-    except requests.exceptions.Timeout:
-        raise HTTPException(status_code=408, detail="Request timeout")
-    except requests.exceptions.RequestException as e:
-        logger.error(f"Request error: {str(e)}")
-        raise HTTPException(status_code=500, detail="Failed to connect to Gemini API")
-
-# Routes
-@app.get("/")
-async def root():
-    return {"message": "Advanced Gemini Proxy is running!", "version": "1.0.0"}
-
-@app.get("/health")
-async def health_check():
-    return {"status": "healthy", "timestamp": time.time()}
-
-@app.get("/v1/models")
-async def list_models(api_key: str = Depends(verify_api_key)):
-    return {
-        "object": "list",
-        "data": [
-            {
-                "id": "gpt-3.5-turbo",
-                "object": "model",
-                "created": int(time.time()),
-                "owned_by": "gemini-proxy"
-            },
-            {
-                "id": "gpt-4",
-                "object": "model",
-                "created": int(time.time()),
-                "owned_by": "gemini-proxy"
-            }
-        ]
-    }
-
-@app.post("/v1/chat/completions")
-async def chat_completions(
-    request: ChatCompletionRequest,
-    client_request: Request,
-    api_key: str = Depends(verify_api_key)
-):
-    # Rate limiting
-    client_ip = client_request.client.host
-    allowed, reset_time = check_rate_limit(client_ip)
-    if not allowed:
-        error_response = {
-            "error": {
-                "message": "Rate limit reached for requests",
-                "type": "rate_limit_exceeded",
-                "param": None,
-                "code": "rate_limit_exceeded"
-            }
-        }
-        headers = {
-            "X-RateLimit-Limit": str(config.MAX_REQUESTS_PER_MINUTE),
-            "X-RateLimit-Remaining": "0",
-            "X-RateLimit-Reset": str(reset_time),
-            "Retry-After": str(60)
-        }
-        return JSONResponse(
-            status_code=429,
-            content=error_response,
-            headers=headers
-        )
-
-    # Validate request
-    if not request.messages:
-        error_response = {
-            "error": {
-                "message": "Missing required parameter: 'messages'",
-                "type": "invalid_request_error",
-                "param": "messages",
-                "code": "missing_required_parameter"
-            }
-        }
-        raise HTTPException(status_code=400, detail=error_response)
-
     try:
-        gemini_response = call_gemini_api(
-            request.messages,
-            request.model,
-            request.temperature,
-            request.max_tokens
-        )
-
-        # Extract response text
-        if "candidates" not in gemini_response or not gemini_response["candidates"]:
-            # Check for blocked content
-            if "promptFeedback" in gemini_response and "blockReason" in gemini_response["promptFeedback"]:
-                block_reason = gemini_response["promptFeedback"]["blockReason"]
-                raise HTTPException(status_code=400, detail=f"Content blocked: {block_reason}")
-            raise HTTPException(status_code=500, detail="No response from Gemini API")
-
-        candidate = gemini_response["candidates"][0]

-        prompt_tokens = estimate_tokens(prompt_text)
-        completion_tokens = estimate_tokens(response_text)
-
-        # Convert to OpenAI format
-        response = ChatCompletionResponse(
-            id=f"chatcmpl-{int(time.time())}{random.randint(1000, 9999)}",
-            created=int(time.time()),
-            model=request.model,
-            choices=[Choice(
-                index=0,
-                message={
-                    "role": "assistant",
-                    "content": response_text
-                },
-                finish_reason="stop"
-            )],
-            usage=Usage(
-                prompt_tokens=prompt_tokens,
-                completion_tokens=completion_tokens,
-                total_tokens=prompt_tokens + completion_tokens
             )
-        )

     except Exception as e:

 if __name__ == "__main__":
-    logger.info(f"🔑 Master API Key: {config.MASTER_API_KEY[:8]}...")
-    logger.info(f"🔧 Loaded {len(config.GEMINI_API_KEYS)} Gemini API key(s)")
-
-    uvicorn.run(
-        app,
-        host=config.HOST,
-        port=config.PORT,
-        log_level=config.LOG_LEVEL.lower()
-    )
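For reference, a minimal sketch (not part of this commit) of how a client would have called the FastAPI proxy removed above. The endpoint path, bearer authentication, and request fields come from the deleted code; the host, port, and key value are placeholders that depended on its config module.

import requests

# Hypothetical values: the real host, port, and MASTER_API_KEY came from the proxy's config.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    headers={"Authorization": "Bearer <MASTER_API_KEY>"},
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello!"}],
        "temperature": 0.7,
    },
    timeout=30,
)
print(resp.json())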
+import gradio as gr
+from transformers import GPT2LMHeadModel, GPT2Tokenizer
+import torch
+
+# Load model and tokenizer (using smaller GPT-2 for free tier)
+model_name = "gpt2"  # You can also use "gpt2-medium" if it fits in memory
+tokenizer = GPT2Tokenizer.from_pretrained(model_name)
+model = GPT2LMHeadModel.from_pretrained(model_name)
+
+# Set pad token
+tokenizer.pad_token = tokenizer.eos_token
+
+def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9, top_k=50):
+    """Generate text using GPT-2"""
     try:
+        # Encode input
+        inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=512, truncation=True)
+
+        # Generate
+        with torch.no_grad():
+            outputs = model.generate(
+                inputs,
+                max_length=min(max_length + len(inputs[0]), 512),  # Limit total length
+                temperature=temperature,
+                top_p=top_p,
+                top_k=top_k,
+                do_sample=True,
+                pad_token_id=tokenizer.eos_token_id,
+                num_return_sequences=1
             )
+
+        # Decode output
+        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        # Return only the new generated part
+        return generated_text[len(prompt):].strip()
+
     except Exception as e:
+        return f"Error generating text: {str(e)}"
+
+# Create Gradio interface
+with gr.Blocks(title="GPT-2 Text Generator") as demo:
+    gr.Markdown("# GPT-2 Text Generation Server")
+    gr.Markdown("Enter a prompt and generate text using GPT-2. Free tier optimized!")
+
+    with gr.Row():
+        with gr.Column():
+            prompt_input = gr.Textbox(
+                label="Prompt",
+                placeholder="Enter your text prompt here...",
+                lines=3
+            )
+
+            with gr.Row():
+                max_length = gr.Slider(
+                    minimum=10,
+                    maximum=200,
+                    value=100,
+                    step=10,
+                    label="Max Length"
+                )
+                temperature = gr.Slider(
+                    minimum=0.1,
+                    maximum=2.0,
+                    value=0.7,
+                    step=0.1,
+                    label="Temperature"
+                )
+
+            with gr.Row():
+                top_p = gr.Slider(
+                    minimum=0.1,
+                    maximum=1.0,
+                    value=0.9,
+                    step=0.1,
+                    label="Top-p"
+                )
+                top_k = gr.Slider(
+                    minimum=1,
+                    maximum=100,
+                    value=50,
+                    step=1,
+                    label="Top-k"
+                )
+
+            generate_btn = gr.Button("Generate Text", variant="primary")
+
+        with gr.Column():
+            output_text = gr.Textbox(
+                label="Generated Text",
+                lines=10,
+                placeholder="Generated text will appear here..."
+            )
+
+    # Examples
+    gr.Examples(
+        examples=[
+            ["Once upon a time in a distant galaxy,"],
+            ["The future of artificial intelligence is"],
+            ["In the heart of the ancient forest,"],
+            ["The detective walked into the room and noticed"],
+        ],
+        inputs=prompt_input
+    )
+
+    # Connect the function
+    generate_btn.click(
+        fn=generate_text,
+        inputs=[prompt_input, max_length, temperature, top_p, top_k],
+        outputs=output_text
+    )
+
+# Launch the app
 if __name__ == "__main__":
+    demo.launch()
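A minimal sketch (not part of the commit) of how the new Gradio app could be driven programmatically once it is running, assuming the gradio_client package is installed and the app is served at the default local address. The argument order mirrors the inputs wired to generate_btn.click above.

from gradio_client import Client

client = Client("http://127.0.0.1:7860")  # assumed default local launch address
result = client.predict(
    "Once upon a time in a distant galaxy,",  # prompt
    100,   # max_length
    0.7,   # temperature
    0.9,   # top_p
    50,    # top_k
    fn_index=0,  # generate_btn.click is the only registered event
)
print(result)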
requirements.txt
CHANGED
@@ -1,6 +1,4 @@
-pydantic==2.5.0
-python-dotenv==1.0.0
+gradio>=3.50.0
+transformers>=4.30.0
+torch>=2.0.0
+tokenizers>=0.13.0