import gradio as gr
import time
import random
import torch
import numpy as np
from PIL import Image
import imageio # For saving video
import tempfile # For creating temporary files
import os
# --- Hugging Face Model Imports ---
from transformers import T5ForConditionalGeneration, T5Tokenizer
from diffusers import StableDiffusionPipeline, AnimateDiffPipeline, DDIMScheduler, MotionAdapter
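# Note: this script assumes torch, transformers (plus sentencepiece for the T5 tokenizer),
# diffusers, gradio, Pillow, numpy and imageio (with an ffmpeg backend such as the
# imageio-ffmpeg package) are installed; xformers is optional but recommended on CUDA.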
# --- Model Loading (Load outside the function for better performance) ---
# Check for CUDA availability
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load Prompt Enhancement Model
print("Loading Prompt Enhancement Model (T5)...")
tokenizer_t5 = T5Tokenizer.from_pretrained("t5-small")
model_t5 = T5ForConditionalGeneration.from_pretrained("t5-small").to(device)
print("T5 model loaded.")
# Load Image Generation Model
print("Loading Image Generation Model (Stable Diffusion 1.5)...")
pipe_sd = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16 if device == "cuda" else torch.float32).to(device)
# Optional: Enable optimizations if using CUDA
if device == "cuda":
    try:
        pipe_sd.enable_xformers_memory_efficient_attention()  # Requires the xformers package
    except Exception as e:
        print(f"xformers not available for Stable Diffusion, continuing without it: {e}")
    pipe_sd.enable_vae_slicing()
    # Other common memory optimizations are pipe_sd.enable_attention_slicing() and
    # pipe_sd.enable_model_cpu_offload() (the latter reduces peak VRAM but is slower).
print("Stable Diffusion 1.5 model loaded.")
# Load Animation Model (AnimateDiff)
print("Loading Animation Model (AnimateDiff)...")
# Load motion module
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5", torch_dtype=torch.float16 if device == "cuda" else torch.float32)
# Build the AnimateDiff pipeline on top of the SD 1.5 base model
pipe_anim = AnimateDiffPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", motion_adapter=adapter, torch_dtype=torch.float16 if device == "cuda" else torch.float32).to(device)
# Configure scheduler
pipe_anim.scheduler = DDIMScheduler.from_config(pipe_anim.scheduler.config, clip_sample=False, timestep_spacing="linspace")
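# The AnimateDiff examples in the diffusers docs typically also set beta_schedule="linear"
# and steps_offset=1 on the scheduler; add them here if motion quality looks off.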
# Optional: Enable optimizations if using CUDA
if device == "cuda":
    try:
        pipe_anim.enable_xformers_memory_efficient_attention()  # Requires the xformers package
    except Exception as e:
        print(f"xformers not available for AnimateDiff, continuing without it: {e}")
    pipe_anim.enable_vae_slicing()
    # pipe_anim.enable_model_cpu_offload()  # Reduces peak VRAM, but is slower because components move between CPU and GPU
print("AnimateDiff model loaded.")
# --- Function to run the pipeline ---
def process_prompt_and_generate(user_prompt, image_resolution, guidance_scale, seed, animation_frames, animation_style):
"""
Runs the AI pipeline using Hugging Face models.
It yields updates for the status and logs.
"""
logs = []
status = "Starting processing..."
# Yield initial state - Gradio expects all outputs to be present, even if empty
yield user_prompt, "", None, None, "", "", "", "", "", "\n".join(logs), status
if not user_prompt:
logs.append("Error: No prompt provided.")
status = "Error: No prompt provided."
yield user_prompt, "", None, None, "", "", "", "", "", "\n".join(logs), status
return
    # Use the provided seed unless it is -1, in which case draw a random one.
    # gr.Number may deliver a float, so cast to int before seeding.
    current_seed = int(seed) if seed != -1 else random.randint(0, 100000000)
    generator = torch.Generator(device=device).manual_seed(current_seed)
    np.random.seed(current_seed)  # Seed numpy too for any potential numpy randomness
    # --- Step 1: Prompt Enhancement (T5) ---
status = "Enhancing prompt (T5)..."
logs.append(f"User Prompt: '{user_prompt}'")
logs.append(f"Parameters: Resolution={image_resolution}, Guidance Scale={guidance_scale}, Seed={current_seed}, Frames={animation_frames}, Style={animation_style}")
yield user_prompt, "", None, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status # Update parameters display early
start_time = time.time()
try:
input_text = f"enhance prompt: {user_prompt}" # T5-small enhancement prefix
input_ids = tokenizer_t5(input_text, return_tensors="pt").input_ids.to(device)
outputs = model_t5.generate(input_ids, max_length=64, num_beams=4, early_stopping=True) # Keep enhancement concise
enhanced_prompt = tokenizer_t5.decode(outputs[0], skip_special_tokens=True)
logs.append(f"Enhanced Prompt: '{enhanced_prompt}'")
yield user_prompt, enhanced_prompt, None, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status
except Exception as e:
logs.append(f"Error during prompt enhancement: {e}")
status = "Error during prompt enhancement."
yield user_prompt, "", None, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status
return
end_time = time.time()
logs.append(f"Prompt enhancement took {end_time - start_time:.2f} seconds.")
    # --- Step 2: Image Generation (Stable Diffusion) ---
status = "Generating image (Stable Diffusion)..."
logs.append(f"Generating initial image ({image_resolution}x{image_resolution}px)...")
yield user_prompt, enhanced_prompt, None, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status
start_time = time.time()
try:
# Generate the image
with torch.no_grad():
image = pipe_sd(
prompt=enhanced_prompt,
height=image_resolution,
width=image_resolution,
guidance_scale=guidance_scale,
generator=generator
).images[0]
# Save the image temporarily
# Gradio can handle PIL images directly, but saving to a temp file is also common
# Using tempfile for a robust approach
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
temp_image_path = tmpfile.name
image.save(temp_image_path)
logs.append(f"Image generated successfully: {temp_image_path}")
yield user_prompt, enhanced_prompt, temp_image_path, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status
except Exception as e:
logs.append(f"Error during image generation: {e}")
status = "Error during image generation."
yield user_prompt, enhanced_prompt, None, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status
# Clean up temp file if it exists from a partial save
if 'temp_image_path' in locals() and os.path.exists(temp_image_path):
os.remove(temp_image_path)
return
end_time = time.time()
logs.append(f"Image generation took {end_time - start_time:.2f} seconds.")
    # --- Step 3: Animation (AnimateDiff) ---
status = "Generating animation (AnimateDiff)..."
logs.append(f"Generating animation ({animation_frames} frames, style: {animation_style}). Note: 'Style' parameter currently doesn't directly control AnimateDiff output...") # Add note about style limitation
yield user_prompt, enhanced_prompt, temp_image_path, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status
start_time = time.time()
try:
        # Generate animation frames.
        # AnimateDiff is text-to-video: it takes the (enhanced) prompt and produces a
        # sequence of frames. The "Animation Style" radio selection does not map to any
        # AnimateDiff option, so it is only recorded in the metadata.
        # The pipeline accepts most Stable Diffusion parameters (guidance_scale,
        # generator, height/width); resolution is left at the 512x512 default used by
        # the SD 1.5 base model. Note that the v1-5 motion adapter was trained on short
        # (roughly 16-frame) clips, so very large frame counts may reduce quality.
        with torch.no_grad():
            output = pipe_anim(
                prompt=enhanced_prompt,
                negative_prompt=None,  # Could add a negative prompt if needed
                num_frames=int(animation_frames),
                guidance_scale=guidance_scale,
                generator=generator,
                # width=image_resolution,   # Commented out: keeping the 512x512 default
                # height=image_resolution,
            )
        # output.frames holds one frame sequence per prompt; take the first sequence and
        # convert its PIL frames to numpy arrays so imageio can write them.
        animation_frames_list = [np.array(frame) for frame in output.frames[0]]
# Compile frames into a video
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
temp_video_path = tmpfile.name
# Use imageio to write video - requires ffmpeg or similar backend
# Ensure imageio can find a writer (like ffmpeg)
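            # (diffusers.utils.export_to_video is an alternative helper for writing the
            # frames if an ffmpeg backend is not available through imageio.)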
try:
imageio.mimwrite(temp_video_path, animation_frames_list, fps=8, quality=8) # Adjust fps and quality as needed
except Exception as ffmpeg_error:
logs.append(f"Error saving video with imageio/ffmpeg: {ffmpeg_error}")
logs.append("Ensure ffmpeg is installed and in your PATH, or use imageio.get_writer with a specific backend.")
status = "Error saving video."
# Attempt cleanup
if os.path.exists(temp_video_path):
os.remove(temp_video_path)
yield user_prompt, enhanced_prompt, temp_image_path, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status
# Clean up temp image
if 'temp_image_path' in locals() and os.path.exists(temp_image_path):
os.remove(temp_image_path)
return
logs.append(f"Animation generated successfully: {temp_video_path}")
yield user_prompt, enhanced_prompt, temp_image_path, temp_video_path, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status
except Exception as e:
logs.append(f"Error during animation generation: {e}")
status = "Error during animation generation."
yield user_prompt, enhanced_prompt, temp_image_path, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status
# Clean up temp files
if 'temp_image_path' in locals() and os.path.exists(temp_image_path):
os.remove(temp_image_path)
if 'temp_video_path' in locals() and os.path.exists(temp_video_path):
os.remove(temp_video_path)
return
end_time = time.time()
logs.append(f"Animation generation took {end_time - start_time:.2f} seconds.")
# --- Finalizing Outputs ---
status = "Process complete!"
logs.append("All steps finished.")
# Ensure all outputs are returned in the final state (yielded)
# The last yield in a generator function provides the final values for Gradio
# Let's make the last yield explicitly contain all final values
yield user_prompt, enhanced_prompt, temp_image_path, temp_video_path, \
str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, \
"\n".join(logs), status
# --- Function to update the parameters display (called after main function) ---
def update_parameters_display(res, gs, seed, frames, style):
    # Formats the parameter strings passed through the hidden outputs of the main function.
if not res: # Check if results exist (e.g., first yield is empty)
return ""
metadata = f"Resolution: {res}px\nGuidance Scale: {gs}\nSeed: {seed}\nFrames: {frames}\nStyle: {style}\n(Note: Animation Style may not directly control model output)" # Add note here too
return metadata
# --- Function to randomize seed ---
def randomize():
return random.randint(1, 100000000) # Generate a random seed
# --- Gradio UI Definition ---
# Choose a more modern theme
theme = gr.themes.Monochrome().set(
# Customize colors slightly for a softer look
# You can inspect theme objects and their attributes
# button_primary_background_fill="linear-gradient(to right, #6a11cb 0%, #2575fc 100%)", # Example gradient
# button_primary_color="white",
# button_secondary_background_fill="gray",
# spacing_size_lg="2rem" # Example spacing adjustment
)
# Use tempfile for a base temp directory managed by the app
temp_dir = tempfile.mkdtemp()
print(f"Using temporary directory: {temp_dir}")
# Set Gradio's temp dir if needed (often handled automatically)
# gr.processing_utils.TEMP_DIR = temp_dir # This might be needed in older Gradio versions or specific setups
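# Note: the NamedTemporaryFile(..., delete=False) calls above write to the system default
# temp location (not temp_dir) and are only removed on the error paths, so long-running
# deployments may want periodic cleanup of old outputs.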
with gr.Blocks(theme=theme, title="AI Creative Studio") as demo:
# --- Header Section ---
with gr.Row(variant="panel"): # Use a panel variant for distinct header background
with gr.Column(scale=1, min_width=100):
# Placeholder for a logo or icon
gr.Image(value="https://www.gradio.app/_app/immutable/assets/gradio.CHB5adID.svg",
label="Studio Logo",
show_label=False, # Hide the label below the image
height=80,
width=80,
container=False) # Prevent adding extra padding/margin around the image
with gr.Column(scale=4):
gr.Markdown(
"""
# 🎨 Multi-Step AI Creative Pipeline πŸš€
Unleash your imagination! Input a prompt, and our AI orchestrates a sequence:
Prompt Enhancement β†’ Image Generation β†’ Animation.
**Using free models from Hugging Face (T5, Stable Diffusion 1.5, AnimateDiff).**
*Note: 'Animation Style' parameter might not directly control the AnimateDiff model output.*
"""
)
gr.Markdown("---") # Separator
# --- Main Content Area (Input & Output side-by-side initially) ---
with gr.Row():
# --- Input & Controls Column ---
with gr.Column(scale=1):
gr.Markdown("## ✍️ Your Creative Input")
prompt_input = gr.TextArea(
label="Enter your prompt here:",
placeholder="e.g., A majestic dragon flying over snow-capped mountains at sunset",
lines=5,
interactive=True
)
gr.Examples(
["A cyberpunk street scene with neon lights", "A cozy cabin in a snowy forest, digital painting", "An astronaut riding a horse on the moon, surrealism"],
inputs=prompt_input
)
# Advanced Options (Collapsed)
with gr.Accordion("πŸ› οΈ Advanced Settings", open=False):
gr.Markdown("Configure specific parameters for generation.")
with gr.Row():
image_resolution = gr.Slider(label="Image Resolution (px)", minimum=256, maximum=1024, value=512, step=128, interactive=True)
guidance_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=20.0, value=7.0, step=0.1, interactive=True)
with gr.Row():
seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0, interactive=True)
animation_frames = gr.Slider(label="Animation Frames", minimum=10, maximum=100, value=40, step=5, interactive=True)
animation_style = gr.Radio(
label="Animation Style",
choices=["Zoom In", "Pan Left", "Rotate", "Swirl"],
value="Zoom In",
interactive=True
)
# Add a button to randomize seed easily
randomize_seed_button = gr.Button("🎲 Randomize Seed")
# Action Button
generate_button = gr.Button("✨ Generate Pipeline Results ✨", variant="primary")
# Live Status Indicator
status_display = gr.Textbox(label="Status", value="Ready", interactive=False, show_copy_button=False)
# --- Output & Results Column ---
with gr.Column(scale=2): # Make output column wider
gr.Markdown("## βœ… Generation Results")
# Row for prompts
with gr.Row():
original_prompt_output = gr.Textbox(label="Original Prompt Used", interactive=False, lines=3, scale=1, show_copy_button=True)
enhanced_prompt_output = gr.Textbox(label="Enhanced Prompt (AI)", interactive=False, lines=3, scale=1, show_copy_button=True)
# Row for media
with gr.Row():
generated_image_output = gr.Image(label="Generated Image", interactive=False, height=450, show_share_button=True, type="filepath") # Specify type="filepath"
generated_animation_output = gr.Video(label="Generated Animation", interactive=False, height=450, show_share_button=True)
# Display Parameters Used (Collapsed or in a smaller section)
with gr.Accordion("πŸ”¬ Parameters Used", open=False): # Collapsible section for details
parameters_used_output = gr.Textbox(
label="Generation Parameters",
interactive=False,
lines=6, # Increased lines slightly to fit the note
max_lines=30,
show_copy_button=True
)
            # Hidden output components to catch the individual parameters.
            # update_parameters_display() combines them into the Textbox above after generation.
            res_out = gr.Textbox(visible=False)
            gs_out = gr.Textbox(visible=False)
            seed_out = gr.Textbox(visible=False)
            frames_out = gr.Textbox(visible=False)
            style_out = gr.Textbox(visible=False)
# Download Buttons (Placeholder)
gr.Markdown("### Download Results")
with gr.Row():
# These buttons are just placeholders for now.
# Real download logic needs separate functions.
# Making them interactive=False as they don't have click events linked
download_image_button = gr.Button("⬇️ Download Image", interactive=False)
download_video_button = gr.Button("⬇️ Download Video", interactive=False)
gr.Markdown("---") # Separator
# --- Logs and Debug Information ---
with gr.Accordion("βš™οΈ Processing Logs & Debug Info", open=False):
logs_output = gr.Textbox(
label="Detailed Logs",
interactive=False,
lines=15, # More lines for detailed logs
max_lines=30,
show_copy_button=True,
# Add some visual cues for logs
container=True # Gives it a distinct container style
)
# --- Define Interactions ---
# Button click triggers the main processing function
# The outputs list maps the function's return values to UI components
# Because process_prompt_and_generate is a generator, Gradio updates the outputs
# with each yielded value. The final yield provides the final state.
generate_button.click(
fn=process_prompt_and_generate,
inputs=[
prompt_input,
image_resolution,
guidance_scale,
seed,
animation_frames,
animation_style
],
outputs=[
original_prompt_output,
enhanced_prompt_output,
generated_image_output,
generated_animation_output,
res_out, # Catch individual params to reconstruct metadata
gs_out,
seed_out,
frames_out,
style_out,
logs_output, # Logs are updated incrementally/finally
status_display # Status is updated incrementally/finally
],
api_name="generate" # Optional: Add an API name for easy calling
).success( # Chain the parameter update after the main process finishes successfully
fn=update_parameters_display,
inputs=[res_out, gs_out, seed_out, frames_out, style_out],
outputs=[parameters_used_output]
)
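    # Sketch (assumption): with api_name="generate", the pipeline can also be driven
    # programmatically via gradio_client once the app is running, for example:
    #   from gradio_client import Client
    #   client = Client("http://127.0.0.1:7860/")
    #   client.predict("A castle at dawn", 512, 7.0, -1, 40, "Zoom In", api_name="/generate")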
# Randomize Seed Button Interaction
randomize_seed_button.click(
fn=randomize,
inputs=[],
outputs=[seed] # Update the seed number input field
)
# --- Launch the App ---
if __name__ == "__main__":
print("Gradio AI Creative Studio is starting...")
# Use share=True to make it accessible over the internet (for testing)
# Use inbrowser=True to auto-open the browser
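    # Note: streaming the generator's intermediate yields relies on Gradio's queue; it is
    # enabled by default in recent Gradio versions, but on older versions call demo.queue()
    # before launch().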
demo.launch(inbrowser=True)
print("App launched!")
# Optional: Clean up the temporary directory when the app stops
# This is not automatically called when you Ctrl+C, but useful in some deployment scenarios
# import shutil
# shutil.rmtree(temp_dir)
# print(f"Cleaned up temporary directory: {temp_dir}")