# sonyps1928
# update app
# 760431c
import gradio as gr
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
# Model choice: the base "gpt2" checkpoint keeps memory usage low for
# free-tier hosting; "gpt2-medium" can be swapped in if it fits in memory.
model_name = "gpt2"

# Load the matching tokenizer and language-model-head model by name.
tokenizer, model = (
    GPT2Tokenizer.from_pretrained(model_name),
    GPT2LMHeadModel.from_pretrained(model_name),
)

# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token
def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9, top_k=50):
    """Generate a continuation of ``prompt`` using GPT-2.

    Args:
        prompt: Text to continue.
        max_length: Number of new tokens to generate (the "Max Length" slider).
        temperature: Sampling temperature; higher values give more random text.
        top_p: Nucleus-sampling cumulative probability threshold.
        top_k: Number of highest-probability tokens considered at each step.

    Returns:
        Only the newly generated text (the prompt is excluded), stripped of
        surrounding whitespace, or a human-readable error message string if
        generation fails.
    """
    if not prompt:
        # An empty prompt tokenizes to zero tokens, which generate() cannot
        # continue from.
        return "Please enter a prompt."
    try:
        # Upper bound on prompt + continuation length, in tokens.
        max_total = 512
        # Gradio sliders may deliver floats, but generate() requires integer
        # token counts and an integer top_k.
        max_new_tokens = min(max(1, int(max_length)), max_total - 1)
        top_k = int(top_k)

        # Truncate the prompt so there is always room for max_new_tokens
        # within the total budget (previously a 512-token prompt left no room
        # for any new tokens).
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            max_length=max_total - max_new_tokens,
            truncation=True,
        )
        input_ids = inputs["input_ids"]
        prompt_len = input_ids.shape[-1]

        with torch.no_grad():
            outputs = model.generate(
                input_ids,
                attention_mask=inputs["attention_mask"],
                max_new_tokens=max_new_tokens,
                temperature=float(temperature),
                top_p=float(top_p),
                top_k=top_k,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                num_return_sequences=1,
            )

        # Decode only the tokens produced after the prompt. Slicing the decoded
        # string by len(prompt) is wrong when the prompt was truncated or when
        # decoding normalises spacing around punctuation.
        new_tokens = outputs[0][prompt_len:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    except Exception as e:
        # UI boundary: surface the failure to the user instead of crashing.
        return f"Error generating text: {e}"
# Create Gradio interface: prompt and sampling controls on the left,
# generated text on the right.
with gr.Blocks(title="GPT-2 Text Generator") as demo:
    gr.Markdown("# GPT-2 Text Generation Server")
    gr.Markdown("Enter a prompt and generate text using GPT-2. Free tier optimized!")

    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your text prompt here...",
                lines=3,
            )
            with gr.Row():
                max_length = gr.Slider(
                    minimum=10,
                    maximum=200,
                    value=100,
                    step=10,
                    label="Max Length",
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                )
            with gr.Row():
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.1,
                    label="Top-p",
                )
                top_k = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Top-k",
                )
            generate_btn = gr.Button("Generate Text", variant="primary")

        with gr.Column():
            output_text = gr.Textbox(
                label="Generated Text",
                lines=10,
                placeholder="Generated text will appear here...",
            )

    # Clickable example prompts that fill the prompt box.
    gr.Examples(
        examples=[
            ["Once upon a time in a distant galaxy,"],
            ["The future of artificial intelligence is"],
            ["In the heart of the ancient forest,"],
            ["The detective walked into the room and noticed"],
        ],
        inputs=prompt_input,
    )

    # Connect the function with an explicit API endpoint name. Gradio adds
    # the leading slash itself, so the endpoint is reachable from clients as
    # api_name="/predict"; passing "/predict" here would double the slash.
    generate_btn.click(
        fn=generate_text,
        inputs=[prompt_input, max_length, temperature, top_p, top_k],
        outputs=output_text,
        api_name="predict",
    )
# Start the Gradio server only when executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()