import gradio as gr import time import random import torch import numpy as np from PIL import Image import imageio # For saving video import tempfile # For creating temporary files import os # --- Hugging Face Model Imports --- from transformers import T5ForConditionalGeneration, T5Tokenizer from diffusers import StableDiffusionPipeline, AnimateDiffPipeline, DDIMScheduler, MotionAdapter # --- Model Loading (Load outside the function for better performance) --- # Check for CUDA availability device = "cuda" if torch.cuda.is_available() else "cpu" print(f"Using device: {device}") # Load Prompt Enhancement Model print("Loading Prompt Enhancement Model (T5)...") tokenizer_t5 = T5Tokenizer.from_pretrained("t5-small") model_t5 = T5ForConditionalGeneration.from_pretrained("t5-small").to(device) print("T5 model loaded.") # Load Image Generation Model print("Loading Image Generation Model (Stable Diffusion 1.5)...") pipe_sd = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16 if device == "cuda" else torch.float32).to(device) # Optional: Enable optimizations if using CUDA if device == "cuda": pipe_sd.enable_xformers_memory_efficient_attention() pipe_sd.enable_vae_slicing() pipe_sd.enable_cfashion_scaling() # Typo: Should be enable_cfashion_scaling - correcting in code # Corrected: # pipe_sd.enable_cfashion_scaling() # This method doesn't exist. Common optimizations are xformers, vae slicing, model CPU offload. Let's stick to standard ones. # For SDXL specifically, you might use enable_model_cpu_offload() print("Stable Diffusion 1.5 model loaded.") # Load Animation Model (AnimateDiff) print("Loading Animation Model (AnimateDiff)...") # Load motion module adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5", torch_dtype=torch.float16 if device == "cuda" else torch.float32) # Load base SD pipeline pipe_anim = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", motion_adapter=adapter, torch_dtype=torch.float16 if device == "cuda" else torch.float32).to(device) # Configure scheduler pipe_anim.scheduler = DDIMScheduler.from_config(pipe_anim.scheduler.config, clip_sample=False, timestep_spacing="uniform") # Optional: Enable optimizations if using CUDA if device == "cuda": pipe_anim.enable_xformers_memory_efficient_attention() pipe_anim.enable_vae_slicing() # pipe_anim.enable_model_cpu_offload() # Can be useful for memory, but slower if components are moved back and forth print("AnimateDiff model loaded.") # --- Function to run the pipeline --- def process_prompt_and_generate(user_prompt, image_resolution, guidance_scale, seed, animation_frames, animation_style): """ Runs the AI pipeline using Hugging Face models. It yields updates for the status and logs. """ logs = [] status = "Starting processing..." # Yield initial state - Gradio expects all outputs to be present, even if empty yield user_prompt, "", None, None, "", "", "", "", "", "\n".join(logs), status if not user_prompt: logs.append("Error: No prompt provided.") status = "Error: No prompt provided." yield user_prompt, "", None, None, "", "", "", "", "", "\n".join(logs), status return # Ensure seed is a positive integer, use random if -1 current_seed = seed if seed != -1 else random.randint(0, 100000000) generator = torch.Generator(device=device).manual_seed(current_seed) np.random.seed(current_seed) # Seed numpy too for any potential numpy randomness # --- Step 1: Simulate Prompt Enhancement (using T5) --- status = "Enhancing prompt (T5)..." logs.append(f"User Prompt: '{user_prompt}'") logs.append(f"Parameters: Resolution={image_resolution}, Guidance Scale={guidance_scale}, Seed={current_seed}, Frames={animation_frames}, Style={animation_style}") yield user_prompt, "", None, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status # Update parameters display early start_time = time.time() try: input_text = f"enhance prompt: {user_prompt}" # T5-small enhancement prefix input_ids = tokenizer_t5(input_text, return_tensors="pt").input_ids.to(device) outputs = model_t5.generate(input_ids, max_length=64, num_beams=4, early_stopping=True) # Keep enhancement concise enhanced_prompt = tokenizer_t5.decode(outputs[0], skip_special_tokens=True) logs.append(f"Enhanced Prompt: '{enhanced_prompt}'") yield user_prompt, enhanced_prompt, None, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status except Exception as e: logs.append(f"Error during prompt enhancement: {e}") status = "Error during prompt enhancement." yield user_prompt, "", None, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status return end_time = time.time() logs.append(f"Prompt enhancement took {end_time - start_time:.2f} seconds.") # --- Step 2: Simulate Image Generation (using Stable Diffusion) --- status = "Generating image (Stable Diffusion)..." logs.append(f"Generating initial image ({image_resolution}x{image_resolution}px)...") yield user_prompt, enhanced_prompt, None, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status start_time = time.time() try: # Generate the image with torch.no_grad(): image = pipe_sd( prompt=enhanced_prompt, height=image_resolution, width=image_resolution, guidance_scale=guidance_scale, generator=generator ).images[0] # Save the image temporarily # Gradio can handle PIL images directly, but saving to a temp file is also common # Using tempfile for a robust approach with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile: temp_image_path = tmpfile.name image.save(temp_image_path) logs.append(f"Image generated successfully: {temp_image_path}") yield user_prompt, enhanced_prompt, temp_image_path, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status except Exception as e: logs.append(f"Error during image generation: {e}") status = "Error during image generation." yield user_prompt, enhanced_prompt, None, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status # Clean up temp file if it exists from a partial save if 'temp_image_path' in locals() and os.path.exists(temp_image_path): os.remove(temp_image_path) return end_time = time.time() logs.append(f"Image generation took {end_time - start_time:.2f} seconds.") # --- Step 3: Simulate Animation (using AnimateDiff) --- status = "Generating animation (AnimateDiff)..." logs.append(f"Generating animation ({animation_frames} frames, style: {animation_style}). Note: 'Style' parameter currently doesn't directly control AnimateDiff output...") # Add note about style limitation yield user_prompt, enhanced_prompt, temp_image_path, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status start_time = time.time() try: # Generate animation frames # AnimateDiff takes text prompt and generates a sequence. # The style parameter doesn't directly map to AnimateDiff options. # We'll use the enhanced prompt and requested frames. # Guidance scale might be applied differently or not at all depending on the pipeline implementation. with torch.no_grad(): # The AnimateDiff pipeline often doesn't have image_resolution, guidance_scale, # etc., parameters in the same way as text2image. It's primarily text-to-video. # We'll use the enhanced prompt and num_frames. # The height/width might default or need explicit setting if supported. # Let's use default resolution for simplicity or check pipeline args. # Assuming base SD resolution (512x512) if not explicitly supported/needed. # The pipe_anim loaded is StableDiffusionPipeline with motion adapter, let's check its call signature. # It should support most SD parameters. animation_frames_list = pipe_anim( prompt=enhanced_prompt, negative_prompt=None, # Could add negative prompt if needed num_frames=animation_frames, guidance_scale=guidance_scale, # Use guidance scale if pipeline supports it generator=generator, # width=image_resolution, # AnimateDiff motion adapter might expect specific resolutions # height=image_resolution, # Commented out for compatibility, using default ).frames # Compile frames into a video with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile: temp_video_path = tmpfile.name # Use imageio to write video - requires ffmpeg or similar backend # Ensure imageio can find a writer (like ffmpeg) try: imageio.mimwrite(temp_video_path, animation_frames_list, fps=8, quality=8) # Adjust fps and quality as needed except Exception as ffmpeg_error: logs.append(f"Error saving video with imageio/ffmpeg: {ffmpeg_error}") logs.append("Ensure ffmpeg is installed and in your PATH, or use imageio.get_writer with a specific backend.") status = "Error saving video." # Attempt cleanup if os.path.exists(temp_video_path): os.remove(temp_video_path) yield user_prompt, enhanced_prompt, temp_image_path, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status # Clean up temp image if 'temp_image_path' in locals() and os.path.exists(temp_image_path): os.remove(temp_image_path) return logs.append(f"Animation generated successfully: {temp_video_path}") yield user_prompt, enhanced_prompt, temp_image_path, temp_video_path, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status except Exception as e: logs.append(f"Error during animation generation: {e}") status = "Error during animation generation." yield user_prompt, enhanced_prompt, temp_image_path, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status # Clean up temp files if 'temp_image_path' in locals() and os.path.exists(temp_image_path): os.remove(temp_image_path) if 'temp_video_path' in locals() and os.path.exists(temp_video_path): os.remove(temp_video_path) return end_time = time.time() logs.append(f"Animation generation took {end_time - start_time:.2f} seconds.") # --- Finalizing Outputs --- status = "Process complete!" logs.append("All steps finished.") # Ensure all outputs are returned in the final state (yielded) # The last yield in a generator function provides the final values for Gradio # Let's make the last yield explicitly contain all final values yield user_prompt, enhanced_prompt, temp_image_path, temp_video_path, \ str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, \ "\n".join(logs), status # --- Function to update the parameters display (called after main function) --- def update_parameters_display(res, gs, seed, frames, style): # This function remains the same, it just formats the strings passed from the main function if not res: # Check if results exist (e.g., first yield is empty) return "" metadata = f"Resolution: {res}px\nGuidance Scale: {gs}\nSeed: {seed}\nFrames: {frames}\nStyle: {style}\n(Note: Animation Style may not directly control model output)" # Add note here too return metadata # --- Function to randomize seed --- def randomize(): return random.randint(1, 100000000) # Generate a random seed # --- Gradio UI Definition --- # Choose a more modern theme theme = gr.themes.Monochrome().set( # Customize colors slightly for a softer look # You can inspect theme objects and their attributes # button_primary_background_fill="linear-gradient(to right, #6a11cb 0%, #2575fc 100%)", # Example gradient # button_primary_color="white", # button_secondary_background_fill="gray", # spacing_size_lg="2rem" # Example spacing adjustment ) # Use tempfile for a base temp directory managed by the app temp_dir = tempfile.mkdtemp() print(f"Using temporary directory: {temp_dir}") # Set Gradio's temp dir if needed (often handled automatically) # gr.processing_utils.TEMP_DIR = temp_dir # This might be needed in older Gradio versions or specific setups with gr.Blocks(theme=theme, title="AI Creative Studio") as demo: # --- Header Section --- with gr.Row(variant="panel"): # Use a panel variant for distinct header background with gr.Column(scale=1, min_width=100): # Placeholder for a logo or icon gr.Image(value="https://www.gradio.app/_app/immutable/assets/gradio.CHB5adID.svg", label="Studio Logo", show_label=False, # Hide the label below the image height=80, width=80, container=False) # Prevent adding extra padding/margin around the image with gr.Column(scale=4): gr.Markdown( """ # 🎨 Multi-Step AI Creative Pipeline 🚀 Unleash your imagination! Input a prompt, and our AI orchestrates a sequence: Prompt Enhancement → Image Generation → Animation. **Using free models from Hugging Face (T5, Stable Diffusion 1.5, AnimateDiff).** *Note: 'Animation Style' parameter might not directly control the AnimateDiff model output.* """ ) gr.Markdown("---") # Separator # --- Main Content Area (Input & Output side-by-side initially) --- with gr.Row(): # --- Input & Controls Column --- with gr.Column(scale=1): gr.Markdown("## ✍️ Your Creative Input") prompt_input = gr.TextArea( label="Enter your prompt here:", placeholder="e.g., A majestic dragon flying over snow-capped mountains at sunset", lines=5, interactive=True ) gr.Examples( ["A cyberpunk street scene with neon lights", "A cozy cabin in a snowy forest, digital painting", "An astronaut riding a horse on the moon, surrealism"], inputs=prompt_input ) # Advanced Options (Collapsed) with gr.Accordion("🛠️ Advanced Settings", open=False): gr.Markdown("Configure specific parameters for generation.") with gr.Row(): image_resolution = gr.Slider(label="Image Resolution (px)", minimum=256, maximum=1024, value=512, step=128, interactive=True) guidance_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=20.0, value=7.0, step=0.1, interactive=True) with gr.Row(): seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0, interactive=True) animation_frames = gr.Slider(label="Animation Frames", minimum=10, maximum=100, value=40, step=5, interactive=True) animation_style = gr.Radio( label="Animation Style", choices=["Zoom In", "Pan Left", "Rotate", "Swirl"], value="Zoom In", interactive=True ) # Add a button to randomize seed easily randomize_seed_button = gr.Button("🎲 Randomize Seed") # Action Button generate_button = gr.Button("✨ Generate Pipeline Results ✨", variant="primary") # Live Status Indicator status_display = gr.Textbox(label="Status", value="Ready", interactive=False, show_copy_button=False) # --- Output & Results Column --- with gr.Column(scale=2): # Make output column wider gr.Markdown("## ✅ Generation Results") # Row for prompts with gr.Row(): original_prompt_output = gr.Textbox(label="Original Prompt Used", interactive=False, lines=3, scale=1, show_copy_button=True) enhanced_prompt_output = gr.Textbox(label="Enhanced Prompt (AI)", interactive=False, lines=3, scale=1, show_copy_button=True) # Row for media with gr.Row(): generated_image_output = gr.Image(label="Generated Image", interactive=False, height=450, show_share_button=True, type="filepath") # Specify type="filepath" generated_animation_output = gr.Video(label="Generated Animation", interactive=False, height=450, show_share_button=True) # Display Parameters Used (Collapsed or in a smaller section) with gr.Accordion("🔬 Parameters Used", open=False): # Collapsible section for details parameters_used_output = gr.Textbox( label="Generation Parameters", interactive=False, lines=6, # Increased lines slightly to fit the note max_lines=30, show_copy_button=True ) # Dummy output components to catch the individual parameters # We will combine them in the process_prompt_and_generate function for the Textbox above res_out = gr.Textbox(visible=False, type="value") gs_out = gr.Textbox(visible=False, type="value") seed_out = gr.Textbox(visible=False, type="value") frames_out = gr.Textbox(visible=False, type="value") style_out = gr.Textbox(visible=False, type="value") # Download Buttons (Placeholder) gr.Markdown("### Download Results") with gr.Row(): # These buttons are just placeholders for now. # Real download logic needs separate functions. # Making them interactive=False as they don't have click events linked download_image_button = gr.Button("⬇️ Download Image", interactive=False) download_video_button = gr.Button("⬇️ Download Video", interactive=False) gr.Markdown("---") # Separator # --- Logs and Debug Information --- with gr.Accordion("⚙️ Processing Logs & Debug Info", open=False): logs_output = gr.Textbox( label="Detailed Logs", interactive=False, lines=15, # More lines for detailed logs max_lines=30, show_copy_button=True, # Add some visual cues for logs container=True # Gives it a distinct container style ) # --- Define Interactions --- # Button click triggers the main processing function # The outputs list maps the function's return values to UI components # Because process_prompt_and_generate is a generator, Gradio updates the outputs # with each yielded value. The final yield provides the final state. generate_button.click( fn=process_prompt_and_generate, inputs=[ prompt_input, image_resolution, guidance_scale, seed, animation_frames, animation_style ], outputs=[ original_prompt_output, enhanced_prompt_output, generated_image_output, generated_animation_output, res_out, # Catch individual params to reconstruct metadata gs_out, seed_out, frames_out, style_out, logs_output, # Logs are updated incrementally/finally status_display # Status is updated incrementally/finally ], api_name="generate" # Optional: Add an API name for easy calling ).success( # Chain the parameter update after the main process finishes successfully fn=update_parameters_display, inputs=[res_out, gs_out, seed_out, frames_out, style_out], outputs=[parameters_used_output] ) # Randomize Seed Button Interaction randomize_seed_button.click( fn=randomize, inputs=[], outputs=[seed] # Update the seed number input field ) # --- Launch the App --- if __name__ == "__main__": print("Gradio AI Creative Studio is starting...") # Use share=True to make it accessible over the internet (for testing) # Use inbrowser=True to auto-open the browser demo.launch(inbrowser=True) print("App launched!") # Optional: Clean up the temporary directory when the app stops # This is not automatically called when you Ctrl+C, but useful in some deployment scenarios # import shutil # shutil.rmtree(temp_dir) # print(f"Cleaned up temporary directory: {temp_dir}")