import gradio as gr
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch

# Load model and tokenizer (using smaller GPT-2 for free tier)
model_name = "gpt2"  # You can also use "gpt2-medium" if it fits in memory
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

# Set pad token
tokenizer.pad_token = tokenizer.eos_token

def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9, top_k=50):
    """Generate text using GPT-2"""
    try:
        # Encode input
        inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=512, truncation=True)

        # Generate
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_length=min(max_length + len(inputs[0]), 512),  # Limit total length
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                num_return_sequences=1
            )

        # Decode output
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Return only the new generated part
        return generated_text[len(prompt):].strip()

    except Exception as e:
        return f"Error generating text: {str(e)}"

# Create Gradio interface
with gr.Blocks(title="GPT-2 Text Generator") as demo:
    gr.Markdown("# GPT-2 Text Generation Server")
    gr.Markdown("Enter a prompt and generate text using GPT-2. Free tier optimized!")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your text prompt here...",
                lines=3
            )

            with gr.Row():
                max_length = gr.Slider(
                    minimum=10,
                    maximum=200,
                    value=100,
                    step=10,
                    label="Max Length"
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )

            with gr.Row():
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.1,
                    label="Top-p"
                )
                top_k = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Top-k"
                )

            generate_btn = gr.Button("Generate Text", variant="primary")

        with gr.Column():
            output_text = gr.Textbox(
                label="Generated Text",
                lines=10,
                placeholder="Generated text will appear here..."
            )

    # Examples
    gr.Examples(
        examples=[
            ["Once upon a time in a distant galaxy,"],
            ["The future of artificial intelligence is"],
            ["In the heart of the ancient forest,"],
            ["The detective walked into the room and noticed"],
        ],
        inputs=prompt_input
    )

    # Connect the function with explicit API endpoint name
    generate_btn.click(
        fn=generate_text,
        inputs=[prompt_input, max_length, temperature, top_p, top_k],
        outputs=output_text,
        api_name="/predict"  # Explicit API endpoint for external calls
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()
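
# --- Illustrative external-client usage (commented out, not part of the app) ---
# A minimal sketch of calling the "/predict" endpoint registered above via the
# gradio_client library. The Space name "your-username/gpt2-text-generator" is a
# placeholder assumption; substitute the actual deployment URL or Space id.
#
#   from gradio_client import Client
#
#   client = Client("your-username/gpt2-text-generator")  # placeholder Space id
#   result = client.predict(
#       "Once upon a time in a distant galaxy,",  # prompt
#       100,    # max_length
#       0.7,    # temperature
#       0.9,    # top_p
#       50,     # top_k
#       api_name="/predict"
#   )
#   print(result)
#
# The positional arguments must follow the same order as the `inputs` list wired
# into generate_btn.click above.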