# Scraped Hugging Face file-view header (kept as a comment so the module parses):
# author: rahul7star — commit b811178 "Update app.py" (verified)
import spaces
import os
import subprocess
import tempfile
import glob
import gc
from huggingface_hub import snapshot_download
import gradio as gr
from PIL import Image
import numpy as np
# -------- Model Download --------
# Download (or reuse the locally cached snapshot of) the model repository at
# import time; the resulting directory is handed to generate.py as --ckpt_dir.
repo_id = "Wan-AI/Wan2.2-TI2V-5B"
print(f"Downloading/loading checkpoints for {repo_id}...")
# `local_dir_use_symlinks` is deprecated in huggingface_hub and only ever had an
# effect together with `local_dir`, which is not used here, so it is omitted.
ckpt_dir = snapshot_download(repo_id)
print(f"Using checkpoints from {ckpt_dir}")
# -------- Constants --------
# Frames-per-second and frame-count bounds for the model.
FIXED_FPS = 24
MIN_FRAMES_MODEL, MAX_FRAMES_MODEL = 8, 121
# Heights/widths are snapped to multiples of this value.
MOD_VALUE = 32
# Default size ("H*W") shown in the UI size boxes.
DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE = 704, 1280
# Pixel-area budget used when deriving a size from an uploaded image.
NEW_FORMULA_MAX_AREA = float(DEFAULT_W_SLIDER_VALUE * DEFAULT_H_SLIDER_VALUE)
# Clamp range for derived heights and widths.
SLIDER_MIN_H = SLIDER_MIN_W = 128
SLIDER_MAX_H = SLIDER_MAX_W = 1280
# -------- Helpers --------
def _calculate_new_dimensions(pil_image):
    """Return ``(height, width)`` matching the image's aspect ratio.

    The size targets roughly ``NEW_FORMULA_MAX_AREA`` pixels. Each side is
    rounded down to a multiple of ``MOD_VALUE`` (never below ``MOD_VALUE``) and
    then clamped to the slider range. Images with a non-positive side fall back
    to the default dimensions.
    """
    width, height = pil_image.size
    if width <= 0 or height <= 0:
        return DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE

    ratio = height / width

    def _side(area_share, lower, upper):
        # Side length for this share of the area, floored to the MOD_VALUE grid.
        snapped = max(MOD_VALUE, (round(np.sqrt(area_share)) // MOD_VALUE) * MOD_VALUE)
        return int(np.clip(snapped, lower, (upper // MOD_VALUE) * MOD_VALUE))

    return (
        _side(NEW_FORMULA_MAX_AREA * ratio, SLIDER_MIN_H, SLIDER_MAX_H),
        _side(NEW_FORMULA_MAX_AREA / ratio, SLIDER_MIN_W, SLIDER_MAX_W),
    )
def handle_image_upload_for_dims(uploaded_pil_image, current_h_val, current_w_val):
    """Return Gradio updates ``(height, width)`` sized to an uploaded image.

    ``uploaded_pil_image`` may be a PIL image or an array (anything with a
    ``shape`` attribute, which is converted to an RGB PIL image first).
    ``current_h_val`` / ``current_w_val`` are accepted but not used. A missing
    image, or any failure while measuring it, yields the default dimensions.
    """
    def _fallback():
        return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)

    if uploaded_pil_image is None:
        return _fallback()
    try:
        image = (
            Image.fromarray(uploaded_pil_image).convert("RGB")
            if hasattr(uploaded_pil_image, 'shape')
            else uploaded_pil_image
        )
        height, width = _calculate_new_dimensions(image)
        return gr.update(value=height), gr.update(value=width)
    except Exception:
        # Best effort: an unreadable upload just resets the dimensions.
        return _fallback()
def get_duration(prompt, size, duration_seconds, steps, progress=None):
    """Return the GPU time budget, in seconds, to request for one generation.

    Passed as the ``duration`` callable of ``spaces.GPU`` for ``generate_t2v``.
    Only ``duration_seconds`` and ``steps`` affect the result. ``prompt``,
    ``size`` and ``progress`` are accepted so the signature mirrors the
    decorated function. ``progress`` is optional because it may not be
    forwarded.

    Missing or non-numeric values (e.g. an empty ``gr.Number``) fall back to
    the UI defaults of 5 seconds and 25 steps.

    Returns:
        220 for clips of 3 s or longer; 180 for clips of at least 2 s with
        more than 35 steps; 105 otherwise.
    """
    def _as_float(value, fallback):
        try:
            return float(value)
        except (TypeError, ValueError):
            return fallback

    seconds = _as_float(duration_seconds, 5.0)
    n_steps = _as_float(steps, 25.0)

    if seconds >= 3:
        return 220
    if n_steps > 35 and seconds >= 2:
        return 180
    # Everything else, including exactly 35 steps, gets the short-clip budget,
    # so more sampling steps never receive less GPU time than fewer steps.
    return 105
def find_latest_mp4():
    """Return the ``*.mp4`` in the current working directory with the newest ctime.

    Returns None when the directory contains no ``.mp4`` files.
    """
    candidates = glob.glob("*.mp4")
    return max(candidates, key=os.path.getctime) if candidates else None
# -------- Generation Functions --------
def _normalize_size(size):
    """Return ``size`` as an ``"A*B"`` string with both parts floored to multiples of MOD_VALUE.

    Spaces are ignored and ``x``/``X`` is accepted as a separator in addition to
    ``*``. Each part is floored to a multiple of MOD_VALUE (minimum MOD_VALUE).
    Anything that does not parse as two integers falls back to the default size.
    NOTE(review): the UI labels this "HxW"; confirm the order generate.py expects.
    """
    default_size = f"{DEFAULT_H_SLIDER_VALUE}*{DEFAULT_W_SLIDER_VALUE}"
    try:
        first, second = str(size).lower().replace(" ", "").replace("x", "*").split("*")
        first = max(MOD_VALUE, (int(first) // MOD_VALUE) * MOD_VALUE)
        second = max(MOD_VALUE, (int(second) // MOD_VALUE) * MOD_VALUE)
    except ValueError:
        return default_size
    return f"{first}*{second}"


def _frames_for_duration(duration_seconds):
    """Convert a clip length in seconds to a frame count for ``--frame_num``.

    The count is ``seconds * FIXED_FPS + 1``, clamped to
    [MIN_FRAMES_MODEL, MAX_FRAMES_MODEL] and snapped to the form 4n+1.
    NOTE(review): Wan's generate.py documents --frame_num as 4n+1; confirm for ti2v-5B.
    Missing or non-finite input falls back to MAX_FRAMES_MODEL.
    """
    try:
        seconds = float(duration_seconds)
    except (TypeError, ValueError):
        return MAX_FRAMES_MODEL
    if not np.isfinite(seconds):
        return MAX_FRAMES_MODEL
    frames = int(np.clip(round(seconds * FIXED_FPS) + 1, MIN_FRAMES_MODEL, MAX_FRAMES_MODEL))
    frames = ((frames - 1) // 4) * 4 + 1
    if frames < MIN_FRAMES_MODEL:
        # Snapping down can drop below the minimum; step up to the next 4n+1.
        frames += 4
    return frames


def _steps_arg(steps, default=25):
    """Return ``steps`` as a positive integer string, or ``default`` if it is unusable."""
    try:
        return str(max(1, int(steps)))
    except (TypeError, ValueError, OverflowError):
        return str(default)


def _run_generation(tag, label, temp_dir, prompt, size, duration_seconds, steps, image_path=None):
    """Run generate.py once and move the produced video into ``temp_dir``.

    Args:
        tag: Short mode name used in the log line and error message ("T2V"/"I2V").
        label: Mode name used in the success message.
        temp_dir: Per-request directory that receives the final video.
        prompt, size, duration_seconds, steps: Raw UI values; normalised here.
        image_path: Optional conditioning image passed via ``--image``.

    Returns:
        ``(video_path, download_html, status)``; the first two are None on failure.
    """
    import html
    import shutil
    import sys
    import time

    cmd = [
        # Run generate.py with the same interpreter/environment as this app.
        sys.executable or "python", "generate.py",
        "--task", "ti2v-5B",
        "--size", _normalize_size(size),
        "--ckpt_dir", ckpt_dir,
        "--offload_model", "True",
        "--sample_steps", _steps_arg(steps),
        "--frame_num", str(_frames_for_duration(duration_seconds)),
        "--convert_model_dtype",
        "--t5_cpu",
    ]
    if image_path is not None:
        cmd += ["--image", image_path]
    cmd += ["--prompt", prompt]
    print(f"[{tag}] Running command: {' '.join(cmd)}")

    started = time.time()
    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        return None, None, f"Error during {tag} generation: {e}"
    finally:
        gc.collect()

    video_file = find_latest_mp4()
    # find_latest_mp4 scans the whole working directory; ignore files that
    # predate this run so a run that wrote nothing cannot return a stale video
    # (1 s of slack for coarse filesystem timestamps).
    if video_file is None or os.path.getmtime(video_file) < started - 1:
        return None, None, "Generation finished but no video file found."

    dest_path = os.path.join(temp_dir, os.path.basename(video_file))
    # shutil.move copies when temp_dir is on another filesystem, where
    # os.rename would raise OSError (EXDEV).
    shutil.move(video_file, dest_path)
    # The file name is chosen by generate.py and may embed prompt text; escape it
    # before putting it into HTML.
    safe_name = html.escape(os.path.basename(dest_path), quote=True)
    # NOTE(review): this bare relative href is not a Gradio file-serving URL, so
    # the link likely does not resolve; the gr.Video player offers a download.
    download_link = f"<a href='{safe_name}' download>📥 Download Video</a>"
    return dest_path, download_link, f"{label} generated successfully!"


@spaces.GPU(duration=get_duration)
def generate_t2v(prompt, size="1280*704", duration_seconds=5, steps=25, progress=gr.Progress(track_tqdm=True)):
    """Generate a video from a text prompt.

    Args:
        prompt: Text prompt; blank or missing prompts are rejected.
        size: "A*B" size string, floored to multiples of MOD_VALUE.
        duration_seconds: Requested clip length, converted to ``--frame_num``.
        steps: Sampling steps passed as ``--sample_steps``.
        progress: Gradio progress tracker (tqdm tracking); not used directly.

    Returns:
        ``(video_path, download_html, status)``; the first two are None on failure.
    """
    if not (prompt or "").strip():
        return None, None, "Please enter a prompt."
    temp_dir = tempfile.mkdtemp()
    return _run_generation("T2V", "Text-to-Video", temp_dir, prompt, size, duration_seconds, steps)


def _get_duration_i2v(image, prompt, size, duration_seconds, steps, progress=None):
    """GPU budget for generate_i2v, which takes a leading ``image`` argument.

    spaces.GPU calls the duration callable with the decorated function's
    arguments, so the image must be consumed here before delegating to
    get_duration.
    """
    return get_duration(prompt, size, duration_seconds, steps, progress)


@spaces.GPU(duration=_get_duration_i2v)
def generate_i2v(image, prompt, size="1280*704", duration_seconds=5, steps=25, progress=gr.Progress(track_tqdm=True)):
    """Generate a video conditioned on an uploaded image and a text prompt.

    Args:
        image: Uploaded image as an array (the UI uses ``gr.Image(type="numpy")``).
        prompt: Text prompt; blank or missing prompts are rejected.
        size: "A*B" size string, floored to multiples of MOD_VALUE.
        duration_seconds: Requested clip length, converted to ``--frame_num``.
        steps: Sampling steps passed as ``--sample_steps``.
        progress: Gradio progress tracker (tqdm tracking); not used directly.

    Returns:
        ``(video_path, download_html, status)``; the first two are None on failure.
    """
    if image is None or not (prompt or "").strip():
        return None, None, "Please upload an image and enter a prompt."
    temp_dir = tempfile.mkdtemp()
    image_path = os.path.join(temp_dir, "input.jpg")
    # JPEG cannot store an alpha channel; convert so RGBA uploads save cleanly.
    Image.fromarray(image).convert("RGB").save(image_path)
    return _run_generation(
        "I2V", "Image-to-Video", temp_dir, prompt, size, duration_seconds, steps,
        image_path=image_path,
    )
# -------- Gradio UI --------
def _size_text_for_image(uploaded_image):
    """Return an "H*W" string for the I2V size textbox, derived from an uploaded image.

    Uses _calculate_new_dimensions for the actual sizing. Returns the default
    size string when the image is missing (e.g. after it is cleared) or cannot
    be measured.
    """
    default_size = f"{DEFAULT_H_SLIDER_VALUE}*{DEFAULT_W_SLIDER_VALUE}"
    if uploaded_image is None:
        return default_size
    try:
        if hasattr(uploaded_image, "shape"):
            pil_image = Image.fromarray(uploaded_image).convert("RGB")
        else:
            pil_image = uploaded_image
        new_h, new_w = _calculate_new_dimensions(pil_image)
    except Exception:
        # Best effort: an unreadable upload should not break the form.
        return default_size
    return f"{new_h}*{new_w}"


with gr.Blocks() as demo:
    gr.Markdown("## 🎥 Wan2.2-TI2V-5B Video Generator")
    gr.Markdown("Choose **Text-to-Video** or **Image-to-Video** mode below.")
    with gr.Tab("Text-to-Video"):
        t2v_prompt = gr.Textbox(
            label="Prompt",
            value="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage",
        )
        t2v_size = gr.Textbox(label="Video Size (HxW)", value=f"{DEFAULT_H_SLIDER_VALUE}*{DEFAULT_W_SLIDER_VALUE}")
        t2v_duration = gr.Number(label="Video Length (seconds)", value=5)
        t2v_steps = gr.Number(label="Inference Steps", value=25)
        t2v_btn = gr.Button("Generate from Text")
        t2v_video = gr.Video(label="Generated Video")
        t2v_download = gr.HTML()
        t2v_status = gr.Textbox(label="Status")
        t2v_btn.click(
            generate_t2v,
            [t2v_prompt, t2v_size, t2v_duration, t2v_steps],
            [t2v_video, t2v_download, t2v_status],
        )
    with gr.Tab("Image-to-Video"):
        i2v_image = gr.Image(type="numpy", label="Upload Image")
        i2v_prompt = gr.Textbox(
            label="Prompt",
            value=(
                "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. "
                "The fluffy-furred feline gazes directly at the camera with a relaxed expression. "
                "Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, "
                "and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, "
                "as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's "
                "intricate details and the refreshing atmosphere of the seaside."
            ),
        )
        i2v_size = gr.Textbox(label="Video Size (HxW)", value=f"{DEFAULT_H_SLIDER_VALUE}*{DEFAULT_W_SLIDER_VALUE}")
        i2v_duration = gr.Number(label="Video Length (seconds)", value=5)
        i2v_steps = gr.Number(label="Inference Steps", value=25)
        i2v_btn = gr.Button("Generate from Image")
        i2v_video = gr.Video(label="Generated Video")
        i2v_download = gr.HTML()
        i2v_status = gr.Textbox(label="Status")
        i2v_btn.click(
            generate_i2v,
            [i2v_image, i2v_prompt, i2v_size, i2v_duration, i2v_steps],
            [i2v_video, i2v_download, i2v_status],
        )
        # Auto-adjust the size box to the uploaded image; clearing restores the
        # default. The box holds one "H*W" string, so the handler returns a single
        # combined value. handle_image_upload_for_dims returns separate
        # height/width updates, which cannot both target this one textbox.
        i2v_image.upload(fn=_size_text_for_image, inputs=i2v_image, outputs=i2v_size)
        i2v_image.clear(fn=_size_text_for_image, inputs=i2v_image, outputs=i2v_size)

if __name__ == "__main__":
    demo.launch()