import gradio as gr
import os
from transformers import (
    GPT2LMHeadModel, GPT2Tokenizer,
    T5ForConditionalGeneration, T5Tokenizer,
    AutoTokenizer, AutoModelForCausalLM
)
import torch
# Configuration for the available models; add more by extending MODEL_CONFIGS.
MODEL_CONFIGS = {
    "gpt2": {
        "type": "causal",
        "model_class": GPT2LMHeadModel,
        "tokenizer_class": GPT2Tokenizer,
        "description": "Original GPT-2, good for creative writing",
        "size": "117M"
    },
    "distilgpt2": {
        "type": "causal",
        "model_class": AutoModelForCausalLM,
        "tokenizer_class": AutoTokenizer,
        "description": "Smaller, faster GPT-2",
        "size": "82M"
    },
    "google/flan-t5-small": {
        "type": "seq2seq",
        "model_class": T5ForConditionalGeneration,
        "tokenizer_class": T5Tokenizer,
        "description": "Instruction-following T5 model",
        "size": "80M"
    },
    "microsoft/DialoGPT-small": {
        "type": "causal",
        "model_class": AutoModelForCausalLM,
        "tokenizer_class": AutoTokenizer,
        "description": "Conversational AI model",
        "size": "117M"
    }
}
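# Example of extending the registry with another causal model (a sketch; gpt2-medium
# is a real Hub checkpoint, but any id that AutoModelForCausalLM can load will work):
#
#     "gpt2-medium": {
#         "type": "causal",
#         "model_class": AutoModelForCausalLM,
#         "tokenizer_class": AutoTokenizer,
#         "description": "Larger GPT-2, better quality, slower",
#         "size": "345M"
#     },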
# Environment variables for optional authentication and private model access
HF_TOKEN = os.getenv("HF_TOKEN")
API_KEY = os.getenv("API_KEY")
ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD")
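# On Hugging Face Spaces, set these under Settings -> Variables and secrets;
# locally, e.g. `export API_KEY=my-secret`. All three are optional.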
# Global state for caching loaded model and tokenizer
loaded_model_name = None
model = None
tokenizer = None
def load_model_and_tokenizer(model_name):
    """Load the requested model and tokenizer, reusing the cached pair if unchanged."""
    global loaded_model_name, model, tokenizer
    if model_name == loaded_model_name and model is not None and tokenizer is not None:
        return model, tokenizer
    config = MODEL_CONFIGS[model_name]
    # token= supersedes the deprecated use_auth_token=; passing token=None is the
    # same as omitting it, so one call path covers both the public and private cases.
    tokenizer = config["tokenizer_class"].from_pretrained(model_name, token=HF_TOKEN)
    model = config["model_class"].from_pretrained(model_name, token=HF_TOKEN)
    # Causal LMs like GPT-2 ship without a pad token; reuse EOS so generation
    # can pad without erroring.
    if config["type"] == "causal" and tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    loaded_model_name = model_name
    return model, tokenizer

def authenticate_api_key(key):
    # When no API_KEY is configured, accept any key; otherwise require an exact match.
    return not API_KEY or key == API_KEY

def generate_text(prompt, model_name, max_length, temperature, top_p, top_k, api_key=""):
    if API_KEY and not authenticate_api_key(api_key):
        return "Error: Invalid API key"
    try:
        config = MODEL_CONFIGS[model_name]
        model, tokenizer = load_model_and_tokenizer(model_name)
        if config["type"] == "causal":
            # Tokenize with an attention mask so generate() doesn't have to guess
            # which positions are padding.
            inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_length,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    do_sample=True,
                    pad_token_id=tokenizer.pad_token_id,
                    num_return_sequences=1
                )
            # Decode only the newly generated tokens; slicing the decoded string by
            # len(prompt) breaks when the tokenizer normalizes or truncates the prompt.
            new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
            return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        elif config["type"] == "seq2seq":
            # Flan-T5 responds better with an explicit instruction prefix.
            task_prompt = f"Complete this text: {prompt}" if "flan-t5" in model_name.lower() else prompt
            inputs = tokenizer(task_prompt, return_tensors="pt", max_length=512, truncation=True)
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_length,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    do_sample=True,
                    num_return_sequences=1
                )
            return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        else:
            return f"Error: Unknown model type '{config['type']}'"
    except Exception as e:
        return f"Error generating text: {str(e)}"

with gr.Blocks(title="Multi-Model Text Generation Server") as demo:
    gr.Markdown("# Multi-Model Text Generation Server")
    gr.Markdown("Choose a model from the dropdown, enter a text prompt, and generate text.")
    with gr.Row():
        with gr.Column():
            model_selector = gr.Dropdown(
                label="Model",
                choices=list(MODEL_CONFIGS.keys()),
                value="gpt2",
                interactive=True
            )
            prompt_input = gr.Textbox(
                label="Text Prompt",
                placeholder="Enter the text prompt here...",
                lines=4
            )
            # step (and everything after value) is keyword-only in Gradio 4+,
            # so the sliders use explicit keywords rather than positional args.
            max_length_slider = gr.Slider(
                minimum=10, maximum=200, value=100, step=10,
                label="Max Generation Length"
            )
            temperature_slider = gr.Slider(
                minimum=0.1, maximum=2.0, value=0.7, step=0.1,
                label="Temperature"
            )
            top_p_slider = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.9, step=0.05,
                label="Top-p (nucleus sampling)"
            )
            top_k_slider = gr.Slider(
                minimum=1, maximum=100, value=50, step=1,
                label="Top-k sampling"
            )
            if API_KEY:
                api_key_input = gr.Textbox(
                    label="API Key",
                    type="password",
                    placeholder="Enter API Key"
                )
            else:
                # Hidden stand-in so generate_text always receives an api_key argument.
                api_key_input = gr.Textbox(value="", visible=False)
            generate_btn = gr.Button("Generate Text", variant="primary")
        with gr.Column():
            output_textbox = gr.Textbox(
                label="Generated Text",
                lines=10,
                placeholder="Generated text will appear here..."
            )
    generate_btn.click(
        fn=generate_text,
        inputs=[prompt_input, model_selector, max_length_slider, temperature_slider, top_p_slider, top_k_slider, api_key_input],
        outputs=output_textbox
    )
    gr.Examples(
        examples=[
            ["Once upon a time in a distant galaxy,"],
            ["The future of artificial intelligence is"],
            ["In the heart of the ancient forest,"],
            ["The detective walked into the room and noticed"],
        ],
        inputs=prompt_input
    )

# Gradio basic auth for the UI: enabled only when ADMIN_PASSWORD is set.
auth_config = ("admin", ADMIN_PASSWORD) if ADMIN_PASSWORD else None

if __name__ == "__main__":
    demo.launch(
        auth=auth_config,
        share=True,  # Creates a public link for local runs; ignored (with a warning) on Spaces
        server_name="0.0.0.0",
        server_port=7860,
        ssr_mode=False  # Disable server-side rendering to avoid a Svelte i18n error
    )
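
# Minimal client sketch for querying the running server with gradio_client.
# The endpoint name "/generate_text" assumes Gradio's default of naming the API
# route after the click handler's function; check the app's "Use via API" page
# if it differs.
#
#     from gradio_client import Client
#
#     client = Client("http://localhost:7860")
#     result = client.predict(
#         "Once upon a time,",  # prompt
#         "gpt2",               # model_name
#         50,                   # max_length
#         0.7,                  # temperature
#         0.9,                  # top_p
#         50,                   # top_k
#         "",                   # api_key (only needed when API_KEY is set)
#         api_name="/generate_text",
#     )
#     print(result)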