"""Gradio Space: animate a still image with Wan 2.1 I2V-14B (480P) plus the FusionX LoRA."""

# NOTE(review): `spaces` is kept as the very first import; ZeroGPU Spaces are
# generally expected to import it before torch touches CUDA — confirm before reordering.
import spaces
import torch
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, UniPCMultistepScheduler
from diffusers.utils import export_to_video
from transformers import CLIPVisionModel
import gradio as gr
import tempfile
import os
import subprocess
from huggingface_hub import hf_hub_download
import numpy as np
from PIL import Image
import random
import warnings

# Silence every Python warning process-wide (library notices included).
warnings.filterwarnings("ignore")

# --- Checkpoint identifiers ---
MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
LORA_REPO_ID = "vrgamedevgirl84/Wan14BT2VFusioniX"
LORA_FILENAME = "FusionX_LoRa/Wan2.1_I2V_14B_FusionX_LoRA.safetensors"

# --- Constants ---
MOD_VALUE = 32  # height/width get snapped to multiples of this
DEFAULT_H = 640
DEFAULT_W = 1024
MAX_AREA = DEFAULT_H * DEFAULT_W  # pixel budget used when sizing from an uploaded image
SLIDER_MIN_H = 128
SLIDER_MAX_H = 1024
SLIDER_MIN_W = 128
SLIDER_MAX_W = 1024
MAX_SEED = np.iinfo(np.int32).max
FIXED_FPS = 24
MIN_FRAMES = 8
MAX_FRAMES = 240

default_prompt = "make this image come alive, cinematic motion, smooth animation"
default_neg_prompt = "static, blurry, watermark, text, signature, ugly, deformed"

# --- Model loading at startup (every component in float16) ---
image_encoder = CLIPVisionModel.from_pretrained(
    MODEL_ID,
    subfolder="image_encoder",
    torch_dtype=torch.float16,
)
vae = AutoencoderKLWan.from_pretrained(
    MODEL_ID,
    subfolder="vae",
    torch_dtype=torch.float16,
)
pipe = WanImageToVideoPipeline.from_pretrained(
    MODEL_ID,
    vae=vae,
    image_encoder=image_encoder,
    torch_dtype=torch.float16,
)
# Swap in UniPC, reusing the shipped scheduler config but with flow_shift=8.0.
pipe.scheduler = UniPCMultistepScheduler.from_config(
    pipe.scheduler.config,
    flow_shift=8.0,
)
# Keep sub-models on the CPU until they are needed, to reduce GPU memory use.
pipe.enable_model_cpu_offload()

# --- LoRA loading ---
# Best effort: any failure is printed and startup continues.
# NOTE(review): a failure after load_lora_weights could leave the adapter loaded but unfused.
try:
    # Despite its name, this holds the path of the FusionX LoRA file (LORA_FILENAME).
    causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
    print(f"✅ LoRA downloaded to: {causvid_path}")
    pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
    # Activate the adapter at 0.75 strength, then merge it into the base weights.
    pipe.set_adapters(["causvid_lora"], adapter_weights=[0.75])
    pipe.fuse_lora()
except Exception as e:
    print(f"❌ Error during LoRA loading: {e}")

# --- Main Generation Function ---
# GPU budget: set a generous, FIXED duration for the decorator.
# 180 seconds (3 minutes) should be enough for the longest video generation.
@spaces.GPU(duration=180)
def generate_video(input_image, prompt, height, width,
                   negative_prompt, duration_seconds,
                   guidance_scale, steps,
                   seed, randomize_seed,
                   progress=gr.Progress(track_tqdm=True)):
    """Animate ``input_image`` with the Wan I2V pipeline and save the result as an MP4.

    Args:
        input_image: PIL image from the upload component, or ``None`` if empty.
        prompt: positive text prompt passed to ``pipe``.
        height: requested height; floored to a multiple of MOD_VALUE.
        width: requested width; floored to a multiple of MOD_VALUE.
        negative_prompt: negative text prompt passed to ``pipe``.
        duration_seconds: clip length; converted to a frame count at FIXED_FPS.
        guidance_scale: CFG scale passed to ``pipe``.
        steps: number of inference steps.
        seed: seed used when ``randomize_seed`` is false.
        randomize_seed: if true, draw a fresh seed in ``[0, MAX_SEED]``.
        progress: Gradio progress tracker; not referenced directly, it only
            lets Gradio mirror tqdm bars (``track_tqdm=True``).

    Returns:
        ``(video_path, seed)``: path of a temporary ``.mp4`` file (not deleted
        here) and the seed actually used.

    Raises:
        gr.Error: no image was uploaded, CUDA ran out of memory, or the
            pipeline failed for any other reason.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image.")

    target_h = _snap_to_mod(height)
    target_w = _snap_to_mod(width)
    num_frames = _frames_for_duration(duration_seconds)

    # Long clips at high resolution are scaled so the longest side is at most 768 px.
    if num_frames > 120 and max(target_h, target_w) > 768:
        scale = 768 / max(target_h, target_w)
        target_h = _snap_to_mod(target_h * scale)
        target_w = _snap_to_mod(target_w * scale)
        gr.Info(f"Reduced resolution to {target_w}x{target_h} for long video.")

    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    # Convert to RGB first so an RGBA/paletted upload cannot reach the VAE with extra channels.
    resized_image = input_image.convert("RGB").resize(
        (target_w, target_h), Image.Resampling.LANCZOS
    )

    try:
        torch.cuda.empty_cache()
        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
            frames = pipe(
                image=resized_image,
                prompt=prompt,
                negative_prompt=negative_prompt,
                height=target_h,
                width=target_w,
                num_frames=num_frames,
                guidance_scale=float(guidance_scale),
                num_inference_steps=int(steps),
                generator=torch.Generator(device="cuda").manual_seed(current_seed),
                return_dict=True,
            ).frames[0]
    except torch.cuda.OutOfMemoryError as err:
        raise gr.Error("Out of GPU memory. Try reducing duration or resolution.") from err
    except Exception as err:
        raise gr.Error(f"Generation failed: {err}") from err
    finally:
        torch.cuda.empty_cache()

    return _write_mp4(frames), current_seed


def _snap_to_mod(value, lo=MOD_VALUE, hi=None):
    """Floor ``value`` to a multiple of MOD_VALUE, then clamp it to ``[lo, hi]``.

    ``hi=None`` means no upper bound. ``lo`` and ``hi`` should themselves be
    multiples of MOD_VALUE so that the result stays aligned.
    """
    snapped = max(lo, int(value) // MOD_VALUE * MOD_VALUE)
    return snapped if hi is None else min(hi, snapped)


def _frames_for_duration(duration_seconds):
    """Convert a duration into a frame count of the form 4k + 1 in [MIN_FRAMES, MAX_FRAMES].

    The Wan pipeline expects ``num_frames - 1`` to be divisible by 4; otherwise
    diffusers re-rounds the count itself and logs a warning. Clamping happens
    BEFORE quantizing so that the bounds cannot break the 4k + 1 form. For
    example, 8 becomes 9 and 240 becomes 237.
    """
    raw_frames = int(round(duration_seconds * FIXED_FPS))
    raw_frames = min(max(raw_frames, MIN_FRAMES), MAX_FRAMES)
    num_frames = (raw_frames - 1) // 4 * 4 + 1
    if num_frames < MIN_FRAMES:
        # Rounding down fell below the floor; step up to the next valid count.
        num_frames += 4
    return num_frames


def _frame_to_uint8(frame):
    """Return ``frame`` as a uint8 array suitable for the video writer.

    Float frames are assumed to lie in [0, 1] (NOTE(review): confirm against the
    pipeline's default output_type). They are clipped and scaled explicitly
    instead of relying on imageio's implicit lossy conversion. Integer frames
    (e.g. from PIL images) pass through unchanged.
    """
    arr = np.asarray(frame)
    if np.issubdtype(arr.dtype, np.floating):
        arr = (np.clip(arr, 0.0, 1.0) * 255.0).round().astype(np.uint8)
    return arr


def _write_mp4(frames):
    """Encode ``frames`` as an H.264 / yuv420p MP4 at FIXED_FPS and return the file path.

    The file is created with ``delete=False`` so that it outlives this call.
    Gradio is handed the path.
    """
    import imageio  # local import, as before: only needed once a clip exists

    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        video_path = tmpfile.name
    # The context manager closes (finalizes) the writer even if a frame fails to encode.
    with imageio.get_writer(video_path, fps=FIXED_FPS, codec="libx264",
                            pixelformat="yuv420p", quality=8) as writer:
        for frame in frames:
            writer.append_data(_frame_to_uint8(frame))
    return video_path


# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("# Wan 2.1 I2V FusionX-LoRA")
    with gr.Row():
        with gr.Column():
            input_image_comp = gr.Image(type="pil", label="Input Image")
            prompt_comp = gr.Textbox(label="Prompt", value=default_prompt)
            duration_comp = gr.Slider(
                minimum=round(MIN_FRAMES / FIXED_FPS, 1),
                maximum=round(MAX_FRAMES / FIXED_FPS, 1),
                step=0.1,
                value=2,
                label="Duration (s)",
            )
            with gr.Accordion("Advanced Settings", open=False):
                neg_prompt_comp = gr.Textbox(label="Negative Prompt", value=default_neg_prompt, lines=3)
                seed_comp = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
                rand_seed_comp = gr.Checkbox(label="Randomize seed", value=True)
                with gr.Row():
                    height_comp = gr.Slider(minimum=SLIDER_MIN_H, maximum=SLIDER_MAX_H, step=MOD_VALUE,
                                            value=DEFAULT_H, label="Height")
                    width_comp = gr.Slider(minimum=SLIDER_MIN_W, maximum=SLIDER_MAX_W, step=MOD_VALUE,
                                           value=DEFAULT_W, label="Width")
                steps_comp = gr.Slider(minimum=1, maximum=30, step=1, value=4, label="Steps")
                # Hidden: CFG is fixed at 1.0 for this setup.
                guidance_comp = gr.Slider(minimum=0.0, maximum=20.0, step=0.5, value=1.0,
                                          label="CFG Scale", visible=False)
            gen_button = gr.Button("Generate Video", variant="primary")
        with gr.Column():
            video_comp = gr.Video(label="Generated Video", autoplay=True, interactive=False)
            gr.Markdown("### Tips:\n- For long videos (>5s), consider lower resolutions.\n- 4-8 steps is often optimal.")

    def handle_upload(img):
        """Suggest height/width slider values that follow the upload's aspect ratio.

        The suggestion targets about MAX_AREA pixels. If either side would exceed
        its slider maximum, both sides are scaled down uniformly. Each side is
        then snapped to MOD_VALUE and clamped to its slider range. Any error
        falls back to the defaults.
        """
        if img is None:
            return gr.update(value=DEFAULT_H), gr.update(value=DEFAULT_W)
        try:
            w, h = img.size
            aspect = h / w
            h_new = np.sqrt(MAX_AREA * aspect)
            w_new = np.sqrt(MAX_AREA / aspect)
            # Keep very tall/wide images inside the sliders' declared range.
            fit = min(1.0, SLIDER_MAX_H / h_new, SLIDER_MAX_W / w_new)
            h_final = _snap_to_mod(h_new * fit, SLIDER_MIN_H, SLIDER_MAX_H)
            w_final = _snap_to_mod(w_new * fit, SLIDER_MIN_W, SLIDER_MAX_W)
            return gr.update(value=h_final), gr.update(value=w_final)
        except Exception:
            return gr.update(value=DEFAULT_H), gr.update(value=DEFAULT_W)

    input_image_comp.upload(handle_upload, inputs=input_image_comp, outputs=[height_comp, width_comp])

    # Order must match generate_video's positional parameters.
    inputs = [input_image_comp, prompt_comp, height_comp, width_comp, neg_prompt_comp,
              duration_comp, guidance_comp, steps_comp, seed_comp, rand_seed_comp]
    outputs = [video_comp, seed_comp]
    gen_button.click(fn=generate_video, inputs=inputs, outputs=outputs)


if __name__ == "__main__":
    demo.queue(max_size=3).launch()