"""Gradio front end that serves GPT-2 text generation.

Loading this module loads the GPT-2 tokenizer and model and builds the
``demo`` Blocks app; running it as a script launches the app.
"""

import gradio as gr
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

# Load model and tokenizer (using smaller GPT-2 for free tier)
model_name = "gpt2"  # You can also use "gpt2-medium" if it fits in memory
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

# GPT-2 ships without a pad token; reuse EOS so generate() has one.
tokenizer.pad_token = tokenizer.eos_token

# Upper bound on prompt tokens + generated tokens for a single request.
MAX_TOTAL_TOKENS = 512


def generate_text(
    prompt: str,
    max_length: int = 100,
    temperature: float = 0.7,
    top_p: float = 0.9,
    top_k: int = 50,
) -> str:
    """Generate a continuation of *prompt* with GPT-2.

    Args:
        prompt: Text to continue. Prompts longer than
            ``MAX_TOTAL_TOKENS - 1`` tokens are truncated by the tokenizer.
        max_length: Requested number of NEW tokens. The actual number is
            reduced if needed so that prompt + output never exceeds
            ``MAX_TOTAL_TOKENS``.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.
        top_k: Top-k sampling cutoff passed to ``model.generate``.

    Returns:
        The newly generated text only (prompt excluded, surrounding
        whitespace stripped). On failure, a string starting with
        ``"Error generating text: "`` is returned instead of raising, so
        the UI always has something to display.
    """
    # An empty prompt encodes to zero tokens, which generate() cannot use;
    # report it clearly instead of surfacing a library error.
    if not prompt:
        return "Error generating text: prompt must not be empty"

    try:
        # Sliders / API clients may deliver numbers as floats.
        requested_new_tokens = max(1, int(max_length))
        top_k = int(top_k)

        # Cap the prompt one token below the total budget so there is always
        # room to generate at least one new token.
        encoded = tokenizer(
            prompt,
            return_tensors="pt",
            max_length=MAX_TOTAL_TOKENS - 1,
            truncation=True,
        )
        input_ids = encoded["input_ids"]
        prompt_token_count = input_ids.shape[-1]

        # Same overall limit as before (prompt + output <= MAX_TOTAL_TOKENS),
        # expressed as a count of new tokens.
        max_new_tokens = min(
            requested_new_tokens, MAX_TOTAL_TOKENS - prompt_token_count
        )

        with torch.no_grad():
            outputs = model.generate(
                input_ids,
                # Passed explicitly because the pad token is the EOS token,
                # so the mask cannot be inferred from the ids.
                attention_mask=encoded["attention_mask"],
                max_new_tokens=max_new_tokens,
                temperature=float(temperature),
                top_p=float(top_p),
                top_k=top_k,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                num_return_sequences=1,
            )

        # Decode only the tokens after the prompt. Slicing the decoded string
        # by len(prompt) is unreliable when the prompt was truncated or does
        # not round-trip through the tokenizer unchanged.
        new_token_ids = outputs[0][prompt_token_count:]
        return tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()

    except Exception as e:
        # UI boundary: show the failure in the output box rather than raising.
        return f"Error generating text: {e}"


# Create Gradio interface
with gr.Blocks(title="GPT-2 Text Generator") as demo:
    gr.Markdown("# GPT-2 Text Generation Server")
    gr.Markdown(
        "Enter a prompt and generate text using GPT-2. Free tier optimized!"
    )

    with gr.Row():
        # Left column: prompt and sampling controls.
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your text prompt here...",
                lines=3,
            )

            with gr.Row():
                max_length = gr.Slider(
                    minimum=10,
                    maximum=200,
                    value=100,
                    step=10,
                    label="Max Length",
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                )

            with gr.Row():
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.1,
                    label="Top-p",
                )
                top_k = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Top-k",
                )

            generate_btn = gr.Button("Generate Text", variant="primary")

        # Right column: generated output.
        with gr.Column():
            output_text = gr.Textbox(
                label="Generated Text",
                lines=10,
                placeholder="Generated text will appear here...",
            )

    # Examples
    gr.Examples(
        examples=[
            ["Once upon a time in a distant galaxy,"],
            ["The future of artificial intelligence is"],
            ["In the heart of the ancient forest,"],
            ["The detective walked into the room and noticed"],
        ],
        inputs=prompt_input,
    )

    # Connect the function with explicit API endpoint name.
    # NOTE(review): Gradio's documentation shows api_name without a leading
    # slash (clients then call "/<api_name>"). The value is kept as-is so
    # existing external callers are unaffected - confirm the intended name.
    generate_btn.click(
        fn=generate_text,
        inputs=[prompt_input, max_length, temperature, top_p, top_k],
        outputs=output_text,
        api_name="/predict",  # Explicit API endpoint for external calls
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()