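"""Multi-model text generation server.

A Gradio UI mounted on a FastAPI app that also exposes REST endpoints
(/generate, /models, /health) for programmatic access.
"""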
import gradio as gr
import os
from transformers import (
    GPT2LMHeadModel, GPT2Tokenizer,
    T5ForConditionalGeneration, T5Tokenizer,
    AutoTokenizer, AutoModelForCausalLM,
)
import torch
from fastapi import FastAPI, HTTPException, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
import uvicorn
from pydantic import BaseModel
from typing import Optional

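# Registry of supported models: each entry maps a Hugging Face model id to its
# architecture type ("causal" or "seq2seq"), the classes used to load it, and
# display metadata for the UI and the /models endpoint.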
MODEL_CONFIGS = {
    "gpt2": {
        "type": "causal",
        "model_class": GPT2LMHeadModel,
        "tokenizer_class": GPT2Tokenizer,
        "description": "Original GPT-2, good for creative writing",
        "size": "117M",
    },
    "distilgpt2": {
        "type": "causal",
        "model_class": AutoModelForCausalLM,
        "tokenizer_class": AutoTokenizer,
        "description": "Smaller, faster GPT-2",
        "size": "82M",
    },
    "google/flan-t5-small": {
        "type": "seq2seq",
        "model_class": T5ForConditionalGeneration,
        "tokenizer_class": T5Tokenizer,
        "description": "Instruction-following T5 model",
        "size": "80M",
    },
    "microsoft/DialoGPT-small": {
        "type": "causal",
        "model_class": AutoModelForCausalLM,
        "tokenizer_class": AutoTokenizer,
        "description": "Conversational AI model",
        "size": "117M",
    },
}

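# Optional configuration via environment variables: HF_TOKEN authenticates
# Hugging Face Hub downloads, API_KEY protects text generation (REST and UI),
# and ADMIN_PASSWORD enables basic auth on the Gradio UI.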
HF_TOKEN = os.getenv("HF_TOKEN")
API_KEY = os.getenv("API_KEY")
ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD")

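# Cache for the single currently loaded model; selecting a different model
# replaces it rather than keeping several models in memory at once.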
loaded_model_name = None
model = None
tokenizer = None

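# Request/response schemas for the REST API.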
class GenerateRequest(BaseModel):
    prompt: str
    model_name: str = "gpt2"
    max_length: int = 100
    temperature: float = 0.7
    top_p: float = 0.9
    top_k: int = 50


class GenerateResponse(BaseModel):
    generated_text: str
    model_used: str
    status: str = "success"

# Bearer-token security scheme; auto_error=False lets requests without a token
# reach authenticate_api_key, which only enforces the key when API_KEY is set.
security = HTTPBearer(auto_error=False)

def load_model_and_tokenizer(model_name):
    """Load the requested model and tokenizer, reusing them if already loaded."""
    global loaded_model_name, model, tokenizer

    if model_name not in MODEL_CONFIGS:
        raise ValueError(
            f"Model {model_name} not supported. "
            f"Available models: {list(MODEL_CONFIGS.keys())}"
        )

    if model_name == loaded_model_name and model is not None and tokenizer is not None:
        return model, tokenizer

    try:
        config = MODEL_CONFIGS[model_name]

        # `token` replaces the deprecated `use_auth_token` argument in recent
        # transformers releases.
        auth_kwargs = {"token": HF_TOKEN} if HF_TOKEN else {}
        tokenizer = config["tokenizer_class"].from_pretrained(model_name, **auth_kwargs)
        model = config["model_class"].from_pretrained(model_name, **auth_kwargs)

        # Causal LMs like GPT-2 ship without a pad token; reuse EOS for padding.
        if config["type"] == "causal" and tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        loaded_model_name = model_name
        return model, tokenizer

    except Exception as e:
        raise RuntimeError(f"Failed to load model {model_name}: {e}") from e

def authenticate_api_key(credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)):
    """Reject the request when API_KEY is configured and the bearer token does not match."""
    if API_KEY:
        if not credentials or credentials.credentials != API_KEY:
            raise HTTPException(status_code=401, detail="Invalid or missing API key")
    return True

def generate_text_core(prompt, model_name, max_length, temperature, top_p, top_k):
    """Core text generation shared by the Gradio UI and the REST API."""
    try:
        config = MODEL_CONFIGS[model_name]
        model, tokenizer = load_model_and_tokenizer(model_name)

        if config["type"] == "causal":
            # Tokenize with an attention mask so generate() does not warn
            # about ambiguous padding.
            inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
            prompt_len = inputs["input_ids"].shape[1]
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_length=min(max_length + prompt_len, 512),
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    do_sample=True,
                    pad_token_id=tokenizer.pad_token_id,
                    num_return_sequences=1,
                )
            # Decode only the newly generated tokens; slicing the decoded
            # string by len(prompt) can misalign when the tokenizer
            # normalizes whitespace.
            generated_text = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
            return generated_text.strip()

        elif config["type"] == "seq2seq":
            # Flan-T5 is instruction-tuned, so wrap raw prompts in a task instruction.
            task_prompt = f"Complete this text: {prompt}" if "flan-t5" in model_name.lower() else prompt
            inputs = tokenizer(task_prompt, return_tensors="pt", max_length=512, truncation=True)
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_length=max_length,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    do_sample=True,
                    num_return_sequences=1,
                )
            generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
            return generated_text.strip()

        else:
            raise ValueError(f"Unknown model type: {config['type']}")

    except Exception as e:
        raise RuntimeError(f"Error generating text: {e}") from e

def generate_text_gradio(prompt, model_name, max_length, temperature, top_p, top_k, api_key=""):
    """Gradio wrapper: checks the optional API key and reports errors as plain text."""
    if API_KEY and api_key != API_KEY:
        return "Error: Invalid API key"

    try:
        return generate_text_core(prompt, model_name, max_length, temperature, top_p, top_k)
    except Exception as e:
        return f"Error: {e}"

app = FastAPI(title="Multi-Model Text Generation API", version="1.0.0")

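# REST API endpoints.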
@app.post("/generate", response_model=GenerateResponse)
async def generate_text_api(
    request: GenerateRequest,
    authenticated: bool = Depends(authenticate_api_key),
):
    try:
        generated_text = generate_text_core(
            request.prompt,
            request.model_name,
            request.max_length,
            request.temperature,
            request.top_p,
            request.top_k,
        )
        return GenerateResponse(
            generated_text=generated_text,
            model_used=request.model_name,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/models")
async def list_models():
    return {
        "models": [
            {
                "name": name,
                "description": config["description"],
                "size": config["size"],
                "type": config["type"],
            }
            for name, config in MODEL_CONFIGS.items()
        ]
    }

@app.get("/health")
async def health_check():
    return {"status": "healthy", "loaded_model": loaded_model_name}

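# Gradio UI: model picker, sampling controls, output pane, and inline API docs.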
with gr.Blocks(title="Multi-Model Text Generation Server") as demo:
    gr.Markdown("# Multi-Model Text Generation Server")
    gr.Markdown("Choose a model from the dropdown, enter a text prompt, and generate text.")

    with gr.Row():
        with gr.Column():
            model_selector = gr.Dropdown(
                label="Model",
                choices=list(MODEL_CONFIGS.keys()),
                value="gpt2",
                interactive=True,
            )

            model_info = gr.Markdown("**Model Info:** Original GPT-2, good for creative writing (117M)")

            def update_model_info(model_name):
                config = MODEL_CONFIGS[model_name]
                return f"**Model Info:** {config['description']} ({config['size']})"

            model_selector.change(update_model_info, inputs=model_selector, outputs=model_info)

            prompt_input = gr.Textbox(
                label="Text Prompt",
                placeholder="Enter the text prompt here...",
                lines=4,
            )
            # Keyword arguments throughout: in Gradio 4+, Slider arguments
            # after `value` (including `step`) are keyword-only.
            max_length_slider = gr.Slider(
                minimum=10, maximum=200, value=100, step=10,
                label="Max Generation Length",
            )
            temperature_slider = gr.Slider(
                minimum=0.1, maximum=2.0, value=0.7, step=0.1,
                label="Temperature",
            )
            top_p_slider = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.9, step=0.05,
                label="Top-p (nucleus sampling)",
            )
            top_k_slider = gr.Slider(
                minimum=1, maximum=100, value=50, step=1,
                label="Top-k sampling",
            )
            # Only show the API key field when the server actually requires one.
            if API_KEY:
                api_key_input = gr.Textbox(
                    label="API Key",
                    type="password",
                    placeholder="Enter API Key",
                )
            else:
                api_key_input = gr.Textbox(value="", visible=False)

            generate_btn = gr.Button("Generate Text", variant="primary")

        with gr.Column():
            output_textbox = gr.Textbox(
                label="Generated Text",
                lines=10,
                placeholder="Generated text will appear here...",
            )

    generate_btn.click(
        fn=generate_text_gradio,
        inputs=[prompt_input, model_selector, max_length_slider, temperature_slider,
                top_p_slider, top_k_slider, api_key_input],
        outputs=output_textbox,
    )

    gr.Examples(
        examples=[
            ["Once upon a time in a distant galaxy,"],
            ["The future of artificial intelligence is"],
            ["In the heart of the ancient forest,"],
            ["The detective walked into the room and noticed"],
        ],
        inputs=prompt_input,
    )

    with gr.Accordion("API Documentation", open=False):
        gr.Markdown("""
## REST API Endpoints

### POST /generate
Generate text using the specified model.

**Request Body:**
```json
{
    "prompt": "Your text prompt here",
    "model_name": "gpt2",
    "max_length": 100,
    "temperature": 0.7,
    "top_p": 0.9,
    "top_k": 50
}
```

**Response:**
```json
{
    "generated_text": "Generated text...",
    "model_used": "gpt2",
    "status": "success"
}
```

### GET /models
List all available models.

### GET /health
Check server health and loaded model status.

**Example cURL:**
```bash
curl -X POST "http://localhost:7860/generate" \\
    -H "Content-Type: application/json" \\
    -H "Authorization: Bearer YOUR_API_KEY" \\
    -d '{"prompt": "Once upon a time", "model_name": "gpt2"}'
```
""")

# Mount the Gradio UI onto the FastAPI app so the REST endpoints and the UI
# are served together. Optional basic auth ("admin" / ADMIN_PASSWORD) guards
# the UI; the `auth` argument is accepted by gr.mount_gradio_app in recent
# Gradio releases.
auth_config = ("admin", ADMIN_PASSWORD) if ADMIN_PASSWORD else None
app = gr.mount_gradio_app(app, demo, path="/", auth=auth_config)

if __name__ == "__main__":
    # Run the combined app with uvicorn; calling demo.launch() here would
    # start Gradio's own server and skip the FastAPI routes defined above.
    uvicorn.run(app, host="0.0.0.0", port=7860)