|
import torch |
|
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, UniPCMultistepScheduler |
|
from diffusers.utils import export_to_video |
|
from transformers import CLIPVisionModel |
|
import gradio as gr |
|
import tempfile |
|
import spaces |
|
from huggingface_hub import hf_hub_download |
|
import numpy as np |
|
from PIL import Image |
|
import random |
|
import logging |
|
import gc |
|
import time |
|
import hashlib |
|
from dataclasses import dataclass |
|
from typing import Optional, Tuple |
|
from functools import wraps |
|
import threading |
|
import os |
|
|
|
|
|
# Limit the CUDA caching allocator's block splitting to curb fragmentation on
# the shared GPU. NOTE(review): presumably read when CUDA is first initialised,
# which is why it is set here before any GPU work — confirm against torch docs.
os.environ.update({"PYTORCH_CUDA_ALLOC_CONF": "max_split_size_mb:256"})

# INFO level so the timing / progress messages emitted below are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
|
@dataclass
class VideoGenerationConfig:
    """Central settings for the Wan 2.1 image-to-video demo.

    Holds the model / LoRA identifiers, resolution and frame limits, UI
    defaults and a few memory-related switches.
    """

    # Hugging Face Hub sources for the base model and the CausVid LoRA.
    model_id: str = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
    lora_repo_id: str = "Kijai/WanVideo_comfy"
    lora_filename: str = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
    # Output height/width are snapped down to multiples of this value.
    mod_value: int = 32

    # Resolution defaults and bounds, in pixels.
    default_height: int = 384
    default_width: int = 384
    max_area: float = 384.0 * 384.0
    slider_min_h: int = 128
    slider_max_h: int = 640
    slider_min_w: int = 128
    slider_max_w: int = 640

    # Frame rate and allowed frame-count range.
    fixed_fps: int = 24
    min_frames: int = 8
    max_frames: int = 36

    # Prompts pre-filled in the UI.
    default_prompt: str = "make this image come alive, cinematic motion"
    default_negative_prompt: str = "static, blurred, low quality"

    # Memory-related toggles.
    # NOTE(review): verify these flags are actually consulted by pipeline setup.
    enable_model_cpu_offload: bool = True
    enable_vae_slicing: bool = True
    enable_vae_tiling: bool = True

    @property
    def max_duration(self) -> float:
        """Longest permitted clip length in seconds (max_frames / fixed_fps)."""
        frames, fps = self.max_frames, self.fixed_fps
        return frames / fps

    @property
    def min_duration(self) -> float:
        """Shortest permitted clip length in seconds (min_frames / fixed_fps)."""
        frames, fps = self.min_frames, self.fixed_fps
        return frames / fps
|
|
|
# Shared module-level state used by the callbacks below.
config = VideoGenerationConfig()

# Largest signed 32-bit integer: upper bound for the seed slider and for
# randomly drawn seeds.
MAX_SEED = int(np.iinfo(np.int32).max)

# Pipeline cache: stays None until generate_video loads the model.
pipe = None

# Lets generate_video reject a request while another one is in progress.
generation_lock = threading.Lock()
|
|
|
|
|
def measure_time(func):
    """Decorator that logs how long each call to *func* takes.

    After every call, including calls that raise, an INFO record
    ``"<name> took <seconds>s"`` is written to the module logger, so slow
    failures show up in the logs too. The wrapper keeps *func*'s metadata
    via ``functools.wraps``.
    """

    @wraps(func)
    def wrapper(*args, **kwargs):
        # perf_counter is monotonic, so the measurement cannot go negative or
        # jump when the wall clock is adjusted (time.time() can).
        start = time.perf_counter()
        try:
            return func(*args, **kwargs)
        finally:
            logger.info("%s took %.2fs", func.__name__, time.perf_counter() - start)

    return wrapper
|
|
|
|
|
def clear_gpu_memory():
    """Best-effort memory cleanup between generations (safe on Zero GPU).

    Runs Python's garbage collector. When CUDA is available, it also empties
    the PyTorch CUDA cache and synchronizes the device. CUDA errors are logged
    at DEBUG level and ignored, so cleanup never breaks a request.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    try:
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    except Exception as exc:  # deliberately best effort, but not a bare except
        logger.debug("GPU memory cleanup skipped: %s", exc)
|
|
|
|
|
class VideoGenerator:
    """Request helpers built on a :class:`VideoGenerationConfig`.

    Provides output sizing from an input image, validation of generation
    requests, and unique output filenames.
    """

    def __init__(self, config: VideoGenerationConfig):
        self.config = config

    def calculate_dimensions(self, image: Image.Image) -> Tuple[int, int]:
        """Choose an output ``(height, width)`` that follows the image's aspect ratio.

        The target area is ``config.max_area``. Each side is snapped down to a
        multiple of ``config.mod_value`` and clamped to the configured slider
        range. An image with a non-positive side falls back to the configured
        default size.
        """
        cfg = self.config
        mod = cfg.mod_value

        orig_w, orig_h = image.size
        if orig_w <= 0 or orig_h <= 0:
            return cfg.default_height, cfg.default_width

        aspect_ratio = orig_h / orig_w

        # Solve h * w == max_area subject to h / w == aspect_ratio. This used
        # to hard-code 384*384, which ignored config.max_area.
        calc_h = round(np.sqrt(cfg.max_area * aspect_ratio))
        calc_w = round(np.sqrt(cfg.max_area / aspect_ratio))

        calc_h = max(mod, (calc_h // mod) * mod)
        calc_w = max(mod, (calc_w // mod) * mod)

        # Clamp to the configured slider range (previously a literal 640).
        new_h = int(np.clip(calc_h, cfg.slider_min_h, cfg.slider_max_h))
        new_w = int(np.clip(calc_w, cfg.slider_min_w, cfg.slider_max_w))

        # Slider bounds need not be multiples of `mod`; re-snap after clamping.
        new_h = (new_h // mod) * mod
        new_w = (new_w // mod) * mod

        return new_h, new_w

    def validate_inputs(self, image: Image.Image, prompt: str, height: int,
                        width: int, duration: float, steps: int) -> Tuple[bool, Optional[str]]:
        """Check a generation request against the Zero GPU limits.

        Returns ``(True, None)`` when the request is acceptable. Otherwise it
        returns ``(False, message)`` with the first problem found, in this
        order: missing image, blank prompt, prompt over 300 characters,
        duration outside 0.3-1.5 s, more than 384*384 total pixels, a side
        above 640 px, or more than 6 steps.
        """
        if image is None:
            return False, "๐ผ๏ธ Please upload an input image"

        if not prompt or len(prompt.strip()) == 0:
            return False, "โ๏ธ Please provide a prompt"

        if len(prompt) > 300:
            return False, "โ ๏ธ Prompt is too long (max 300 characters)"

        if duration < 0.3:
            return False, "โฑ๏ธ Duration too short (min 0.3s)"

        if duration > 1.5:
            return False, "โฑ๏ธ Duration too long (max 1.5s for stability)"

        max_pixels = 384 * 384
        if height * width > max_pixels:
            return False, f"๐ Total pixels limited to {max_pixels:,} (e.g., 384ร384)"

        if height > 640 or width > 640:
            return False, "๐ Maximum dimension is 640 pixels"

        if steps > 6:
            return False, "๐ง Maximum 6 steps in Zero GPU environment"

        return True, None

    def generate_unique_filename(self, seed: int) -> str:
        """Return ``video_<8 hex chars>.mp4`` derived from time, *seed* and a random salt.

        MD5 is used only as a short, filename-safe digest, not for security.
        """
        timestamp = int(time.time())
        unique_str = f"{timestamp}_{seed}_{random.randint(1000, 9999)}"
        hash_obj = hashlib.md5(unique_str.encode())
        return f"video_{hash_obj.hexdigest()[:8]}.mp4"
|
|
|
# Single shared helper instance used by the Gradio callbacks.
video_generator: VideoGenerator = VideoGenerator(config)
|
|
|
|
|
def handle_image_upload(image):
    """Suggest height/width slider values for a newly uploaded image.

    Returns a pair of ``gr.update`` objects (height, width). The values come
    from ``video_generator.calculate_dimensions``. When there is no image, the
    configured defaults are returned. When processing fails, the defaults are
    also returned, a warning toast is shown, and the error is logged with its
    traceback.
    """
    defaults = (
        gr.update(value=config.default_height),
        gr.update(value=config.default_width),
    )
    if image is None:
        return defaults

    try:
        if not isinstance(image, Image.Image):
            raise ValueError("Invalid image format")
        new_h, new_w = video_generator.calculate_dimensions(image)
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception("Error processing image: %s", e)
        gr.Warning("โ ๏ธ Error processing image")
        return defaults

    return gr.update(value=new_h), gr.update(value=new_w)
|
|
|
def get_duration(input_image, prompt, height, width, negative_prompt=None,
                 duration_seconds=1.0, guidance_scale=1, steps=3,
                 seed=42, randomize_seed=False, progress=None):
    """Estimate the GPU time budget, in seconds, for one generate_video call.

    This is used as ``spaces.GPU(duration=...)``. Per the ``spaces`` ZeroGPU
    docs, it is invoked with the same arguments as ``generate_video``, so the
    parameter list mirrors that function. The optional parameters carry the
    same defaults as ``generate_video``, so calls that rely on those defaults
    no longer fail here with ``TypeError``.

    Budget: 40 s baseline, +10 s above 384*384 pixels (+20 s above 200,000
    pixels), +10 s for more than 4 steps, capped at 70 s. The unused
    parameters only exist to match the call signature.
    """
    budget = 40

    pixels = height * width
    if pixels > 200000:
        budget += 20
    elif pixels > 147456:  # anything larger than 384x384
        budget += 10

    if steps > 4:
        budget += 10

    return min(budget, 70)
|
|
|
def _load_pipeline():
    """Build the Wan 2.1 image-to-video pipeline, configure it and move it to CUDA.

    Returns the ready pipeline. Errors from ``from_pretrained`` or
    ``.to("cuda")`` propagate to the caller. LoRA loading and xformers
    attention are optional: their failures are logged and skipped.
    """
    image_encoder = CLIPVisionModel.from_pretrained(
        config.model_id,
        subfolder="image_encoder",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    vae = AutoencoderKLWan.from_pretrained(
        config.model_id,
        subfolder="vae",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )
    new_pipe = WanImageToVideoPipeline.from_pretrained(
        config.model_id,
        vae=vae,
        image_encoder=image_encoder,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        use_safetensors=True,
    )
    new_pipe.scheduler = UniPCMultistepScheduler.from_config(
        new_pipe.scheduler.config, flow_shift=8.0
    )

    # The CausVid LoRA is optional; generation proceeds without it.
    try:
        causvid_path = hf_hub_download(
            repo_id=config.lora_repo_id, filename=config.lora_filename
        )
        new_pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
        new_pipe.set_adapters(["causvid_lora"], adapter_weights=[0.95])
        new_pipe.fuse_lora()
    except Exception as e:
        logger.warning("LoRA loading skipped: %s", e)

    new_pipe.to("cuda")
    new_pipe.enable_vae_slicing()
    new_pipe.enable_vae_tiling()

    # xformers is an optional extra; keep the default attention if missing.
    try:
        new_pipe.enable_xformers_memory_efficient_attention()
    except Exception as e:
        logger.debug("xformers attention unavailable: %s", e)

    return new_pipe


@spaces.GPU(duration=get_duration)
@measure_time
def generate_video(input_image, prompt, height, width,
                   negative_prompt=config.default_negative_prompt,
                   duration_seconds=1.0, guidance_scale=1, steps=3,
                   seed=42, randomize_seed=False,
                   progress=gr.Progress(track_tqdm=True)):
    """Animate *input_image* following *prompt*; return ``(video_path, seed)``.

    Steps:
    1. Validate the request with ``video_generator.validate_inputs``.
    2. Lazily load the pipeline.
    3. Snap height/width down to multiples of ``config.mod_value``.
    4. Convert the duration to a frame count clamped to
       ``[config.min_frames, config.max_frames]``.
    5. Run the pipeline and write an MP4 to a temporary file.

    The prompt is truncated to 200 characters and the negative prompt to 100
    before being passed to the pipeline. ``seed`` is ignored when
    ``randomize_seed`` is true; the seed actually used is returned.

    Raises:
        gr.Error: if another generation is running, the input is invalid,
            the model fails to load, CUDA runs out of memory, or generation
            fails for any other reason (the message is shown in the UI).
    """
    global pipe

    # Non-blocking: reject rather than queue behind a running generation.
    # NOTE(review): under ZeroGPU, spaces.GPU may execute this function in a
    # separate worker process; if so, neither this lock nor the cached `pipe`
    # persists across calls — confirm against the `spaces` documentation.
    if not generation_lock.acquire(blocking=False):
        raise gr.Error("โณ Another video is being generated. Please wait...")

    try:
        progress(0.05, desc="๐ Validating inputs...")
        logger.info(f"Starting generation - Resolution: {height}x{width}, Duration: {duration_seconds}s, Steps: {steps}")

        is_valid, error_msg = video_generator.validate_inputs(
            input_image, prompt, height, width, duration_seconds, steps
        )
        if not is_valid:
            raise gr.Error(error_msg)

        clear_gpu_memory()

        progress(0.1, desc="๐ Loading model...")

        if pipe is None:
            try:
                # Publish the global only after loading fully succeeds, so a
                # failure mid-way (e.g. in .to("cuda")) is retried next time
                # instead of leaving a half-configured pipeline cached.
                pipe = _load_pipeline()
                logger.info("Model loaded successfully")
            except Exception as e:
                logger.exception("Model loading failed: %s", e)
                raise gr.Error("Failed to load model") from e

        progress(0.3, desc="๐ฏ Preparing image...")

        mod = config.mod_value
        target_h = max(mod, (int(height) // mod) * mod)
        target_w = max(mod, (int(width) // mod) * mod)

        # Convert duration to frames and clamp to the configured range.
        num_frames = int(round(duration_seconds * config.fixed_fps))
        num_frames = max(config.min_frames, min(num_frames, config.max_frames))

        current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)

        resized_image = input_image.resize((target_w, target_h), Image.Resampling.LANCZOS)

        progress(0.4, desc="๐ฌ Generating video...")

        with torch.inference_mode(), torch.amp.autocast('cuda', enabled=True):
            try:
                output_frames_list = pipe(
                    image=resized_image,
                    prompt=prompt[:200],
                    # Guard against None, which would make the slice raise.
                    negative_prompt=(negative_prompt or "")[:100],
                    height=target_h,
                    width=target_w,
                    num_frames=num_frames,
                    guidance_scale=float(guidance_scale),
                    num_inference_steps=int(steps),
                    generator=torch.Generator(device="cuda").manual_seed(current_seed),
                    return_dict=True,
                ).frames[0]
            except torch.cuda.OutOfMemoryError:
                clear_gpu_memory()
                raise gr.Error("๐พ GPU out of memory. Try smaller dimensions.")
            except Exception as e:
                logger.exception("Generation error: %s", e)
                raise gr.Error(f"โ Generation failed: {str(e)[:100]}") from e

        progress(0.9, desc="๐พ Saving video...")

        # delete=False: the file must outlive this function so Gradio can serve it.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
            video_path = tmpfile.name

        export_to_video(output_frames_list, video_path, fps=config.fixed_fps)

        progress(1.0, desc="โจ Complete!")
        logger.info(f"Video generated: {num_frames} frames, {target_h}x{target_w}")

        # Drop large references before the cleanup in `finally`.
        del output_frames_list, resized_image

        return video_path, current_seed

    except gr.Error:
        raise
    except Exception as e:
        logger.exception("Unexpected error: %s", e)
        raise gr.Error(f"โ Error: {str(e)[:100]}") from e
    finally:
        generation_lock.release()
        clear_gpu_memory()
|
|
|
|
|
# Stylesheet passed to gr.Blocks(css=...). The "container" and "generate-btn"
# classes are applied via elem_classes in the layout; "header" and
# "warning-box" are used by the gr.HTML snippets.
css = """

.container {

max-width: 1000px;

margin: auto;

padding: 20px;

}



.header {

text-align: center;

margin-bottom: 20px;

background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);

padding: 30px;

border-radius: 15px;

color: white;

box-shadow: 0 5px 15px rgba(0,0,0,0.2);

}



.header h1 {

font-size: 2.5em;

margin-bottom: 10px;

}



.warning-box {

background: #fff3cd;

border: 1px solid #ffeaa7;

border-radius: 8px;

padding: 12px;

margin: 10px 0;

color: #856404;

font-size: 0.9em;

}



.generate-btn {

background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);

color: white;

font-size: 1.2em;

padding: 12px 30px;

border-radius: 25px;

border: none;

cursor: pointer;

width: 100%;

margin-top: 15px;

}



.generate-btn:hover {

transform: translateY(-2px);

box-shadow: 0 5px 15px rgba(102, 126, 234, 0.4);

}

"""
|
|
|
|
|
# Gradio layout. Left column: inputs and settings. Right column: output video
# and tips. The height/width sliders are driven by `config`, so their bounds,
# step and defaults stay in sync with VideoGenerator.calculate_dimensions
# (which handle_image_upload uses to set them).
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_classes="container"):
        gr.HTML("""

<div class="header">

<h1>๐ฌ AI Video Generator</h1>

<p>Transform images into videos with Wan 2.1 (Zero GPU Optimized)</p>

</div>

""")

        gr.HTML("""

<div class="warning-box">

<strong>โก Zero GPU Limitations:</strong>

<ul style="margin: 5px 0; padding-left: 20px;">

<li>Max resolution: 384ร384 (recommended)</li>

<li>Max duration: 1.5 seconds</li>

<li>Max steps: 6 (3-4 recommended)</li>

<li>Processing time: ~40-60 seconds</li>

</ul>

</div>

""")

        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(
                    type="pil",
                    label="๐ผ๏ธ Upload Image"
                )

                prompt_input = gr.Textbox(
                    label="โจ Animation Prompt",
                    value=config.default_prompt,
                    placeholder="Describe the motion...",
                    lines=2,
                    max_lines=3
                )

                # Range mirrors the limits enforced by validate_inputs.
                duration_input = gr.Slider(
                    minimum=0.3,
                    maximum=1.5,
                    step=0.1,
                    value=1.0,
                    label="โฑ๏ธ Duration (seconds)"
                )

                with gr.Accordion("โ๏ธ Settings", open=False):
                    negative_prompt = gr.Textbox(
                        label="Negative Prompt",
                        value=config.default_negative_prompt,
                        lines=1
                    )

                    with gr.Row():
                        height_slider = gr.Slider(
                            minimum=config.slider_min_h,
                            maximum=config.slider_max_h,
                            step=config.mod_value,
                            value=config.default_height,
                            label="Height"
                        )
                        width_slider = gr.Slider(
                            minimum=config.slider_min_w,
                            maximum=config.slider_max_w,
                            step=config.mod_value,
                            value=config.default_width,
                            label="Width"
                        )

                    steps_slider = gr.Slider(
                        minimum=1,
                        maximum=6,
                        step=1,
                        value=3,
                        label="Steps (3-4 recommended)"
                    )

                    with gr.Row():
                        seed = gr.Slider(
                            minimum=0,
                            maximum=MAX_SEED,
                            step=1,
                            value=42,
                            label="Seed"
                        )
                        randomize_seed = gr.Checkbox(
                            label="Random",
                            value=True
                        )

                    # Hidden: still passed to generate_video with its default value.
                    guidance_scale = gr.Slider(
                        minimum=0.0,
                        maximum=5.0,
                        step=0.5,
                        value=1.0,
                        label="Guidance Scale",
                        visible=False
                    )

                generate_btn = gr.Button(
                    "๐ฌ Generate Video",
                    variant="primary",
                    elem_classes="generate-btn"
                )

            with gr.Column(scale=1):
                video_output = gr.Video(
                    label="Generated Video",
                    autoplay=True
                )

                gr.Markdown("""

### ๐ก Tips:

- Use 384ร384 for best results

- Keep prompts simple and clear

- 3-4 steps is optimal

- Wait for completion before next generation

""")

    # Suggest aspect-ratio-preserving dimensions whenever an image is uploaded.
    input_image.upload(
        fn=handle_image_upload,
        inputs=[input_image],
        outputs=[height_slider, width_slider]
    )

    # The seed actually used is written back to the seed slider.
    generate_btn.click(
        fn=generate_video,
        inputs=[
            input_image, prompt_input, height_slider, width_slider,
            negative_prompt, duration_input, guidance_scale,
            steps_slider, seed, randomize_seed
        ],
        outputs=[video_output, seed]
    )
|
|
|
if __name__ == "__main__":
    logger.info("Starting app in Zero GPU environment")
    # queue() returns the Blocks instance, so launch() can be chained.
    # max_size=2 limits how many requests may wait in the queue.
    demo.queue(max_size=2).launch()