# NOTE(review): removed non-Python extraction artifacts (a "File size" line,
# git-blame commit hashes, and a line-number gutter) that preceded the script
# and would have caused a SyntaxError.
import gradio as gr
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
# Load model and tokenizer (using smaller GPT-2 for free tier).
# NOTE: from_pretrained downloads weights on first run (network side effect).
model_name = "gpt2" # You can also use "gpt2-medium" if it fits in memory
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
# GPT-2 ships without a pad token; reuse EOS so generate() can pad batches
# without warnings.
tokenizer.pad_token = tokenizer.eos_token
def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9, top_k=50):
    """Generate a continuation of *prompt* with GPT-2.

    Args:
        prompt: Seed text to continue.
        max_length: Approximate number of new tokens to generate; the total
            sequence (prompt + continuation) is capped at 512 tokens.
        temperature: Softmax temperature for sampling (higher = more random).
        top_p: Nucleus-sampling probability mass cutoff.
        top_k: Sample only from the top-k most likely tokens.

    Returns:
        The newly generated text (prompt excluded), stripped of surrounding
        whitespace, or an ``"Error generating text: ..."`` message string.
    """
    try:
        # Encode input; truncate so the prompt alone cannot exceed the
        # model's 512-token budget.
        inputs = tokenizer.encode(
            prompt, return_tensors="pt", max_length=512, truncation=True
        )
        prompt_token_count = inputs.shape[-1]

        # Generate (inference only, so skip autograd bookkeeping).
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_length=min(max_length + prompt_token_count, 512),  # limit total length
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                num_return_sequences=1,
            )

        # BUG FIX: the original sliced the decoded string with
        # generated_text[len(prompt):], which assumes decode(encode(prompt))
        # reproduces the prompt byte-for-byte. That breaks whenever the
        # tokenizer truncates the prompt (max_length=512 above) or normalizes
        # whitespace, yielding text chopped at the wrong character offset.
        # Slicing the *token ids* at the prompt's token count is exact.
        new_token_ids = outputs[0][prompt_token_count:]
        return tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()
    except Exception as e:
        # Deliberately broad: this runs behind a Gradio UI, so any failure is
        # surfaced to the user as a readable message rather than a traceback.
        return f"Error generating text: {str(e)}"
# Create Gradio interface.
# NOTE: widget creation order inside gr.Blocks determines on-page layout, and
# the component variables are wired positionally into generate_btn.click()
# below — keep creation order and the inputs list in sync.
with gr.Blocks(title="GPT-2 Text Generator") as demo:
    gr.Markdown("# GPT-2 Text Generation Server")
    gr.Markdown("Enter a prompt and generate text using GPT-2. Free tier optimized!")
    with gr.Row():
        # Left column: prompt entry plus the four sampling controls.
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Enter your text prompt here...",
                lines=3
            )
            with gr.Row():
                # Slider defaults mirror generate_text's keyword defaults
                # (max_length=100, temperature=0.7).
                max_length = gr.Slider(
                    minimum=10,
                    maximum=200,
                    value=100,
                    step=10,
                    label="Max Length"
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
            with gr.Row():
                # Defaults mirror generate_text (top_p=0.9, top_k=50).
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.1,
                    label="Top-p"
                )
                top_k = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Top-k"
                )
            generate_btn = gr.Button("Generate Text", variant="primary")
        # Right column: read-only output area.
        with gr.Column():
            output_text = gr.Textbox(
                label="Generated Text",
                lines=10,
                placeholder="Generated text will appear here..."
            )
    # Examples: clicking one fills prompt_input only.
    gr.Examples(
        examples=[
            ["Once upon a time in a distant galaxy,"],
            ["The future of artificial intelligence is"],
            ["In the heart of the ancient forest,"],
            ["The detective walked into the room and noticed"],
        ],
        inputs=prompt_input
    )
    # Connect the function with explicit API endpoint name.
    # The inputs list order must match generate_text's parameter order.
    generate_btn.click(
        fn=generate_text,
        inputs=[prompt_input, max_length, temperature, top_p, top_k],
        outputs=output_text,
        api_name="/predict" # Explicit API endpoint for external calls
    )
# Launch the app only when run as a script (not when imported, e.g. by a
# Hugging Face Spaces runner that serves `demo` itself).
if __name__ == "__main__":
    demo.launch()