import spaces
import torch
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, UniPCMultistepScheduler
from transformers import CLIPVisionModel
import gradio as gr
import tempfile
from huggingface_hub import hf_hub_download
import numpy as np
import imageio
from PIL import Image
import random
import warnings

warnings.filterwarnings("ignore")
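
# Hub locations for the base Wan 2.1 image-to-video model and the FusionX LoRA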
MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
LORA_REPO_ID = "vrgamedevgirl84/Wan14BT2VFusioniX"
LORA_FILENAME = "FusionX_LoRa/Wan2.1_I2V_14B_FusionX_LoRA.safetensors"

# --- Model Loading at Startup ---
image_encoder = CLIPVisionModel.from_pretrained(MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float16)
vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float16)
pipe = WanImageToVideoPipeline.from_pretrained(
    MODEL_ID, vae=vae, image_encoder=image_encoder, torch_dtype=torch.float16
)
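# UniPC with a shifted flow-matching sigma schedule; larger flow_shift values
# push sampling toward the high-noise end of the trajectory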
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
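# Offload idle submodules to CPU between forward passes so the 14B model fits in limited VRAM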
pipe.enable_model_cpu_offload()

# LoRA Loading
try:
    lora_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
    print("LoRA downloaded to:", lora_path)
    pipe.load_lora_weights(lora_path, adapter_name="fusionx_lora")
    pipe.set_adapters(["fusionx_lora"], adapter_weights=[0.75])
    pipe.fuse_lora()
except Exception as e:
    print(f"Error during LoRA loading: {e}")

# --- Constants ---
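# Spatial dims are snapped to multiples of MOD_VALUE; 32 is a safe alignment
# for the Wan VAE/transformer stride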
MOD_VALUE = 32
DEFAULT_H, DEFAULT_W = 640, 1024
MAX_AREA = DEFAULT_H * DEFAULT_W
SLIDER_MIN_H, SLIDER_MAX_H = 128, 1024
SLIDER_MIN_W, SLIDER_MAX_W = 128, 1024
MAX_SEED = np.iinfo(np.int32).max
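# Output runs at a fixed 24 fps; requested durations are converted to frame
# counts and clamped to this range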
FIXED_FPS, MIN_FRAMES, MAX_FRAMES = 24, 8, 240

default_prompt = "make this image come alive, cinematic motion, smooth animation"
default_neg_prompt = "static, blurry, watermark, text, signature, ugly, deformed"

# --- Main Generation Function ---
# Give the ZeroGPU decorator a generous, fixed duration: 180 seconds (3 minutes)
# should cover the longest video generation.
@spaces.GPU(duration=180)
def generate_video(input_image, prompt, height, width,
                   negative_prompt, duration_seconds,
                   guidance_scale, steps, seed, randomize_seed,
                   progress=gr.Progress(track_tqdm=True)):
    if input_image is None:
        raise gr.Error("Please upload an input image.")
    target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
    target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
    # Clamp first, then snap: Wan only generates frame counts of the form 4k + 1,
    # and clamping after snapping could push the count back to an invalid value.
    raw_frames = int(round(duration_seconds * FIXED_FPS))
    raw_frames = int(np.clip(raw_frames, MIN_FRAMES, MAX_FRAMES))
    num_frames = ((raw_frames - 1) // 4) * 4 + 1
    if num_frames > 120 and max(target_h, target_w) > 768:
        scale = 768 / max(target_h, target_w)
        target_h = max(MOD_VALUE, int(target_h * scale) // MOD_VALUE * MOD_VALUE)
        target_w = max(MOD_VALUE, int(target_w * scale) // MOD_VALUE * MOD_VALUE)
        gr.Info(f"Reduced resolution to {target_w}x{target_h} for long video.")
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    resized_image = input_image.resize((target_w, target_h), Image.Resampling.LANCZOS)
    try:
        torch.cuda.empty_cache()
        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
            frames = pipe(
                image=resized_image, prompt=prompt, negative_prompt=negative_prompt,
                height=target_h, width=target_w, num_frames=num_frames,
                guidance_scale=float(guidance_scale), num_inference_steps=int(steps),
                generator=torch.Generator(device="cuda").manual_seed(current_seed),
                return_dict=True
            ).frames[0]
    except torch.cuda.OutOfMemoryError:
        raise gr.Error("Out of GPU memory. Try reducing duration or resolution.")
    except Exception as e:
        raise gr.Error(f"Generation failed: {e}")
    finally:
        torch.cuda.empty_cache()
    # Encode frames to an H.264 MP4; yuv420p keeps the file playable in browsers
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        video_path = tmpfile.name
    writer = imageio.get_writer(video_path, fps=FIXED_FPS, codec='libx264',
                                pixelformat='yuv420p', quality=8)
    for frame in frames:
        frame = np.asarray(frame)
        if frame.dtype != np.uint8:
            # The pipeline may return float frames in [0, 1]; convert for the encoder
            frame = (np.clip(frame, 0.0, 1.0) * 255).astype(np.uint8)
        writer.append_data(frame)
    writer.close()
    return video_path, current_seed

# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("# Wan 2.1 I2V FusionX-LoRA")
    with gr.Row():
        with gr.Column():
            input_image_comp = gr.Image(type="pil", label="Input Image")
            prompt_comp = gr.Textbox(label="Prompt", value=default_prompt)
            duration_comp = gr.Slider(minimum=round(MIN_FRAMES/FIXED_FPS, 1),
                                      maximum=round(MAX_FRAMES/FIXED_FPS, 1),
                                      step=0.1, value=2, label="Duration (s)")
            with gr.Accordion("Advanced Settings", open=False):
                neg_prompt_comp = gr.Textbox(label="Negative Prompt", value=default_neg_prompt, lines=3)
                seed_comp = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
                rand_seed_comp = gr.Checkbox(label="Randomize seed", value=True)
                with gr.Row():
                    height_comp = gr.Slider(minimum=SLIDER_MIN_H, maximum=SLIDER_MAX_H, step=MOD_VALUE, value=DEFAULT_H, label="Height")
                    width_comp = gr.Slider(minimum=SLIDER_MIN_W, maximum=SLIDER_MAX_W, step=MOD_VALUE, value=DEFAULT_W, label="Width")
                steps_comp = gr.Slider(minimum=1, maximum=30, step=1, value=4, label="Steps")
                guidance_comp = gr.Slider(minimum=0.0, maximum=20.0, step=0.5, value=1.0, label="CFG Scale", visible=False)
            gen_button = gr.Button("Generate Video", variant="primary")
        with gr.Column():
            video_comp = gr.Video(label="Generated Video", autoplay=True, interactive=False)
            gr.Markdown("### Tips:\n- For long videos (>5s), consider lower resolutions.\n- 4-8 steps is often optimal.")
    def handle_upload(img):
        if img is None:
            return gr.update(value=DEFAULT_H), gr.update(value=DEFAULT_W)
        try:
            w, h = img.size
            a = h / w
            h_new = int(np.sqrt(MAX_AREA * a))
            w_new = int(np.sqrt(MAX_AREA / a))
            h_final = max(MOD_VALUE, h_new // MOD_VALUE * MOD_VALUE)
            w_final = max(MOD_VALUE, w_new // MOD_VALUE * MOD_VALUE)
            return gr.update(value=h_final), gr.update(value=w_final)
        except Exception:
            return gr.update(value=DEFAULT_H), gr.update(value=DEFAULT_W)

    input_image_comp.upload(handle_upload, inputs=input_image_comp, outputs=[height_comp, width_comp])

    inputs = [input_image_comp, prompt_comp, height_comp, width_comp, neg_prompt_comp,
              duration_comp, guidance_comp, steps_comp, seed_comp, rand_seed_comp]
    outputs = [video_comp, seed_comp]
    gen_button.click(fn=generate_video, inputs=inputs, outputs=outputs)
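
# A small queue keeps concurrent requests from oversubscribing the single GPU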
if __name__ == "__main__":
    demo.queue(max_size=3).launch()