# Hugging Face Space page residue (Space status when captured: Sleeping)
import torch
from diffusers import UniPCMultistepScheduler
from diffusers import WanPipeline, AutoencoderKLWan
from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
from huggingface_hub import hf_hub_download
from PIL import Image
import numpy as np
import gradio as gr
import spaces
# --- INITIAL SETUP ---
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"

print("Loading VAE...")
# The VAE is kept in float32 for precision during decoding. Later dtype fixes
# recast only the transformer so this choice is not silently undone.
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)

print("Loading WanPipeline in bfloat16...")
# Every other component loads in bfloat16 to save memory; passing `vae=`
# plugs in the float32 VAE loaded above.
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)

# Rebuild the scheduler as UniPC from the pipeline's own config, overriding
# only flow_shift.
flow_shift = 1.0
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)

print("Moving pipeline to device...")
pipe.to(device)

# --- LORA FUSING (done once at startup) ---
# The CausVid LoRA is fused permanently into the base weights.
# CUSTOM_LORA_NAME is the adapter name generate() uses for per-request LoRAs.
print("Loading and Fusing base LoRA...")
CAUSVID_LORA_REPO = "Kijai/WanVideo_comfy"
CAUSVID_LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
CUSTOM_LORA_NAME = "custom_lora"
try:
    causvid_path = hf_hub_download(repo_id=CAUSVID_LORA_REPO, filename=CAUSVID_LORA_FILENAME)
    pipe.load_lora_weights(causvid_path)
    print("✅ Base LoRA loaded.")
    pipe.fuse_lora()
    print("✅ Base LoRA fused.")
    # Bake the fused weights in and drop the LoRA adapter layers. The
    # per-request cleanup in generate() calls unload_lora_weights(), which
    # takes no adapter name (and unfuse_lora may likewise ignore
    # `adapter_names`). While the base adapter is still attached, that
    # cleanup would also strip this base LoRA the first time a custom LoRA
    # is used.
    pipe.unload_lora_weights()
    print("✅ Base LoRA baked in; adapter layers unloaded.")
    # Dtype-mismatch fix: fusion left some layers in float32. Recast only the
    # transformer so the VAE keeps the float32 dtype chosen above.
    pipe.transformer.to(dtype=torch.bfloat16)
    print("✅ Transformer converted back to bfloat16 post-fusion.")
except Exception as e:
    # Best effort: the app still starts (without the base LoRA) on failure.
    print(f"⚠️ Could not load or fuse the base LoRA: {e}")

print("Initialization complete. Gradio is starting...")
def _remove_custom_lora():
    """Unfuse and unload the per-request LoRA so the next request starts clean.

    NOTE(review): both diffusers calls below may act on *every* LoRA still
    attached to ``pipe`` (unfuse_lora may ignore ``adapter_names``). This is
    only safe if no other adapter (e.g. a base LoRA) is still loaded —
    confirm for the installed diffusers version.
    """
    print(f"Cleaning up custom LoRA '{CUSTOM_LORA_NAME}'...")
    # Revert the fused delta first; unloading alone would leave it in the weights.
    pipe.unfuse_lora(adapter_names=[CUSTOM_LORA_NAME])
    # unload_lora_weights() takes no adapter name (passing one raised TypeError).
    pipe.unload_lora_weights()
    print("✅ Custom LoRA unfused and unloaded.")


def _apply_custom_lora(lora_id):
    """Load and fuse ``lora_id`` into ``pipe`` under ``CUSTOM_LORA_NAME``.

    Args:
        lora_id: Non-empty LoRA identifier accepted by ``load_lora_weights``
            (e.g. a Hub repo id).

    Returns:
        ``lora_id`` when the LoRA is fused and active, ``""`` when it could not
        be applied. On failure, an adapter that was already loaded is rolled
        back. Otherwise it would stay active for this run and keep
        ``CUSTOM_LORA_NAME`` occupied for every later request.
    """
    loaded = False
    try:
        print(f"Applying custom LoRA '{lora_id}' for this run...")
        # lora_id comes straight from the UI: only accept safetensors files,
        # never a pickle-based checkpoint.
        pipe.load_lora_weights(lora_id, adapter_name=CUSTOM_LORA_NAME, use_safetensors=True)
        loaded = True
        # Fusing (rather than running the adapter unfused) is what avoided the
        # device/meta errors this app originally hit.
        pipe.fuse_lora(adapter_names=[CUSTOM_LORA_NAME])
        # Fusion can leave float32 layers; recast only the transformer so the
        # VAE keeps whatever dtype it was loaded in.
        pipe.transformer.to(dtype=torch.bfloat16)
        print(f"✅ Custom LoRA '{CUSTOM_LORA_NAME}' fused and activated.")
        return lora_id
    except Exception as e:
        print(f"⚠️ Failed to load or fuse custom LoRA '{lora_id}': {e}. Running without it.")
        if loaded:
            try:
                _remove_custom_lora()
            except Exception as rollback_err:
                print(f"⚠️ Rollback of custom LoRA '{CUSTOM_LORA_NAME}' failed: {rollback_err}")
        return ""


def _frame_to_pil(frame):
    """Convert one decoded frame to an 8-bit RGB PIL image.

    Assumes ``frame`` is a float ``(H, W, 3)`` array in ``[0, 1]`` (the
    pipeline's default numpy output) — confirm if ``output_type`` changes.
    Values are clipped and rounded, not truncated, before the uint8 cast.
    """
    pixels = np.clip(frame, 0.0, 1.0) * 255.0
    return Image.fromarray(pixels.round().astype(np.uint8))


# On ZeroGPU hardware this attaches a GPU for the duration of the call;
# elsewhere the decorator is a no-op. (`spaces` was imported for this.)
@spaces.GPU
def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_steps=30, lora_id=None, progress=gr.Progress(track_tqdm=True)):
    """Render a single image as a one-frame Wan video.

    Args:
        prompt: Text prompt.
        negative_prompt: Forwarded to the pipeline. NOTE(review): with
            ``guidance_scale=1.0`` the pipeline presumably skips
            classifier-free guidance, so this likely has no effect — confirm.
        width: Output width in pixels.
        height: Output height in pixels.
        num_inference_steps: Number of denoising steps.
        lora_id: Optional LoRA identifier fused for this call only and removed
            afterwards. ``None`` or blank means no custom LoRA.
        progress: Gradio progress tracker; ``track_tqdm`` mirrors the
            pipeline's tqdm bar in the UI.

    Returns:
        PIL.Image.Image: The generated frame.

    NOTE(review): this mutates the shared global ``pipe`` (fuse/unfuse), so
    two calls must never overlap — relies on Gradio's per-event concurrency
    limit staying at its default of 1; confirm if that is ever raised.
    """
    clean_lora_id = lora_id.strip() if lora_id else ""
    if clean_lora_id:
        # Becomes "" on failure, so the cleanup in `finally` is skipped.
        clean_lora_id = _apply_custom_lora(clean_lora_id)

    # para_attn first-block cache. Called on every request.
    # NOTE(review): assumed to be a no-op once the pipe is already patched.
    apply_cache_on_pipe(pipe)
    try:
        output = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=int(height),
            width=int(width),
            num_frames=1,
            num_inference_steps=int(num_inference_steps),
            guidance_scale=1.0,
        )
        # First video of the batch, first (only) frame.
        return _frame_to_pil(output.frames[0][0])
    finally:
        if clean_lora_id:
            try:
                _remove_custom_lora()
            except Exception as cleanup_err:
                # Log rather than raise so a cleanup failure cannot mask an
                # exception from the pipeline call above.
                print(f"⚠️ Failed to clean up custom LoRA '{CUSTOM_LORA_NAME}': {cleanup_err}")
# --- GRADIO INTERFACE ---
# Inputs map positionally onto generate(prompt, negative_prompt, width,
# height, num_inference_steps, lora_id); Gradio injects `progress` itself.
iface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Input prompt"),
        gr.Textbox(
            label="Negative prompt",
            value="Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards",
            # generate() runs with guidance_scale=1.0, where the pipeline does
            # not apply the negative prompt; tell users instead of silently
            # ignoring their input.
            info="Only used when guidance scale > 1.0; this app runs at 1.0, so it currently has no effect.",
        ),
        # Sizes move in steps of 16 starting from 480, so every value is a
        # multiple of 16 (presumably what the model requires — confirm).
        gr.Slider(label="Width", minimum=480, maximum=1280, step=16, value=1024),
        gr.Slider(label="Height", minimum=480, maximum=1280, step=16, value=1024),
        gr.Slider(minimum=1, maximum=80, step=1, label="Inference Steps", value=10),
        gr.Textbox(label="LoRA ID (Optional)"),
    ],
    outputs=gr.Image(label="output"),
)
iface.launch()