|
import gradio as gr |
|
import os |
|
from transformers import GPT2LMHeadModel, GPT2Tokenizer |
|
import torch |
|
|
|
print("🚀 Starting GPT-2 Text Generator...")

# Optional deployment secrets. Only their presence is reported below;
# none of them are read anywhere else in this script.
HF_TOKEN = os.getenv("HF_TOKEN")
API_KEY = os.getenv("API_KEY")
ADMIN_PASSWORD = os.getenv("ADMIN_PASSWORD")

print(f"HF_TOKEN: {'Set' if HF_TOKEN else 'Not set'}")
print(f"API_KEY: {'Set' if API_KEY else 'Not set'}")
print(f"ADMIN_PASSWORD: {'Set' if ADMIN_PASSWORD else 'Not set'}")

print("Loading GPT-2 model...")
try:
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    # GPT-2 ships without a pad token; reuse EOS so generate() can pad.
    tokenizer.pad_token = tokenizer.eos_token
    model.eval()  # inference only — disables dropout
    print("✅ Model loaded successfully!")
except Exception as e:
    print(f"❌ Error loading model: {e}")
    # Bare raise re-raises the active exception with its original traceback
    # (idiomatic vs. `raise e`). The app cannot run without the model.
    raise
|
|
|
def generate_text(prompt, max_length=100, temperature=0.7):
    """Generate a GPT-2 continuation of *prompt*.

    Args:
        prompt: Seed text (1-500 characters; whitespace-only is rejected).
        max_length: Maximum number of new tokens to generate.
        temperature: Sampling temperature; higher values are more random.

    Returns:
        The newly generated text (prompt excluded), or a human-readable
        error string — the Gradio UI displays whatever is returned.
    """
    if not prompt or not prompt.strip():
        return "Please enter a prompt"

    if len(prompt) > 500:
        return "Prompt too long (max 500 characters)"

    try:
        print(f"Generating text for: {prompt[:30]}...")

        # Cap the prompt at 300 tokens so prompt + continuation stays well
        # inside GPT-2's 1024-token context window.
        inputs = tokenizer.encode(prompt, return_tensors="pt", max_length=300, truncation=True)

        with torch.no_grad():  # inference only — skip gradient bookkeeping
            outputs = model.generate(
                inputs,
                max_length=inputs.shape[1] + max_length,
                temperature=temperature,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                no_repeat_ngram_size=2,
            )

        # Slice off the prompt at the *token* level before decoding.
        # Slicing the decoded string by len(prompt) (the previous approach)
        # breaks whenever decode() re-normalizes whitespace or the prompt
        # was truncated above, which corrupted the returned text.
        new_tokens = outputs[0][inputs.shape[1]:]
        new_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

        print(f"✅ Generated {len(new_text)} characters")
        return new_text if new_text else "No text generated. Try a different prompt."

    except Exception as e:
        # Surface the failure in the UI instead of crashing the app.
        error_msg = f"Error generating text: {str(e)}"
        print(f"❌ {error_msg}")
        return error_msg
|
|
|
|
|
print("Creating Gradio interface...")

# Build the UI declaratively. Statement order inside the Blocks context
# determines the rendered layout: controls in the left column, output on
# the right, examples and event wiring below.
with gr.Blocks() as demo:
    gr.Markdown("# GPT-2 Text Generator")
    gr.Markdown("Enter a prompt and click generate to create text using GPT-2")

    with gr.Row():
        with gr.Column():
            # Seed text for generation (generate_text rejects > 500 chars).
            prompt_input = gr.Textbox(
                label="Enter your prompt",
                placeholder="Type your text here...",
                lines=3
            )

            # Upper bound on generated length, forwarded to generate_text.
            max_length_slider = gr.Slider(
                minimum=20,
                maximum=200,
                value=100,
                step=10,
                label="Max length of generated text"
            )

            # Sampling temperature: low = predictable, high = creative.
            temperature_slider = gr.Slider(
                minimum=0.1,
                maximum=1.5,
                value=0.7,
                step=0.1,
                label="Temperature (creativity)"
            )

            generate_button = gr.Button("Generate Text", variant="primary")

        with gr.Column():
            # Output display; also shows error strings from generate_text.
            output_text = gr.Textbox(
                label="Generated Text",
                lines=8,
                placeholder="Generated text will appear here..."
            )

    # Clickable sample prompts that populate the input box.
    gr.Examples(
        examples=[
            "Once upon a time",
            "The future of technology is",
            "In a world where",
        ],
        inputs=prompt_input
    )

    # Wire the button: slider values are passed positionally to
    # generate_text(prompt, max_length, temperature).
    generate_button.click(
        fn=generate_text,
        inputs=[prompt_input, max_length_slider, temperature_slider],
        outputs=output_text
    )
|
|
|
|
|
print("Launching Gradio app...")

if __name__ == "__main__":
    # launch() blocks until the server is shut down (e.g. Ctrl-C).
    demo.launch()

# NOTE: because launch() blocks, this line only prints after the server
# exits (or immediately when the module is imported rather than run).
print("✅ App is running!")