sonyps1928 committed
Commit · 40bbb95
1 Parent(s): e100984
update app
Files changed:
- app.py +175 -33
- requirements.txt +7 -4
app.py
CHANGED
@@ -6,8 +6,14 @@ from transformers import (
     AutoTokenizer, AutoModelForCausalLM
 )
 import torch
+import json
+from fastapi import FastAPI, HTTPException, Depends, Header
+from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
+import uvicorn
+from pydantic import BaseModel
+from typing import Optional
 
-# Configuration for multiple models
+# Configuration for multiple models
 MODEL_CONFIGS = {
     "gpt2": {
         "type": "causal",
@@ -39,45 +45,71 @@ MODEL_CONFIGS = {
     }
 }
 
-# Environment variables
+# Environment variables
 HF_TOKEN = os.getenv("HF_TOKEN")
 API_KEY = os.getenv("API_KEY")
 ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD")
 
-# Global state for caching
+# Global state for caching
 loaded_model_name = None
 model = None
 tokenizer = None
 
+# Pydantic models for API
+class GenerateRequest(BaseModel):
+    prompt: str
+    model_name: str = "gpt2"
+    max_length: int = 100
+    temperature: float = 0.7
+    top_p: float = 0.9
+    top_k: int = 50
+
+class GenerateResponse(BaseModel):
+    generated_text: str
+    model_used: str
+    status: str = "success"
+
+# Security
+security = HTTPBearer(auto_error=False)
+
 def load_model_and_tokenizer(model_name):
     global loaded_model_name, model, tokenizer
+
+    if model_name not in MODEL_CONFIGS:
+        raise ValueError(f"Model {model_name} not supported. Available models: {list(MODEL_CONFIGS.keys())}")
+
     if model_name == loaded_model_name and model is not None and tokenizer is not None:
         return model, tokenizer
 
-…
-        tokenizer.pad_token = tokenizer.eos_token
-
-    loaded_model_name = model_name
-    return model, tokenizer
-
-def authenticate_api_key(key):
-    if API_KEY and key != API_KEY:
-        return False
-    return True
+    try:
+        config = MODEL_CONFIGS[model_name]
+
+        # Load tokenizer and model
+        if HF_TOKEN:
+            tokenizer = config["tokenizer_class"].from_pretrained(model_name, use_auth_token=HF_TOKEN)
+            model = config["model_class"].from_pretrained(model_name, use_auth_token=HF_TOKEN)
+        else:
+            tokenizer = config["tokenizer_class"].from_pretrained(model_name)
+            model = config["model_class"].from_pretrained(model_name)
 
-…
+        # Set pad token for causal models if missing
+        if config["type"] == "causal" and tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        loaded_model_name = model_name
+        return model, tokenizer
+
+    except Exception as e:
+        raise RuntimeError(f"Failed to load model {model_name}: {str(e)}")
 
+def authenticate_api_key(credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)):
+    if API_KEY:
+        if not credentials or credentials.credentials != API_KEY:
+            raise HTTPException(status_code=401, detail="Invalid or missing API key")
+    return True
+
+def generate_text_core(prompt, model_name, max_length, temperature, top_p, top_k):
+    """Core text generation function"""
     try:
         config = MODEL_CONFIGS[model_name]
         model, tokenizer = load_model_and_tokenizer(model_name)
@@ -96,11 +128,9 @@ def generate_text(prompt, model_name, max_length, temperature, top_p, top_k, api
                 num_return_sequences=1
             )
             generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-            # Return generated continuation (remove original prompt)
             return generated_text[len(prompt):].strip()
 
         elif config["type"] == "seq2seq":
-            # Add task prefix for certain seq2seq models like flan-t5
             task_prompt = f"Complete this text: {prompt}" if "flan-t5" in model_name.lower() else prompt
             inputs = tokenizer(task_prompt, return_tensors="pt", max_length=512, truncation=True)
             with torch.no_grad():
@@ -117,8 +147,62 @@ def generate_text(prompt, model_name, max_length, temperature, top_p, top_k, api
             return generated_text.strip()
 
     except Exception as e:
-…
+        raise RuntimeError(f"Error generating text: {str(e)}")
 
+# Gradio interface function
+def generate_text_gradio(prompt, model_name, max_length, temperature, top_p, top_k, api_key=""):
+    if API_KEY and api_key != API_KEY:
+        return "Error: Invalid API key"
+
+    try:
+        return generate_text_core(prompt, model_name, max_length, temperature, top_p, top_k)
+    except Exception as e:
+        return f"Error: {str(e)}"
+
+# Create FastAPI app
+app = FastAPI(title="Multi-Model Text Generation API", version="1.0.0")
+
+# API Routes
+@app.post("/generate", response_model=GenerateResponse)
+async def generate_text_api(
+    request: GenerateRequest,
+    authenticated: bool = Depends(authenticate_api_key)
+):
+    try:
+        generated_text = generate_text_core(
+            request.prompt,
+            request.model_name,
+            request.max_length,
+            request.temperature,
+            request.top_p,
+            request.top_k
+        )
+        return GenerateResponse(
+            generated_text=generated_text,
+            model_used=request.model_name
+        )
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.get("/models")
+async def list_models():
+    return {
+        "models": [
+            {
+                "name": name,
+                "description": config["description"],
+                "size": config["size"],
+                "type": config["type"]
+            }
+            for name, config in MODEL_CONFIGS.items()
+        ]
+    }
+
+@app.get("/health")
+async def health_check():
+    return {"status": "healthy", "loaded_model": loaded_model_name}
+
+# Create Gradio interface
 with gr.Blocks(title="Multi-Model Text Generation Server") as demo:
     gr.Markdown("# Multi-Model Text Generation Server")
     gr.Markdown("Choose a model from the dropdown, enter a text prompt, and generate text.")
@@ -131,6 +215,16 @@ with gr.Blocks(title="Multi-Model Text Generation Server") as demo:
             value="gpt2",
             interactive=True
         )
+
+        # Show model info
+        model_info = gr.Markdown("**Model Info:** Original GPT-2, good for creative writing (117M)")
+
+        def update_model_info(model_name):
+            config = MODEL_CONFIGS[model_name]
+            return f"**Model Info:** {config['description']} ({config['size']})"
+
+        model_selector.change(update_model_info, inputs=model_selector, outputs=model_info)
+
         prompt_input = gr.Textbox(
             label="Text Prompt",
             placeholder="Enter the text prompt here...",
@@ -171,7 +265,7 @@ with gr.Blocks(title="Multi-Model Text Generation Server") as demo:
         )
 
     generate_btn.click(
-        fn=
+        fn=generate_text_gradio,
        inputs=[prompt_input, model_selector, max_length_slider, temperature_slider, top_p_slider, top_k_slider, api_key_input],
        outputs=output_textbox
    )
@@ -186,13 +280,61 @@ with gr.Blocks(title="Multi-Model Text Generation Server") as demo:
         inputs=prompt_input
     )
 
-…
+    # API documentation
+    with gr.Accordion("API Documentation", open=False):
+        gr.Markdown("""
+        ## REST API Endpoints
+
+        ### POST /generate
+        Generate text using the specified model.
+
+        **Request Body:**
+        ```json
+        {
+            "prompt": "Your text prompt here",
+            "model_name": "gpt2",
+            "max_length": 100,
+            "temperature": 0.7,
+            "top_p": 0.9,
+            "top_k": 50
+        }
+        ```
+
+        **Response:**
+        ```json
+        {
+            "generated_text": "Generated text...",
+            "model_used": "gpt2",
+            "status": "success"
+        }
+        ```
+
+        ### GET /models
+        List all available models.
+
+        ### GET /health
+        Check server health and loaded model status.
+
+        **Example cURL:**
+        ```bash
+        curl -X POST "http://localhost:7860/generate" \
+             -H "Content-Type: application/json" \
+             -H "Authorization: Bearer YOUR_API_KEY" \
+             -d '{"prompt": "Once upon a time", "model_name": "gpt2"}'
+        ```
+        """)
+
+# Mount Gradio app to FastAPI
+app = gr.mount_gradio_app(app, demo, path="/")
 
 if __name__ == "__main__":
+    auth_config = ("admin", ADMIN_PASSWORD) if ADMIN_PASSWORD else None
+
+    # Launch with both FastAPI and Gradio
     demo.launch(
         auth=auth_config,
-        # share=True, # Required for Spaces if localhost isn't accessible
         server_name="0.0.0.0",
         server_port=7860,
-        ssr_mode=False
-…
+        ssr_mode=False,
+        share=False
+    )
requirements.txt
CHANGED
@@ -1,4 +1,7 @@
-gradio>=
-transformers>=4.
-torch>=
-…
+gradio>=4.0.0
+transformers>=4.21.0
+torch>=1.12.0
+fastapi>=0.68.0
+uvicorn>=0.15.0
+pydantic>=1.8.0
+python-multipart>=0.0.5
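The added dependencies (fastapi, uvicorn, pydantic, python-multipart) back the REST layer above. Note that app.py still starts the server through demo.launch(); since the Gradio UI is mounted onto the FastAPI instance with gr.mount_gradio_app, a hypothetical alternative would be to serve the combined app with uvicorn directly — a sketch under that assumption, not what this commit does:

```python
# Hypothetical alternative entry point (not part of this commit): serve the
# FastAPI app returned by gr.mount_gradio_app() instead of calling demo.launch().
import uvicorn

from app import app  # app.py's module-level FastAPI instance, Gradio mounted at "/"

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
```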