import gradio as gr
import os
from transformers import (
GPT2LMHeadModel, GPT2Tokenizer,
T5ForConditionalGeneration, T5Tokenizer,
AutoTokenizer, AutoModelForCausalLM
)
import torch
from fastapi import FastAPI, HTTPException, Depends
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
import uvicorn
from pydantic import BaseModel
from typing import Optional
# Configuration for multiple models
MODEL_CONFIGS = {
"gpt2": {
"type": "causal",
"model_class": GPT2LMHeadModel,
"tokenizer_class": GPT2Tokenizer,
"description": "Original GPT-2, good for creative writing",
"size": "117M"
},
"distilgpt2": {
"type": "causal",
"model_class": AutoModelForCausalLM,
"tokenizer_class": AutoTokenizer,
"description": "Smaller, faster GPT-2",
"size": "82M"
},
"google/flan-t5-small": {
"type": "seq2seq",
"model_class": T5ForConditionalGeneration,
"tokenizer_class": T5Tokenizer,
"description": "Instruction-following T5 model",
"size": "80M"
},
"microsoft/DialoGPT-small": {
"type": "causal",
"model_class": AutoModelForCausalLM,
"tokenizer_class": AutoTokenizer,
"description": "Conversational AI model",
"size": "117M"
}
}
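# New models can be added here; "type" selects the causal vs. seq2seq
# generation path in generate_text_core below.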
# Environment variables
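# HF_TOKEN: optional Hugging Face token for downloading gated/private models.
# API_KEY: if set, required as a Bearer token on /generate and in the UI.
# ADMIN_PASSWORD: if set, enables a basic "admin" login on the Gradio UI.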
HF_TOKEN = os.getenv("HF_TOKEN")
API_KEY = os.getenv("API_KEY")
ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD")
# Global state for caching
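# Only one model is resident in memory at a time; selecting a different
# model replaces the cached one.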
loaded_model_name = None
model = None
tokenizer = None
# Pydantic models for API
class GenerateRequest(BaseModel):
prompt: str
model_name: str = "gpt2"
max_length: int = 100
temperature: float = 0.7
top_p: float = 0.9
top_k: int = 50
class GenerateResponse(BaseModel):
generated_text: str
model_used: str
status: str = "success"
# Security
security = HTTPBearer(auto_error=False)
def load_model_and_tokenizer(model_name):
    """Return the cached model/tokenizer pair, loading it on first use."""
    global loaded_model_name, model, tokenizer
    if model_name not in MODEL_CONFIGS:
        raise ValueError(f"Model {model_name} not supported. Available models: {list(MODEL_CONFIGS.keys())}")
    if model_name == loaded_model_name and model is not None and tokenizer is not None:
        return model, tokenizer
    try:
        config = MODEL_CONFIGS[model_name]
        # Load into locals first so a failed load cannot leave the global
        # cache half-updated (e.g. a new tokenizer paired with the old model).
        # `token` supersedes the deprecated `use_auth_token` argument in
        # recent transformers releases; None means anonymous access.
        new_tokenizer = config["tokenizer_class"].from_pretrained(model_name, token=HF_TOKEN or None)
        new_model = config["model_class"].from_pretrained(model_name, token=HF_TOKEN or None)
        # Set pad token for causal models if missing
        if config["type"] == "causal" and new_tokenizer.pad_token is None:
            new_tokenizer.pad_token = new_tokenizer.eos_token
        model, tokenizer, loaded_model_name = new_model, new_tokenizer, model_name
        return model, tokenizer
    except Exception as e:
        raise RuntimeError(f"Failed to load model {model_name}: {str(e)}")
def authenticate_api_key(credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)):
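    """Require a matching Bearer token whenever API_KEY is configured; allow all requests otherwise."""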
if API_KEY:
if not credentials or credentials.credentials != API_KEY:
raise HTTPException(status_code=401, detail="Invalid or missing API key")
return True
def generate_text_core(prompt, model_name, max_length, temperature, top_p, top_k):
    """Core text generation shared by the Gradio UI and the REST API."""
    try:
        config = MODEL_CONFIGS[model_name]
        model, tokenizer = load_model_and_tokenizer(model_name)
        if config["type"] == "causal":
            # Tokenize with an attention mask so generate() is not ambiguous
            # about padding positions.
            inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_length=min(max_length + inputs["input_ids"].shape[1], 512),
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    do_sample=True,
                    pad_token_id=tokenizer.pad_token_id,
                    num_return_sequences=1
                )
            # Decode only the newly generated tokens; slicing the decoded string
            # by len(prompt) breaks when the tokenizer normalizes whitespace.
            new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
            return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        elif config["type"] == "seq2seq":
            # Flan-T5 is instruction-tuned, so wrap plain prompts in a task instruction.
            task_prompt = f"Complete this text: {prompt}" if "flan-t5" in model_name.lower() else prompt
            inputs = tokenizer(task_prompt, return_tensors="pt", max_length=512, truncation=True)
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_length=max_length,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    do_sample=True,
                    num_return_sequences=1
                )
            generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
            return generated_text.strip()
    except Exception as e:
        raise RuntimeError(f"Error generating text: {str(e)}")
# Gradio interface function
def generate_text_gradio(prompt, model_name, max_length, temperature, top_p, top_k, api_key=""):
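    """UI wrapper around generate_text_core that returns errors as display strings."""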
if API_KEY and api_key != API_KEY:
return "Error: Invalid API key"
try:
return generate_text_core(prompt, model_name, max_length, temperature, top_p, top_k)
except Exception as e:
return f"Error: {str(e)}"
# Create FastAPI app
app = FastAPI(title="Multi-Model Text Generation API", version="1.0.0")
# API Routes
@app.post("/generate", response_model=GenerateResponse)
async def generate_text_api(
request: GenerateRequest,
authenticated: bool = Depends(authenticate_api_key)
):
try:
generated_text = generate_text_core(
request.prompt,
request.model_name,
request.max_length,
request.temperature,
request.top_p,
request.top_k
)
return GenerateResponse(
generated_text=generated_text,
model_used=request.model_name
)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@app.get("/models")
async def list_models():
return {
"models": [
{
"name": name,
"description": config["description"],
"size": config["size"],
"type": config["type"]
}
for name, config in MODEL_CONFIGS.items()
]
}
@app.get("/health")
async def health_check():
return {"status": "healthy", "loaded_model": loaded_model_name}
# Create Gradio interface
with gr.Blocks(title="Multi-Model Text Generation Server") as demo:
gr.Markdown("# Multi-Model Text Generation Server")
gr.Markdown("Choose a model from the dropdown, enter a text prompt, and generate text.")
with gr.Row():
with gr.Column():
model_selector = gr.Dropdown(
label="Model",
choices=list(MODEL_CONFIGS.keys()),
value="gpt2",
interactive=True
)
# Show model info
model_info = gr.Markdown("**Model Info:** Original GPT-2, good for creative writing (117M)")
def update_model_info(model_name):
config = MODEL_CONFIGS[model_name]
return f"**Model Info:** {config['description']} ({config['size']})"
model_selector.change(update_model_info, inputs=model_selector, outputs=model_info)
prompt_input = gr.Textbox(
label="Text Prompt",
placeholder="Enter the text prompt here...",
lines=4
)
            max_length_slider = gr.Slider(
                minimum=10, maximum=200, value=100, step=10,
                label="Max Generation Length"
            )
            temperature_slider = gr.Slider(
                minimum=0.1, maximum=2.0, value=0.7, step=0.1,
                label="Temperature"
            )
            top_p_slider = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.9, step=0.05,
                label="Top-p (nucleus sampling)"
            )
            top_k_slider = gr.Slider(
                minimum=1, maximum=100, value=50, step=1,
                label="Top-k sampling"
            )
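            # Only show an API key field when one is configured; otherwise pass
            # a hidden empty value so the click handler's signature stays the same.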
if API_KEY:
api_key_input = gr.Textbox(
label="API Key",
type="password",
placeholder="Enter API Key"
)
else:
api_key_input = gr.Textbox(value="", visible=False)
generate_btn = gr.Button("Generate Text", variant="primary")
with gr.Column():
output_textbox = gr.Textbox(
label="Generated Text",
lines=10,
placeholder="Generated text will appear here..."
)
generate_btn.click(
fn=generate_text_gradio,
inputs=[prompt_input, model_selector, max_length_slider, temperature_slider, top_p_slider, top_k_slider, api_key_input],
outputs=output_textbox
)
gr.Examples(
examples=[
["Once upon a time in a distant galaxy,"],
["The future of artificial intelligence is"],
["In the heart of the ancient forest,"],
["The detective walked into the room and noticed"],
],
inputs=prompt_input
)
# API documentation
with gr.Accordion("API Documentation", open=False):
gr.Markdown("""
## REST API Endpoints
### POST /generate
Generate text using the specified model.
**Request Body:**
```json
{
"prompt": "Your text prompt here",
"model_name": "gpt2",
"max_length": 100,
"temperature": 0.7,
"top_p": 0.9,
"top_k": 50
}
```
**Response:**
```json
{
"generated_text": "Generated text...",
"model_used": "gpt2",
"status": "success"
}
```
### GET /models
List all available models.
### GET /health
Check server health and loaded model status.
**Example cURL:**
```bash
curl -X POST "http://localhost:7860/generate" \
-H "Content-Type: application/json" \
-H "Authorization: Bearer YOUR_API_KEY" \
-d '{"prompt": "Once upon a time", "model_name": "gpt2"}'
```
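**Example Python client** (a minimal sketch using the `requests` package; the host, port, and `YOUR_API_KEY` placeholder are deployment-specific):
```python
import requests

# POST to /generate; omit the Authorization header if no API_KEY is configured.
resp = requests.post(
    "http://localhost:7860/generate",
    headers={"Authorization": "Bearer YOUR_API_KEY"},
    json={"prompt": "Once upon a time", "model_name": "gpt2", "max_length": 60},
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["generated_text"])
```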
""")
# Mount the Gradio UI onto the FastAPI app so the REST endpoints
# (/generate, /models, /health) and the UI are served together. The optional
# basic-auth login ("admin" + ADMIN_PASSWORD) is applied at mount time
# (supported in recent Gradio releases).
auth_config = ("admin", ADMIN_PASSWORD) if ADMIN_PASSWORD else None
app = gr.mount_gradio_app(app, demo, path="/", auth=auth_config)
if __name__ == "__main__":
    # Run the combined app with uvicorn; calling demo.launch() instead would
    # start a standalone Gradio server and skip the custom FastAPI routes.
    uvicorn.run(app, host="0.0.0.0", port=7860)