import os

import gradio as gr
import torch
from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
    AutoTokenizer,
    AutoModelForCausalLM,
)
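# Registry of the models this app can serve. "causal" models are decoder-only
# and continue the prompt in place; "seq2seq" models (T5-style) map the input
# to a separately generated output sequence.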
MODEL_CONFIGS = {
    "gpt2": {
        "type": "causal",
        "model_class": GPT2LMHeadModel,
        "tokenizer_class": GPT2Tokenizer,
        "description": "Original GPT-2, good for creative writing",
        "size": "117M",
    },
    "distilgpt2": {
        "type": "causal",
        "model_class": AutoModelForCausalLM,
        "tokenizer_class": AutoTokenizer,
        "description": "Smaller, faster GPT-2",
        "size": "82M",
    },
    "google/flan-t5-small": {
        "type": "seq2seq",
        "model_class": T5ForConditionalGeneration,
        "tokenizer_class": T5Tokenizer,
        "description": "Instruction-following T5 model",
        "size": "80M",
    },
    "microsoft/DialoGPT-small": {
        "type": "causal",
        "model_class": AutoModelForCausalLM,
        "tokenizer_class": AutoTokenizer,
        "description": "Conversational AI model",
        "size": "117M",
    },
}
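# Optional secrets read from the environment: HF_TOKEN authenticates Hugging
# Face Hub downloads, API_KEY gates text generation, and ADMIN_PASSWORD turns
# on Gradio's built-in login prompt.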
HF_TOKEN = os.getenv("HF_TOKEN")
API_KEY = os.getenv("API_KEY")
ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD")
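# Single-slot cache: keep only the most recently requested model in memory,
# trading a reload on model switch for a small memory footprint.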
loaded_model_name = None
model = None
tokenizer = None


def load_model_and_tokenizer(model_name):
    """Load the requested model and tokenizer, reusing the cached pair if possible."""
    global loaded_model_name, model, tokenizer
    if model_name == loaded_model_name and model is not None and tokenizer is not None:
        return model, tokenizer

    config = MODEL_CONFIGS[model_name]
    # Newer transformers releases deprecate `use_auth_token` in favor of
    # `token`; passing token=None simply falls back to anonymous downloads,
    # so no branching on HF_TOKEN is needed.
    tokenizer = config["tokenizer_class"].from_pretrained(model_name, token=HF_TOKEN)
    model = config["model_class"].from_pretrained(model_name, token=HF_TOKEN)

    # GPT-2-style models ship without a pad token; reuse EOS so generate()
    # has a valid pad_token_id.
    if config["type"] == "causal" and tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    loaded_model_name = model_name
    return model, tokenizer
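# Shared-secret check for the generation endpoint. This deters casual abuse
# but is not a substitute for real authentication.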
def authenticate_api_key(key):
    # When no API_KEY is configured, the endpoint is open.
    if API_KEY and key != API_KEY:
        return False
    return True
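# Main inference entry point, wired to the Gradio button below. Dispatches on
# the configured model type.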
def generate_text(prompt, model_name, max_length, temperature, top_p, top_k, api_key=""):
    # authenticate_api_key already allows everything when no API_KEY is set.
    if not authenticate_api_key(api_key):
        return "Error: Invalid API key"

    try:
        config = MODEL_CONFIGS[model_name]
        model, tokenizer = load_model_and_tokenizer(model_name)

        if config["type"] == "causal":
            # Truncate the prompt to 512 tokens so prompt plus continuation
            # stays well within GPT-2's 1024-token context window. Calling the
            # tokenizer (rather than .encode) also supplies the attention mask.
            inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=int(max_length),
                    temperature=temperature,
                    top_p=top_p,
                    top_k=int(top_k),  # generate() requires an integer top_k
                    do_sample=True,
                    pad_token_id=tokenizer.pad_token_id,
                    num_return_sequences=1,
                )
            # Decode only the newly generated tokens; slicing the decoded
            # string by len(prompt) misaligns whenever tokenization does not
            # round-trip the prompt exactly.
            new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
            return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        elif config["type"] == "seq2seq":
            # Flan-T5 is instruction-tuned, so frame the request as a task.
            task_prompt = f"Complete this text: {prompt}" if "flan-t5" in model_name.lower() else prompt
            inputs = tokenizer(task_prompt, return_tensors="pt", max_length=512, truncation=True)
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=int(max_length),
                    temperature=temperature,
                    top_p=top_p,
                    top_k=int(top_k),
                    do_sample=True,
                    num_return_sequences=1,
                )
            # Seq2seq outputs contain only the generated sequence.
            return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

        return f"Error: unknown model type {config['type']}"

    except Exception as e:
        return f"Error generating text: {e}"
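# Gradio UI: model picker, prompt, and sampling controls on the left,
# generated text on the right.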
with gr.Blocks(title="Multi-Model Text Generation Server") as demo:
    gr.Markdown("# Multi-Model Text Generation Server")
    gr.Markdown("Choose a model from the dropdown, enter a text prompt, and generate text.")

    with gr.Row():
        with gr.Column():
            model_selector = gr.Dropdown(
                label="Model",
                choices=list(MODEL_CONFIGS.keys()),
                value="gpt2",
                interactive=True,
            )
            prompt_input = gr.Textbox(
                label="Text Prompt",
                placeholder="Enter the text prompt here...",
                lines=4,
            )
            # Recent Gradio releases make most component arguments after
            # `value` keyword-only, so the sliders are configured explicitly.
            max_length_slider = gr.Slider(
                minimum=10, maximum=200, value=100, step=10,
                label="Max Generation Length",
            )
            temperature_slider = gr.Slider(
                minimum=0.1, maximum=2.0, value=0.7, step=0.1,
                label="Temperature",
            )
            top_p_slider = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.9, step=0.05,
                label="Top-p (nucleus sampling)",
            )
            top_k_slider = gr.Slider(
                minimum=1, maximum=100, value=50, step=1,
                label="Top-k sampling",
            )
            # Only render the key field when the endpoint is actually gated.
            if API_KEY:
                api_key_input = gr.Textbox(
                    label="API Key",
                    type="password",
                    placeholder="Enter API Key",
                )
            else:
                api_key_input = gr.Textbox(value="", visible=False)

            generate_btn = gr.Button("Generate Text", variant="primary")

        with gr.Column():
            output_textbox = gr.Textbox(
                label="Generated Text",
                lines=10,
                placeholder="Generated text will appear here...",
            )
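    # Wire the button to the inference function; Gradio passes component
    # values positionally in the order listed in `inputs`.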
    generate_btn.click(
        fn=generate_text,
        inputs=[
            prompt_input,
            model_selector,
            max_length_slider,
            temperature_slider,
            top_p_slider,
            top_k_slider,
            api_key_input,
        ],
        outputs=output_textbox,
    )
    gr.Examples(
        examples=[
            ["Once upon a time in a distant galaxy,"],
            ["The future of artificial intelligence is"],
            ["In the heart of the ancient forest,"],
            ["The detective walked into the room and noticed"],
        ],
        inputs=prompt_input,
    )
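# Gradio's `auth` expects a (username, password) tuple (or a list of them);
# leaving it as None disables the login screen when no password is set.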
auth_config = ("admin", ADMIN_PASSWORD) if ADMIN_PASSWORD else None

if __name__ == "__main__":
    demo.launch(
        auth=auth_config,
        share=True,
        server_name="0.0.0.0",  # listen on all interfaces
        server_port=7860,
        ssr_mode=False,
    )