import torch
from diffusers import UniPCMultistepScheduler
from diffusers import WanPipeline, AutoencoderKLWan
from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
from huggingface_hub import hf_hub_download
from PIL import Image
import numpy as np
import gradio as gr
import spaces
import gc
# --- INITIAL SETUP ---
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
# --- PRE-FUSION STRATEGY ---
# 1. Load everything on the CPU first. Do NOT use device_map or torch_dtype here yet.
# This prevents ZeroGPU/accelerate from taking over too early.
print("Loading all components on CPU for pre-fusion...")
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae")
pipe = WanPipeline.from_pretrained(model_id, vae=vae)
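# Rough estimate: the 14B transformer alone is ~56 GB in fp32 (4 bytes/param),
# so this CPU pre-fusion step needs substantial system RAM. When reproducing
# outside Spaces, loading in a lower-precision dtype may be necessary.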
# 2. Load and FUSE the LoRA while the entire pipeline is still on the CPU.
print("Loading and fusing base LoRA on CPU...")
CAUSVID_LORA_REPO = "Kijai/WanVideo_comfy"
CAUSVID_LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
try:
    causvid_path = hf_hub_download(repo_id=CAUSVID_LORA_REPO, filename=CAUSVID_LORA_FILENAME)
    pipe.load_lora_weights(causvid_path)
    print("✅ Base LoRA loaded on CPU.")
    pipe.fuse_lora()
    print("✅ Base LoRA fused on CPU.")
    # After fusing, unload the original LoRA weights to save memory.
    pipe.unload_lora_weights()
    print("✅ Original LoRA weights unloaded.")
except Exception as e:
    print(f"⚠️ Could not pre-fuse the LoRA: {e}. The model will run without it.")
# 3. NOW, move the final, fused model to the target device and set the dtype.
# ZeroGPU will now take over a single, coherent model.
print("Moving the final, fused pipeline to device and setting dtype...")
pipe.to(device=device, dtype=torch.bfloat16)
# 4. Configure the scheduler
flow_shift = 1.0
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
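# Note: flow_shift skews the flow-matching timestep schedule toward high noise;
# the diffusers Wan docs suggest ~3.0 for 480p and ~5.0 for 720p video. A value
# of 1.0 is used here for few-step single-image sampling.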
# Clean up memory before starting the server. Note that `vae` is still
# referenced by the pipeline, so `del` only drops the duplicate local name.
del vae
gc.collect()
torch.cuda.empty_cache()
print("Initialization complete. Gradio is starting...")
@spaces.GPU()
def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_steps=30, progress=gr.Progress(track_tqdm=True)):
    # The LoRA is permanently fused, so there is no per-call setup or cleanup.
    # para-attn's first-block cache reuses the remaining transformer blocks'
    # output on steps where the first block's residual barely changes.
    apply_cache_on_pipe(pipe)
    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=height,
        width=width,
        num_frames=1,  # a single frame turns the video pipeline into an image generator
        num_inference_steps=num_inference_steps,
        guidance_scale=1.0,  # CFG disabled, as expected for the CausVid-distilled weights
    )
    # WanPipeline returns frames as float arrays in [0, 1]; convert to an 8-bit PIL image.
    image = output.frames[0][0]
    image = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
    return Image.fromarray(image)
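# Hypothetical local smoke test (bypasses Gradio; the prompt, sizes, and output
# path are illustrative only):
#
#     img = generate("a red fox in fresh snow", "blurry, low quality",
#                    width=768, height=768, num_inference_steps=10)
#     img.save("fox.png")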
# Interface is simplified as the custom LoRA option is removed for stability.
iface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Input prompt"),
        gr.Textbox(
            label="Negative prompt",
            value=(
                "Bright tones, overexposed, static, blurred details, subtitles, style, works, "
                "paintings, images, static, overall gray, worst quality, low quality, JPEG "
                "compression residue, ugly, incomplete, extra fingers, poorly drawn hands, "
                "poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, "
                "still picture, messy background, three legs, many people in the background, "
                "walking backwards"
            ),
        ),
        gr.Slider(label="Width", minimum=480, maximum=1280, step=16, value=1024),
        gr.Slider(label="Height", minimum=480, maximum=1280, step=16, value=1024),
        gr.Slider(label="Inference Steps", minimum=1, maximum=80, step=1, value=10),
    ],
    outputs=gr.Image(label="Output"),
)
iface.launch()