seawolf2357's picture
Update app.py
8b98825 verified
raw
history blame
20.8 kB
import torch
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, UniPCMultistepScheduler
from diffusers.utils import export_to_video
from transformers import CLIPVisionModel
import gradio as gr
import tempfile
import spaces
from huggingface_hub import hf_hub_download
import numpy as np
from PIL import Image
import random
import logging
import gc
import time
import hashlib
from dataclasses import dataclass
from typing import Optional, Tuple
from functools import wraps
# Logging setup: route INFO-and-above records through the root logger's
# default handler and give this module its own named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
# Configuration management
@dataclass
class VideoGenerationConfig:
    """Static settings for the image-to-video demo.

    Collects the Hub identifiers, resolution constraints, UI slider bounds,
    frame-count limits and default prompts in one place.
    """

    # Hugging Face Hub identifiers for the base pipeline and the LoRA file.
    model_id: str = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
    lora_repo_id: str = "Kijai/WanVideo_comfy"
    lora_filename: str = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
    # Heights/widths are snapped down to multiples of this value.
    mod_value: int = 32
    # Fallback resolution when no usable image size is available.
    default_height: int = 512
    default_width: int = 896
    # Pixel budget (480 x 832) used when deriving a size from an aspect ratio.
    max_area: float = float(480 * 832)
    # Resolution slider bounds, in pixels.
    slider_min_h: int = 128
    slider_max_h: int = 896
    slider_min_w: int = 128
    slider_max_w: int = 896
    # Output frame rate and the range the generated frame count is clamped to.
    fixed_fps: int = 24
    min_frames: int = 8
    max_frames: int = 81
    # Prompts pre-filled in the UI.
    default_prompt: str = "make this image come alive, cinematic motion, smooth animation"
    default_negative_prompt: str = "static, blurred, low quality, watermark, text"


config = VideoGenerationConfig()

# Largest seed the UI offers: the maximum of a signed 32-bit integer.
MAX_SEED = np.iinfo("int32").max
# Performance-measurement decorator
def measure_time(func):
    """Log the wall-time duration of every call to *func*.

    Each call logs ``"<func name> took <seconds>s"`` at INFO level through the
    module logger. The timing is logged even when *func* raises; the exception
    itself propagates unchanged.

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper with the same call signature, whose metadata (name,
        docstring, ``__wrapped__``) is copied from *func* via ``functools.wraps``.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # when the system clock is adjusted.
        start = time.perf_counter()
        try:
            return func(*args, **kwargs)
        finally:
            # Logged in `finally` so slow failures (e.g. a failing model load)
            # are timed too.
            logger.info("%s took %.2fs", func.__name__, time.perf_counter() - start)
    return wrapper
# Model manager
class ModelManager:
    """Lazily construct and cache the Wan 2.1 image-to-video pipeline.

    Construction is cheap: nothing is downloaded until ``pipe`` is first read.
    That first read builds the pipeline (float32 image encoder and VAE,
    bfloat16 for the rest), swaps in a UniPC scheduler with ``flow_shift=8.0``,
    moves it to CUDA and fuses the CausVid LoRA at weight 0.95.
    """

    def __init__(self):
        self._pipe = None
        # Set only after the full load sequence completes, so a load that
        # fails part-way is attempted again on the next access.
        self._is_loaded = False

    @property
    def pipe(self):
        """Return the ready pipeline, loading it on first use."""
        # NOTE(review): the first access happens inside the @spaces.GPU-decorated
        # generate_video; confirm the cached pipeline persists between ZeroGPU
        # calls, otherwise this load repeats on every request.
        if self._is_loaded:
            return self._pipe
        self._load_model()
        return self._pipe

    @measure_time
    def _load_model(self):
        """Build the pipeline, move it to CUDA and fuse the CausVid LoRA."""
        logger.info("Loading model...")
        repo = config.model_id
        encoder = CLIPVisionModel.from_pretrained(
            repo, subfolder="image_encoder", torch_dtype=torch.float32
        )
        autoencoder = AutoencoderKLWan.from_pretrained(
            repo, subfolder="vae", torch_dtype=torch.float32
        )
        pipeline = WanImageToVideoPipeline.from_pretrained(
            repo, vae=autoencoder, image_encoder=encoder, torch_dtype=torch.bfloat16
        )
        self._pipe = pipeline
        pipeline.scheduler = UniPCMultistepScheduler.from_config(
            pipeline.scheduler.config, flow_shift=8.0
        )
        pipeline.to("cuda")
        self._fuse_causvid_lora(pipeline)
        self._is_loaded = True
        logger.info("Model loaded successfully")

    @staticmethod
    def _fuse_causvid_lora(pipeline):
        """Download the CausVid LoRA weights and fuse them into *pipeline*."""
        weights_path = hf_hub_download(
            repo_id=config.lora_repo_id, filename=config.lora_filename
        )
        pipeline.load_lora_weights(weights_path, adapter_name="causvid_lora")
        # Activate the adapter at 0.95 strength, then merge it into the base weights.
        pipeline.set_adapters(["causvid_lora"], adapter_weights=[0.95])
        pipeline.fuse_lora()


model_manager = ModelManager()
# Video generator class
class VideoGenerator:
    """Input validation, sizing and naming helpers for the image-to-video flow.

    All limits are read from ``self.config`` (a ``VideoGenerationConfig``).
    """

    def __init__(self, config: "VideoGenerationConfig", model_manager: "ModelManager"):
        self.config = config
        self.model_manager = model_manager

    def calculate_dimensions(self, image: "Image.Image") -> Tuple[int, int]:
        """Suggest an output ``(height, width)`` for *image*.

        Keeps the image's aspect ratio at roughly ``config.max_area`` pixels.
        Each side is snapped down to a multiple of ``config.mod_value`` (at
        least one multiple), then clamped to the slider range. A degenerate
        (non-positive) image size yields the configured defaults.

        Args:
            image: Any object with a ``size`` attribute of ``(width, height)``,
                e.g. a PIL image.

        Returns:
            ``(height, width)`` as plain ints.
        """
        orig_w, orig_h = image.size
        if orig_w <= 0 or orig_h <= 0:
            return self.config.default_height, self.config.default_width
        mod = self.config.mod_value
        aspect_ratio = orig_h / orig_w
        calc_h = round(np.sqrt(self.config.max_area * aspect_ratio))
        calc_w = round(np.sqrt(self.config.max_area / aspect_ratio))
        calc_h = max(mod, (calc_h // mod) * mod)
        calc_w = max(mod, (calc_w // mod) * mod)
        # Upper bounds are also snapped to the modulus so the clamp keeps alignment.
        new_h = int(np.clip(calc_h, self.config.slider_min_h,
                            (self.config.slider_max_h // mod) * mod))
        new_w = int(np.clip(calc_w, self.config.slider_min_w,
                            (self.config.slider_max_w // mod) * mod))
        return new_h, new_w

    def validate_inputs(self, image: "Image.Image", prompt: str, height: int,
                        width: int, duration: float, steps: int) -> Tuple[bool, Optional[str]]:
        """Check user inputs before generation.

        Args:
            image: The uploaded image, or ``None`` if missing.
            prompt: The animation prompt; must be non-blank and at most 500 chars.
            height, width, steps: Accepted for interface symmetry; not checked here.
            duration: Requested clip length in seconds.

        Returns:
            ``(True, None)`` when valid, otherwise ``(False, message)`` with a
            user-facing error message.
        """
        if image is None:
            return False, "🖼️ Please upload an input image"
        if not prompt or not prompt.strip():
            return False, "✍️ Please provide a prompt"
        if len(prompt) > 500:
            return False, "⚠️ Prompt is too long (max 500 characters)"
        # The duration slider's bounds are the frame limits converted to seconds
        # and rounded to one decimal (0.3 s / 3.4 s for 8 / 81 frames at 24 fps),
        # which lie just outside the exact limits (0.333 s / 3.375 s). Validate
        # against the same rounded bounds so the slider extremes are accepted;
        # the frame count is clamped to [min_frames, max_frames] at generation.
        fps = self.config.fixed_fps
        min_duration = round(self.config.min_frames / fps, 1)
        max_duration = round(self.config.max_frames / fps, 1)
        # Absorb float noise from slider step arithmetic (e.g. 3.4000000000000004).
        tolerance = 1e-6
        if duration < min_duration - tolerance:
            return False, f"⏱️ Duration too short (min {min_duration:.1f}s)"
        if duration > max_duration + tolerance:
            return False, f"⏱️ Duration too long (max {max_duration:.1f}s)"
        return True, None

    def generate_unique_filename(self, seed: int) -> str:
        """Return a pseudo-unique ``video_<8 hex chars>.mp4`` name.

        Derived from the current second, *seed* and a random suffix. MD5 is used
        only as a short hash here, not for security; collisions are unlikely
        but not impossible.
        """
        timestamp = int(time.time())
        unique_str = f"{timestamp}_{seed}_{random.randint(1000, 9999)}"
        hash_obj = hashlib.md5(unique_str.encode())
        return f"video_{hash_obj.hexdigest()[:8]}.mp4"
# Shared generator instance used by the Gradio callbacks below.
video_generator = VideoGenerator(config=config, model_manager=model_manager)
# Gradio callbacks
def handle_image_upload(image):
    """Suggest height/width slider values for the current input image.

    Used for both the image ``upload`` and ``clear`` events. When the image is
    cleared (``None``) or cannot be processed, both sliders are reset to the
    configured defaults and, on failure, a UI warning is shown.

    Args:
        image: The image from the ``gr.Image(type="pil")`` component, or
            ``None`` when it was cleared.

    Returns:
        A pair of ``gr.update`` objects for the height and width sliders.
    """
    defaults = (
        gr.update(value=config.default_height),
        gr.update(value=config.default_width),
    )
    if image is None:
        return defaults
    if not isinstance(image, Image.Image):
        logger.error("Error processing image: Invalid image format")
        gr.Warning("⚠️ Error processing image")
        return defaults
    try:
        new_h, new_w = video_generator.calculate_dimensions(image)
    except Exception:
        # Best-effort: keep the UI usable, but keep the traceback in the logs.
        logger.exception("Error processing image")
        gr.Warning("⚠️ Error processing image")
        return defaults
    return gr.update(value=new_h), gr.update(value=new_w)
def get_duration(input_image=None, prompt=None, height=None, width=None,
                 negative_prompt=None, duration_seconds=2, guidance_scale=1,
                 steps=4, seed=42, randomize_seed=False, progress=None):
    """Estimate the GPU time budget, in seconds, for one ``generate_video`` call.

    Passed as ``spaces.GPU(duration=...)`` for ``generate_video``, so its
    parameter list mirrors that function's. Every parameter has a default
    matching ``generate_video``'s. The estimate therefore still works when the
    call supplies only the leading positional inputs (the ``gr.Examples`` in
    this file wire just image, prompt, height and width).

    Returns:
        90 when both more than 4 steps and more than 2 s of video are requested,
        75 when only one of those holds, otherwise 60.
    """
    many_steps = steps > 4
    long_clip = duration_seconds > 2
    if many_steps and long_clip:
        return 90
    if many_steps or long_clip:
        return 75
    return 60
@spaces.GPU(duration=get_duration)
@measure_time
def generate_video(input_image, prompt, height, width,
                   negative_prompt=config.default_negative_prompt,
                   duration_seconds=2, guidance_scale=1, steps=4,
                   seed=42, randomize_seed=False,
                   progress=gr.Progress(track_tqdm=True)):
    """Animate *input_image* into a short video and return its file path.

    Args:
        input_image: Source PIL image from the Gradio image component.
        prompt: Text describing the desired motion.
        height, width: Requested output size in pixels. Each is snapped down to
            a multiple of ``config.mod_value`` (at least one multiple).
        negative_prompt: Text describing what to avoid.
        duration_seconds: Requested clip length. It is converted to a frame count
            at ``config.fixed_fps`` and clamped to
            ``[config.min_frames, config.max_frames]``.
        guidance_scale: Guidance scale forwarded to the pipeline.
        steps: Number of inference steps.
        seed: Seed used when *randomize_seed* is false.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]`` instead of *seed*.
        progress: Gradio progress tracker.

    Returns:
        ``(video_path, seed_used)``: the path of the exported ``.mp4`` and the
        seed actually used.

    Raises:
        gr.Error: If the inputs fail validation.
    """
    progress(0.1, desc="🔍 Validating inputs...")
    # Input validation
    is_valid, error_msg = video_generator.validate_inputs(
        input_image, prompt, height, width, duration_seconds, steps
    )
    if not is_valid:
        raise gr.Error(error_msg)

    # Bound up front so the cleanup in `finally` never depends on how far we got.
    output_frames_list = None
    try:
        progress(0.2, desc="🎯 Preparing image...")
        target_h = max(config.mod_value, (int(height) // config.mod_value) * config.mod_value)
        target_w = max(config.mod_value, (int(width) // config.mod_value) * config.mod_value)
        # np.clip returns a NumPy scalar; hand the pipeline a plain int.
        num_frames = int(np.clip(int(round(duration_seconds * config.fixed_fps)),
                                 config.min_frames, config.max_frames))
        current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
        # Normalize to RGB so RGBA / palette / grayscale uploads (e.g. PNGs)
        # reach the pipeline with three channels.
        resized_image = input_image.convert("RGB").resize(
            (target_w, target_h), Image.Resampling.LANCZOS
        )

        progress(0.3, desc="🎨 Loading model...")
        pipe = model_manager.pipe

        progress(0.4, desc="🎬 Generating video frames...")
        with torch.inference_mode():
            output_frames_list = pipe(
                image=resized_image,
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=target_h,
                width=target_w,
                num_frames=num_frames,
                guidance_scale=float(guidance_scale),
                num_inference_steps=int(steps),
                generator=torch.Generator(device="cuda").manual_seed(current_seed)
            ).frames[0]

        progress(0.9, desc="💾 Saving video...")
        # delete=False: the file must outlive this function so Gradio can serve it.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
            video_path = tmpfile.name
        export_to_video(output_frames_list, video_path, fps=config.fixed_fps)

        progress(1.0, desc="✨ Complete!")
        return video_path, current_seed
    finally:
        # Memory cleanup: drop the frame list and release cached CUDA memory.
        del output_frames_list
        gc.collect()
        torch.cuda.empty_cache()
# CSS styles for the Gradio app (passed to gr.Blocks(css=...)); the class
# names match the elem_classes used by the UI components.
css = """
.container {
max-width: 1200px;
margin: auto;
padding: 20px;
}
.header {
text-align: center;
margin-bottom: 30px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
padding: 40px;
border-radius: 20px;
color: white;
box-shadow: 0 10px 30px rgba(0,0,0,0.2);
}
.header h1 {
font-size: 3em;
margin-bottom: 10px;
text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
}
.header p {
font-size: 1.2em;
opacity: 0.95;
}
.main-content {
background: rgba(255, 255, 255, 0.95);
border-radius: 20px;
padding: 30px;
box-shadow: 0 5px 20px rgba(0,0,0,0.1);
backdrop-filter: blur(10px);
}
.input-section {
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
padding: 25px;
border-radius: 15px;
margin-bottom: 20px;
}
.generate-btn {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
font-size: 1.3em;
padding: 15px 40px;
border-radius: 30px;
border: none;
cursor: pointer;
transition: all 0.3s ease;
box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
width: 100%;
margin-top: 20px;
}
.generate-btn:hover {
transform: translateY(-2px);
box-shadow: 0 7px 20px rgba(102, 126, 234, 0.6);
}
.video-output {
background: #f8f9fa;
padding: 20px;
border-radius: 15px;
text-align: center;
min-height: 400px;
display: flex;
align-items: center;
justify-content: center;
}
.accordion {
background: rgba(255, 255, 255, 0.7);
border-radius: 10px;
margin-top: 15px;
padding: 15px;
}
.slider-container {
background: rgba(255, 255, 255, 0.5);
padding: 15px;
border-radius: 10px;
margin: 10px 0;
}
body {
background: linear-gradient(-45deg, #ee7752, #e73c7e, #23a6d5, #23d5ab);
background-size: 400% 400%;
animation: gradient 15s ease infinite;
}
@keyframes gradient {
0% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
100% { background-position: 0% 50%; }
}
.gr-button-secondary {
background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
}
.footer {
text-align: center;
margin-top: 30px;
color: #666;
font-size: 0.9em;
}
"""
# Gradio UI
# NOTE(review): the nesting of the `with` blocks below was reconstructed from a
# copy whose indentation had been stripped; confirm it matches the intended layout.
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_classes="container"):
        # Header
        gr.HTML("""
<div class="header">
<h1>🎬 AI Video Magic Studio</h1>
<p>Transform your images into captivating videos with Wan 2.1 + CausVid LoRA</p>
</div>
""")
        with gr.Row(elem_classes="main-content"):
            # Left column: input image, prompt, duration and advanced settings.
            with gr.Column(scale=1):
                gr.Markdown("### 📸 Input Settings")
                with gr.Column(elem_classes="input-section"):
                    input_image = gr.Image(
                        type="pil",
                        label="🖼️ Upload Your Image",
                        elem_classes="image-upload"
                    )
                    prompt_input = gr.Textbox(
                        label="✨ Animation Prompt",
                        value=config.default_prompt,
                        placeholder="Describe how you want your image to move...",
                        lines=2
                    )
                    # Bounds are the min/max frame counts at the fixed fps, rounded to 0.1 s.
                    duration_input = gr.Slider(
                        minimum=round(config.min_frames/config.fixed_fps, 1),
                        maximum=round(config.max_frames/config.fixed_fps, 1),
                        step=0.1,
                        value=2,
                        label="⏱️ Video Duration (seconds)",
                        elem_classes="slider-container"
                    )
                    with gr.Accordion("🎛️ Advanced Settings", open=False, elem_classes="accordion"):
                        negative_prompt = gr.Textbox(
                            label="🚫 Negative Prompt",
                            value=config.default_negative_prompt,
                            lines=2
                        )
                        with gr.Row():
                            seed = gr.Slider(
                                minimum=0,
                                maximum=MAX_SEED,
                                step=1,
                                value=42,
                                label="🎲 Seed"
                            )
                            # Checked by default, so each run draws a fresh seed unless unticked.
                            randomize_seed = gr.Checkbox(
                                label="🔀 Randomize",
                                value=True
                            )
                        # Both size sliders move in steps of config.mod_value.
                        with gr.Row():
                            height_slider = gr.Slider(
                                minimum=config.slider_min_h,
                                maximum=config.slider_max_h,
                                step=config.mod_value,
                                value=config.default_height,
                                label="📏 Height"
                            )
                            width_slider = gr.Slider(
                                minimum=config.slider_min_w,
                                maximum=config.slider_max_w,
                                step=config.mod_value,
                                value=config.default_width,
                                label="📐 Width"
                            )
                        steps_slider = gr.Slider(
                            minimum=1,
                            maximum=30,
                            step=1,
                            value=4,
                            label="🔧 Quality Steps (4-8 recommended)"
                        )
                        # Hidden, so the UI always submits the default guidance of 1.0.
                        guidance_scale = gr.Slider(
                            minimum=0.0,
                            maximum=20.0,
                            step=0.5,
                            value=1.0,
                            label="🎯 Guidance Scale",
                            visible=False
                        )
                generate_btn = gr.Button(
                    "🎬 Generate Video",
                    variant="primary",
                    elem_classes="generate-btn"
                )
            # Right column: generated video.
            with gr.Column(scale=1):
                gr.Markdown("### 🎥 Generated Video")
                video_output = gr.Video(
                    label="",
                    autoplay=True,
                    elem_classes="video-output"
                )
                gr.HTML("""
<div class="footer">
<p>💡 Tip: For best results, use clear images with good lighting</p>
</div>
""")
        # Examples
        # Each row fills only image, prompt, height and width; the remaining
        # generate_video parameters fall back to their defaults.
        gr.Examples(
            examples=[
                ["peng.png", "a penguin playfully dancing in the snow, Antarctica", 896, 512],
                ["forg.jpg", "the frog jumps around", 448, 832],
            ],
            inputs=[input_image, prompt_input, height_slider, width_slider],
            outputs=[video_output, seed],
            fn=generate_video,
            cache_examples="lazy"
        )
        # Added after the Examples section: a static feature-summary panel.
        gr.HTML("""
<div class="improvements-container" style="background: rgba(255, 255, 255, 0.95); backdrop-filter: blur(10px); border-radius: 15px; padding: 20px; margin: 20px auto; max-width: 800px; box-shadow: 0 5px 20px rgba(0,0,0,0.1); font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;">
<div class="improvements-header" style="text-align: center; margin-bottom: 20px;">
<h3 style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; font-size: 1.5em; margin: 0; font-weight: 700;">✨ Enhanced Features</h3>
<p style="color: #666; font-size: 0.9em; margin-top: 5px;">Optimized for performance, stability, and user experience</p>
</div>
<div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 15px;">
<div style="background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%); border-radius: 10px; padding: 15px;">
<span style="font-size: 1.5em; margin-bottom: 8px; display: block;">🛡️</span>
<div style="font-weight: 600; color: #333; font-size: 0.95em; margin-bottom: 5px;">Robust Error Handling</div>
<div style="font-size: 0.75em; color: #666; line-height: 1.4;">Advanced validation and recovery mechanisms</div>
</div>
<div style="background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%); border-radius: 10px; padding: 15px;">
<span style="font-size: 1.5em; margin-bottom: 8px; display: block;">⚡</span>
<div style="font-weight: 600; color: #333; font-size: 0.95em; margin-bottom: 5px;">Performance Optimized</div>
<div style="font-size: 0.75em; color: #666; line-height: 1.4;">Faster processing with smart resource management</div>
</div>
<div style="background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%); border-radius: 10px; padding: 15px;">
<span style="font-size: 1.5em; margin-bottom: 8px; display: block;">🎨</span>
<div style="font-weight: 600; color: #333; font-size: 0.95em; margin-bottom: 5px;">Modern UI/UX</div>
<div style="font-size: 0.75em; color: #666; line-height: 1.4;">Beautiful interface with smooth animations</div>
</div>
<div style="background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%); border-radius: 10px; padding: 15px;">
<span style="font-size: 1.5em; margin-bottom: 8px; display: block;">🔧</span>
<div style="font-weight: 600; color: #333; font-size: 0.95em; margin-bottom: 5px;">Clean Architecture</div>
<div style="font-size: 0.75em; color: #666; line-height: 1.4;">Modular design for easy maintenance</div>
</div>
</div>
<div style="display: flex; flex-wrap: wrap; gap: 5px; margin-top: 15px; justify-content: center;">
<span style="background: rgba(102, 126, 234, 0.1); color: #667eea; padding: 3px 10px; border-radius: 20px; font-size: 0.7em; font-weight: 500;">PyTorch</span>
<span style="background: rgba(102, 126, 234, 0.1); color: #667eea; padding: 3px 10px; border-radius: 20px; font-size: 0.7em; font-weight: 500;">Diffusers</span>
<span style="background: rgba(102, 126, 234, 0.1); color: #667eea; padding: 3px 10px; border-radius: 20px; font-size: 0.7em; font-weight: 500;">Gradio</span>
<span style="background: rgba(102, 126, 234, 0.1); color: #667eea; padding: 3px 10px; border-radius: 20px; font-size: 0.7em; font-weight: 500;">CUDA Optimized</span>
<span style="background: rgba(102, 126, 234, 0.1); color: #667eea; padding: 3px 10px; border-radius: 20px; font-size: 0.7em; font-weight: 500;">LoRA Enhanced</span>
</div>
</div>
""")
    # Event handlers
    # Upload suggests a height/width from the image's aspect ratio; clearing
    # the image resets both sliders to the configured defaults.
    input_image.upload(
        fn=handle_image_upload,
        inputs=[input_image],
        outputs=[height_slider, width_slider]
    )
    input_image.clear(
        fn=handle_image_upload,
        inputs=[input_image],
        outputs=[height_slider, width_slider]
    )
    # The second output writes the seed actually used back into the seed slider.
    generate_btn.click(
        fn=generate_video,
        inputs=[
            input_image, prompt_input, height_slider, width_slider,
            negative_prompt, duration_input, guidance_scale,
            steps_slider, seed, randomize_seed
        ],
        outputs=[video_output, seed]
    )
if __name__ == "__main__":
    # Enable Gradio's request queue (queue() returns the Blocks app), then serve it.
    app = demo.queue()
    app.launch()