import os
import sys
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
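# Make the bundled `wan` package importable when the app is launched from the
# Space's root directory.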
#import subprocess
#subprocess.run('pip install flash-attn==2.7.4.post1 --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
# wan2.2-main/gradio_ti2v.py
import gradio as gr
import torch
from huggingface_hub import snapshot_download
from PIL import Image
import random
import numpy as np
import spaces
import wan
from wan.configs import WAN_CONFIGS, SIZE_CONFIGS, MAX_AREA_CONFIGS, SUPPORTED_SIZES
from wan.utils.utils import cache_video
import gc
# --- 1. Global Setup and Model Loading ---
print("Starting Gradio App for Wan 2.2 TI2V-5B...")
# Download model snapshots from Hugging Face Hub
repo_id = "Wan-AI/Wan2.2-TI2V-5B"
print(f"Downloading/loading checkpoints for {repo_id}...")
ckpt_dir = snapshot_download(repo_id)
print(f"Using checkpoints from {ckpt_dir}")
# Load the model configuration
TASK_NAME = 'ti2v-5B'
cfg = WAN_CONFIGS[TASK_NAME]
FIXED_FPS = 24
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 121
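# Frame budget: frames = duration_seconds * FIXED_FPS, clamped to the model's
# supported range; e.g. 2.0 s * 24 fps = 48 frames, while the 121-frame cap
# corresponds to roughly 5 seconds of output.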
# Instantiate the pipeline in the global scope
print("Initializing WanTI2V pipeline...")
device = "cuda" if torch.cuda.is_available() else "cpu"
device_id = 0 if torch.cuda.is_available() else -1
pipeline = wan.WanTI2V(
config=cfg,
checkpoint_dir=ckpt_dir,
device_id=device_id,
rank=0,
t5_fsdp=False,
dit_fsdp=False,
use_sp=False,
t5_cpu=False,
init_on_cpu=False,
convert_model_dtype=True,
)
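# Single-process, single-GPU setup: FSDP sharding and sequence parallelism are
# disabled, and convert_model_dtype casts the model weights to the dtype given
# in the config to reduce memory use.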
print("Pipeline initialized and ready.")
# --- Helper Functions ---
def clear_gpu_memory():
"""Clear GPU memory more thoroughly"""
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
gc.collect()
def select_best_size_for_image(image, available_sizes):
"""Select the size option with aspect ratio closest to the input image."""
if image is None:
return available_sizes[0] # Return first option if no image
img_width, img_height = image.size
img_aspect_ratio = img_height / img_width
best_size = available_sizes[0]
best_diff = float('inf')
for size_str in available_sizes:
# Parse size string like "704*1280"
height, width = map(int, size_str.split('*'))
size_aspect_ratio = height / width
diff = abs(img_aspect_ratio - size_aspect_ratio)
if diff < best_diff:
best_diff = diff
best_size = size_str
return best_size
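# Example: a 1080x1920 portrait image has aspect ratio 1920/1080 ~= 1.78; with
# size strings parsed as height*width, "1280*704" (1280/704 ~= 1.82) is a
# closer match than "704*1280" (~= 0.55), so the portrait size is selected.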
def handle_image_upload(image):
"""Handle image upload and return the best matching size."""
if image is None:
return gr.update()
pil_image = Image.fromarray(image).convert("RGB")
available_sizes = list(SUPPORTED_SIZES[TASK_NAME])
best_size = select_best_size_for_image(pil_image, available_sizes)
return gr.update(value=best_size)
def validate_inputs(image, prompt, duration_seconds):
"""Validate user inputs"""
errors = []
if not prompt or len(prompt.strip()) < 5:
errors.append("Prompt must be at least 5 characters long.")
    if image is not None:
        img = Image.fromarray(image)
        # The cap is on total pixel count, equivalent to a 4096x4096 image.
        if img.size[0] * img.size[1] > 4096 * 4096:
            errors.append("Image is too large (more than 4096x4096 total pixels).")
if duration_seconds > 5.0 and image is None:
errors.append("Videos longer than 5 seconds require an input image.")
return errors
def get_duration(image,
prompt,
size,
duration_seconds,
sampling_steps,
guide_scale,
shift,
seed,
progress):
"""Calculate dynamic GPU duration based on parameters."""
if sampling_steps > 35 and duration_seconds >= 2:
return 120
elif sampling_steps < 35 or duration_seconds < 2:
return 105
else:
return 90
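# On ZeroGPU Spaces, @spaces.GPU accepts a callable for `duration`: it is
# invoked with the same arguments as the decorated function and returns the
# number of seconds of GPU time to reserve for that call.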
def apply_template(template, current_prompt):
"""Apply prompt template"""
if "{subject}" in template:
# Extract the main subject from current prompt (simple heuristic)
subject = current_prompt.split(",")[0] if "," in current_prompt else current_prompt
return template.replace("{subject}", subject)
return template + " " + current_prompt
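# Example: apply_template("cinematic shot of {subject}, ...", "a red fox, running")
# yields "cinematic shot of a red fox, ..."; the text before the first comma is
# treated as the subject.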
# --- 2. Gradio Inference Function ---
@spaces.GPU(duration=get_duration)
def generate_video(
image,
prompt,
size,
duration_seconds,
sampling_steps,
guide_scale,
shift,
seed,
progress=gr.Progress(track_tqdm=True)
):
"""The main function to generate video, called by the Gradio interface."""
# Validate inputs
errors = validate_inputs(image, prompt, duration_seconds)
if errors:
raise gr.Error("\n".join(errors))
progress(0, desc="Setting up...")
if seed == -1:
seed = random.randint(0, sys.maxsize)
progress(0.1, desc="Processing image...")
input_image = None
if image is not None:
input_image = Image.fromarray(image).convert("RGB")
# Resize image to match selected size
target_height, target_width = map(int, size.split('*'))
input_image = input_image.resize((target_width, target_height))
# Calculate number of frames based on duration
num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
progress(0.2, desc="Generating video...")
try:
video_tensor = pipeline.generate(
input_prompt=prompt,
img=input_image, # Pass None for T2V, Image for I2V
size=SIZE_CONFIGS[size],
max_area=MAX_AREA_CONFIGS[size],
frame_num=num_frames, # Use calculated frames instead of cfg.frame_num
shift=shift,
sample_solver='unipc',
sampling_steps=int(sampling_steps),
guide_scale=guide_scale,
seed=seed,
offload_model=True
)
progress(0.9, desc="Saving video...")
# Save the video to a temporary file
video_path = cache_video(
tensor=video_tensor[None], # Add a batch dimension
save_file=None, # cache_video will create a temp file
fps=cfg.sample_fps,
normalize=True,
value_range=(-1, 1)
)
progress(1.0, desc="Complete!")
except torch.cuda.OutOfMemoryError:
clear_gpu_memory()
raise gr.Error("GPU out of memory. Please try with lower settings.")
except Exception as e:
raise gr.Error(f"Video generation failed: {str(e)}")
finally:
if 'video_tensor' in locals():
del video_tensor
clear_gpu_memory()
return video_path
# --- 3. Gradio Interface ---
css = """
.gradio-container {max-width: 1100px !important; margin: 0 auto}
#output_video {height: 500px;}
#input_image {height: 500px;}
.template-btn {margin: 2px !important;}
"""
# Default prompt with motion emphasis
DEFAULT_PROMPT = "Generate a video with smooth and natural movement. Objects should have visible motion while maintaining fluid transitions."
# Prompt templates
templates = {
"Cinematic": "cinematic shot of {subject}, professional lighting, smooth camera movement, 4k quality",
"Animation": "animated style {subject}, vibrant colors, fluid motion, dynamic movement",
"Nature": "nature documentary footage of {subject}, wildlife photography, natural movement",
"Slow Motion": "slow motion capture of {subject}, high speed camera, detailed motion",
"Action": "dynamic action shot of {subject}, fast paced movement, energetic motion"
}
with gr.Blocks(css=css, theme=gr.themes.Soft(), delete_cache=(60, 900)) as demo:
gr.Markdown("""
# Wan 2.2 TI2V Enhanced
    Generate high-quality videos with the **Wan 2.2 5B Text-Image-to-Video model**
[[model]](https://huggingface.co/Wan-AI/Wan2.2-TI2V-5B), [[paper]](https://arxiv.org/abs/2503.20314)
### 💡 Tips for best results:
- 🖼️ Upload an image for better control over the video content
- ⏱️ Longer videos require more processing time
- 🎯 Be specific and descriptive in your prompts
- 🎬 Include motion-related keywords for dynamic videos
""")
with gr.Row():
with gr.Column(scale=2):
image_input = gr.Image(type="numpy", label="Input Image (Optional)", elem_id="input_image")
prompt_input = gr.Textbox(
label="Prompt",
value=DEFAULT_PROMPT,
lines=3,
placeholder="Describe the video you want to generate..."
)
# Prompt templates section
with gr.Accordion("Prompt Templates", open=False):
gr.Markdown("Click a template to apply it to your prompt:")
with gr.Row():
template_buttons = {}
for name, template in templates.items():
btn = gr.Button(name, size="sm", elem_classes=["template-btn"])
template_buttons[name] = (btn, template)
# Connect template buttons
                for name, (btn, template) in template_buttons.items():
                    btn.click(
                        fn=lambda current_prompt, t=template: apply_template(t, current_prompt),
                        inputs=[prompt_input],
                        outputs=prompt_input
                    )
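                # `t=template` binds the loop variable at definition time,
                # avoiding Python's late-binding closure pitfall; Gradio passes
                # the current prompt text as the lambda's positional argument.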
duration_input = gr.Slider(
minimum=round(MIN_FRAMES_MODEL/FIXED_FPS, 1),
maximum=round(MAX_FRAMES_MODEL/FIXED_FPS, 1),
step=0.1,
value=2.0,
label="Duration (seconds)",
info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps."
)
size_input = gr.Dropdown(
label="Output Resolution",
choices=list(SUPPORTED_SIZES[TASK_NAME]),
value="704*1280"
)
with gr.Column(scale=2):
video_output = gr.Video(label="Generated Video", elem_id="output_video")
# Status indicators
with gr.Row():
status_text = gr.Textbox(
label="Status",
value="Ready",
interactive=False,
max_lines=1
)
with gr.Accordion("Advanced Settings", open=False):
steps_input = gr.Slider(
label="Sampling Steps",
minimum=10,
maximum=50,
value=38,
step=1,
info="Higher values = better quality but slower"
)
scale_input = gr.Slider(
label="Guidance Scale",
minimum=1.0,
maximum=10.0,
value=cfg.sample_guide_scale,
step=0.1,
info="Higher values = closer to prompt but less creative"
)
shift_input = gr.Slider(
label="Sample Shift",
minimum=1.0,
maximum=20.0,
value=cfg.sample_shift,
step=0.1,
info="Affects the sampling process dynamics"
)
seed_input = gr.Number(
label="Seed (-1 for random)",
value=-1,
precision=0,
info="Use same seed for reproducible results"
)
run_button = gr.Button("Generate Video", variant="primary", size="lg")
# Add image upload handler
image_input.upload(
fn=handle_image_upload,
inputs=[image_input],
outputs=[size_input]
)
image_input.clear(
fn=handle_image_upload,
inputs=[image_input],
outputs=[size_input]
)
    # Update status when generating.
    # NOTE: this wrapper is not wired to any event below, and assigning to
    # status_text.value after the UI is built does not refresh the frontend;
    # a component only updates when an event handler returns a value for it
    # (e.g. via outputs=[video_output, status_text]).
    def update_status_and_generate(*args):
        status_text.value = "Generating..."
        try:
            result = generate_video(*args)
            status_text.value = "Complete!"
            return result
        except Exception:
            status_text.value = "Error occurred"
            raise
example_image_path = os.path.join(os.path.dirname(__file__), "examples/i2v_input.JPG")
gr.Examples(
examples=[
[example_image_path, "The cat removes the glasses from its eyes with smooth motion.", "1280*704", 1.5],
[None, "A cinematic shot of a boat sailing on calm waves with gentle rocking motion at sunset.", "1280*704", 2.0],
[None, "Drone footage flying smoothly over a futuristic city with flying cars in continuous motion.", "1280*704", 2.0],
[None, DEFAULT_PROMPT + " A waterfall cascading down rocks.", "704*1280", 2.5],
[None, DEFAULT_PROMPT + " Birds flying across a cloudy sky.", "1280*704", 3.0],
],
inputs=[image_input, prompt_input, size_input, duration_input],
outputs=video_output,
fn=generate_video,
cache_examples=False,
)
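    # With cache_examples=False and no run_on_click, selecting an example only
    # fills the inputs; generation still requires pressing "Generate Video".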
run_button.click(
fn=generate_video,
inputs=[image_input, prompt_input, size_input, duration_input, steps_input, scale_input, shift_input, seed_input],
outputs=video_output
)
if __name__ == "__main__":
demo.launch() |