import gc
import glob
import os
import shutil
import subprocess
import tempfile

import gradio as gr
import numpy as np
import spaces
from huggingface_hub import snapshot_download
from PIL import Image

# -------- Model Download --------
repo_id = "Wan-AI/Wan2.2-TI2V-5B"
print(f"Downloading/loading checkpoints for {repo_id}...")
ckpt_dir = snapshot_download(repo_id, local_dir_use_symlinks=False)
print(f"Using checkpoints from {ckpt_dir}")

# -------- Constants --------
FIXED_FPS = 24                      # frame rate assumed when converting seconds -> frames
MIN_FRAMES_MODEL = 8                # shortest clip the model supports
MAX_FRAMES_MODEL = 121              # longest clip the model supports (121 == 4*30 + 1)
MOD_VALUE = 32                      # height/width must be multiples of this
DEFAULT_H_SLIDER_VALUE = 704
DEFAULT_W_SLIDER_VALUE = 1280
NEW_FORMULA_MAX_AREA = 1280.0 * 704.0   # target pixel budget for auto-sizing
SLIDER_MIN_H, SLIDER_MAX_H = 128, 1280
SLIDER_MIN_W, SLIDER_MAX_W = 128, 1280


# -------- Helpers --------
def _calculate_new_dimensions(pil_image):
    """Pick an output (height, width) for *pil_image*.

    Preserves the image's aspect ratio while targeting roughly
    NEW_FORMULA_MAX_AREA pixels, then snaps both sides down to multiples of
    MOD_VALUE and clamps them to the slider range.

    Returns:
        (int, int): (height, width); the defaults when the image is degenerate.
    """
    orig_w, orig_h = pil_image.size
    if orig_w <= 0 or orig_h <= 0:
        return DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE
    aspect_ratio = orig_h / orig_w
    calc_h = round(np.sqrt(NEW_FORMULA_MAX_AREA * aspect_ratio))
    calc_w = round(np.sqrt(NEW_FORMULA_MAX_AREA / aspect_ratio))
    calc_h = max(MOD_VALUE, (calc_h // MOD_VALUE) * MOD_VALUE)
    calc_w = max(MOD_VALUE, (calc_w // MOD_VALUE) * MOD_VALUE)
    new_h = int(np.clip(calc_h, SLIDER_MIN_H, (SLIDER_MAX_H // MOD_VALUE) * MOD_VALUE))
    new_w = int(np.clip(calc_w, SLIDER_MIN_W, (SLIDER_MAX_W // MOD_VALUE) * MOD_VALUE))
    return new_h, new_w


def handle_image_upload_for_dims(uploaded_pil_image, current_h_val, current_w_val):
    """Gradio callback: suggest (height, width) updates for an uploaded image.

    Accepts either a numpy array (gr.Image type="numpy") or a PIL image and
    returns two gr.update objects (height first, then width).  Falls back to
    the default dimensions when nothing is uploaded or conversion fails.
    """
    if uploaded_pil_image is None:
        return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
    try:
        if hasattr(uploaded_pil_image, 'shape'):  # numpy array from gr.Image
            pil_image = Image.fromarray(uploaded_pil_image).convert("RGB")
        else:
            pil_image = uploaded_pil_image
        new_h, new_w = _calculate_new_dimensions(pil_image)
        return gr.update(value=new_h), gr.update(value=new_w)
    except Exception:
        return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)


def get_duration(*args, **kwargs):
    """GPU-time budget (seconds) for @spaces.GPU.

    spaces invokes this callable with the same arguments as the decorated
    function.  generate_t2v takes 5 parameters and generate_i2v takes 6, so
    the previous fixed 5-parameter signature raised TypeError on the i2v
    path.  Both entry points end with (..., duration_seconds, steps,
    progress), so the values are read from the tail of the argument list.
    """
    duration_seconds = kwargs.get("duration_seconds", args[-3] if len(args) >= 3 else 5)
    steps = kwargs.get("steps", args[-2] if len(args) >= 2 else 25)
    if duration_seconds >= 3:
        return 220
    if steps > 35 and duration_seconds >= 2:
        return 180
    if steps < 35 or duration_seconds < 2:
        return 105
    return 90


def _sanitize_size(size):
    """Normalize an 'H*W' string: both dims floored to multiples of MOD_VALUE
    (minimum MOD_VALUE).  Any parse failure yields the default size."""
    try:
        h_str, w_str = size.lower().replace(" ", "").split("*")
        h = max(MOD_VALUE, (int(h_str) // MOD_VALUE) * MOD_VALUE)
        w = max(MOD_VALUE, (int(w_str) // MOD_VALUE) * MOD_VALUE)
        return f"{h}*{w}"
    except Exception:
        return f"{DEFAULT_H_SLIDER_VALUE}*{DEFAULT_W_SLIDER_VALUE}"


def _frames_for_duration(duration_seconds):
    """Convert a duration in seconds into a frame count for generate.py.

    Clamps to [MIN_FRAMES_MODEL, MAX_FRAMES_MODEL] and snaps onto the 4n+1
    grid suggested by MAX_FRAMES_MODEL (121 == 4*30 + 1).
    """
    try:
        frames = int(round(float(duration_seconds) * FIXED_FPS))
    except (TypeError, ValueError):
        frames = MAX_FRAMES_MODEL
    frames = max(MIN_FRAMES_MODEL, min(MAX_FRAMES_MODEL, frames))
    frames = ((frames - 1) // 4) * 4 + 1  # snap down to 4n+1
    if frames < MIN_FRAMES_MODEL:
        frames += 4  # snapping down undershot the minimum; step back up
    return frames


def find_latest_mp4():
    """Return the most recently created .mp4 in the working directory, or None."""
    files = glob.glob("*.mp4")
    if not files:
        return None
    return max(files, key=os.path.getctime)


def _collect_output(temp_dir):
    """Move the newest generated .mp4 into *temp_dir*; return its path or None."""
    video_file = find_latest_mp4()
    if video_file is None:
        return None
    dest_path = os.path.join(temp_dir, os.path.basename(video_file))
    # shutil.move, not os.rename: temp_dir may live on a different filesystem
    shutil.move(video_file, dest_path)
    return dest_path


# -------- Generation Functions --------
@spaces.GPU(duration=get_duration)
def generate_t2v(prompt, size="1280*704", duration_seconds=5, steps=25,
                 progress=gr.Progress(track_tqdm=True)):
    """Run text-to-video generation via the generate.py CLI.

    Returns:
        (video_path | None, download_html | None, status_message)
    """
    if not prompt.strip():
        return None, None, "Please enter a prompt."
    temp_dir = tempfile.mkdtemp()
    size = _sanitize_size(size)
    cmd = [
        "python", "generate.py",
        "--task", "ti2v-5B",
        "--size", size,
        "--ckpt_dir", ckpt_dir,
        "--offload_model", "True",
        "--sample_steps", str(int(steps)),
        # NOTE(review): assumes generate.py accepts --frame_num (Wan2.x CLI
        # convention — verify); previously duration_seconds was silently ignored.
        "--frame_num", str(_frames_for_duration(duration_seconds)),
        "--convert_model_dtype",
        "--t5_cpu",
        "--prompt", prompt,
    ]
    print(f"[T2V] Running command: {' '.join(cmd)}")
    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        return None, None, f"Error during T2V generation: {e}"
    gc.collect()
    dest_path = _collect_output(temp_dir)
    if dest_path is None:
        return None, None, "Generation finished but no video file found."
    download_link = "📥 Download Video"  # static caption (f-prefix was unused)
    return dest_path, download_link, "Text-to-Video generated successfully!"


@spaces.GPU(duration=get_duration)
def generate_i2v(image, prompt, size="1280*704", duration_seconds=5, steps=25,
                 progress=gr.Progress(track_tqdm=True)):
    """Run image-to-video generation via the generate.py CLI.

    Args:
        image: numpy array from gr.Image(type="numpy").

    Returns:
        (video_path | None, download_html | None, status_message)
    """
    if image is None or not prompt.strip():
        return None, None, "Please upload an image and enter a prompt."
    temp_dir = tempfile.mkdtemp()
    size = _sanitize_size(size)
    image_path = os.path.join(temp_dir, "input.jpg")
    # convert("RGB") so RGBA/grayscale uploads can still be saved as JPEG
    Image.fromarray(image).convert("RGB").save(image_path)
    cmd = [
        "python", "generate.py",
        "--task", "ti2v-5B",
        "--size", size,
        "--ckpt_dir", ckpt_dir,
        "--offload_model", "True",
        # --sample_steps was missing here although the UI exposes `steps`
        "--sample_steps", str(int(steps)),
        # NOTE(review): assumes generate.py accepts --frame_num — verify
        "--frame_num", str(_frames_for_duration(duration_seconds)),
        "--convert_model_dtype",
        "--t5_cpu",
        "--image", image_path,
        "--prompt", prompt,
    ]
    print(f"[I2V] Running command: {' '.join(cmd)}")
    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        return None, None, f"Error during I2V generation: {e}"
    gc.collect()
    dest_path = _collect_output(temp_dir)
    if dest_path is None:
        return None, None, "Generation finished but no video file found."
    download_link = "📥 Download Video"  # static caption (f-prefix was unused)
    return dest_path, download_link, "Image-to-Video generated successfully!"
# -------- Gradio UI --------
def _size_update_from_image(img):
    """Gradio callback: return one gr.update for the single 'H*W' size textbox.

    The size field is a single textbox, so the two separate height/width
    updates produced by handle_image_upload_for_dims cannot be wired to it;
    this combines the suggested dimensions into one "H*W" string, or restores
    the default when the image is cleared or cannot be converted.
    """
    if img is None:
        return gr.update(value=f"{DEFAULT_H_SLIDER_VALUE}*{DEFAULT_W_SLIDER_VALUE}")
    try:
        pil = Image.fromarray(img).convert("RGB") if hasattr(img, "shape") else img
        h, w = _calculate_new_dimensions(pil)
    except Exception:
        h, w = DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE
    return gr.update(value=f"{h}*{w}")


with gr.Blocks() as demo:
    gr.Markdown("## 🎥 Wan2.2-TI2V-5B Video Generator")
    gr.Markdown("Choose **Text-to-Video** or **Image-to-Video** mode below.")

    with gr.Tab("Text-to-Video"):
        t2v_prompt = gr.Textbox(
            label="Prompt",
            value="Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage"
        )
        # label said "HxW" but the generators parse the "H*W" format
        t2v_size = gr.Textbox(label="Video Size (H*W)",
                              value=f"{DEFAULT_H_SLIDER_VALUE}*{DEFAULT_W_SLIDER_VALUE}")
        t2v_duration = gr.Number(label="Video Length (seconds)", value=5)
        t2v_steps = gr.Number(label="Inference Steps", value=25)
        t2v_btn = gr.Button("Generate from Text")
        t2v_video = gr.Video(label="Generated Video")
        t2v_download = gr.HTML()
        t2v_status = gr.Textbox(label="Status")
        t2v_btn.click(
            generate_t2v,
            [t2v_prompt, t2v_size, t2v_duration, t2v_steps],
            [t2v_video, t2v_download, t2v_status]
        )

    with gr.Tab("Image-to-Video"):
        i2v_image = gr.Image(type="numpy", label="Upload Image")
        i2v_prompt = gr.Textbox(
            label="Prompt",
            value=(
                "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. "
                "The fluffy-furred feline gazes directly at the camera with a relaxed expression. "
                "Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, "
                "and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, "
                "as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's "
                "intricate details and the refreshing atmosphere of the seaside."
            )
        )
        i2v_size = gr.Textbox(label="Video Size (H*W)",
                              value=f"{DEFAULT_H_SLIDER_VALUE}*{DEFAULT_W_SLIDER_VALUE}")
        i2v_duration = gr.Number(label="Video Length (seconds)", value=5)
        i2v_steps = gr.Number(label="Inference Steps", value=25)
        i2v_btn = gr.Button("Generate from Image")
        i2v_video = gr.Video(label="Generated Video")
        i2v_download = gr.HTML()
        i2v_status = gr.Textbox(label="Status")
        i2v_btn.click(
            generate_i2v,
            [i2v_image, i2v_prompt, i2v_size, i2v_duration, i2v_steps],
            [i2v_video, i2v_download, i2v_status]
        )

        # Auto-adjust the size textbox when an image is uploaded or cleared.
        # (Previously the two return values of handle_image_upload_for_dims
        # were wired to the SAME textbox twice, leaving only the bare width
        # number in the field and breaking the "H*W" format.)
        i2v_image.upload(fn=_size_update_from_image, inputs=[i2v_image], outputs=[i2v_size])
        i2v_image.clear(fn=_size_update_from_image, inputs=[i2v_image], outputs=[i2v_size])

if __name__ == "__main__":
    demo.launch()