sonyps1928 committed · Commit 40bbb95 · 1 parent: e100984

update app

Files changed (2):
  1. app.py +175 -33
  2. requirements.txt +7 -4
app.py CHANGED
@@ -6,8 +6,14 @@ from transformers import (
     AutoTokenizer, AutoModelForCausalLM
 )
 import torch
+import json
+from fastapi import FastAPI, HTTPException, Depends, Header
+from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
+import uvicorn
+from pydantic import BaseModel
+from typing import Optional
 
-# Configuration for multiple models, can add more by extending MODEL_CONFIGS dict
+# Configuration for multiple models
 MODEL_CONFIGS = {
     "gpt2": {
         "type": "causal",
@@ -39,45 +45,71 @@ MODEL_CONFIGS = {
     }
 }
 
-# Environment variables for optional authentication and private model access
+# Environment variables
 HF_TOKEN = os.getenv("HF_TOKEN")
 API_KEY = os.getenv("API_KEY")
 ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD")
 
-# Global state for caching loaded model and tokenizer
+# Global state for caching
 loaded_model_name = None
 model = None
 tokenizer = None
 
+# Pydantic models for API
+class GenerateRequest(BaseModel):
+    prompt: str
+    model_name: str = "gpt2"
+    max_length: int = 100
+    temperature: float = 0.7
+    top_p: float = 0.9
+    top_k: int = 50
+
+class GenerateResponse(BaseModel):
+    generated_text: str
+    model_used: str
+    status: str = "success"
+
+# Security
+security = HTTPBearer(auto_error=False)
+
 def load_model_and_tokenizer(model_name):
     global loaded_model_name, model, tokenizer
+
+    if model_name not in MODEL_CONFIGS:
+        raise ValueError(f"Model {model_name} not supported. Available models: {list(MODEL_CONFIGS.keys())}")
+
     if model_name == loaded_model_name and model is not None and tokenizer is not None:
         return model, tokenizer
 
-    config = MODEL_CONFIGS[model_name]
-    if HF_TOKEN:
-        tokenizer = config["tokenizer_class"].from_pretrained(model_name, use_auth_token=HF_TOKEN)
-        model = config["model_class"].from_pretrained(model_name, use_auth_token=HF_TOKEN)
-    else:
-        tokenizer = config["tokenizer_class"].from_pretrained(model_name)
-        model = config["model_class"].from_pretrained(model_name)
-
-    # Set pad token for causal models if missing (important for generation padding)
-    if config["type"] == "causal" and tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
-
-    loaded_model_name = model_name
-    return model, tokenizer
-
-def authenticate_api_key(key):
-    if API_KEY and key != API_KEY:
-        return False
-    return True
+    try:
+        config = MODEL_CONFIGS[model_name]
+
+        # Load tokenizer and model
+        if HF_TOKEN:
+            tokenizer = config["tokenizer_class"].from_pretrained(model_name, use_auth_token=HF_TOKEN)
+            model = config["model_class"].from_pretrained(model_name, use_auth_token=HF_TOKEN)
+        else:
+            tokenizer = config["tokenizer_class"].from_pretrained(model_name)
+            model = config["model_class"].from_pretrained(model_name)
 
-def generate_text(prompt, model_name, max_length, temperature, top_p, top_k, api_key=""):
-    if API_KEY and not authenticate_api_key(api_key):
-        return "Error: Invalid API key"
+        # Set pad token for causal models if missing
+        if config["type"] == "causal" and tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        loaded_model_name = model_name
+        return model, tokenizer
+
+    except Exception as e:
+        raise RuntimeError(f"Failed to load model {model_name}: {str(e)}")
 
+def authenticate_api_key(credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)):
+    if API_KEY:
+        if not credentials or credentials.credentials != API_KEY:
+            raise HTTPException(status_code=401, detail="Invalid or missing API key")
+    return True
+
+def generate_text_core(prompt, model_name, max_length, temperature, top_p, top_k):
+    """Core text generation function"""
     try:
        config = MODEL_CONFIGS[model_name]
        model, tokenizer = load_model_and_tokenizer(model_name)
@@ -96,11 +128,9 @@ def generate_text(prompt, model_name, max_length, temperature, top_p, top_k, api_key=""):
                 num_return_sequences=1
             )
             generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
-            # Return generated continuation (remove original prompt)
             return generated_text[len(prompt):].strip()
 
         elif config["type"] == "seq2seq":
-            # Add task prefix for certain seq2seq models like flan-t5
             task_prompt = f"Complete this text: {prompt}" if "flan-t5" in model_name.lower() else prompt
             inputs = tokenizer(task_prompt, return_tensors="pt", max_length=512, truncation=True)
             with torch.no_grad():
@@ -117,8 +147,62 @@ def generate_text(prompt, model_name, max_length, temperature, top_p, top_k, api_key=""):
         return generated_text.strip()
 
     except Exception as e:
-        return f"Error generating text: {str(e)}"
+        raise RuntimeError(f"Error generating text: {str(e)}")
 
+# Gradio interface function
+def generate_text_gradio(prompt, model_name, max_length, temperature, top_p, top_k, api_key=""):
+    if API_KEY and api_key != API_KEY:
+        return "Error: Invalid API key"
+
+    try:
+        return generate_text_core(prompt, model_name, max_length, temperature, top_p, top_k)
+    except Exception as e:
+        return f"Error: {str(e)}"
+
+# Create FastAPI app
+app = FastAPI(title="Multi-Model Text Generation API", version="1.0.0")
+
+# API Routes
+@app.post("/generate", response_model=GenerateResponse)
+async def generate_text_api(
+    request: GenerateRequest,
+    authenticated: bool = Depends(authenticate_api_key)
+):
+    try:
+        generated_text = generate_text_core(
+            request.prompt,
+            request.model_name,
+            request.max_length,
+            request.temperature,
+            request.top_p,
+            request.top_k
+        )
+        return GenerateResponse(
+            generated_text=generated_text,
+            model_used=request.model_name
+        )
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
+
+@app.get("/models")
+async def list_models():
+    return {
+        "models": [
+            {
+                "name": name,
+                "description": config["description"],
+                "size": config["size"],
+                "type": config["type"]
+            }
+            for name, config in MODEL_CONFIGS.items()
+        ]
+    }
+
+@app.get("/health")
+async def health_check():
+    return {"status": "healthy", "loaded_model": loaded_model_name}
+
+# Create Gradio interface
 with gr.Blocks(title="Multi-Model Text Generation Server") as demo:
     gr.Markdown("# Multi-Model Text Generation Server")
     gr.Markdown("Choose a model from the dropdown, enter a text prompt, and generate text.")
@@ -131,6 +215,16 @@ with gr.Blocks(title="Multi-Model Text Generation Server") as demo:
         value="gpt2",
         interactive=True
     )
+
+    # Show model info
+    model_info = gr.Markdown("**Model Info:** Original GPT-2, good for creative writing (117M)")
+
+    def update_model_info(model_name):
+        config = MODEL_CONFIGS[model_name]
+        return f"**Model Info:** {config['description']} ({config['size']})"
+
+    model_selector.change(update_model_info, inputs=model_selector, outputs=model_info)
+
     prompt_input = gr.Textbox(
         label="Text Prompt",
         placeholder="Enter the text prompt here...",
@@ -171,7 +265,7 @@ with gr.Blocks(title="Multi-Model Text Generation Server") as demo:
     )
 
     generate_btn.click(
-        fn=generate_text,
+        fn=generate_text_gradio,
         inputs=[prompt_input, model_selector, max_length_slider, temperature_slider, top_p_slider, top_k_slider, api_key_input],
         outputs=output_textbox
     )
@@ -186,13 +280,61 @@ with gr.Blocks(title="Multi-Model Text Generation Server") as demo:
         inputs=prompt_input
     )
 
-auth_config = ("admin", ADMIN_PASSWORD) if ADMIN_PASSWORD else None
+    # API documentation
+    with gr.Accordion("API Documentation", open=False):
+        gr.Markdown("""
+        ## REST API Endpoints
+
+        ### POST /generate
+        Generate text using the specified model.
+
+        **Request Body:**
+        ```json
+        {
+            "prompt": "Your text prompt here",
+            "model_name": "gpt2",
+            "max_length": 100,
+            "temperature": 0.7,
+            "top_p": 0.9,
+            "top_k": 50
+        }
+        ```
+
+        **Response:**
+        ```json
+        {
+            "generated_text": "Generated text...",
+            "model_used": "gpt2",
+            "status": "success"
+        }
+        ```
+
+        ### GET /models
+        List all available models.
+
+        ### GET /health
+        Check server health and loaded model status.
+
+        **Example cURL:**
+        ```bash
+        curl -X POST "http://localhost:7860/generate" \
+             -H "Content-Type: application/json" \
+             -H "Authorization: Bearer YOUR_API_KEY" \
+             -d '{"prompt": "Once upon a time", "model_name": "gpt2"}'
+        ```
+        """)
+
+# Mount Gradio app to FastAPI
+app = gr.mount_gradio_app(app, demo, path="/")
 
 if __name__ == "__main__":
+    auth_config = ("admin", ADMIN_PASSWORD) if ADMIN_PASSWORD else None
+
+    # Launch with both FastAPI and Gradio
     demo.launch(
         auth=auth_config,
-        # share=True, # Required for Spaces if localhost isn't accessible
         server_name="0.0.0.0",
         server_port=7860,
-        ssr_mode=False # Optional: disable server-side rendering to avoid Svelte i18n error
-    )
+        ssr_mode=False,
+        share=False
+    )
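The commit's new REST layer (`POST /generate`, `GET /models`, `GET /health`) can be exercised outside the browser. Below is a minimal client sketch, assuming the server from this commit is reachable at `http://localhost:7860` and that `requests` is installed (it is not listed in requirements.txt); the payload fields mirror the `GenerateRequest` schema, and the bearer token matches the `API_KEY` check in `authenticate_api_key`:

```python
import os
import requests

BASE_URL = "http://localhost:7860"  # assumed local deployment of app.py
API_KEY = os.getenv("API_KEY")      # same env var the server checks; may be unset

headers = {}
if API_KEY:
    # /generate expects a bearer token whenever API_KEY is set on the server
    headers["Authorization"] = f"Bearer {API_KEY}"

# List the configured models (GET /models)
models = requests.get(f"{BASE_URL}/models", timeout=30).json()
print([m["name"] for m in models["models"]])

# Generate text (POST /generate); fields mirror GenerateRequest
payload = {
    "prompt": "Once upon a time",
    "model_name": "gpt2",
    "max_length": 100,
    "temperature": 0.7,
    "top_p": 0.9,
    "top_k": 50,
}
resp = requests.post(f"{BASE_URL}/generate", json=payload, headers=headers, timeout=120)
resp.raise_for_status()
print(resp.json()["generated_text"])
```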
requirements.txt CHANGED
@@ -1,4 +1,7 @@
-gradio>=3.50.0
-transformers>=4.30.0
-torch>=2.0.0
-tokenizers>=0.13.0
+gradio>=4.0.0
+transformers>=4.21.0
+torch>=1.12.0
+fastapi>=0.68.0
+uvicorn>=0.15.0
+pydantic>=1.8.0
+python-multipart>=0.0.5
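Note that `fastapi` and `uvicorn` become runtime dependencies here, yet `app.py` still ends with `demo.launch(...)`, which starts Gradio's own server; routes registered on `app` are only served when the mounted FastAPI application itself is run. A minimal sketch of serving both the UI and the REST API under uvicorn, assuming `app.py` exposes the `app` object returned by `gr.mount_gradio_app` (this is a hypothetical entry point, not part of the commit, and `demo.launch()` options such as `auth` do not apply in this mode):

```python
# serve.py -- hypothetical alternative entry point (not in this commit).
# Imports app.py (safe: demo.launch() is guarded by __name__ == "__main__")
# and serves the FastAPI app with the Gradio UI mounted at "/", so
# /generate, /models, and /health are reachable alongside the UI.
import uvicorn

if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
```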