import os

import gradio as gr
import torch
from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
    AutoTokenizer,
    AutoModelForCausalLM,
)

# Configuration for the available models; add more by extending MODEL_CONFIGS.
MODEL_CONFIGS = {
    "gpt2": {
        "type": "causal",
        "model_class": GPT2LMHeadModel,
        "tokenizer_class": GPT2Tokenizer,
        "description": "Original GPT-2, good for creative writing",
        "size": "117M",
    },
    "distilgpt2": {
        "type": "causal",
        "model_class": AutoModelForCausalLM,
        "tokenizer_class": AutoTokenizer,
        "description": "Smaller, faster GPT-2",
        "size": "82M",
    },
    "google/flan-t5-small": {
        "type": "seq2seq",
        "model_class": T5ForConditionalGeneration,
        "tokenizer_class": T5Tokenizer,
        "description": "Instruction-following T5 model",
        "size": "80M",
    },
    "microsoft/DialoGPT-small": {
        "type": "causal",
        "model_class": AutoModelForCausalLM,
        "tokenizer_class": AutoTokenizer,
        "description": "Conversational AI model",
        "size": "117M",
    },
}

# Environment variables for optional authentication and private model access.
HF_TOKEN = os.getenv("HF_TOKEN")
API_KEY = os.getenv("API_KEY")
ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD")

# Global state caching the most recently loaded model and tokenizer.
loaded_model_name = None
model = None
tokenizer = None


def load_model_and_tokenizer(model_name):
    """Load (or return the cached) model and tokenizer for model_name."""
    global loaded_model_name, model, tokenizer
    if model_name == loaded_model_name and model is not None and tokenizer is not None:
        return model, tokenizer
    config = MODEL_CONFIGS[model_name]
    # `token` replaces the deprecated `use_auth_token` argument; passing None
    # is the same as omitting it, so one call covers both the authenticated
    # and the anonymous case.
    tokenizer = config["tokenizer_class"].from_pretrained(model_name, token=HF_TOKEN)
    model = config["model_class"].from_pretrained(model_name, token=HF_TOKEN)
    # GPT-style causal models ship without a pad token; reuse EOS so that
    # generation with padding does not fail.
    if config["type"] == "causal" and tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    loaded_model_name = model_name
    return model, tokenizer


def authenticate_api_key(key):
    """Accept any key when API_KEY is unset; otherwise require an exact match."""
    return not API_KEY or key == API_KEY


def generate_text(prompt, model_name, max_length, temperature, top_p, top_k, api_key=""):
    if not authenticate_api_key(api_key):
        return "Error: Invalid API key"
    try:
        config = MODEL_CONFIGS[model_name]
        model, tokenizer = load_model_and_tokenizer(model_name)

        if config["type"] == "causal":
            # Tokenize with the full tokenizer call so the attention mask is
            # included alongside the input ids.
            inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
            prompt_length = inputs["input_ids"].shape[1]
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_length=min(max_length + prompt_length, 512),
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    do_sample=True,
                    pad_token_id=tokenizer.pad_token_id,
                    num_return_sequences=1,
                )
            # Decode only the newly generated tokens so the prompt is not
            # echoed back; slicing the decoded string by len(prompt) is
            # unreliable after a tokenization round-trip.
            return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

        elif config["type"] == "seq2seq":
            # Add a task prefix for instruction-tuned seq2seq models like Flan-T5.
            task_prompt = (
                f"Complete this text: {prompt}" if "flan-t5" in model_name.lower() else prompt
            )
            inputs = tokenizer(task_prompt, return_tensors="pt", max_length=512, truncation=True)
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_length=max_length,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    do_sample=True,
                    num_return_sequences=1,
                )
            return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    except Exception as e:
        return f"Error generating text: {e}"
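
# Example: registering an additional model. The MODEL_CONFIGS comment above
# notes that new models are added by extending the dict; the entry below is an
# illustrative sketch using "EleutherAI/gpt-neo-125m" as an assumed checkpoint
# (any causal Hugging Face model id with the same config shape would work).
# Uncomment to make it appear in the dropdown.
#
# MODEL_CONFIGS["EleutherAI/gpt-neo-125m"] = {
#     "type": "causal",
#     "model_class": AutoModelForCausalLM,
#     "tokenizer_class": AutoTokenizer,
#     "description": "GPT-Neo, a GPT-style model from EleutherAI",
#     "size": "125M",
# }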
with gr.Blocks(title="Multi-Model Text Generation Server") as demo:
    gr.Markdown("# Multi-Model Text Generation Server")
    gr.Markdown("Choose a model from the dropdown, enter a text prompt, and generate text.")

    with gr.Row():
        with gr.Column():
            model_selector = gr.Dropdown(
                label="Model",
                choices=list(MODEL_CONFIGS.keys()),
                value="gpt2",
                interactive=True,
            )
            prompt_input = gr.Textbox(
                label="Text Prompt",
                placeholder="Enter the text prompt here...",
                lines=4,
            )
            # Sliders use keyword arguments: on newer Gradio versions most
            # component parameters are keyword-only, so positional
            # (min, max, value, step) calls raise a TypeError.
            max_length_slider = gr.Slider(
                minimum=10, maximum=200, value=100, step=10,
                label="Max Generation Length",
            )
            temperature_slider = gr.Slider(
                minimum=0.1, maximum=2.0, value=0.7, step=0.1,
                label="Temperature",
            )
            top_p_slider = gr.Slider(
                minimum=0.1, maximum=1.0, value=0.9, step=0.05,
                label="Top-p (nucleus sampling)",
            )
            top_k_slider = gr.Slider(
                minimum=1, maximum=100, value=50, step=1,
                label="Top-k sampling",
            )
            if API_KEY:
                api_key_input = gr.Textbox(
                    label="API Key",
                    type="password",
                    placeholder="Enter API Key",
                )
            else:
                # Hidden placeholder so the click handler's input list is the
                # same whether or not an API key is required.
                api_key_input = gr.Textbox(value="", visible=False)
            generate_btn = gr.Button("Generate Text", variant="primary")

        with gr.Column():
            output_textbox = gr.Textbox(
                label="Generated Text",
                lines=10,
                placeholder="Generated text will appear here...",
            )

    generate_btn.click(
        fn=generate_text,
        inputs=[
            prompt_input,
            model_selector,
            max_length_slider,
            temperature_slider,
            top_p_slider,
            top_k_slider,
            api_key_input,
        ],
        outputs=output_textbox,
    )

    gr.Examples(
        examples=[
            ["Once upon a time in a distant galaxy,"],
            ["The future of artificial intelligence is"],
            ["In the heart of the ancient forest,"],
            ["The detective walked into the room and noticed"],
        ],
        inputs=prompt_input,
    )

# Enable HTTP basic auth only when an admin password is configured.
auth_config = ("admin", ADMIN_PASSWORD) if ADMIN_PASSWORD else None

if __name__ == "__main__":
    demo.launch(
        auth=auth_config,
        # share=True,  # Required for Spaces if localhost isn't accessible
        server_name="0.0.0.0",
        server_port=7860,
        ssr_mode=False,  # Optional: disable server-side rendering to avoid Svelte i18n error
    )
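
# Example local usage (an assumed workflow, not part of the app itself),
# assuming this file is saved as app.py. The environment variables are the
# optional ones read at the top of this file; all of them can be omitted.
# sentencepiece is needed by T5Tokenizer for the flan-t5 model.
#
#   pip install gradio torch transformers sentencepiece
#   HF_TOKEN=hf_... API_KEY=my-secret ADMIN_PASSWORD=changeme python app.py
#
# Then open http://localhost:7860 (basic-auth user "admin" when
# ADMIN_PASSWORD is set).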