Spaces: Running on Zero

Update app.py

app.py CHANGED
@@ -1,97 +1,82 @@
 import torch
-from diffusers import UniPCMultistepScheduler
-from diffusers import WanPipeline, AutoencoderKLWan
-# from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig
+from diffusers import UniPCMultistepScheduler
+from diffusers import WanPipeline, AutoencoderKLWan
 from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
-from diffusers.models import UNetSpatioTemporalConditionModel
-from transformers import T5EncoderModel, T5Tokenizer
 from huggingface_hub import hf_hub_download
-
 from PIL import Image
 import numpy as np
-
 import gradio as gr
 import spaces

+# --- INITIAL SETUP ---
 device = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"Using device: {device}")

-model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
+model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"  # Using the large model
+print("Loading VAE...")
 vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+
+print("Loading WanPipeline...")
+# low_cpu_mem_usage is often needed for meta device loading
 pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
-
+
+flow_shift = 1.0
 pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=flow_shift)

+# Move the base pipeline to the GPU. This is where meta-device loading happens.
+print("Moving pipeline to device...")
 pipe.to(device)

-#
-#
-#
-# beta_start=0.00085, # Starting beta value
-# beta_end=0.012, # Ending beta value
-# beta_schedule="linear", # Linear beta schedule (other options: "scaled_linear", "squaredcos_cap_v2")
-# num_train_timesteps=1000, # Number of timesteps
-# flow_shift=flow_shift
-# )
-
+# --- LORA FUSING (Done ONCE at startup) ---
+# We will fuse the base LoRA permanently into the model for performance and to solve the device issue.
+# This means we cannot dynamically unload it, but it's the correct approach for a fixed setup.

-
-# pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
-#     pipe.scheduler.config,
-#     flow_shift=flow_shift # Retain flow_shift for WanPipeline compatibility
-# )
-
-# --- LoRA State Management ---
-# Define unique names for our adapters
-DEFAULT_LORA_NAME = "causvid_lora"
-CUSTOM_LORA_NAME = "custom_lora"
-# Track which custom LoRA is currently loaded to avoid reloading
-CURRENTLY_LOADED_CUSTOM_LORA = None
-
-# Load the default base LoRA ONCE at startup
-print("Loading base LoRA...")
+print("Loading and Fusing base LoRA...")
 CAUSVID_LORA_REPO = "Kijai/WanVideo_comfy"
 CAUSVID_LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
-#
-# causvid_path = hf_hub_download(repo_id=CAUSVID_LORA_REPO, filename=CAUSVID_LORA_FILENAME)
-# pipe.load_lora_weights(causvid_path, adapter_name=DEFAULT_LORA_NAME)
-# print(f"✅ Default LoRA '{DEFAULT_LORA_NAME}' loaded successfully.")
-# except Exception as e:
-# print(f"⚠️ Default LoRA could not be loaded: {e}")
-# DEFAULT_LORA_NAME = None
-
-# print("Initialization complete. Gradio is starting...")
-
+CUSTOM_LORA_NAME = "custom_lora"  # For any optional LoRA

+try:
+    causvid_path = hf_hub_download(repo_id=CAUSVID_LORA_REPO, filename=CAUSVID_LORA_FILENAME)
+    # 1. Load the LoRA weights. They will likely be on the CPU.
+    pipe.load_lora_weights(causvid_path)  # Use the default adapter name
+    print("✅ Base LoRA loaded.")
+
+    # 2. Fuse the weights into the base model. This resolves the device mismatch.
+    pipe.fuse_lora()
+    print("✅ Base LoRA fused successfully.")
+
+except Exception as e:
+    print(f"⚠️ Could not load or fuse the base LoRA: {e}")
+
+print("Initialization complete. Gradio is starting...")

 @spaces.GPU()
 def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_steps=30, lora_id=None, progress=gr.Progress(track_tqdm=True)):
-    #
-    # pipe.unload_lora_weights()
-    # pipe.load_lora_weights(lora_id.strip())
-
+    # The base LoRA is already fused. We only need to handle the optional custom LoRA.
     clean_lora_id = lora_id.strip() if lora_id else ""
-
-
-    pipe.load_lora_weights(causvid_path, adapter_name=DEFAULT_LORA_NAME)
-
-    # If a custom LoRA is provided, load it as well.
     if clean_lora_id:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        try:
+            print(f"Loading custom LoRA '{clean_lora_id}' for this run...")
+            # Load the custom LoRA. Note: We will NOT fuse this one to keep it temporary.
+            pipe.load_lora_weights(clean_lora_id, adapter_name=CUSTOM_LORA_NAME)
+
+            # ✨ This is the critical part for dynamic LoRAs on large models.
+            # We must explicitly move the adapter to the correct device.
+            pipe.to(device, dtype=pipe.transformer.dtype)  # Ensure dtype matches
+
+            pipe.set_adapters([CUSTOM_LORA_NAME], adapter_weights=[1.0])
+            print(f"✅ Custom LoRA '{CUSTOM_LORA_NAME}' activated.")
+        except Exception as e:
+            print(f"⚠️ Failed to load custom LoRA '{clean_lora_id}': {e}. Running without it.")
+            # Ensure no adapters are active if loading failed
+            pipe.disable_lora()
+
+    # Apply performance optimizations
+    apply_cache_on_pipe(pipe)
+
     try:
         output = pipe(
             prompt=prompt,
@@ -100,36 +85,29 @@ def generate(prompt, negative_prompt, width=1024, height=1024, num_inference_ste
             width=width,
             num_frames=1,
             num_inference_steps=num_inference_steps,
-            guidance_scale=1.0,
+            guidance_scale=1.0,
         )
         image = output.frames[0][0]
         image = (image * 255).astype(np.uint8)
         return Image.fromarray(image)
     finally:
-        #
-
-
-
-
-
-
-
-        # pipe.disable_lora()
-        print("Unloading all LoRAs to clean up.")
-        pipe.unload_lora_weights()
-
-
+        # Clean up the dynamic LoRA if it was loaded
+        if clean_lora_id:
+            print(f"Unloading '{CUSTOM_LORA_NAME}' to clean up.")
+            pipe.unload_lora_weights(CUSTOM_LORA_NAME)
+            # It's good practice to disable all just in case.
+            pipe.disable_lora()
+
+# --- Your Gradio Interface Code (no changes needed here) ---
 iface = gr.Interface(
     fn=generate,
     inputs=[
         gr.Textbox(label="Input prompt"),
-    ],
-    additional_inputs = [
         gr.Textbox(label="Negative prompt", value = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"),
         gr.Slider(label="Width", minimum=480, maximum=1280, step=16, value=1024),
         gr.Slider(label="Height", minimum=480, maximum=1280, step=16, value=1024),
         gr.Slider(minimum=1, maximum=80, step=1, label="Inference Steps", value=10),
-        gr.Textbox(label="LoRA ID"),
+        gr.Textbox(label="LoRA ID (Optional, loads dynamically)"),
     ],
     outputs=gr.Image(label="output"),
 )
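
A note on the fuse step introduced above: fuse_lora() folds each adapter's low-rank update directly into the base weights (per layer, W <- W + scale * B @ A), so inference pays no per-layer LoRA overhead and no separate adapter modules need to live on the GPU. A minimal sketch of that lifecycle, assuming the standard diffusers LoRA mixin API; unfuse_lora() is the documented inverse, which this commit deliberately never calls:

    import torch
    from diffusers import WanPipeline

    pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-14B-Diffusers", torch_dtype=torch.bfloat16)
    pipe.load_lora_weights("path/to/lora.safetensors")  # attach adapter layers
    pipe.fuse_lora()      # merge the low-rank deltas into the base weights
    # ... run inference with no LoRA overhead ...
    # pipe.unfuse_lora()  # only if the original base weights must be restored

This also fits the CausVid LoRA specifically: it is a distillation adapter typically run with few steps and without classifier-free guidance, which is why the pipeline call uses guidance_scale=1.0 and the interface defaults to 10 steps.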
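One caveat in the new finally block: in the diffusers releases I am aware of, unload_lora_weights() takes no arguments and removes every loaded adapter, so pipe.unload_lora_weights(CUSTOM_LORA_NAME) is likely to raise a TypeError the first time a custom LoRA is used. The targeted call for removing one named adapter is delete_adapters(). A hedged sketch of the cleanup under that assumption (cleanup_custom_lora is a hypothetical helper, not part of this commit):

    def cleanup_custom_lora(pipe, loaded: bool, adapter_name: str = "custom_lora"):
        """Drop only the per-run adapter loaded inside generate()."""
        if loaded:
            pipe.delete_adapters(adapter_name)  # remove just this adapter's layers
        pipe.disable_lora()  # deactivate anything still attached, as the commit intends

Dropping the stray argument (or switching to delete_adapters) keeps the per-run cleanup from failing at runtime.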
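Also worth noting: apply_cache_on_pipe(pipe) now runs inside generate(), so ParaAttention's first-block cache is re-applied on every request. In para_attn's published examples the cache is applied once, right after the pipeline is built and moved to the device. A sketch under that assumption; the residual_diff_threshold knob and its value follow para_attn's examples, not this commit:

    from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe

    # Apply once at startup; denoising steps whose first-block residual changes
    # less than the threshold reuse cached results and skip the remaining
    # transformer blocks.
    apply_cache_on_pipe(pipe, residual_diff_threshold=0.08)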