import torch
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, UniPCMultistepScheduler
from diffusers.utils import export_to_video
from transformers import CLIPVisionModel
import gradio as gr
import tempfile
import spaces
from huggingface_hub import hf_hub_download
import numpy as np
from PIL import Image
import random
import logging
import gc
import time
import hashlib
from dataclasses import dataclass
from typing import Optional, Tuple
from functools import wraps
import threading
import os
# GPU memory allocator configuration
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:512'
# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Configuration management
@dataclass
class VideoGenerationConfig:
model_id: str = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
lora_repo_id: str = "Kijai/WanVideo_comfy"
lora_filename: str = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
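# Output height/width are snapped to multiples of mod_value (see calculate_dimensions)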
mod_value: int = 32
default_height: int = 512
default_width: int = 896
max_area: float = 480.0 * 832.0
slider_min_h: int = 128
slider_max_h: int = 896
slider_min_w: int = 128
slider_max_w: int = 896
fixed_fps: int = 24
min_frames: int = 8
max_frames: int = 81
default_prompt: str = "make this image come alive, cinematic motion, smooth animation"
default_negative_prompt: str = "static, blurred, low quality, watermark, text"
# GPU memory optimization settings
enable_model_cpu_offload: bool = True
enable_vae_slicing: bool = True
enable_vae_tiling: bool = True
@property
def max_duration(self):
"""최대 허용 duration (초)"""
return self.max_frames / self.fixed_fps
@property
def min_duration(self):
"""최소 허용 duration (초)"""
return self.min_frames / self.fixed_fps
config = VideoGenerationConfig()
MAX_SEED = np.iinfo(np.int32).max
# Global lock (prevents concurrent generations)
generation_lock = threading.Lock()
# Timing decorator
def measure_time(func):
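"""Log the wall-clock time taken by the wrapped function."""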
@wraps(func)
def wrapper(*args, **kwargs):
start = time.time()
result = func(*args, **kwargs)
logger.info(f"{func.__name__} took {time.time()-start:.2f}s")
return result
return wrapper
# GPU memory cleanup helper
def clear_gpu_memory():
"""강력한 GPU 메모리 정리"""
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
gc.collect()
# Log GPU memory state
allocated = torch.cuda.memory_allocated() / 1024**3
reserved = torch.cuda.memory_reserved() / 1024**3
logger.info(f"GPU Memory - Allocated: {allocated:.2f}GB, Reserved: {reserved:.2f}GB")
# Model manager (singleton pattern)
class ModelManager:
_instance = None
_lock = threading.Lock()
def __new__(cls):
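# Double-checked locking: check, acquire the lock, then check again before creating the instance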
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
if not hasattr(self, '_initialized'):
self._pipe = None
self._is_loaded = False
self._initialized = True
@property
def pipe(self):
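"""Lazily load the pipeline on first access."""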
if not self._is_loaded:
self._load_model()
return self._pipe
@measure_time
def _load_model(self):
"""메모리 효율적인 모델 로딩"""
with self._lock:
if self._is_loaded:
return
try:
logger.info("Loading model with memory optimizations...")
clear_gpu_memory()
# Load model components (memory-efficient)
with torch.cuda.amp.autocast(enabled=False):
image_encoder = CLIPVisionModel.from_pretrained(
config.model_id,
subfolder="image_encoder",
torch_dtype=torch.float16,  # float16 instead of float32 to halve memory use
low_cpu_mem_usage=True
)
vae = AutoencoderKLWan.from_pretrained(
config.model_id,
subfolder="vae",
torch_dtype=torch.float16,  # float16 instead of float32 to halve memory use
low_cpu_mem_usage=True
)
self._pipe = WanImageToVideoPipeline.from_pretrained(
config.model_id,
vae=vae,
image_encoder=image_encoder,
torch_dtype=torch.bfloat16,
low_cpu_mem_usage=True,
use_safetensors=True
)
# Configure the scheduler
self._pipe.scheduler = UniPCMultistepScheduler.from_config(
self._pipe.scheduler.config, flow_shift=8.0
)
# Load LoRA weights
causvid_path = hf_hub_download(
repo_id=config.lora_repo_id, filename=config.lora_filename
)
self._pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
self._pipe.set_adapters(["causvid_lora"], adapter_weights=[0.95])
self._pipe.fuse_lora()
# GPU optimization settings
if config.enable_model_cpu_offload:
self._pipe.enable_model_cpu_offload()
else:
self._pipe.to("cuda")
if config.enable_vae_slicing:
self._pipe.enable_vae_slicing()
if config.enable_vae_tiling:
self._pipe.enable_vae_tiling()
# Enable xFormers memory-efficient attention when available
try:
self._pipe.enable_xformers_memory_efficient_attention()
logger.info("xFormers memory efficient attention enabled")
except Exception:
logger.info("xFormers not available, using default attention")
self._is_loaded = True
logger.info("Model loaded successfully with optimizations")
clear_gpu_memory()
except Exception as e:
logger.error(f"Error loading model: {e}")
self._is_loaded = False
clear_gpu_memory()
raise
def unload_model(self):
"""모델 언로드 및 메모리 해제"""
with self._lock:
if self._pipe is not None:
del self._pipe
self._pipe = None
self._is_loaded = False
clear_gpu_memory()
logger.info("Model unloaded and memory cleared")
# Singleton instance
model_manager = ModelManager()
# Video generator class
class VideoGenerator:
def __init__(self, config: VideoGenerationConfig, model_manager: ModelManager):
self.config = config
self.model_manager = model_manager
def calculate_dimensions(self, image: Image.Image) -> Tuple[int, int]:
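"""Fit the image into max_area, preserving aspect ratio and snapping to multiples of mod_value."""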
orig_w, orig_h = image.size
if orig_w <= 0 or orig_h <= 0:
return self.config.default_height, self.config.default_width
aspect_ratio = orig_h / orig_w
calc_h = round(np.sqrt(self.config.max_area * aspect_ratio))
calc_w = round(np.sqrt(self.config.max_area / aspect_ratio))
calc_h = max(self.config.mod_value, (calc_h // self.config.mod_value) * self.config.mod_value)
calc_w = max(self.config.mod_value, (calc_w // self.config.mod_value) * self.config.mod_value)
new_h = int(np.clip(calc_h, self.config.slider_min_h,
(self.config.slider_max_h // self.config.mod_value) * self.config.mod_value))
new_w = int(np.clip(calc_w, self.config.slider_min_w,
(self.config.slider_max_w // self.config.mod_value) * self.config.mod_value))
return new_h, new_w
def validate_inputs(self, image: Image.Image, prompt: str, height: int,
width: int, duration: float, steps: int) -> Tuple[bool, Optional[str]]:
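"""Validate prompt, duration, resolution, and GPU memory headroom; return (is_valid, error_message)."""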
if image is None:
return False, "🖼️ Please upload an input image"
if not prompt or len(prompt.strip()) == 0:
return False, "✍️ Please provide a prompt"
if len(prompt) > 500:
return False, "⚠️ Prompt is too long (max 500 characters)"
# Exact duration range check
min_duration = self.config.min_duration
max_duration = self.config.max_duration
if duration < min_duration:
return False, f"⏱️ Duration too short (min {min_duration:.1f}s)"
if duration > max_duration:
return False, f"⏱️ Duration too long (max {max_duration:.1f}s)"
# Apply more conservative limits in the Zero GPU environment
if hasattr(spaces, 'GPU'):  # Spaces environment check
if duration > 2.5:  # capped at 2.5s on Zero GPU
return False, "⏱️ In Zero GPU environment, duration is limited to 2.5s for stability"
if height > 640 or width > 640:  # resolution is limited as well
return False, "📐 In Zero GPU environment, resolution is limited to 640x640"
# GPU memory check
if torch.cuda.is_available():
try:
free_memory = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated()
required_memory = (height * width * 3 * 8 * duration * self.config.fixed_fps) / (1024**3)
if free_memory < required_memory * 2:
clear_gpu_memory()
return False, "⚠️ Not enough GPU memory. Try smaller dimensions or shorter duration."
except Exception:
pass  # Continue even if the GPU memory check fails
return True, None
def generate_unique_filename(self, seed: int) -> str:
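"""Build a collision-resistant .mp4 filename from the timestamp, seed, and a random salt."""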
timestamp = int(time.time())
unique_str = f"{timestamp}_{seed}_{random.randint(1000, 9999)}"
hash_obj = hashlib.md5(unique_str.encode())
return f"video_{hash_obj.hexdigest()[:8]}.mp4"
video_generator = VideoGenerator(config, model_manager)
# Gradio callbacks
def handle_image_upload(image):
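"""Update the height/width sliders to match the uploaded image's aspect ratio."""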
if image is None:
return gr.update(value=config.default_height), gr.update(value=config.default_width)
try:
if not isinstance(image, Image.Image):
raise ValueError("Invalid image format")
new_h, new_w = video_generator.calculate_dimensions(image)
return gr.update(value=new_h), gr.update(value=new_w)
except Exception as e:
logger.error(f"Error processing image: {e}")
gr.Warning("⚠️ Error processing image")
return gr.update(value=config.default_height), gr.update(value=config.default_width)
def get_duration(input_image, prompt, height, width, negative_prompt,
duration_seconds, guidance_scale, steps, seed, randomize_seed, progress):
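"""Estimate the GPU time slice (in seconds) to request from @spaces.GPU for this job."""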
# Allocate GPU time more conservatively in the Zero GPU environment
base_duration = 60
# Extra time for more inference steps
if steps > 8:
base_duration += 30
elif steps > 4:
base_duration += 15
# Extra time for longer durations
if duration_seconds > 2:
base_duration += 20
elif duration_seconds > 1.5:
base_duration += 10
# Extra time for higher resolutions
pixels = height * width
if pixels > 400000:  # roughly 640x640
base_duration += 20
elif pixels > 250000:  # roughly 512x512
base_duration += 10
# Cap at 90 seconds in the Zero GPU environment
return min(base_duration, 90)
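# The ZeroGPU slot duration is computed per request by get_duration (passed as a callable below)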
@spaces.GPU(duration=get_duration)
@measure_time
def generate_video(input_image, prompt, height, width,
negative_prompt=config.default_negative_prompt,
duration_seconds=1.5, guidance_scale=1, steps=4,
seed=42, randomize_seed=False,
progress=gr.Progress(track_tqdm=True)):
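"""Generate a short video from an input image and prompt; returns (video_path, seed)."""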
# Prevent concurrent generations
if not generation_lock.acquire(blocking=False):
raise gr.Error("⏳ Another video is being generated. Please wait...")
try:
progress(0.1, desc="🔍 Validating inputs...")
# Extra validation for the Zero GPU environment
if hasattr(spaces, 'GPU'):
logger.info(f"Zero GPU environment detected. Duration: {duration_seconds}s, Resolution: {height}x{width}")
# Input validation
is_valid, error_msg = video_generator.validate_inputs(
input_image, prompt, height, width, duration_seconds, steps
)
if not is_valid:
raise gr.Error(error_msg)
# Free GPU memory before generation
clear_gpu_memory()
progress(0.2, desc="🎯 Preparing image...")
target_h = max(config.mod_value, (int(height) // config.mod_value) * config.mod_value)
target_w = max(config.mod_value, (int(width) // config.mod_value) * config.mod_value)
# Compute the frame count (further limited in the Zero GPU environment)
max_allowed_frames = int(2.5 * config.fixed_fps) if hasattr(spaces, 'GPU') else config.max_frames
num_frames = min(
int(round(duration_seconds * config.fixed_fps)),
max_allowed_frames
)
num_frames = int(np.clip(num_frames, config.min_frames, max_allowed_frames))
current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
# Resize the image (memory-efficient)
resized_image = input_image.resize((target_w, target_h), Image.Resampling.LANCZOS)
progress(0.3, desc="🎨 Loading model...")
pipe = model_manager.pipe
progress(0.4, desc="🎬 Generating video frames...")
# Memory-efficient generation
with torch.inference_mode(), torch.cuda.amp.autocast(enabled=True):
try:
output_frames_list = pipe(
image=resized_image,
prompt=prompt,
negative_prompt=negative_prompt,
height=target_h,
width=target_w,
num_frames=num_frames,
guidance_scale=float(guidance_scale),
num_inference_steps=int(steps),
generator=torch.Generator(device="cuda").manual_seed(current_seed),
return_dict=True
).frames[0]
except torch.cuda.OutOfMemoryError:
clear_gpu_memory()
raise gr.Error("💾 GPU out of memory. Try smaller dimensions or shorter duration.")
except Exception as e:
logger.error(f"Generation error: {e}")
raise gr.Error(f"❌ Generation failed: {str(e)}")
progress(0.9, desc="💾 Saving video...")
filename = video_generator.generate_unique_filename(current_seed)
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
video_path = tmpfile.name
export_to_video(output_frames_list, video_path, fps=config.fixed_fps)
progress(1.0, desc="✨ Complete!")
logger.info(f"Video generated successfully: {num_frames} frames, {target_h}x{target_w}")
return video_path, current_seed
except Exception as e:
logger.error(f"Unexpected error: {e}")
raise
finally:
# Always release the lock and clean up memory
generation_lock.release()
# Free intermediate tensors
if 'output_frames_list' in locals():
del output_frames_list
if 'resized_image' in locals():
del resized_image
clear_gpu_memory()
# Improved CSS styles
css = """
.container {
max-width: 1200px;
margin: auto;
padding: 20px;
}
.header {
text-align: center;
margin-bottom: 30px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
padding: 40px;
border-radius: 20px;
color: white;
box-shadow: 0 10px 30px rgba(0,0,0,0.2);
position: relative;
overflow: hidden;
}
.header::before {
content: '';
position: absolute;
top: -50%;
left: -50%;
width: 200%;
height: 200%;
background: radial-gradient(circle, rgba(255,255,255,0.1) 0%, transparent 70%);
animation: pulse 4s ease-in-out infinite;
}
@keyframes pulse {
0%, 100% { transform: scale(1); opacity: 0.5; }
50% { transform: scale(1.1); opacity: 0.8; }
}
.header h1 {
font-size: 3em;
margin-bottom: 10px;
text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
position: relative;
z-index: 1;
}
.header p {
font-size: 1.2em;
opacity: 0.95;
position: relative;
z-index: 1;
}
.gpu-status {
position: absolute;
top: 10px;
right: 10px;
background: rgba(0,0,0,0.3);
padding: 5px 15px;
border-radius: 20px;
font-size: 0.8em;
}
.main-content {
background: rgba(255, 255, 255, 0.95);
border-radius: 20px;
padding: 30px;
box-shadow: 0 5px 20px rgba(0,0,0,0.1);
backdrop-filter: blur(10px);
}
.input-section {
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
padding: 25px;
border-radius: 15px;
margin-bottom: 20px;
}
.generate-btn {
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
font-size: 1.3em;
padding: 15px 40px;
border-radius: 30px;
border: none;
cursor: pointer;
transition: all 0.3s ease;
box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);
width: 100%;
margin-top: 20px;
}
.generate-btn:hover {
transform: translateY(-2px);
box-shadow: 0 7px 20px rgba(102, 126, 234, 0.6);
}
.generate-btn:active {
transform: translateY(0);
}
.video-output {
background: #f8f9fa;
padding: 20px;
border-radius: 15px;
text-align: center;
min-height: 400px;
display: flex;
align-items: center;
justify-content: center;
}
.accordion {
background: rgba(255, 255, 255, 0.7);
border-radius: 10px;
margin-top: 15px;
padding: 15px;
}
.slider-container {
background: rgba(255, 255, 255, 0.5);
padding: 15px;
border-radius: 10px;
margin: 10px 0;
}
body {
background: linear-gradient(-45deg, #ee7752, #e73c7e, #23a6d5, #23d5ab);
background-size: 400% 400%;
animation: gradient 15s ease infinite;
}
@keyframes gradient {
0% { background-position: 0% 50%; }
50% { background-position: 100% 50%; }
100% { background-position: 0% 50%; }
}
.warning-box {
background: rgba(255, 193, 7, 0.1);
border: 1px solid rgba(255, 193, 7, 0.3);
border-radius: 10px;
padding: 15px;
margin: 10px 0;
color: #856404;
font-size: 0.9em;
}
.footer {
text-align: center;
margin-top: 30px;
color: #666;
font-size: 0.9em;
}
"""
# Gradio UI
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
with gr.Column(elem_classes="container"):
# Header with GPU status
gr.HTML("""
<div class="header">
<h1>🎬 AI Video Magic Studio</h1>
<p>Transform your images into captivating videos with Wan 2.1 + CausVid LoRA</p>
<div class="gpu-status">🖥️ GPU Optimized</div>
</div>
""")
# GPU memory warning
gr.HTML("""
<div class="warning-box">
<strong>💡 Zero GPU Performance Tips:</strong>
<ul style="margin: 5px 0; padding-left: 20px;">
<li>Maximum duration: 2.5 seconds (limited by Zero GPU)</li>
<li>Recommended resolution: 512x512 for stable generation</li>
<li>Use 4-6 steps for optimal speed/quality balance</li>
<li>Wait between generations to avoid queue errors</li>
</ul>
</div>
""")
with gr.Row(elem_classes="main-content"):
with gr.Column(scale=1):
gr.Markdown("### 📸 Input Settings")
with gr.Column(elem_classes="input-section"):
input_image = gr.Image(
type="pil",
label="🖼️ Upload Your Image",
elem_classes="image-upload"
)
prompt_input = gr.Textbox(
label="✨ Animation Prompt",
value=config.default_prompt,
placeholder="Describe how you want your image to move...",
lines=2
)
duration_input = gr.Slider(
minimum=round(config.min_duration, 1),
maximum=2.5 if hasattr(spaces, 'GPU') else round(config.max_duration, 1),  # Zero GPU environment limit
step=0.1,
value=1.5,  # safe default
label="⏱️ Video Duration (seconds) - Limited to 2.5s in Zero GPU",
elem_classes="slider-container"
)
with gr.Accordion("🎛️ Advanced Settings", open=False, elem_classes="accordion"):
negative_prompt = gr.Textbox(
label="🚫 Negative Prompt",
value=config.default_negative_prompt,
lines=2
)
with gr.Row():
seed = gr.Slider(
minimum=0,
maximum=MAX_SEED,
step=1,
value=42,
label="🎲 Seed"
)
randomize_seed = gr.Checkbox(
label="🔀 Randomize",
value=True
)
with gr.Row():
height_slider = gr.Slider(
minimum=config.slider_min_h,
maximum=config.slider_max_h,
step=config.mod_value,
value=config.default_height,
label="📏 Height (lower = more stable)"
)
width_slider = gr.Slider(
minimum=config.slider_min_w,
maximum=config.slider_max_w,
step=config.mod_value,
value=config.default_width,
label="📐 Width (lower = more stable)"
)
steps_slider = gr.Slider(
minimum=1,
maximum=30,
step=1,
value=4,
label="🔧 Quality Steps (4-8 recommended)"
)
guidance_scale = gr.Slider(
minimum=0.0,
maximum=20.0,
step=0.5,
value=1.0,
label="🎯 Guidance Scale",
visible=False
)
generate_btn = gr.Button(
"🎬 Generate Video",
variant="primary",
elem_classes="generate-btn"
)
with gr.Column(scale=1):
gr.Markdown("### 🎥 Generated Video")
video_output = gr.Video(
label="",
autoplay=True,
elem_classes="video-output"
)
gr.HTML("""
<div class="footer">
<p>💡 Tip: For best results, use clear images with good lighting</p>
</div>
""")
# Examples
gr.Examples(
examples=[
["peng.png", "a penguin playfully dancing in the snow, Antarctica", 512, 512],
["forg.jpg", "the frog jumps around", 448, 448],
],
inputs=[input_image, prompt_input, height_slider, width_slider],
outputs=[video_output, seed],
fn=generate_video,
cache_examples=False  # disable caching to save memory
)
# Summary of improvements (small print)
gr.HTML("""
<div style="background: rgba(255,255,255,0.9); border-radius: 10px; padding: 15px; margin-top: 20px; font-size: 0.8em; text-align: center;">
<p style="margin: 0; color: #666;">
<strong style="color: #667eea;">Enhanced with:</strong>
🛡️ GPU Crash Protection • ⚡ Memory Optimization • 🎨 Modern UI • 🔧 Clean Architecture
</p>
</div>
""")
# Event handlers
input_image.upload(
fn=handle_image_upload,
inputs=[input_image],
outputs=[height_slider, width_slider]
)
input_image.clear(
fn=handle_image_upload,
inputs=[input_image],
outputs=[height_slider, width_slider]
)
generate_btn.click(
fn=generate_video,
inputs=[
input_image, prompt_input, height_slider, width_slider,
negative_prompt, duration_input, guidance_scale,
steps_slider, seed, randomize_seed
],
outputs=[video_output, seed]
)
if __name__ == "__main__":
demo.queue(max_size=1).launch()  # limit queue size to keep memory in check