|
import gradio as gr |
|
from transformers import GPT2LMHeadModel, GPT2Tokenizer |
|
import torch |
|
|
|
|
|
# Load model and tokenizer (using smaller GPT-2 for free tier)
# NOTE(review): this runs at import time, so the app blocks until the weights
# are downloaded (first run) or read from the local transformers cache.

model_name = "gpt2" # You can also use "gpt2-medium" if it fits in memory

tokenizer = GPT2Tokenizer.from_pretrained(model_name)

model = GPT2LMHeadModel.from_pretrained(model_name)




# Set pad token
# GPT-2 ships with no dedicated pad token; reusing EOS lets generate() pad
# sequences without raising a missing-pad-token warning.
tokenizer.pad_token = tokenizer.eos_token
|
|
|
|
|
def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9, top_k=50):
    """Generate a continuation of *prompt* using the module-level GPT-2 model.

    Args:
        prompt: Input text to continue.
        max_length: Maximum number of new tokens to generate; the total
            sequence (prompt + continuation) is capped at 512 tokens.
        temperature: Sampling temperature; higher values are more random.
        top_p: Nucleus-sampling cumulative probability cutoff.
        top_k: Number of highest-probability tokens sampled from at each step.

    Returns:
        The newly generated text only (prompt stripped), or an error message
        string if generation fails (the Gradio UI displays it in the output
        textbox rather than crashing).
    """
    try:
        # Encode input, truncating to the 512-token budget used below.
        inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=512, truncation=True)
        prompt_token_count = inputs.shape[1]

        # Inference only — disable gradient tracking to save memory.
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_length=min(max_length + prompt_token_count, 512),  # limit total length
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                num_return_sequences=1,
            )

        # BUG FIX: the previous code decoded the full sequence and sliced off
        # len(prompt) characters. That is unreliable: detokenization does not
        # round-trip the prompt text exactly (whitespace normalization), and
        # a truncated prompt makes the character offset plain wrong. Decode
        # only the tokens generated *after* the prompt instead.
        new_tokens = outputs[0][prompt_token_count:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    except Exception as e:
        # Broad catch is deliberate: surface any failure as text in the UI.
        return f"Error generating text: {str(e)}"
|
|
|
|
|
# Create Gradio interface |
|
# Build the Gradio UI: a two-column layout with prompt + sampling controls on
# the left and the generated text on the right.
with gr.Blocks(title="GPT-2 Text Generator") as demo:

    gr.Markdown("# GPT-2 Text Generation Server")

    gr.Markdown("Enter a prompt and generate text using GPT-2. Free tier optimized!")



    with gr.Row():

        with gr.Column():

            # Free-form prompt input fed to generate_text.
            prompt_input = gr.Textbox(

                label="Prompt",

                placeholder="Enter your text prompt here...",

                lines=3

            )



            # Sampling controls; defaults mirror generate_text's signature.
            with gr.Row():

                max_length = gr.Slider(

                    minimum=10,

                    maximum=200,

                    value=100,

                    step=10,

                    label="Max Length"

                )

                temperature = gr.Slider(

                    minimum=0.1,

                    maximum=2.0,

                    value=0.7,

                    step=0.1,

                    label="Temperature"

                )



            with gr.Row():

                top_p = gr.Slider(

                    minimum=0.1,

                    maximum=1.0,

                    value=0.9,

                    step=0.1,

                    label="Top-p"

                )

                top_k = gr.Slider(

                    minimum=1,

                    maximum=100,

                    value=50,

                    step=1,

                    label="Top-k"

                )



            generate_btn = gr.Button("Generate Text", variant="primary")



        with gr.Column():

            # Read-only result box populated by the click handler below.
            output_text = gr.Textbox(

                label="Generated Text",

                lines=10,

                placeholder="Generated text will appear here..."

            )



    # Examples — clicking one fills prompt_input only; it does not trigger
    # generation.

    gr.Examples(

        examples=[

            ["Once upon a time in a distant galaxy,"],

            ["The future of artificial intelligence is"],

            ["In the heart of the ancient forest,"],

            ["The detective walked into the room and noticed"],

        ],

        inputs=prompt_input

    )



    # Connect the function with explicit API endpoint name.
    # Input order must match generate_text's positional parameters.

    generate_btn.click(

        fn=generate_text,

        inputs=[prompt_input, max_length, temperature, top_p, top_k],

        outputs=output_text,

        api_name="/predict" # Explicit API endpoint for external calls

    )
|
|
|
|
|
# Launch the app
# Script entry point: start the Gradio server only when run directly, not
# when this module is imported (e.g. by a test or a hosting wrapper).

if __name__ == "__main__":

    # Default launch: local server on 127.0.0.1; no public share link.
    demo.launch()