import gradio as gr
import time
import random
import torch
import numpy as np
from PIL import Image
import imageio
import tempfile
import os

from transformers import T5ForConditionalGeneration, T5Tokenizer
from diffusers import StableDiffusionPipeline, AnimateDiffPipeline, DDIMScheduler, MotionAdapter
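
# Assumed dependency set (not pinned in the original): gradio, torch,
# transformers, diffusers, accelerate, sentencepiece (required by T5Tokenizer),
# imageio plus imageio-ffmpeg (or a system ffmpeg) for the mp4 writer, and
# optionally xformers for the memory-efficient attention path below.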

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

print("Loading Prompt Enhancement Model (T5)...")
tokenizer_t5 = T5Tokenizer.from_pretrained("t5-small")
model_t5 = T5ForConditionalGeneration.from_pretrained("t5-small").to(device)
print("T5 model loaded.")
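
# Note: t5-small is a general-purpose text-to-text model, not a checkpoint
# fine-tuned for prompt enhancement, so the "enhance prompt:" instruction used
# below is a lightweight placeholder; a Hub model trained specifically for
# prompt expansion would give more useful rewrites.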

print("Loading Image Generation Model (Stable Diffusion 1.5)...")
pipe_sd = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)

if device == "cuda":
    pipe_sd.enable_xformers_memory_efficient_attention()
    pipe_sd.enable_vae_slicing()
    # The original called pipe_sd.enable_cfashion_scaling(), which is not a
    # diffusers API; attention slicing is assumed to be the intended memory
    # optimization.
    pipe_sd.enable_attention_slicing()

print("Stable Diffusion 1.5 model loaded.")
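
# enable_xformers_memory_efficient_attention() raises if the optional xformers
# package is missing. A defensive sketch, if that dependency cannot be
# guaranteed:
#
#     try:
#         pipe_sd.enable_xformers_memory_efficient_attention()
#     except Exception as exc:
#         print(f"xformers unavailable, continuing without it: {exc}")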

print("Loading Animation Model (AnimateDiff)...")

adapter = MotionAdapter.from_pretrained(
    "guoyww/animatediff-motion-adapter-v1-5",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
)

# AnimateDiff requires its own pipeline class; StableDiffusionPipeline does
# not accept a motion_adapter argument.
pipe_anim = AnimateDiffPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    motion_adapter=adapter,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)

# "uniform" is not a valid timestep_spacing value; "linspace" (with a linear
# beta schedule) matches the diffusers AnimateDiff example.
pipe_anim.scheduler = DDIMScheduler.from_config(
    pipe_anim.scheduler.config,
    clip_sample=False,
    timestep_spacing="linspace",
    beta_schedule="linear",
)

if device == "cuda":
    pipe_anim.enable_xformers_memory_efficient_attention()
    pipe_anim.enable_vae_slicing()

print("AnimateDiff model loaded.")
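
# The v1-5 motion adapter was trained with a 16-frame context, so requests far
# above 16 frames tend to degrade; the UI slider below is capped accordingly.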


def process_prompt_and_generate(user_prompt, image_resolution, guidance_scale, seed, animation_frames, animation_style):
    """
    Runs the full AI pipeline (prompt enhancement -> image -> animation) using
    Hugging Face models, yielding status and log updates to the UI at each step.
    """
    logs = []
    status = "Starting processing..."

    yield user_prompt, "", None, None, "", "", "", "", "", "\n".join(logs), status

    if not user_prompt:
        logs.append("Error: No prompt provided.")
        status = "Error: No prompt provided."
        yield user_prompt, "", None, None, "", "", "", "", "", "\n".join(logs), status
        return

    current_seed = seed if seed != -1 else random.randint(0, 100000000)
    generator = torch.Generator(device=device).manual_seed(current_seed)
    np.random.seed(current_seed)

    status = "Enhancing prompt (T5)..."
    logs.append(f"User Prompt: '{user_prompt}'")
    logs.append(f"Parameters: Resolution={image_resolution}, Guidance Scale={guidance_scale}, Seed={current_seed}, Frames={animation_frames}, Style={animation_style}")
    yield user_prompt, "", None, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status
    start_time = time.time()

    try:
        input_text = f"enhance prompt: {user_prompt}"
        input_ids = tokenizer_t5(input_text, return_tensors="pt").input_ids.to(device)
        outputs = model_t5.generate(input_ids, max_length=64, num_beams=4, early_stopping=True)
        enhanced_prompt = tokenizer_t5.decode(outputs[0], skip_special_tokens=True)
        logs.append(f"Enhanced Prompt: '{enhanced_prompt}'")
        yield user_prompt, enhanced_prompt, None, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status
    except Exception as e:
        logs.append(f"Error during prompt enhancement: {e}")
        status = "Error during prompt enhancement."
        yield user_prompt, "", None, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status
        return
    end_time = time.time()
    logs.append(f"Prompt enhancement took {end_time - start_time:.2f} seconds.")

    status = "Generating image (Stable Diffusion)..."
    logs.append(f"Generating initial image ({image_resolution}x{image_resolution}px)...")
    yield user_prompt, enhanced_prompt, None, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status
    start_time = time.time()

    try:
        with torch.no_grad():
            image = pipe_sd(
                prompt=enhanced_prompt,
                height=int(image_resolution),
                width=int(image_resolution),
                guidance_scale=guidance_scale,
                generator=generator,
            ).images[0]

        # temp_dir is created at module level below, before the UI launches.
        # The image is saved after the handle is closed so this also works on
        # platforms that forbid reopening an open temporary file.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False, dir=temp_dir) as tmpfile:
            temp_image_path = tmpfile.name
        image.save(temp_image_path)

        logs.append(f"Image generated successfully: {temp_image_path}")
        yield user_prompt, enhanced_prompt, temp_image_path, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status
    except Exception as e:
        logs.append(f"Error during image generation: {e}")
        status = "Error during image generation."
        yield user_prompt, enhanced_prompt, None, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status
        if 'temp_image_path' in locals() and os.path.exists(temp_image_path):
            os.remove(temp_image_path)
        return
    end_time = time.time()
    logs.append(f"Image generation took {end_time - start_time:.2f} seconds.")
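
    # Note: `generator` was consumed by the image pass above, so the animation
    # pass below runs from a different random state. For a reproducible pairing
    # one could re-seed first (a sketch):
    #
    #     generator = torch.Generator(device=device).manual_seed(current_seed)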

    status = "Generating animation (AnimateDiff)..."
    logs.append(f"Generating animation ({animation_frames} frames, style: {animation_style}). Note: the 'Style' parameter currently doesn't directly control AnimateDiff output...")
    yield user_prompt, enhanced_prompt, temp_image_path, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status
    start_time = time.time()

    try:
        with torch.no_grad():
            # AnimateDiffPipeline returns frames batched per prompt; take the
            # first (and only) batch as a list of PIL images.
            animation_frames_list = pipe_anim(
                prompt=enhanced_prompt,
                negative_prompt=None,
                num_frames=int(animation_frames),
                guidance_scale=guidance_scale,
                generator=generator,
            ).frames[0]

        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False, dir=temp_dir) as tmpfile:
            temp_video_path = tmpfile.name

        try:
            # imageio expects arrays, so convert the PIL frames before writing.
            imageio.mimwrite(temp_video_path, [np.asarray(frame) for frame in animation_frames_list], fps=8, quality=8)
        except Exception as ffmpeg_error:
            logs.append(f"Error saving video with imageio/ffmpeg: {ffmpeg_error}")
            logs.append("Ensure ffmpeg is installed and in your PATH, or use imageio.get_writer with a specific backend.")
            status = "Error saving video."
            if os.path.exists(temp_video_path):
                os.remove(temp_video_path)
            yield user_prompt, enhanced_prompt, temp_image_path, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status
            if 'temp_image_path' in locals() and os.path.exists(temp_image_path):
                os.remove(temp_image_path)
            return

        logs.append(f"Animation generated successfully: {temp_video_path}")
        yield user_prompt, enhanced_prompt, temp_image_path, temp_video_path, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status
    except Exception as e:
        logs.append(f"Error during animation generation: {e}")
        status = "Error during animation generation."
        yield user_prompt, enhanced_prompt, temp_image_path, None, str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, "\n".join(logs), status
        if 'temp_image_path' in locals() and os.path.exists(temp_image_path):
            os.remove(temp_image_path)
        if 'temp_video_path' in locals() and os.path.exists(temp_video_path):
            os.remove(temp_video_path)
        return
    end_time = time.time()
    logs.append(f"Animation generation took {end_time - start_time:.2f} seconds.")
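
    # An alternative to hand-rolling the ffmpeg write: diffusers ships a small
    # helper that accepts the PIL frames directly (a sketch, assuming a recent
    # diffusers version):
    #
    #     from diffusers.utils import export_to_video
    #     export_to_video(animation_frames_list, temp_video_path, fps=8)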

    status = "Process complete!"
    logs.append("All steps finished.")

    yield user_prompt, enhanced_prompt, temp_image_path, temp_video_path, \
        str(image_resolution), str(guidance_scale), str(current_seed), str(animation_frames), animation_style, \
        "\n".join(logs), status


def update_parameters_display(res, gs, seed, frames, style):
    if not res:
        return ""
    metadata = f"Resolution: {res}px\nGuidance Scale: {gs}\nSeed: {seed}\nFrames: {frames}\nStyle: {style}\n(Note: Animation Style may not directly control model output)"
    return metadata


def randomize():
    return random.randint(1, 100000000)


# .set() with no overrides is a no-op, so the plain theme is equivalent.
theme = gr.themes.Monochrome()

# Directory that holds the generated image/video files produced above.
temp_dir = tempfile.mkdtemp()
print(f"Using temporary directory: {temp_dir}")
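
# temp_dir is never removed by this script; a cleanup hook could be added
# (a sketch using only standard-library calls):
#
#     import atexit, shutil
#     atexit.register(shutil.rmtree, temp_dir, ignore_errors=True)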


with gr.Blocks(theme=theme, title="AI Creative Studio") as demo:

    with gr.Row(variant="panel"):
        with gr.Column(scale=1, min_width=100):
            gr.Image(value="https://www.gradio.app/_app/immutable/assets/gradio.CHB5adID.svg",
                     label="Studio Logo",
                     show_label=False,
                     height=80,
                     width=80,
                     container=False)

        with gr.Column(scale=4):
            gr.Markdown(
                """
                # 🎨 Multi-Step AI Creative Pipeline 🚀
                Unleash your imagination! Input a prompt, and our AI orchestrates a sequence:
                Prompt Enhancement → Image Generation → Animation.
                **Using free models from Hugging Face (T5, Stable Diffusion 1.5, AnimateDiff).**
                *Note: the 'Animation Style' parameter does not directly control the AnimateDiff model output.*
                """
            )
    gr.Markdown("---")

    with gr.Row():

        with gr.Column(scale=1):
            gr.Markdown("## ✏️ Your Creative Input")
            prompt_input = gr.TextArea(
                label="Enter your prompt here:",
                placeholder="e.g., A majestic dragon flying over snow-capped mountains at sunset",
                lines=5,
                interactive=True
            )
            gr.Examples(
                ["A cyberpunk street scene with neon lights", "A cozy cabin in a snowy forest, digital painting", "An astronaut riding a horse on the moon, surrealism"],
                inputs=prompt_input
            )

            with gr.Accordion("🛠️ Advanced Settings", open=False):
                gr.Markdown("Configure specific parameters for generation.")

                with gr.Row():
                    image_resolution = gr.Slider(label="Image Resolution (px)", minimum=256, maximum=1024, value=512, step=128, interactive=True)
                    guidance_scale = gr.Slider(label="Guidance Scale", minimum=1.0, maximum=20.0, value=7.0, step=0.1, interactive=True)

                with gr.Row():
                    seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0, interactive=True)
                    # Capped near the motion adapter's 16-frame training context
                    # (see the note after the AnimateDiff setup above).
                    animation_frames = gr.Slider(label="Animation Frames", minimum=8, maximum=32, value=16, step=4, interactive=True)

                animation_style = gr.Radio(
                    label="Animation Style",
                    choices=["Zoom In", "Pan Left", "Rotate", "Swirl"],
                    value="Zoom In",
                    interactive=True
                )

                randomize_seed_button = gr.Button("🎲 Randomize Seed")

            generate_button = gr.Button("✨ Generate Pipeline Results ✨", variant="primary")

            status_display = gr.Textbox(label="Status", value="Ready", interactive=False, show_copy_button=False)

        with gr.Column(scale=2):
            gr.Markdown("## ✅ Generation Results")

            with gr.Row():
                original_prompt_output = gr.Textbox(label="Original Prompt Used", interactive=False, lines=3, scale=1, show_copy_button=True)
                enhanced_prompt_output = gr.Textbox(label="Enhanced Prompt (AI)", interactive=False, lines=3, scale=1, show_copy_button=True)

            with gr.Row():
                generated_image_output = gr.Image(label="Generated Image", interactive=False, height=450, show_share_button=True, type="filepath")
                generated_animation_output = gr.Video(label="Generated Animation", interactive=False, height=450, show_share_button=True)

            with gr.Accordion("🔬 Parameters Used", open=False):
                parameters_used_output = gr.Textbox(
                    label="Generation Parameters",
                    interactive=False,
                    lines=6,
                    max_lines=30,
                    show_copy_button=True
                )

            # Hidden carriers for the parameter values; gr.Textbox has no
            # type="value" option, so the invalid argument is dropped.
            res_out = gr.Textbox(visible=False)
            gs_out = gr.Textbox(visible=False)
            seed_out = gr.Textbox(visible=False)
            frames_out = gr.Textbox(visible=False)
            style_out = gr.Textbox(visible=False)

            gr.Markdown("### Download Results")
            with gr.Row():
                # Placeholders: no click handlers are wired up, and the
                # image/video components already offer built-in downloads.
                download_image_button = gr.Button("⬇️ Download Image", interactive=False)
                download_video_button = gr.Button("⬇️ Download Video", interactive=False)

            gr.Markdown("---")

    with gr.Accordion("⚙️ Processing Logs & Debug Info", open=False):
        logs_output = gr.Textbox(
            label="Detailed Logs",
            interactive=False,
            lines=15,
            max_lines=30,
            show_copy_button=True,
            container=True
        )

    generate_button.click(
        fn=process_prompt_and_generate,
        inputs=[
            prompt_input,
            image_resolution,
            guidance_scale,
            seed,
            animation_frames,
            animation_style
        ],
        outputs=[
            original_prompt_output,
            enhanced_prompt_output,
            generated_image_output,
            generated_animation_output,
            res_out,
            gs_out,
            seed_out,
            frames_out,
            style_out,
            logs_output,
            status_display
        ],
        api_name="generate"
    ).success(
        fn=update_parameters_display,
        inputs=[res_out, gs_out, seed_out, frames_out, style_out],
        outputs=[parameters_used_output]
    )
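
    # api_name="generate" also exposes this event for programmatic calls;
    # a sketch using the gradio_client package (assuming the default local URL):
    #
    #     from gradio_client import Client
    #     client = Client("http://127.0.0.1:7860/")
    #     client.predict("A red fox in the snow", 512, 7.0, -1, 16, "Zoom In",
    #                    api_name="/generate")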

    randomize_seed_button.click(
        fn=randomize,
        inputs=[],
        outputs=[seed]
    )


if __name__ == "__main__":
    print("Gradio AI Creative Studio is starting...")
    # demo.launch() blocks until the server is shut down, so this print is the
    # last message before the app goes live.
    demo.launch(inbrowser=True)
    print("App stopped.")