import gradio as gr
import torch
from transformers import pipeline
import os
# --- App Configuration ---
TITLE = "✍️ AI Story Outliner"
DESCRIPTION = """
Enter a prompt and get 10 unique story outlines from a CPU-friendly AI model.
The app uses **DistilGPT-2**, a reliable and lightweight model, to generate creative outlines.
**How it works:**
1. Enter your story idea.
2. The AI will generate 10 different story outlines.
3. Each outline has a dramatic beginning and is concise, like a song.
"""
# --- Example Prompts for Storytelling ---
examples = [
    ["The old lighthouse keeper stared into the storm. He'd seen many tempests, but this one was different. This one had eyes..."],
    ["In a city powered by dreams, a young inventor creates a machine that can record them. His first recording reveals a nightmare that doesn't belong to him."],
    ["The knight adjusted his helmet, the dragon's roar echoing in the valley. He was ready for the fight, but not for what the dragon said when it finally spoke."],
    ["She found the old leather-bound journal in her grandfather's attic. The first entry read: 'To relieve stress, I walk in the woods. But today, the woods walked with me.'"],
    ["The meditation app promised to help her 'delete unhelpful thoughts.' She tapped the button, and to her horror, the memory of her own name began to fade..."]
]
# --- Model Initialization ---
# This section loads a smaller, stable, and CPU-friendly model that requires no authentication.
generator = None
model_error = None
try:
    print("Initializing model... This may take a moment.")
    # Using 'distilgpt2', a stable and widely supported model that does not require a token.
    # This is much more suitable for a standard CPU environment.
    generator = pipeline(
        "text-generation",
        model="distilgpt2",
        torch_dtype=torch.float32,  # Use float32 for wider CPU compatibility
        device_map="auto"  # Will use GPU if available, otherwise CPU
    )
    print("✅ distilgpt2 model loaded successfully!")
except Exception as e:
    model_error = e
    print("--- 🚨 Error loading model ---")
    print(f"Error: {model_error}")
# --- App Logic ---
def generate_stories(prompt: str) -> list[str]:
    """
    Generates 10 story outlines from the loaded model based on the user's prompt.
    """
    print("--- Button clicked. Attempting to generate stories... ---")

    # If the model failed to load during startup, display that error.
    if model_error:
        error_message = f"**Model failed to load during startup.**\n\nPlease check the console logs for details.\n\n**Error:**\n`{str(model_error)}`"
        print(f"Returning startup error: {error_message}")
        return [error_message] * 10

    if not prompt:
        # Return a list of 10 empty strings to clear the outputs
        return [""] * 10

    # --- DEBUGGING STEP ---
    # To isolate the problem, we will first return a simple list of strings
    # to confirm the Gradio UI is working correctly. If this works, the issue
    # is with the model pipeline itself.
    print("--- RUNNING IN DEBUG MODE ---")
    debug_stories = [f"### Story Placeholder {i+1}\n\nThis is a test to confirm the UI is working." for i in range(10)]
    return debug_stories
    # --- ORIGINAL CODE (Temporarily disabled for debugging) ---
    # try:
    #     # A generic story prompt that works well with models like GPT-2.
    #     story_prompt = f"""
    # Story Idea: "{prompt}"
    # Create a short story outline based on this idea.
    # ### 🎬 The Hook
    # A dramatic opening.
    # ### 🎼 The Ballad
    # The main story, told concisely.
    # ### 🔚 The Finale
    # A clear and satisfying ending.
    # ---
    # """
    #     # Parameters for the pipeline to generate 10 diverse results.
    #     params = {
    #         "max_new_tokens": 200,
    #         "num_return_sequences": 10,
    #         "do_sample": True,
    #         "temperature": 0.9,
    #         "top_k": 50,
    #         "pad_token_id": generator.tokenizer.eos_token_id
    #     }
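    #     # (With do_sample=True, temperature=0.9 rescales the logits before
    #     # sampling, slightly sharpening the distribution, and top_k=50 limits
    #     # sampling to the 50 most likely tokens; the sampling randomness is
    #     # what makes the 10 returned sequences differ from each other.)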
# print("Generating text with the model...")
# # Generate 10 different story variations
# outputs = generator(story_prompt, **params)
# print("βœ… Text generation complete.")
# # Extract the generated text.
# stories = []
# for out in outputs:
# full_text = out['generated_text']
# stories.append(full_text)
# # Ensure we return exactly 10 stories, padding if necessary.
# while len(stories) < 10:
# stories.append("Failed to generate a story for this slot.")
# return stories
# except Exception as e:
# # Catch any errors that happen DURING generation and display them in the UI.
# print(f"--- 🚨 Error during story generation ---")
# print(f"Error: {e}")
# runtime_error_message = f"**An error occurred during story generation.**\n\nPlease check the console logs for details.\n\n**Error:**\n`{str(e)}`"
# return [runtime_error_message] * 10
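# A minimal sketch (not part of the original app) of how the disabled code above
# could trim the echoed prompt before display: text-generation pipelines for
# GPT-2 style models return the prompt plus its continuation in 'generated_text',
# so the outline alone is the suffix after the prompt. `_strip_prompt` is a
# hypothetical helper name introduced here for illustration.
def _strip_prompt(full_text: str, prompt_text: str) -> str:
    """Return only the newly generated continuation of a GPT-2 style output."""
    if full_text.startswith(prompt_text):
        return full_text[len(prompt_text):].lstrip()
    return full_text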
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), css=".gradio-container {max-width: 95% !important;}") as demo:
    gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(scale=1):
            input_area = gr.TextArea(
                lines=5,
                label="Your Story Prompt 👇",
                placeholder="e.g., 'The last dragon on Earth lived not in a cave, but in a library...'"
            )
            generate_button = gr.Button("Generate 10 Outlines ✨", variant="primary")

    gr.Markdown("---")
    gr.Markdown("## 📖 Your 10 Story Outlines")

    # Create 10 markdown components to display the stories in two columns
    story_outputs = []
    with gr.Row():
        with gr.Column():
            for i in range(5):
                md = gr.Markdown(label=f"Story Outline {i + 1}")
                story_outputs.append(md)
        with gr.Column():
            for i in range(5, 10):
                md = gr.Markdown(label=f"Story Outline {i + 1}")
                story_outputs.append(md)

    gr.Examples(
        examples=examples,
        inputs=input_area,
        label="Example Story Starters (Click to use)"
    )

    generate_button.click(
        fn=generate_stories,
        inputs=input_area,
        outputs=story_outputs,
        api_name="generate"
    )
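    # Optional (not in the original script): calling demo.queue() before launch
    # is a common Gradio pattern for serializing long-running generation
    # requests on shared CPU hardware.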
if __name__ == "__main__":
    demo.launch()