# NOTE: Hugging Face Spaces page-status residue removed ("Spaces: Paused") —
# it was scraped into the source and is not part of the program.
import spaces  # NOTE: keep first — ZeroGPU Spaces expect it imported before GPU-touching libs

import gc
import glob
import os
import shutil
import subprocess
import tempfile

import gradio as gr
import numpy as np
from huggingface_hub import snapshot_download
from PIL import Image
# -------- Model Download --------
# Fetch (or reuse a cached copy of) the Wan2.2 TI2V-5B checkpoint at import
# time; ckpt_dir is consumed by the generate.py subprocess calls below.
repo_id = "Wan-AI/Wan2.2-TI2V-5B"
print(f"Downloading/loading checkpoints for {repo_id}...")
# NOTE(review): local_dir_use_symlinks is deprecated and ignored by recent
# huggingface_hub releases — confirm against the pinned version.
ckpt_dir = snapshot_download(repo_id, local_dir_use_symlinks=False)
print(f"Using checkpoints from {ckpt_dir}")
# -------- Constants --------
FIXED_FPS = 24               # output video frame rate (frames per second)
MIN_FRAMES_MODEL = 8         # smallest frame count to request from the model
MAX_FRAMES_MODEL = 121       # largest frame count (~5 s at FIXED_FPS)
MOD_VALUE = 32               # height and width must be multiples of this
DEFAULT_H_SLIDER_VALUE = 704   # fallback height when auto-sizing fails
DEFAULT_W_SLIDER_VALUE = 1280  # fallback width when auto-sizing fails
NEW_FORMULA_MAX_AREA = 1280.0 * 704.0  # target pixel budget for auto-sizing
SLIDER_MIN_H, SLIDER_MAX_H = 128, 1280  # allowed height range
SLIDER_MIN_W, SLIDER_MAX_W = 128, 1280  # allowed width range
# -------- Helpers -------- | |
def _calculate_new_dimensions(pil_image):
    """Choose an output (height, width) for *pil_image*.

    The pair preserves the image's aspect ratio, targets roughly
    NEW_FORMULA_MAX_AREA pixels, is snapped down to multiples of MOD_VALUE,
    and is clipped to the slider limits. Degenerate sizes fall back to the
    defaults. Returns (height, width).
    """
    width, height = pil_image.size
    if width <= 0 or height <= 0:
        return DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE

    ratio = height / width
    # Solve h*w == MAX_AREA subject to h/w == ratio.
    raw_h = round(np.sqrt(NEW_FORMULA_MAX_AREA * ratio))
    raw_w = round(np.sqrt(NEW_FORMULA_MAX_AREA / ratio))

    # Snap down to the model's granularity, never below a single unit.
    snapped_h = max(MOD_VALUE, (raw_h // MOD_VALUE) * MOD_VALUE)
    snapped_w = max(MOD_VALUE, (raw_w // MOD_VALUE) * MOD_VALUE)

    ceiling_h = (SLIDER_MAX_H // MOD_VALUE) * MOD_VALUE
    ceiling_w = (SLIDER_MAX_W // MOD_VALUE) * MOD_VALUE
    final_h = int(np.clip(snapped_h, SLIDER_MIN_H, ceiling_h))
    final_w = int(np.clip(snapped_w, SLIDER_MIN_W, ceiling_w))
    return final_h, final_w
def handle_image_upload_for_dims(uploaded_pil_image, current_h_val, current_w_val):
    """Gradio callback: propose height/width values for an uploaded image.

    Returns a pair of gr.update objects (height update, width update).
    When no image is present, or sizing fails for any reason, both updates
    carry the default slider values.
    """
    fallback = (
        gr.update(value=DEFAULT_H_SLIDER_VALUE),
        gr.update(value=DEFAULT_W_SLIDER_VALUE),
    )
    if uploaded_pil_image is None:
        return fallback
    try:
        # Gradio may deliver a numpy array rather than a PIL image.
        if hasattr(uploaded_pil_image, 'shape'):
            image = Image.fromarray(uploaded_pil_image).convert("RGB")
        else:
            image = uploaded_pil_image
        height, width = _calculate_new_dimensions(image)
        return gr.update(value=height), gr.update(value=width)
    except Exception:
        return fallback
def get_duration(prompt, size, duration_seconds, steps, progress):
    """Estimate the GPU-time budget (in seconds) for one generation request.

    Longer clips and higher step counts get larger budgets; prompt, size and
    progress are accepted for signature compatibility but unused here.
    """
    if duration_seconds >= 3:
        return 220
    if steps > 35 and duration_seconds >= 2:
        return 180
    if steps < 35 or duration_seconds < 2:
        return 105
    # Only steps == 35 with 2 <= duration_seconds < 3 reaches this point.
    return 90
def find_latest_mp4():
    """Return the most recently created .mp4 in the current directory, or None."""
    candidates = glob.glob("*.mp4")
    if not candidates:
        return None
    return max(candidates, key=os.path.getctime)
# -------- Generation Functions --------
def generate_t2v(prompt, size="1280*704", duration_seconds=5, steps=25, progress=gr.Progress(track_tqdm=True)):
    """Generate a video from text by shelling out to Wan2.2's generate.py.

    Args:
        prompt: text prompt; must be non-empty.
        size: "H*W" string; each dimension is coerced to a multiple of
            MOD_VALUE, falling back to the defaults on parse failure.
        duration_seconds: requested clip length; converted to a frame count.
        steps: diffusion sampling steps.
        progress: Gradio progress tracker (tqdm-driven).

    Returns:
        (video_path, download_html, status_message); the first two are None
        on any failure.
    """
    if not prompt.strip():
        return None, None, "Please enter a prompt."
    temp_dir = tempfile.mkdtemp()
    # Ensure size is multiples of MOD_VALUE (h*w)
    try:
        h_str, w_str = size.lower().replace(" ", "").split("*")
        h = max(MOD_VALUE, (int(h_str) // MOD_VALUE) * MOD_VALUE)
        w = max(MOD_VALUE, (int(w_str) // MOD_VALUE) * MOD_VALUE)
        size = f"{h}*{w}"
    except Exception:
        size = f"{DEFAULT_H_SLIDER_VALUE}*{DEFAULT_W_SLIDER_VALUE}"
    # FIX: duration_seconds was accepted but never used (and the frame-count
    # constants sat unused). Map it to --frame_num, clamped to the model's
    # range and snapped to 4n+1, the frame-count form Wan checkpoints expect.
    # TODO(review): confirm the 4n+1 requirement against the pinned Wan2.2
    # generate.py.
    num_frames = int(round(duration_seconds * FIXED_FPS))
    num_frames = max(MIN_FRAMES_MODEL, min(MAX_FRAMES_MODEL, num_frames))
    num_frames = ((num_frames - 1) // 4) * 4 + 1
    cmd = [
        "python", "generate.py",
        "--task", "ti2v-5B",
        "--size", size,
        "--ckpt_dir", ckpt_dir,
        "--offload_model", "True",
        "--frame_num", str(num_frames),
        "--sample_steps", str(int(steps)),
        "--convert_model_dtype",
        "--t5_cpu",
        "--prompt", prompt
    ]
    print(f"[T2V] Running command: {' '.join(cmd)}")
    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        return None, None, f"Error during T2V generation: {e}"
    gc.collect()
    video_file = find_latest_mp4()
    if video_file is None:
        return None, None, "Generation finished but no video file found."
    dest_path = os.path.join(temp_dir, os.path.basename(video_file))
    # FIX: shutil.move works across filesystems; os.rename raises EXDEV when
    # the CWD and the temp dir live on different mounts.
    shutil.move(video_file, dest_path)
    # NOTE(review): this relative href likely does not resolve under Gradio's
    # file serving — verify; the Video component already offers a download.
    download_link = f"<a href='{os.path.basename(dest_path)}' download>π₯ Download Video</a>"
    return dest_path, download_link, "Text-to-Video generated successfully!"
def generate_i2v(image, prompt, size="1280*704", duration_seconds=5, steps=25, progress=gr.Progress(track_tqdm=True)):
    """Generate a video from an image plus prompt via Wan2.2's generate.py.

    Args:
        image: input image as a numpy array (Gradio `type="numpy"`).
        prompt: text prompt; must be non-empty.
        size: "H*W" string; each dimension is coerced to a multiple of
            MOD_VALUE, falling back to the defaults on parse failure.
        duration_seconds: requested clip length; converted to a frame count.
        steps: diffusion sampling steps.
        progress: Gradio progress tracker (tqdm-driven).

    Returns:
        (video_path, download_html, status_message); the first two are None
        on any failure.
    """
    if image is None or not prompt.strip():
        return None, None, "Please upload an image and enter a prompt."
    temp_dir = tempfile.mkdtemp()
    # Ensure size is multiples of MOD_VALUE (h*w)
    try:
        h_str, w_str = size.lower().replace(" ", "").split("*")
        h = max(MOD_VALUE, (int(h_str) // MOD_VALUE) * MOD_VALUE)
        w = max(MOD_VALUE, (int(w_str) // MOD_VALUE) * MOD_VALUE)
        size = f"{h}*{w}"
    except Exception:
        size = f"{DEFAULT_H_SLIDER_VALUE}*{DEFAULT_W_SLIDER_VALUE}"
    image_path = os.path.join(temp_dir, "input.jpg")
    # FIX: convert to RGB first — saving an RGBA/grayscale array straight to
    # JPEG raises OSError ("cannot write mode RGBA as JPEG").
    Image.fromarray(image).convert("RGB").save(image_path)
    # FIX: duration_seconds and steps were accepted but never forwarded to
    # the CLI (the T2V path did pass --sample_steps). Clamp the frame count
    # to the model's range and snap to 4n+1 as Wan checkpoints expect.
    # TODO(review): confirm the 4n+1 requirement against the pinned Wan2.2
    # generate.py.
    num_frames = int(round(duration_seconds * FIXED_FPS))
    num_frames = max(MIN_FRAMES_MODEL, min(MAX_FRAMES_MODEL, num_frames))
    num_frames = ((num_frames - 1) // 4) * 4 + 1
    cmd = [
        "python", "generate.py",
        "--task", "ti2v-5B",
        "--size", size,
        "--ckpt_dir", ckpt_dir,
        "--offload_model", "True",
        "--frame_num", str(num_frames),
        "--sample_steps", str(int(steps)),
        "--convert_model_dtype",
        "--t5_cpu",
        "--image", image_path,
        "--prompt", prompt
    ]
    print(f"[I2V] Running command: {' '.join(cmd)}")
    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        return None, None, f"Error during I2V generation: {e}"
    gc.collect()
    video_file = find_latest_mp4()
    if video_file is None:
        return None, None, "Generation finished but no video file found."
    dest_path = os.path.join(temp_dir, os.path.basename(video_file))
    # FIX: shutil.move works across filesystems; os.rename raises EXDEV when
    # the CWD and the temp dir live on different mounts.
    shutil.move(video_file, dest_path)
    # NOTE(review): this relative href likely does not resolve under Gradio's
    # file serving — verify; the Video component already offers a download.
    download_link = f"<a href='{os.path.basename(dest_path)}' download>π₯ Download Video</a>"
    return dest_path, download_link, "Image-to-Video generated successfully!"
# -------- Gradio UI --------
def _size_update_from_image(img):
    """Build a single 'H*W' textbox update from an uploaded image (or the default)."""
    fallback = gr.update(value=f"{DEFAULT_H_SLIDER_VALUE}*{DEFAULT_W_SLIDER_VALUE}")
    if img is None:
        return fallback
    try:
        pil_image = Image.fromarray(img).convert("RGB") if hasattr(img, "shape") else img
        new_h, new_w = _calculate_new_dimensions(pil_image)
        return gr.update(value=f"{new_h}*{new_w}")
    except Exception:
        return fallback

with gr.Blocks() as demo:
    gr.Markdown("## π₯ Wan2.2-TI2V-5B Video Generator")
    gr.Markdown("Choose **Text-to-Video** or **Image-to-Video** mode below.")
    with gr.Tab("Text-to-Video"):
        t2v_prompt = gr.Textbox(
            label="Prompt",
            value="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage"
        )
        t2v_size = gr.Textbox(label="Video Size (HxW)", value=f"{DEFAULT_H_SLIDER_VALUE}*{DEFAULT_W_SLIDER_VALUE}")
        t2v_duration = gr.Number(label="Video Length (seconds)", value=5)
        t2v_steps = gr.Number(label="Inference Steps", value=25)
        t2v_btn = gr.Button("Generate from Text")
        t2v_video = gr.Video(label="Generated Video")
        t2v_download = gr.HTML()
        t2v_status = gr.Textbox(label="Status")
        t2v_btn.click(
            generate_t2v,
            [t2v_prompt, t2v_size, t2v_duration, t2v_steps],
            [t2v_video, t2v_download, t2v_status]
        )
    with gr.Tab("Image-to-Video"):
        i2v_image = gr.Image(type="numpy", label="Upload Image")
        i2v_prompt = gr.Textbox(
            label="Prompt",
            value=(
                "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. "
                "The fluffy-furred feline gazes directly at the camera with a relaxed expression. "
                "Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, "
                "and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, "
                "as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's "
                "intricate details and the refreshing atmosphere of the seaside."
            )
        )
        i2v_size = gr.Textbox(label="Video Size (HxW)", value=f"{DEFAULT_H_SLIDER_VALUE}*{DEFAULT_W_SLIDER_VALUE}")
        i2v_duration = gr.Number(label="Video Length (seconds)", value=5)
        i2v_steps = gr.Number(label="Inference Steps", value=25)
        i2v_btn = gr.Button("Generate from Image")
        i2v_video = gr.Video(label="Generated Video")
        i2v_download = gr.HTML()
        i2v_status = gr.Textbox(label="Status")
        i2v_btn.click(
            generate_i2v,
            [i2v_image, i2v_prompt, i2v_size, i2v_duration, i2v_steps],
            [i2v_video, i2v_download, i2v_status]
        )
        # Auto adjust size on image upload for i2v.
        # FIX: the original wired handle_image_upload_for_dims here with
        # [i2v_size, i2v_size] as both inputs and outputs — the same Textbox
        # listed twice, receiving two bare numeric slider updates that do not
        # match the "H*W" format generate_i2v parses. Emit one "H*W" string
        # into the single size textbox instead.
        i2v_image.upload(
            fn=_size_update_from_image,
            inputs=[i2v_image],
            outputs=[i2v_size]
        )
        i2v_image.clear(
            fn=_size_update_from_image,
            inputs=[i2v_image],
            outputs=[i2v_size]
        )

if __name__ == "__main__":
    demo.launch()