sonyps1928 committed
Commit cef31a4 · 1 Parent(s): 40bbb95

update app

Files changed (1)
  1. app.py +191 -298
app.py CHANGED
@@ -1,340 +1,233 @@
 import gradio as gr
 import os
-from transformers import (
-    GPT2LMHeadModel, GPT2Tokenizer,
-    T5ForConditionalGeneration, T5Tokenizer,
-    AutoTokenizer, AutoModelForCausalLM
-)
 import torch
-import json
-from fastapi import FastAPI, HTTPException, Depends, Header
-from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
-import uvicorn
-from pydantic import BaseModel
-from typing import Optional
 
-# Configuration for multiple models
-MODEL_CONFIGS = {
-    "gpt2": {
-        "type": "causal",
-        "model_class": GPT2LMHeadModel,
-        "tokenizer_class": GPT2Tokenizer,
-        "description": "Original GPT-2, good for creative writing",
-        "size": "117M"
-    },
-    "distilgpt2": {
-        "type": "causal",
-        "model_class": AutoModelForCausalLM,
-        "tokenizer_class": AutoTokenizer,
-        "description": "Smaller, faster GPT-2",
-        "size": "82M"
-    },
-    "google/flan-t5-small": {
-        "type": "seq2seq",
-        "model_class": T5ForConditionalGeneration,
-        "tokenizer_class": T5Tokenizer,
-        "description": "Instruction-following T5 model",
-        "size": "80M"
-    },
-    "microsoft/DialoGPT-small": {
-        "type": "causal",
-        "model_class": AutoModelForCausalLM,
-        "tokenizer_class": AutoTokenizer,
-        "description": "Conversational AI model",
-        "size": "117M"
-    }
-}
-
-# Environment variables
 HF_TOKEN = os.getenv("HF_TOKEN")
 API_KEY = os.getenv("API_KEY")
 ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD")
 
-# Global state for caching
-loaded_model_name = None
-model = None
-tokenizer = None
-
-# Pydantic models for API
-class GenerateRequest(BaseModel):
-    prompt: str
-    model_name: str = "gpt2"
-    max_length: int = 100
-    temperature: float = 0.7
-    top_p: float = 0.9
-    top_k: int = 50
-
-class GenerateResponse(BaseModel):
-    generated_text: str
-    model_used: str
-    status: str = "success"
-
-# Security
-security = HTTPBearer(auto_error=False)
 
-def load_model_and_tokenizer(model_name):
-    global loaded_model_name, model, tokenizer
-
-    if model_name not in MODEL_CONFIGS:
-        raise ValueError(f"Model {model_name} not supported. Available models: {list(MODEL_CONFIGS.keys())}")
-
-    if model_name == loaded_model_name and model is not None and tokenizer is not None:
-        return model, tokenizer
-
     try:
-        config = MODEL_CONFIGS[model_name]
-
-        # Load tokenizer and model
-        if HF_TOKEN:
-            tokenizer = config["tokenizer_class"].from_pretrained(model_name, use_auth_token=HF_TOKEN)
-            model = config["model_class"].from_pretrained(model_name, use_auth_token=HF_TOKEN)
-        else:
-            tokenizer = config["tokenizer_class"].from_pretrained(model_name)
-            model = config["model_class"].from_pretrained(model_name)
-
-        # Set pad token for causal models if missing
-        if config["type"] == "causal" and tokenizer.pad_token is None:
-            tokenizer.pad_token = tokenizer.eos_token
-
-        loaded_model_name = model_name
-        return model, tokenizer
-
     except Exception as e:
-        raise RuntimeError(f"Failed to load model {model_name}: {str(e)}")
-
-def authenticate_api_key(credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)):
     if API_KEY:
-        if not credentials or credentials.credentials != API_KEY:
-            raise HTTPException(status_code=401, detail="Invalid or missing API key")
-    return True
-
-def generate_text_core(prompt, model_name, max_length, temperature, top_p, top_k):
-    """Core text generation function"""
-    try:
-        config = MODEL_CONFIGS[model_name]
-        model, tokenizer = load_model_and_tokenizer(model_name)
-
-        if config["type"] == "causal":
-            inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=512, truncation=True)
-            with torch.no_grad():
-                outputs = model.generate(
-                    inputs,
-                    max_length=min(max_length + inputs.shape[1], 512),
-                    temperature=temperature,
-                    top_p=top_p,
-                    top_k=top_k,
-                    do_sample=True,
-                    pad_token_id=tokenizer.pad_token_id,
-                    num_return_sequences=1
-                )
-            generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-            return generated_text[len(prompt):].strip()
-
-        elif config["type"] == "seq2seq":
-            task_prompt = f"Complete this text: {prompt}" if "flan-t5" in model_name.lower() else prompt
-            inputs = tokenizer(task_prompt, return_tensors="pt", max_length=512, truncation=True)
-            with torch.no_grad():
-                outputs = model.generate(
-                    **inputs,
-                    max_length=max_length,
-                    temperature=temperature,
-                    top_p=top_p,
-                    top_k=top_k,
-                    do_sample=True,
-                    num_return_sequences=1
-                )
-            generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-            return generated_text.strip()
-
-    except Exception as e:
-        raise RuntimeError(f"Error generating text: {str(e)}")
-
-# Gradio interface function
-def generate_text_gradio(prompt, model_name, max_length, temperature, top_p, top_k, api_key=""):
-    if API_KEY and api_key != API_KEY:
-        return "Error: Invalid API key"
 
-    try:
-        return generate_text_core(prompt, model_name, max_length, temperature, top_p, top_k)
-    except Exception as e:
-        return f"Error: {str(e)}"
-
-# Create FastAPI app
-app = FastAPI(title="Multi-Model Text Generation API", version="1.0.0")
-
-# API Routes
-@app.post("/generate", response_model=GenerateResponse)
-async def generate_text_api(
-    request: GenerateRequest,
-    authenticated: bool = Depends(authenticate_api_key)
-):
-    try:
-        generated_text = generate_text_core(
-            request.prompt,
-            request.model_name,
-            request.max_length,
-            request.temperature,
-            request.top_p,
-            request.top_k
-        )
-        return GenerateResponse(
-            generated_text=generated_text,
-            model_used=request.model_name
-        )
-    except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
-
-@app.get("/models")
-async def list_models():
-    return {
-        "models": [
-            {
-                "name": name,
-                "description": config["description"],
-                "size": config["size"],
-                "type": config["type"]
-            }
-            for name, config in MODEL_CONFIGS.items()
-        ]
-    }
-
-@app.get("/health")
-async def health_check():
-    return {"status": "healthy", "loaded_model": loaded_model_name}
-
-# Create Gradio interface
-with gr.Blocks(title="Multi-Model Text Generation Server") as demo:
-    gr.Markdown("# Multi-Model Text Generation Server")
-    gr.Markdown("Choose a model from the dropdown, enter a text prompt, and generate text.")
-
     with gr.Row():
         with gr.Column():
-            model_selector = gr.Dropdown(
-                label="Model",
-                choices=list(MODEL_CONFIGS.keys()),
-                value="gpt2",
-                interactive=True
-            )
-
-            # Show model info
-            model_info = gr.Markdown("**Model Info:** Original GPT-2, good for creative writing (117M)")
-
-            def update_model_info(model_name):
-                config = MODEL_CONFIGS[model_name]
-                return f"**Model Info:** {config['description']} ({config['size']})"
-
-            model_selector.change(update_model_info, inputs=model_selector, outputs=model_info)
-
             prompt_input = gr.Textbox(
-                label="Text Prompt",
-                placeholder="Enter the text prompt here...",
-                lines=4
-            )
-            max_length_slider = gr.Slider(
-                10, 200, 100, 10,
-                label="Max Generation Length"
-            )
-            temperature_slider = gr.Slider(
-                0.1, 2.0, 0.7, 0.1,
-                label="Temperature"
-            )
-            top_p_slider = gr.Slider(
-                0.1, 1.0, 0.9, 0.05,
-                label="Top-p (nucleus sampling)"
-            )
-            top_k_slider = gr.Slider(
-                1, 100, 50, 1,
-                label="Top-k sampling"
             )
             if API_KEY:
                 api_key_input = gr.Textbox(
-                    label="API Key",
                     type="password",
-                    placeholder="Enter API Key"
                 )
             else:
                 api_key_input = gr.Textbox(value="", visible=False)
-
-            generate_btn = gr.Button("Generate Text", variant="primary")
-
         with gr.Column():
-            output_textbox = gr.Textbox(
-                label="Generated Text",
-                lines=10,
-                placeholder="Generated text will appear here..."
             )
-
-            generate_btn.click(
-                fn=generate_text_gradio,
-                inputs=[prompt_input, model_selector, max_length_slider, temperature_slider, top_p_slider, top_k_slider, api_key_input],
-                outputs=output_textbox
-            )
-
     gr.Examples(
         examples=[
             ["Once upon a time in a distant galaxy,"],
             ["The future of artificial intelligence is"],
             ["In the heart of the ancient forest,"],
             ["The detective walked into the room and noticed"],
         ],
-        inputs=prompt_input
     )
 
-    # API documentation
-    with gr.Accordion("API Documentation", open=False):
-        gr.Markdown("""
-        ## REST API Endpoints
-
-        ### POST /generate
-        Generate text using the specified model.
-
-        **Request Body:**
-        ```json
-        {
-            "prompt": "Your text prompt here",
-            "model_name": "gpt2",
-            "max_length": 100,
-            "temperature": 0.7,
-            "top_p": 0.9,
-            "top_k": 50
-        }
-        ```
-
-        **Response:**
-        ```json
-        {
-            "generated_text": "Generated text...",
-            "model_used": "gpt2",
-            "status": "success"
-        }
-        ```
-
-        ### GET /models
-        List all available models.
-
-        ### GET /health
-        Check server health and loaded model status.
-
-        **Example cURL:**
-        ```bash
-        curl -X POST "http://localhost:7860/generate" \
-             -H "Content-Type: application/json" \
-             -H "Authorization: Bearer YOUR_API_KEY" \
-             -d '{"prompt": "Once upon a time", "model_name": "gpt2"}'
-        ```
-        """)
-
-# Mount Gradio app to FastAPI
-app = gr.mount_gradio_app(app, demo, path="/")
 
 if __name__ == "__main__":
-    auth_config = ("admin", ADMIN_PASSWORD) if ADMIN_PASSWORD else None
-
-    # Launch with both FastAPI and Gradio
     demo.launch(
-        auth=auth_config,
-        server_name="0.0.0.0",
-        server_port=7860,
-        ssr_mode=False,
-        share=False
-    )
 
 import gradio as gr
 import os
+import hashlib
+import time
+from collections import defaultdict
+from transformers import GPT2LMHeadModel, GPT2Tokenizer
 import torch
 
+# Load secrets from environment
 HF_TOKEN = os.getenv("HF_TOKEN")
 API_KEY = os.getenv("API_KEY")
 ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD")
 
+print(f"🔐 Security Status:")
+print(f" HF_TOKEN: {'✅ Set' if HF_TOKEN else '❌ Not set'}")
+print(f" API_KEY: {'✅ Set' if API_KEY else '❌ Not set'}")
+print(f" ADMIN_PASSWORD: {'✅ Set' if ADMIN_PASSWORD else '❌ Not set'}")
+
+# Rate limiting storage
+request_counts = defaultdict(list)
+
+# Load model with optional HF token
+model_name = "gpt2"
+try:
+    if HF_TOKEN:
+        tokenizer = GPT2Tokenizer.from_pretrained(model_name, use_auth_token=HF_TOKEN)
+        model = GPT2LMHeadModel.from_pretrained(model_name, use_auth_token=HF_TOKEN)
+        print("✅ Model loaded with HF token")
+    else:
+        tokenizer = GPT2Tokenizer.from_pretrained(model_name)
+        model = GPT2LMHeadModel.from_pretrained(model_name)
+        print("✅ Model loaded without token")
+
+    tokenizer.pad_token = tokenizer.eos_token
+    print("✅ Model initialization complete")
+
+except Exception as e:
+    print(f"❌ Model loading failed: {e}")
+    raise e
+
+def validate_api_key(provided_key):
+    """Validate API key with rate limiting"""
+    if not API_KEY:
+        return True, "No API key required"
+
+    if not provided_key:
+        return False, "API key required but not provided"
+
+    if provided_key != API_KEY:
+        return False, "Invalid API key"
+
+    # Rate limiting per API key
+    now = time.time()
+    key_hash = hashlib.sha256(provided_key.encode()).hexdigest()[:8]
+
+    # Clean old requests (last hour)
+    request_counts[key_hash] = [
+        req_time for req_time in request_counts[key_hash]
+        if now - req_time < 3600
+    ]
+
+    # Check rate limit (100 requests per hour)
+    if len(request_counts[key_hash]) >= 100:
+        return False, "Rate limit exceeded (100 requests/hour)"
+
+    # Log successful request
+    request_counts[key_hash].append(now)
+    return True, f"Authenticated (Requests: {len(request_counts[key_hash])}/100)"
 
+def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9, top_k=50, api_key=""):
+    """Generate text with security validation"""
+
+    # Validate API key
+    is_valid, message = validate_api_key(api_key)
+    if not is_valid:
+        return f"🔒 Authentication Error: {message}"
 
+    # Input validation
+    if not prompt or len(prompt.strip()) == 0:
+        return "❌ Error: Prompt cannot be empty"
 
+    if len(prompt) > 1000:
+        return "❌ Error: Prompt too long (max 1000 characters)"
 
     try:
+        print(f"🔑 {message}")
+        print(f"📝 Generating text for prompt: {prompt[:50]}...")
 
+        inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=512, truncation=True)
+
+        with torch.no_grad():
+            outputs = model.generate(
+                inputs,
+                max_length=min(max_length + len(inputs[0]), 512),
+                temperature=max(0.1, min(2.0, temperature)),
+                top_p=max(0.1, min(1.0, top_p)),
+                top_k=max(1, min(100, top_k)),
+                do_sample=True,
+                pad_token_id=tokenizer.eos_token_id,
+                num_return_sequences=1
+            )
+
+        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        result = generated_text[len(prompt):].strip()
+
+        print(f"✅ Generation successful, length: {len(result)} chars")
+        return result
 
     except Exception as e:
+        error_msg = f"❌ Generation error: {str(e)}"
+        print(error_msg)
+        return error_msg
+
+# Create Gradio interface with conditional elements
+with gr.Blocks(title="🔐 Secure GPT-2 Generator") as demo:
+    gr.Markdown("# 🔐 Secure GPT-2 Text Generator")
+    gr.Markdown("**Security Features**: API Authentication • Rate Limiting • Admin Protection")
+
+    # Security status display
+    security_status = []
+    if HF_TOKEN:
+        security_status.append("🔑 HF Token Active")
     if API_KEY:
+        security_status.append("🔒 API Authentication Enabled")
+    if ADMIN_PASSWORD:
+        security_status.append("👀 Admin Protection Active")
+
+    if security_status:
+        gr.Markdown(f"**Active Security**: {' • '.join(security_status)}")
+    else:
+        gr.Markdown("⚠️ **No security features enabled** - running in public mode")
 
     with gr.Row():
         with gr.Column():
             prompt_input = gr.Textbox(
+                label="✏️ Text Prompt",
+                placeholder="Enter your prompt here... (max 1000 chars)",
+                lines=3,
+                max_lines=5
             )
+
+            # Show API key input only if API_KEY is configured
             if API_KEY:
                 api_key_input = gr.Textbox(
+                    label="🔑 API Key (Required)",
                     type="password",
+                    placeholder="Enter your API key...",
+                    info="API authentication is enabled for this Space"
                 )
             else:
                 api_key_input = gr.Textbox(value="", visible=False)
+                gr.Markdown("🔓 **Public Access**: No API key required")
+
+            with gr.Accordion("⚙️ Generation Parameters", open=False):
+                with gr.Row():
+                    max_length = gr.Slider(
+                        minimum=10,
+                        maximum=200,
+                        value=100,
+                        step=10,
+                        label="📏 Max Length"
+                    )
+                    temperature = gr.Slider(
+                        minimum=0.1,
+                        maximum=2.0,
+                        value=0.7,
+                        step=0.1,
+                        label="🌡️ Temperature"
+                    )
+
+                with gr.Row():
+                    top_p = gr.Slider(
+                        minimum=0.1,
+                        maximum=1.0,
+                        value=0.9,
+                        step=0.1,
+                        label="🎯 Top-p"
+                    )
+                    top_k = gr.Slider(
+                        minimum=1,
+                        maximum=100,
+                        value=50,
+                        step=1,
+                        label="🔢 Top-k"
+                    )
+
+            generate_btn = gr.Button("🚀 Generate Text", variant="primary", size="lg")
+
         with gr.Column():
+            output_text = gr.Textbox(
+                label="📄 Generated Text",
+                lines=12,
+                placeholder="Generated text will appear here...",
+                show_copy_button=True
             )
+
+    # Rate limit info
+    if API_KEY:
+        gr.Markdown("**Rate Limits**: 100 requests per hour per API key")
+
+    # Examples
     gr.Examples(
         examples=[
             ["Once upon a time in a distant galaxy,"],
             ["The future of artificial intelligence is"],
             ["In the heart of the ancient forest,"],
             ["The detective walked into the room and noticed"],
+            ["Write a short story about a robot who dreams of"],
         ],
+        inputs=prompt_input,
+        label="💡 Example Prompts"
+    )
+
+    # Connect the generation function
+    generate_btn.click(
+        fn=generate_text,
+        inputs=[prompt_input, max_length, temperature, top_p, top_k, api_key_input],
+        outputs=output_text
     )
 
+# Launch with authentication
+auth_tuple = None
+if ADMIN_PASSWORD:
+    auth_tuple = ("admin", ADMIN_PASSWORD)
+    print("🔐 Admin authentication enabled")
 
 if __name__ == "__main__":
     demo.launch(
+        auth=auth_tuple,
+        show_api=True,  # Enable API documentation
+        show_error=True
+    )
+    print("🚀 Secure GPT-2 Generator is running!")