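"""AI Video Magic Studio: a Gradio Space that animates a still image into a
short video with Wan 2.1 I2V 14B (480P) plus the CausVid LoRA, running on
ZeroGPU hardware."""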
import torch
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, UniPCMultistepScheduler
from diffusers.utils import export_to_video
from transformers import CLIPVisionModel
import gradio as gr
import tempfile
import os
import spaces
from huggingface_hub import hf_hub_download
import numpy as np
from PIL import Image
import random
import logging
import gc
import time
import hashlib
from dataclasses import dataclass
from typing import Optional, Tuple
from functools import wraps
# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Configuration
@dataclass
class VideoGenerationConfig:
    model_id: str = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
    lora_repo_id: str = "Kijai/WanVideo_comfy"
    lora_filename: str = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
    mod_value: int = 32
    default_height: int = 512
    default_width: int = 896
    max_area: float = 480.0 * 832.0
    slider_min_h: int = 128
    slider_max_h: int = 896
    slider_min_w: int = 128
    slider_max_w: int = 896
    fixed_fps: int = 24
    min_frames: int = 8
    max_frames: int = 81
    default_prompt: str = "make this image come alive, cinematic motion, smooth animation"
    default_negative_prompt: str = "static, blurred, low quality, watermark, text"

config = VideoGenerationConfig()
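# Heights and widths are snapped to multiples of mod_value (32) before
# generation; max_area caps the total pixel count at the 480p budget
# (480 * 832 = 399,360 px).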
MAX_SEED = np.iinfo(np.int32).max
# Timing decorator
def measure_time(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        logger.info(f"{func.__name__} took {time.time()-start:.2f}s")
        return result
    return wrapper
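# Used on _load_model below; apply to any other function to log its runtime.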
# Model manager
class ModelManager:
    def __init__(self):
        self._pipe = None
        self._is_loaded = False

    @property
    def pipe(self):
        if not self._is_loaded:
            self._load_model()
        return self._pipe
    @measure_time
    def _load_model(self):
        logger.info("Loading model...")
        # Image encoder and VAE stay in float32; the rest of the pipeline
        # runs in bfloat16.
        image_encoder = CLIPVisionModel.from_pretrained(
            config.model_id, subfolder="image_encoder", torch_dtype=torch.float32
        )
        vae = AutoencoderKLWan.from_pretrained(
            config.model_id, subfolder="vae", torch_dtype=torch.float32
        )
        self._pipe = WanImageToVideoPipeline.from_pretrained(
            config.model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16
        )
        self._pipe.scheduler = UniPCMultistepScheduler.from_config(
            self._pipe.scheduler.config, flow_shift=8.0
        )
        self._pipe.to("cuda")

        # Fuse the CausVid LoRA, which enables the low step count used below.
        causvid_path = hf_hub_download(
            repo_id=config.lora_repo_id, filename=config.lora_filename
        )
        self._pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
        self._pipe.set_adapters(["causvid_lora"], adapter_weights=[0.95])
        self._pipe.fuse_lora()

        self._is_loaded = True
        logger.info("Model loaded successfully")
model_manager = ModelManager()
# Video generator
class VideoGenerator:
    def __init__(self, config: VideoGenerationConfig, model_manager: ModelManager):
        self.config = config
        self.model_manager = model_manager

    def calculate_dimensions(self, image: Image.Image) -> Tuple[int, int]:
        orig_w, orig_h = image.size
        if orig_w <= 0 or orig_h <= 0:
            return self.config.default_height, self.config.default_width
        # Scale to max_area while preserving aspect ratio, then snap both
        # dimensions down to multiples of mod_value and clamp to slider range.
        aspect_ratio = orig_h / orig_w
        calc_h = round(np.sqrt(self.config.max_area * aspect_ratio))
        calc_w = round(np.sqrt(self.config.max_area / aspect_ratio))
        calc_h = max(self.config.mod_value, (calc_h // self.config.mod_value) * self.config.mod_value)
        calc_w = max(self.config.mod_value, (calc_w // self.config.mod_value) * self.config.mod_value)
        new_h = int(np.clip(calc_h, self.config.slider_min_h,
                            (self.config.slider_max_h // self.config.mod_value) * self.config.mod_value))
        new_w = int(np.clip(calc_w, self.config.slider_min_w,
                            (self.config.slider_max_w // self.config.mod_value) * self.config.mod_value))
        return new_h, new_w
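    # Example: a 1920x1080 input has aspect ratio 0.5625, so
    # calc_h = round(sqrt(399360 * 0.5625)) = 474 and
    # calc_w = round(sqrt(399360 / 0.5625)) = 843, which snap to 448 x 832.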
    def validate_inputs(self, image: Image.Image, prompt: str, height: int,
                        width: int, duration: float, steps: int) -> Tuple[bool, Optional[str]]:
        if image is None:
            return False, "🖼️ Please upload an input image"
        if not prompt or len(prompt.strip()) == 0:
            return False, "✍️ Please provide a prompt"
        if len(prompt) > 500:
            return False, "⚠️ Prompt is too long (max 500 characters)"
        if duration < self.config.min_frames / self.config.fixed_fps:
            return False, f"⏱️ Duration too short (min {self.config.min_frames/self.config.fixed_fps:.1f}s)"
        if duration > self.config.max_frames / self.config.fixed_fps:
            return False, f"⏱️ Duration too long (max {self.config.max_frames/self.config.fixed_fps:.1f}s)"
        return True, None
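    # With fixed_fps = 24, valid durations run from 8/24 ≈ 0.3 s to 81/24 ≈ 3.4 s.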
    def generate_unique_filename(self, seed: int) -> str:
        timestamp = int(time.time())
        unique_str = f"{timestamp}_{seed}_{random.randint(1000, 9999)}"
        hash_obj = hashlib.md5(unique_str.encode())
        return f"video_{hash_obj.hexdigest()[:8]}.mp4"
video_generator = VideoGenerator(config, model_manager)
# Gradio callbacks
def handle_image_upload(image):
    if image is None:
        return gr.update(value=config.default_height), gr.update(value=config.default_width)
    try:
        if not isinstance(image, Image.Image):
            raise ValueError("Invalid image format")
        new_h, new_w = video_generator.calculate_dimensions(image)
        return gr.update(value=new_h), gr.update(value=new_w)
    except Exception as e:
        logger.error(f"Error processing image: {e}")
        gr.Warning("⚠️ Error processing image")
        return gr.update(value=config.default_height), gr.update(value=config.default_width)
def get_duration(input_image, prompt, height, width, negative_prompt,
                 duration_seconds, guidance_scale, steps, seed, randomize_seed, progress):
    # Longer/higher-step jobs get a bigger GPU time budget.
    if steps > 4 and duration_seconds > 2:
        return 90
    elif steps > 4 or duration_seconds > 2:
        return 75
    else:
        return 60
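# On ZeroGPU, a callable passed as duration= to @spaces.GPU receives the same
# arguments as the decorated function and returns the GPU budget in seconds.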
@spaces.GPU(duration=get_duration)
def generate_video(input_image, prompt, height, width,
                   negative_prompt=config.default_negative_prompt,
                   duration_seconds=2, guidance_scale=1, steps=4,
                   seed=42, randomize_seed=False,
                   progress=gr.Progress(track_tqdm=True)):
    progress(0.1, desc="🔍 Validating inputs...")

    # Input validation
    is_valid, error_msg = video_generator.validate_inputs(
        input_image, prompt, height, width, duration_seconds, steps
    )
    if not is_valid:
        raise gr.Error(error_msg)
    try:
        progress(0.2, desc="🎯 Preparing image...")
        target_h = max(config.mod_value, (int(height) // config.mod_value) * config.mod_value)
        target_w = max(config.mod_value, (int(width) // config.mod_value) * config.mod_value)
        num_frames = np.clip(int(round(duration_seconds * config.fixed_fps)),
                             config.min_frames, config.max_frames)
        current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

        resized_image = input_image.resize((target_w, target_h), Image.Resampling.LANCZOS)

        progress(0.3, desc="🎨 Loading model...")
        pipe = model_manager.pipe

        progress(0.4, desc="🎬 Generating video frames...")
        with torch.inference_mode():
            output_frames_list = pipe(
                image=resized_image,
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=target_h,
                width=target_w,
                num_frames=num_frames,
                guidance_scale=float(guidance_scale),
                num_inference_steps=int(steps),
                generator=torch.Generator(device="cuda").manual_seed(current_seed)
            ).frames[0]

        progress(0.9, desc="💾 Saving video...")
        filename = video_generator.generate_unique_filename(current_seed)
        video_path = os.path.join(tempfile.gettempdir(), filename)
        export_to_video(output_frames_list, video_path, fps=config.fixed_fps)

        progress(1.0, desc="✨ Complete!")
        return video_path, current_seed
    finally:
        # Memory cleanup
        if 'output_frames_list' in locals():
            del output_frames_list
        gc.collect()
        torch.cuda.empty_cache()
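# A minimal sketch of driving the pipeline without the UI (hypothetical
# "input.png"; needs a CUDA GPU since the pipeline lives on "cuda"):
#
#   from PIL import Image
#   img = Image.open("input.png").convert("RGB")
#   h, w = video_generator.calculate_dimensions(img)
#   frames = model_manager.pipe(
#       image=img.resize((w, h)), prompt=config.default_prompt,
#       height=h, width=w, num_frames=49,
#       guidance_scale=1.0, num_inference_steps=4,
#   ).frames[0]
#   export_to_video(frames, "out.mp4", fps=config.fixed_fps)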
# CSS styles
css = """
.container {
    max-width: 1200px;
    margin: auto;
    padding: 20px;
}
.header {
    text-align: center;
    margin-bottom: 30px;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    padding: 40px;
    border-radius: 20px;
    color: white;
    box-shadow: 0 10px 30px rgba(0,0,0,0.2);
}
.header h1 {
    font-size: 3em;
    margin-bottom: 10px;
    text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
}
.header p {
    font-size: 1.2em;
    opacity: 0.95;
}
.main-content {
    background: rgba(255, 255, 255, 0.95);
    border-radius: 20px;
    padding: 30px;
    box-shadow: 0 5px 20px rgba(0,0,0,0.1);
    backdrop-filter: blur(10px);
}
.input-section {
    background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
    padding: 25px;
    border-radius: 15px;
    margin-bottom: 20px;
}
.generate-btn {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    color: white;
    font-size: 1.3em;
    padding: 15px 40px;
    border-radius: 30px;
    border: none;
    cursor: pointer;
    transition: all 0.3s ease;
    box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
    width: 100%;
    margin-top: 20px;
}
.generate-btn:hover {
    transform: translateY(-2px);
    box-shadow: 0 7px 20px rgba(102, 126, 234, 0.6);
}
.video-output {
    background: #f8f9fa;
    padding: 20px;
    border-radius: 15px;
    text-align: center;
    min-height: 400px;
    display: flex;
    align-items: center;
    justify-content: center;
}
.accordion {
    background: rgba(255, 255, 255, 0.7);
    border-radius: 10px;
    margin-top: 15px;
    padding: 15px;
}
.slider-container {
    background: rgba(255, 255, 255, 0.5);
    padding: 15px;
    border-radius: 10px;
    margin: 10px 0;
}
body {
    background: linear-gradient(-45deg, #ee7752, #e73c7e, #23a6d5, #23d5ab);
    background-size: 400% 400%;
    animation: gradient 15s ease infinite;
}
@keyframes gradient {
    0% { background-position: 0% 50%; }
    50% { background-position: 100% 50%; }
    100% { background-position: 0% 50%; }
}
.gr-button-secondary {
    background: linear-gradient(135deg, #f093fb 0%, #f5576c 100%);
}
.footer {
    text-align: center;
    margin-top: 30px;
    color: #666;
    font-size: 0.9em;
}
"""
# Gradio UI
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_classes="container"):
        # Header
        gr.HTML("""
        <div class="header">
            <h1>🎬 AI Video Magic Studio</h1>
            <p>Transform your images into captivating videos with Wan 2.1 + CausVid LoRA</p>
        </div>
        """)
        with gr.Row(elem_classes="main-content"):
            with gr.Column(scale=1):
                gr.Markdown("### 📸 Input Settings")
                with gr.Column(elem_classes="input-section"):
                    input_image = gr.Image(
                        type="pil",
                        label="🖼️ Upload Your Image",
                        elem_classes="image-upload"
                    )
                    prompt_input = gr.Textbox(
                        label="✨ Animation Prompt",
                        value=config.default_prompt,
                        placeholder="Describe how you want your image to move...",
                        lines=2
                    )
                    duration_input = gr.Slider(
                        minimum=round(config.min_frames/config.fixed_fps, 1),
                        maximum=round(config.max_frames/config.fixed_fps, 1),
                        step=0.1,
                        value=2,
                        label="⏱️ Video Duration (seconds)",
                        elem_classes="slider-container"
                    )
                    with gr.Accordion("🎛️ Advanced Settings", open=False, elem_classes="accordion"):
                        negative_prompt = gr.Textbox(
                            label="🚫 Negative Prompt",
                            value=config.default_negative_prompt,
                            lines=2
                        )
                        with gr.Row():
                            seed = gr.Slider(
                                minimum=0,
                                maximum=MAX_SEED,
                                step=1,
                                value=42,
                                label="🎲 Seed"
                            )
                            randomize_seed = gr.Checkbox(
                                label="🔀 Randomize",
                                value=True
                            )
                        with gr.Row():
                            height_slider = gr.Slider(
                                minimum=config.slider_min_h,
                                maximum=config.slider_max_h,
                                step=config.mod_value,
                                value=config.default_height,
                                label="📏 Height"
                            )
                            width_slider = gr.Slider(
                                minimum=config.slider_min_w,
                                maximum=config.slider_max_w,
                                step=config.mod_value,
                                value=config.default_width,
                                label="📐 Width"
                            )
                        steps_slider = gr.Slider(
                            minimum=1,
                            maximum=30,
                            step=1,
                            value=4,
                            label="🔧 Quality Steps (4-8 recommended)"
                        )
                        guidance_scale = gr.Slider(
                            minimum=0.0,
                            maximum=20.0,
                            step=0.5,
                            value=1.0,
                            label="🎯 Guidance Scale",
                            visible=False
                        )
                    generate_btn = gr.Button(
                        "🎬 Generate Video",
                        variant="primary",
                        elem_classes="generate-btn"
                    )
            with gr.Column(scale=1):
                gr.Markdown("### 🎥 Generated Video")
                video_output = gr.Video(
                    label="",
                    autoplay=True,
                    elem_classes="video-output"
                )
                gr.HTML("""
                <div class="footer">
                    <p>💡 Tip: For best results, use clear images with good lighting</p>
                </div>
                """)

        # Examples
        gr.Examples(
            examples=[
                ["peng.png", "a penguin playfully dancing in the snow, Antarctica", 896, 512],
                ["forg.jpg", "the frog jumps around", 448, 832],
            ],
            inputs=[input_image, prompt_input, height_slider, width_slider],
            outputs=[video_output, seed],
            fn=generate_video,
            cache_examples="lazy"
        )
        # Feature highlights (shown below the Examples section)
        gr.HTML("""
        <div class="improvements-container" style="background: rgba(255, 255, 255, 0.95); backdrop-filter: blur(10px); border-radius: 15px; padding: 20px; margin: 20px auto; max-width: 800px; box-shadow: 0 5px 20px rgba(0,0,0,0.1); font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif;">
            <div class="improvements-header" style="text-align: center; margin-bottom: 20px;">
                <h3 style="background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; font-size: 1.5em; margin: 0; font-weight: 700;">✨ Enhanced Features</h3>
                <p style="color: #666; font-size: 0.9em; margin-top: 5px;">Optimized for performance, stability, and user experience</p>
            </div>
            <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 15px;">
                <div style="background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%); border-radius: 10px; padding: 15px;">
                    <span style="font-size: 1.5em; margin-bottom: 8px; display: block;">🛡️</span>
                    <div style="font-weight: 600; color: #333; font-size: 0.95em; margin-bottom: 5px;">Robust Error Handling</div>
                    <div style="font-size: 0.75em; color: #666; line-height: 1.4;">Advanced validation and recovery mechanisms</div>
                </div>
                <div style="background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%); border-radius: 10px; padding: 15px;">
                    <span style="font-size: 1.5em; margin-bottom: 8px; display: block;">⚡</span>
                    <div style="font-weight: 600; color: #333; font-size: 0.95em; margin-bottom: 5px;">Performance Optimized</div>
                    <div style="font-size: 0.75em; color: #666; line-height: 1.4;">Faster processing with smart resource management</div>
                </div>
                <div style="background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%); border-radius: 10px; padding: 15px;">
                    <span style="font-size: 1.5em; margin-bottom: 8px; display: block;">🎨</span>
                    <div style="font-weight: 600; color: #333; font-size: 0.95em; margin-bottom: 5px;">Modern UI/UX</div>
                    <div style="font-size: 0.75em; color: #666; line-height: 1.4;">Beautiful interface with smooth animations</div>
                </div>
                <div style="background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%); border-radius: 10px; padding: 15px;">
                    <span style="font-size: 1.5em; margin-bottom: 8px; display: block;">🔧</span>
                    <div style="font-weight: 600; color: #333; font-size: 0.95em; margin-bottom: 5px;">Clean Architecture</div>
                    <div style="font-size: 0.75em; color: #666; line-height: 1.4;">Modular design for easy maintenance</div>
                </div>
            </div>
            <div style="display: flex; flex-wrap: wrap; gap: 5px; margin-top: 15px; justify-content: center;">
                <span style="background: rgba(102, 126, 234, 0.1); color: #667eea; padding: 3px 10px; border-radius: 20px; font-size: 0.7em; font-weight: 500;">PyTorch</span>
                <span style="background: rgba(102, 126, 234, 0.1); color: #667eea; padding: 3px 10px; border-radius: 20px; font-size: 0.7em; font-weight: 500;">Diffusers</span>
                <span style="background: rgba(102, 126, 234, 0.1); color: #667eea; padding: 3px 10px; border-radius: 20px; font-size: 0.7em; font-weight: 500;">Gradio</span>
                <span style="background: rgba(102, 126, 234, 0.1); color: #667eea; padding: 3px 10px; border-radius: 20px; font-size: 0.7em; font-weight: 500;">CUDA Optimized</span>
                <span style="background: rgba(102, 126, 234, 0.1); color: #667eea; padding: 3px 10px; border-radius: 20px; font-size: 0.7em; font-weight: 500;">LoRA Enhanced</span>
            </div>
        </div>
        """)
    # Event handlers
    input_image.upload(
        fn=handle_image_upload,
        inputs=[input_image],
        outputs=[height_slider, width_slider]
    )
    input_image.clear(
        fn=handle_image_upload,
        inputs=[input_image],
        outputs=[height_slider, width_slider]
    )
    generate_btn.click(
        fn=generate_video,
        inputs=[
            input_image, prompt_input, height_slider, width_slider,
            negative_prompt, duration_input, guidance_scale,
            steps_slider, seed, randomize_seed
        ],
        outputs=[video_output, seed]
    )
if __name__ == "__main__":
    demo.queue().launch()