Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -12,12 +12,13 @@ import spaces
 device = "cuda" if torch.cuda.is_available() else "cpu"
 print(f"Using device: {device}")
 
-model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
+model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
 print("Loading VAE...")
+# VAE is often kept in float32 for precision during decoding
 vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
 
-print("Loading WanPipeline...")
-#
+print("Loading WanPipeline in bfloat16...")
+# Load the main model in bfloat16 to save memory
 pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
 
 flow_shift = 1.0
@@ -28,24 +29,24 @@ print("Moving pipeline to device...")
 pipe.to(device)
 
 # --- LORA FUSING (Done ONCE at startup) ---
-# We will fuse the base LoRA permanently into the model for performance and to solve the device issue.
-# This means we cannot dynamically unload it, but it's the correct approach for a fixed setup.
-
 print("Loading and Fusing base LoRA...")
 CAUSVID_LORA_REPO = "Kijai/WanVideo_comfy"
 CAUSVID_LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
-CUSTOM_LORA_NAME = "custom_lora"
+CUSTOM_LORA_NAME = "custom_lora"
 
 try:
     causvid_path = hf_hub_download(repo_id=CAUSVID_LORA_REPO, filename=CAUSVID_LORA_FILENAME)
-
-    pipe.load_lora_weights(causvid_path)  # Use the default adapter name
+    pipe.load_lora_weights(causvid_path)  # Loads LoRA, likely onto CPU
    print("✅ Base LoRA loaded.")
 
-    #
-
-    print("✅ Base LoRA fused successfully.")
+    pipe.fuse_lora()  # Fuses float32 LoRA into bfloat16 model
+    print("✅ Base LoRA fused.")
 
+    # FIX for Dtype Mismatch: After fusing, some layers became float32.
+    # We must cast the entire pipeline back to bfloat16 to ensure consistency.
+    pipe.to(dtype=torch.bfloat16)
+    print("✅ Pipeline converted back to bfloat16 post-fusion.")
+
 except Exception as e:
     print(f"⚠️ Could not load or fuse the base LoRA: {e}")
 
@@ -53,28 +54,25 @@ print("Initialization complete. Gradio is starting...")
 
 @spaces.GPU()
 def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_steps=30, lora_id=None, progress=gr.Progress(track_tqdm=True)):
-    # The base LoRA is already fused. We only need to handle the optional custom LoRA.
     clean_lora_id = lora_id.strip() if lora_id else ""
 
-    # We will load and unload the custom LoRA dynamically for each run
     if clean_lora_id:
         try:
-            print(f"
-            # Load the
+            print(f"Applying custom LoRA '{clean_lora_id}' for this run...")
+            # 1. Load the temporary LoRA
             pipe.load_lora_weights(clean_lora_id, adapter_name=CUSTOM_LORA_NAME)
 
-            #
-
-            pipe.to(device, dtype=pipe.transformer.dtype)  # Ensure dtype matches
+            # 2. Fuse it into the model. This is the key to avoiding device/meta errors.
+            pipe.fuse_lora(adapter_names=[CUSTOM_LORA_NAME])
 
-
-
+            # 3. Ensure dtype consistency after the new fusion
+            pipe.to(dtype=torch.bfloat16)
+            print(f"✅ Custom LoRA '{CUSTOM_LORA_NAME}' fused and activated.")
+
         except Exception as e:
-            print(f"⚠️ Failed to load custom LoRA '{clean_lora_id}': {e}. Running without it.")
-            #
-            pipe.disable_lora()
+            print(f"⚠️ Failed to load or fuse custom LoRA '{clean_lora_id}': {e}. Running without it.")
+            clean_lora_id = ""  # Clear the id so we don't try to clean it up
 
-    # Apply performance optimizations
     apply_cache_on_pipe(pipe)
 
     try:
@@ -91,12 +89,16 @@ def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_ste
         image = (image * 255).astype(np.uint8)
         return Image.fromarray(image)
     finally:
-        # Clean up the dynamic LoRA if it was loaded
+        # Clean up the dynamic LoRA if it was successfully loaded and fused
        if clean_lora_id:
-            print(f"
-
-
-
+            print(f"Cleaning up custom LoRA '{CUSTOM_LORA_NAME}'...")
+            # 1. Unfuse the temporary LoRA to revert the model's weights
+            pipe.unfuse_lora(adapter_names=[CUSTOM_LORA_NAME])
+
+            # 2. Unload the LoRA weights from memory
+            # FIX for TypeError: unload_lora_weights takes no adapter name
+            pipe.unload_lora_weights()
+            print("✅ Custom LoRA unfused and unloaded.")
 
 # --- Your Gradio Interface Code (no changes needed here) ---
 iface = gr.Interface(
@@ -107,7 +109,7 @@ iface = gr.Interface(
         gr.Slider(label="Width", minimum=480, maximum=1280, step=16, value=1024),
         gr.Slider(label="Height", minimum=480, maximum=1280, step=16, value=1024),
         gr.Slider(minimum=1, maximum=80, step=1, label="Inference Steps", value=10),
-        gr.Textbox(label="LoRA ID (Optional
+        gr.Textbox(label="LoRA ID (Optional)"),
     ],
     outputs=gr.Image(label="output"),
 )
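Taken together, the startup portion of the new app.py follows the pattern below. This is a minimal sketch assembled from the changed lines above, not a separate file; it assumes a diffusers release that ships WanPipeline/AutoencoderKLWan and the PEFT-backed LoRA helpers (load_lora_weights, fuse_lora).

import torch
from diffusers import AutoencoderKLWan, WanPipeline
from huggingface_hub import hf_hub_download

device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"

# VAE stays in float32 for decode precision; the main model loads in bfloat16
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
pipe.to(device)

# Fuse the base CausVid LoRA once at startup
causvid_path = hf_hub_download(repo_id="Kijai/WanVideo_comfy",
                               filename="Wan21_CausVid_14B_T2V_lora_rank32.safetensors")
pipe.load_lora_weights(causvid_path)   # default adapter name
pipe.fuse_lora()                       # bake the LoRA into the base weights
pipe.to(dtype=torch.bfloat16)          # fusing can leave float32 layers behind; re-cast everything

Fusing once at startup trades the ability to unload the base LoRA for not having to manage adapters on every request.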
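The generate() changes implement a per-request lifecycle for the optional user LoRA: load, fuse, run, then unfuse and unload in a finally block. Below is a condensed, hypothetical helper that mirrors that flow; run_with_optional_lora is illustrative and not part of app.py, and the adapter_names arguments to fuse_lora/unfuse_lora follow the commit and may require a recent diffusers version.

import torch

CUSTOM_LORA_NAME = "custom_lora"

def run_with_optional_lora(pipe, lora_id, **pipe_kwargs):
    # Hypothetical helper mirroring generate(): apply an optional custom LoRA for one call.
    clean_lora_id = lora_id.strip() if lora_id else ""
    if clean_lora_id:
        try:
            pipe.load_lora_weights(clean_lora_id, adapter_name=CUSTOM_LORA_NAME)
            pipe.fuse_lora(adapter_names=[CUSTOM_LORA_NAME])
            pipe.to(dtype=torch.bfloat16)   # keep dtypes consistent after the new fusion
        except Exception as e:
            print(f"⚠️ Failed to load or fuse custom LoRA '{clean_lora_id}': {e}. Running without it.")
            clean_lora_id = ""              # nothing to clean up later
    try:
        return pipe(**pipe_kwargs)
    finally:
        if clean_lora_id:
            pipe.unfuse_lora(adapter_names=[CUSTOM_LORA_NAME])  # revert the fused weights
            pipe.unload_lora_weights()                          # takes no adapter-name argument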
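The dtype-mismatch fix rests on the observation that fusing a float32 LoRA into a bfloat16 pipeline can leave some parameters in float32. A small, hypothetical audit helper (not part of this commit) can confirm that before and after the post-fusion pipe.to(dtype=torch.bfloat16) cast, assuming the standard DiffusionPipeline.components mapping.

import torch
from collections import Counter

def dtype_histogram(pipe):
    # Count parameters per (component, dtype) to spot stray float32 tensors after fusing
    counts = Counter()
    for name, component in pipe.components.items():
        if isinstance(component, torch.nn.Module):
            for param in component.parameters():
                counts[(name, str(param.dtype))] += param.numel()
    return counts

# Example: compare the histogram before and after the post-fusion cast
# print(dtype_histogram(pipe))
# pipe.to(dtype=torch.bfloat16)
# print(dtype_histogram(pipe))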