# wan-fusionx-lora / app_lora.py
# Last commit: e9b5918 (verified) "Update app_lora.py" by Lemonator — 7.11 kB.
import spaces
import torch
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, UniPCMultistepScheduler
from diffusers.utils import export_to_video
from transformers import CLIPVisionModel
import gradio as gr
import tempfile
import os
import subprocess
from huggingface_hub import hf_hub_download
import numpy as np
from PIL import Image
import random
import warnings
# Checkpoint identifiers: the diffusers-format base model, plus the repository
# and in-repo path of the FusionX LoRA that is applied on top of it.
MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
LORA_REPO_ID = "vrgamedevgirl84/Wan14BT2VFusioniX"
LORA_FILENAME = "FusionX_LoRa/Wan2.1_I2V_14B_FusionX_LoRA.safetensors"

# Silence every Python warning for the whole process (applies to all modules,
# not just this one).
warnings.filterwarnings("ignore")
# --- Model Loading at Startup ---
# Everything below runs once at import time; each request reuses these objects.
image_encoder = CLIPVisionModel.from_pretrained(MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float16)
vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float16)
pipe = WanImageToVideoPipeline.from_pretrained(
    MODEL_ID, vae=vae, image_encoder=image_encoder, torch_dtype=torch.float16
)
# Swap in a UniPC scheduler built from the checkpoint's scheduler config,
# overriding only flow_shift.
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
# Offload sub-models to CPU between uses to reduce peak GPU memory.
pipe.enable_model_cpu_offload()

# --- LoRA Loading ---
# Best effort: any failure is logged and the app keeps running with the base
# model only.
# NOTE(review): the UI defaults (4 steps, CFG 1.0) presumably rely on this LoRA;
# expect degraded output if loading fails.
try:
    causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
    print("✅ LoRA downloaded to:", causvid_path)
    # The adapter name "causvid_lora" is presumably historical; the weights
    # actually loaded are LORA_REPO_ID / LORA_FILENAME.
    pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
    pipe.set_adapters(["causvid_lora"], adapter_weights=[0.75])
    # Merge the weighted adapter into the base weights.
    pipe.fuse_lora()
except Exception as e:
    print(f"❌ Error during LoRA loading: {e}")
# --- Constants ---
# Output height/width are snapped down to a multiple of this value.
MOD_VALUE = 32

# Default output size. Its pixel count is also the area budget used when
# proposing dimensions for an uploaded image.
DEFAULT_H = 640
DEFAULT_W = 1024
MAX_AREA = DEFAULT_H * DEFAULT_W

# Bounds of the height / width sliders, in pixels.
SLIDER_MIN_H = 128
SLIDER_MAX_H = 1024
SLIDER_MIN_W = 128
SLIDER_MAX_W = 1024

# Largest selectable seed: the maximum signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max

# Output frame rate and the frame-count bounds for generated clips.
# NOTE: 8 and 240 are not themselves of the 4k + 1 form that generation
# snaps frame counts to.
FIXED_FPS = 24
MIN_FRAMES = 8
MAX_FRAMES = 240

default_prompt = "make this image come alive, cinematic motion, smooth animation"
default_neg_prompt = "static, blurry, watermark, text, signature, ugly, deformed"
# --- Main Generation Function ---


def _frames_for_duration(duration_seconds):
    """Convert a clip duration into a frame count of the form ``4k + 1``.

    The raw count (``duration_seconds * FIXED_FPS``) is snapped down to
    ``4k + 1`` and then clamped to the smallest / largest ``4k + 1`` values
    inside [MIN_FRAMES, MAX_FRAMES]. Clamping to MIN_FRAMES / MAX_FRAMES
    directly could reintroduce a count that is not ``4k + 1`` (e.g. 8).
    NOTE(review): the 4k + 1 requirement is presumably the pipeline's
    temporal VAE stride — confirm against the diffusers docs.
    """
    raw_frames = max(1, int(round(duration_seconds * FIXED_FPS)))
    snapped = ((raw_frames - 1) // 4) * 4 + 1
    lowest = ((MIN_FRAMES + 2) // 4) * 4 + 1   # smallest 4k + 1 >= MIN_FRAMES
    highest = ((MAX_FRAMES - 1) // 4) * 4 + 1  # largest 4k + 1 <= MAX_FRAMES
    return min(max(snapped, lowest), highest)


def _to_uint8_frame(frame):
    """Return *frame* as a uint8 array for the video writer.

    Floating-point frames are assumed to lie in [0, 1] (presumably the
    pipeline's default numpy output — confirm). They are clipped and scaled
    explicitly rather than left to the writer's implicit conversion. Integer
    frames and PIL images pass through ``np.asarray`` unchanged.
    """
    arr = np.asarray(frame)
    if np.issubdtype(arr.dtype, np.floating):
        arr = (np.clip(arr, 0.0, 1.0) * 255.0).round().astype(np.uint8)
    return arr


# The GPU is requested for a fixed window per call. 180 s (3 minutes) is meant
# as a generous upper bound for the longest clip this UI allows.
# NOTE(review): not measured here — confirm against real generation timings.
@spaces.GPU(duration=180)
def generate_video(input_image, prompt, height, width,
                   negative_prompt, duration_seconds,
                   guidance_scale, steps, seed, randomize_seed,
                   progress=gr.Progress(track_tqdm=True)):
    """Animate *input_image* into an MP4 clip.

    Args:
        input_image: PIL image from the UI (``gr.Image(type="pil")``).
        prompt: text prompt passed to the pipeline.
        height, width: requested output size. Each is snapped down to a
            multiple of MOD_VALUE (never below MOD_VALUE).
        negative_prompt: negative text prompt passed to the pipeline.
        duration_seconds: clip length, converted to a ``4k + 1`` frame count
            at FIXED_FPS.
        guidance_scale: CFG scale passed to the pipeline.
        steps: number of inference steps.
        seed: seed used when *randomize_seed* is false.
        randomize_seed: if true, draw a fresh seed in [0, MAX_SEED].
        progress: Gradio progress tracker that mirrors tqdm bars.

    Returns:
        ``(video_path, seed_used)``: the path of a temporary ``.mp4`` file
        (not deleted here) and the seed that produced it.

    Raises:
        gr.Error: no image was given, the GPU ran out of memory, or generation
            failed for any other reason.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image.")

    target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
    target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
    num_frames = _frames_for_duration(duration_seconds)

    # For long clips, cap the longer side at 768 px (presumably to stay within
    # GPU memory and the time window above).
    if num_frames > 120 and max(target_h, target_w) > 768:
        scale = 768 / max(target_h, target_w)
        target_h = max(MOD_VALUE, int(target_h * scale) // MOD_VALUE * MOD_VALUE)
        target_w = max(MOD_VALUE, int(target_w * scale) // MOD_VALUE * MOD_VALUE)
        gr.Info(f"Reduced resolution to {target_w}x{target_h} for long video.")

    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    # Normalize to 3-channel RGB so RGBA / palette / grayscale uploads reach the
    # pipeline in the same format as ordinary photos.
    resized_image = input_image.convert("RGB").resize(
        (target_w, target_h), Image.Resampling.LANCZOS
    )

    try:
        torch.cuda.empty_cache()
        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.float16):
            frames = pipe(
                image=resized_image, prompt=prompt, negative_prompt=negative_prompt,
                height=target_h, width=target_w, num_frames=num_frames,
                guidance_scale=float(guidance_scale), num_inference_steps=int(steps),
                generator=torch.Generator(device="cuda").manual_seed(current_seed),
                return_dict=True
            ).frames[0]
    except torch.cuda.OutOfMemoryError as e:
        raise gr.Error("Out of GPU memory. Try reducing duration or resolution.") from e
    except Exception as e:
        raise gr.Error(f"Generation failed: {e}") from e
    finally:
        torch.cuda.empty_cache()

    import imageio

    # delete=False: the file must outlive this function so the caller can read it.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        video_path = tmpfile.name
    # The context manager closes (finalizes) the writer even if encoding fails.
    with imageio.get_writer(video_path, fps=FIXED_FPS, codec='libx264',
                            pixelformat='yuv420p', quality=8) as writer:
        for frame in frames:
            writer.append_data(_to_uint8_frame(frame))
    return video_path, current_seed
# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("# Wan 2.1 I2V FusionX-LoRA")
    with gr.Row():
        with gr.Column():
            input_image_comp = gr.Image(type="pil", label="Input Image")
            prompt_comp = gr.Textbox(label="Prompt", value=default_prompt)
            duration_comp = gr.Slider(
                minimum=round(MIN_FRAMES / FIXED_FPS, 1),
                maximum=round(MAX_FRAMES / FIXED_FPS, 1),
                step=0.1, value=2, label="Duration (s)",
            )
            with gr.Accordion("Advanced Settings", open=False):
                neg_prompt_comp = gr.Textbox(label="Negative Prompt", value=default_neg_prompt, lines=3)
                seed_comp = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
                rand_seed_comp = gr.Checkbox(label="Randomize seed", value=True)
                with gr.Row():
                    height_comp = gr.Slider(minimum=SLIDER_MIN_H, maximum=SLIDER_MAX_H, step=MOD_VALUE, value=DEFAULT_H, label="Height")
                    width_comp = gr.Slider(minimum=SLIDER_MIN_W, maximum=SLIDER_MAX_W, step=MOD_VALUE, value=DEFAULT_W, label="Width")
                steps_comp = gr.Slider(minimum=1, maximum=30, step=1, value=4, label="Steps")
                # Hidden, but still wired as an input so the generator receives
                # a CFG value (1.0 unless changed programmatically).
                guidance_comp = gr.Slider(minimum=0.0, maximum=20.0, step=0.5, value=1.0, label="CFG Scale", visible=False)
            gen_button = gr.Button("Generate Video", variant="primary")
        with gr.Column():
            video_comp = gr.Video(label="Generated Video", autoplay=True, interactive=False)
            gr.Markdown("### Tips:\n- For long videos (>5s), consider lower resolutions.\n- 4-8 steps is often optimal.")

    def handle_upload(img):
        """Propose (height, width) slider values matching *img*'s aspect ratio.

        The target covers roughly MAX_AREA pixels. If that would exceed a
        slider maximum, both sides are scaled down together so the aspect
        ratio is kept. The result is snapped down to a multiple of MOD_VALUE
        and clamped into the slider ranges. The defaults are returned when
        there is no image or its size cannot be used (e.g. zero width).
        """
        if img is None:
            return gr.update(value=DEFAULT_H), gr.update(value=DEFAULT_W)
        try:
            w, h = img.size
            aspect = h / w
            h_new = np.sqrt(MAX_AREA * aspect)
            w_new = np.sqrt(MAX_AREA / aspect)
            # Fit inside the slider maxima without distorting the aspect ratio.
            if w_new > SLIDER_MAX_W:
                w_new, h_new = SLIDER_MAX_W, SLIDER_MAX_W * aspect
            if h_new > SLIDER_MAX_H:
                h_new, w_new = SLIDER_MAX_H, SLIDER_MAX_H / aspect
            h_final = int(h_new) // MOD_VALUE * MOD_VALUE
            w_final = int(w_new) // MOD_VALUE * MOD_VALUE
            # Clamp so the proposed value is always one the slider can display.
            h_final = min(max(h_final, SLIDER_MIN_H, MOD_VALUE), SLIDER_MAX_H)
            w_final = min(max(w_final, SLIDER_MIN_W, MOD_VALUE), SLIDER_MAX_W)
            return gr.update(value=h_final), gr.update(value=w_final)
        except Exception:
            return gr.update(value=DEFAULT_H), gr.update(value=DEFAULT_W)

    input_image_comp.upload(handle_upload, inputs=input_image_comp, outputs=[height_comp, width_comp])

    # Order must match generate_video's positional parameters.
    inputs = [input_image_comp, prompt_comp, height_comp, width_comp, neg_prompt_comp, duration_comp, guidance_comp, steps_comp, seed_comp, rand_seed_comp]
    outputs = [video_comp, seed_comp]
    gen_button.click(fn=generate_video, inputs=inputs, outputs=outputs)
if __name__ == "__main__":
    # Enable request queuing (at most 3 waiting jobs), then start the server.
    queued_app = demo.queue(max_size=3)
    queued_app.launch()