Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -1,5 +1,5 @@
 import torch
-from diffusers import
+from diffusers import UniPCMultistepScheduler
 from diffusers import WanPipeline, AutoencoderKLWan
 from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
 from huggingface_hub import hf_hub_download
@@ -14,65 +14,97 @@ print(f"Using device: {device}")
 
 model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
 print("Loading VAE...")
-# VAE is often kept in float32 for precision during decoding
 vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
 
 print("Loading WanPipeline in bfloat16...")
-#
+# This will use ZeroGPU/accelerate with meta devices
 pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
 
 flow_shift = 1.0
 pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)
 
-# Move the base pipeline to the GPU.
-print("Moving pipeline to device...")
+# Move the base pipeline to the GPU. ZeroGPU will manage this.
+print("Moving pipeline to device (ZeroGPU will handle offloading)...")
 pipe.to(device)
 
-# --- LORA
-
+# --- LORA SETUP ---
+# We will NOT fuse anything. Everything will be handled dynamically.
 CAUSVID_LORA_REPO = "Kijai/WanVideo_comfy"
 CAUSVID_LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
+BASE_LORA_NAME = "causvid_lora"
 CUSTOM_LORA_NAME = "custom_lora"
 
+print("Downloading base LoRA...")
 try:
     causvid_path = hf_hub_download(repo_id=CAUSVID_LORA_REPO, filename=CAUSVID_LORA_FILENAME)
-
-    print("✅ Base LoRA loaded.")
-
-    pipe.fuse_lora() # Fuses float32 LoRA into bfloat16 model
-    print("✅ Base LoRA fused.")
-
-    # FIX for Dtype Mismatch: After fusing, some layers became float32.
-    # We must cast the entire pipeline back to bfloat16 to ensure consistency.
-    pipe.to(dtype=torch.bfloat16)
-    print("✅ Pipeline converted back to bfloat16 post-fusion.")
-
+    print("✅ Base LoRA downloaded.")
 except Exception as e:
-
+    causvid_path = None
+    print(f"⚠️ Could not download base LoRA: {e}")
 
 print("Initialization complete. Gradio is starting...")
 
+def move_adapter_to_device(pipe, adapter_name, device):
+    """
+    Surgically moves only the parameters of a specific LoRA adapter to the target device.
+    This avoids touching the base model's meta tensors.
+    """
+    print(f"Moving adapter '{adapter_name}' to {device}...")
+    for param in pipe.transformer.parameters():
+        if hasattr(param, "adapter_name") and param.adapter_name == adapter_name:
+            param.data = param.data.to(device, non_blocking=True)
+            if param.grad is not None:
+                param.grad.data = param.grad.data.to(device, non_blocking=True)
+    print(f"✅ Adapter '{adapter_name}' moved.")
+
 @spaces.GPU()
 def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_steps=30, lora_id=None, progress=gr.Progress(track_tqdm=True)):
-    clean_lora_id = lora_id.strip() if lora_id else ""
 
+    # --- DYNAMIC LORA MANAGEMENT FOR EACH RUN ---
+    # Start with a clean slate by disabling any active adapters from previous runs
+    pipe.disable_lora()
+
+    active_adapters = []
+    adapter_weights = []
+
+    # 1. Load the Base LoRA
+    if causvid_path:
+        try:
+            # We load it for every run to ensure a clean state
+            print(f"Loading base LoRA '{BASE_LORA_NAME}'...")
+            pipe.load_lora_weights(causvid_path, adapter_name=BASE_LORA_NAME)
+
+            # THE CRITICAL FIX: Move only this adapter's weights to the GPU
+            move_adapter_to_device(pipe, BASE_LORA_NAME, device)
+
+            active_adapters.append(BASE_LORA_NAME)
+            adapter_weights.append(1.0)
+        except Exception as e:
+            print(f"⚠️ Failed to load base LoRA: {e}")
+
+    # 2. Load the Custom LoRA if provided
+    clean_lora_id = lora_id.strip() if lora_id else ""
     if clean_lora_id:
         try:
-            print(f"
-            # 1. Load the temporary LoRA
+            print(f"Loading custom LoRA '{CUSTOM_LORA_NAME}' from '{clean_lora_id}'...")
             pipe.load_lora_weights(clean_lora_id, adapter_name=CUSTOM_LORA_NAME)
 
-            #
-            pipe
-
-            # 3. Ensure dtype consistency after the new fusion
-            pipe.to(dtype=torch.bfloat16)
-            print(f"✅ Custom LoRA '{CUSTOM_LORA_NAME}' fused and activated.")
+            # THE CRITICAL FIX: Move only this adapter's weights to the GPU
+            move_adapter_to_device(pipe, CUSTOM_LORA_NAME, device)
 
+            active_adapters.append(CUSTOM_LORA_NAME)
+            adapter_weights.append(1.0)
         except Exception as e:
-            print(f"⚠️ Failed to load
-
+            print(f"⚠️ Failed to load custom LoRA '{clean_lora_id}': {e}")
+            # If it fails, delete the adapter config to prevent issues
+            if CUSTOM_LORA_NAME in pipe.transformer.peft_config:
+                del pipe.transformer.peft_config[CUSTOM_LORA_NAME]
 
+    # 3. Activate the successfully loaded adapters
+    if active_adapters:
+        print(f"Activating adapters: {active_adapters} with weights: {adapter_weights}")
+        pipe.set_adapters(active_adapters, adapter_weights)
+
     apply_cache_on_pipe(pipe)
 
     try:
@@ -89,18 +121,19 @@ def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_ste
         image = (image * 255).astype(np.uint8)
         return Image.fromarray(image)
     finally:
-        #
-
-
-
-
-
-
-
-
-
+        # --- PROPER CLEANUP ---
+        print("Cleaning up LoRAs for this run...")
+        # Disable adapters to stop them from being used
+        pipe.disable_lora()
+
+        # Delete the LoRA configs from the model to truly unload them
+        if BASE_LORA_NAME in pipe.transformer.peft_config:
+            del pipe.transformer.peft_config[BASE_LORA_NAME]
+        if CUSTOM_LORA_NAME in pipe.transformer.peft_config:
+            del pipe.transformer.peft_config[CUSTOM_LORA_NAME]
+        print("✅ LoRAs cleaned up.")
+
 
-# --- Your Gradio Interface Code (no changes needed here) ---
 iface = gr.Interface(
     fn=generate,
     inputs=[
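
The change above replaces the old fuse-and-recast LoRA handling with per-request adapter management: each call to generate() loads the base and optional custom LoRA as named adapters, activates them with set_adapters, and tears them down in the finally block. The sketch below is a minimal illustration of that lifecycle on a generic diffusers pipeline, not the Space's code: run_with_adapters, base_lora_path, and custom_lora_id are placeholder names, and it cleans up with unload_lora_weights() instead of deleting peft_config entries directly.

def run_with_adapters(pipe, prompt, base_lora_path=None, custom_lora_id=None):
    # Start from a clean state: no adapters left active by a previous request.
    pipe.disable_lora()
    active, weights = [], []
    try:
        if base_lora_path:
            # Placeholder path/repo; adapter_name lets several LoRAs coexist.
            pipe.load_lora_weights(base_lora_path, adapter_name="base")
            active.append("base")
            weights.append(1.0)
        if custom_lora_id:
            pipe.load_lora_weights(custom_lora_id, adapter_name="custom")
            active.append("custom")
            weights.append(1.0)
        if active:
            # Activate only the adapters that actually loaded.
            pipe.set_adapters(active, weights)
        return pipe(prompt=prompt, num_inference_steps=30)
    finally:
        # Tear down so the next request starts clean.
        pipe.disable_lora()
        pipe.unload_lora_weights()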