Spaces: Running on Zero
Update app.py
app.py
CHANGED
@@ -19,176 +19,6 @@ LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-MANUAL_PATCHES_STORE = {"diff": {}, "diff_b": {}}
-
-def _custom_convert_non_diffusers_wan_lora_to_diffusers(state_dict):
-    global MANUAL_PATCHES_STORE
-    MANUAL_PATCHES_STORE = {"diff": {}, "diff_b": {}}  # Reset for each conversion
-    peft_compatible_state_dict = {}
-    unhandled_keys = []
-
-    original_keys_map_to_diffusers = {}
-
-    # Mapping based on ComfyUI's WanModel structure and PeftAdapterMixin logic
-    # This needs to map the original LoRA key naming to Diffusers' expected PEFT keys
-    # diffusion_model.blocks.0.self_attn.q.lora_down.weight -> transformer.blocks.0.attn1.to_q.lora_A.weight
-    # diffusion_model.blocks.0.ffn.0.lora_down.weight -> transformer.blocks.0.ffn.net.0.proj.lora_A.weight
-    # diffusion_model.text_embedding.0.lora_down.weight -> transformer.condition_embedder.text_embedder.linear_1.lora_A.weight (example)
-
-    # Strip "diffusion_model." and map
-    for k, v in state_dict.items():
-        original_k = k  # Keep for logging/debugging
-        if k.startswith("diffusion_model."):
-            k_stripped = k[len("diffusion_model."):]
-        elif k.startswith("difusion_model."):  # Handle potential typo
-            k_stripped = k[len("difusion_model."):]
-            logger.warning(f"Key '{original_k}' starts with 'difusion_model.' (potential typo), processing as 'diffusion_model.'.")
-        else:
-            unhandled_keys.append(original_k)
-            continue
-
-        # Handle .diff and .diff_b keys by storing them separately
-        if k_stripped.endswith(".diff"):
-            target_model_key = k_stripped[:-len(".diff")] + ".weight"
-            MANUAL_PATCHES_STORE["diff"][target_model_key] = v
-            continue
-        elif k_stripped.endswith(".diff_b"):
-            target_model_key = k_stripped[:-len(".diff_b")] + ".bias"
-            MANUAL_PATCHES_STORE["diff_b"][target_model_key] = v
-            continue
-
-        # Handle standard LoRA A/B matrices
-        if ".lora_down.weight" in k_stripped:
-            diffusers_key_base = k_stripped.replace(".lora_down.weight", "")
-            # Apply transformations similar to _convert_non_diffusers_wan_lora_to_diffusers from diffusers
-            # but adapt to the PEFT naming convention (lora_A/lora_B)
-            # This part needs careful mapping based on WanTransformer3DModel structure
-
-            # Example mappings (these need to be comprehensive for all layers)
-            if diffusers_key_base.startswith("blocks."):
-                parts = diffusers_key_base.split(".")
-                block_idx = parts[1]
-                attn_type = parts[2]  # self_attn or cross_attn
-                proj_type = parts[3]  # q, k, v, o
-
-                if attn_type == "self_attn":
-                    diffusers_peft_key = f"transformer.blocks.{block_idx}.attn1.to_{proj_type}.lora_A.weight"
-                elif attn_type == "cross_attn":
-                    # WanTransformer3DModel uses attn2 for cross-attention like features
-                    diffusers_peft_key = f"transformer.blocks.{block_idx}.attn2.to_{proj_type}.lora_A.weight"
-                else:  # ffn
-                    ffn_idx = proj_type  # "0" or "2"
-                    diffusers_peft_key = f"transformer.blocks.{block_idx}.ffn.net.{ffn_idx}.proj.lora_A.weight"
-            elif diffusers_key_base.startswith("text_embedding."):
-                idx_map = {"0": "linear_1", "2": "linear_2"}
-                idx = diffusers_key_base.split(".")[1]
-                diffusers_peft_key = f"transformer.condition_embedder.text_embedder.{idx_map[idx]}.lora_A.weight"
-            elif diffusers_key_base.startswith("time_embedding."):
-                idx_map = {"0": "linear_1", "2": "linear_2"}
-                idx = diffusers_key_base.split(".")[1]
-                diffusers_peft_key = f"transformer.condition_embedder.time_embedder.{idx_map[idx]}.lora_A.weight"
-            elif diffusers_key_base.startswith("time_projection."):  # Assuming '1' from your example
-                diffusers_peft_key = f"transformer.condition_embedder.time_proj.lora_A.weight"
-            elif diffusers_key_base.startswith("patch_embedding"):
-                # WanTransformer3DModel has 'patch_embedding' at the top level
-                diffusers_peft_key = f"transformer.patch_embedding.lora_A.weight"  # This needs to match how PEFT would name it
-            elif diffusers_key_base.startswith("head.head"):
-                diffusers_peft_key = f"transformer.proj_out.lora_A.weight"
-            else:
-                unhandled_keys.append(original_k)
-                continue
-
-            peft_compatible_state_dict[diffusers_peft_key] = v
-            original_keys_map_to_diffusers[k_stripped] = diffusers_peft_key
-
-        elif ".lora_up.weight" in k_stripped:
-            # Find the corresponding lora_down key to determine the base name
-            down_key_stripped = k_stripped.replace(".lora_up.weight", ".lora_down.weight")
-            if down_key_stripped in original_keys_map_to_diffusers:
-                diffusers_peft_key_A = original_keys_map_to_diffusers[down_key_stripped]
-                diffusers_peft_key_B = diffusers_peft_key_A.replace(".lora_A.weight", ".lora_B.weight")
-                peft_compatible_state_dict[diffusers_peft_key_B] = v
-            else:
-                unhandled_keys.append(original_k)
-        elif not (k_stripped.endswith(".alpha") or k_stripped.endswith(".dora_scale")):  # Alphas are handled by PEFT if lora_A/B present
-            unhandled_keys.append(original_k)
-
-
-    if unhandled_keys:
-        logger.warning(f"Custom Wan LoRA Converter: Unhandled keys: {unhandled_keys}")
-
-    return peft_compatible_state_dict
-
-
-def apply_manual_diff_patches(pipe_model, patches_store, lora_strength=1.0):
-    if not hasattr(pipe_model, "transformer"):
-        logger.error("Pipeline model does not have a 'transformer' attribute to patch.")
-        return
-
-    transformer = pipe_model.transformer
-    changed_params_count = 0
-
-    for key_base, diff_tensor in patches_store.get("diff", {}).items():
-        # key_base is like "blocks.0.self_attn.q.weight"
-        # We need to prepend "transformer." to match diffusers internal naming
-        target_key_full = f"transformer.{key_base}"
-        try:
-            module_path_parts = target_key_full.split('.')
-            param_name = module_path_parts[-1]
-            module_path = ".".join(module_path_parts[:-1])
-            module = transformer
-            for part in module_path.split('.')[1:]:  # Skip the first 'transformer'
-                module = getattr(module, part)
-
-            original_param = getattr(module, param_name)
-            if original_param.shape != diff_tensor.shape:
-                logger.warning(f"Shape mismatch for diff patch on {target_key_full}: model {original_param.shape}, lora {diff_tensor.shape}. Skipping.")
-                continue
-
-            with torch.no_grad():
-                scaled_diff = (lora_strength * diff_tensor.to(original_param.device, original_param.dtype))
-                original_param.data.add_(scaled_diff)
-            changed_params_count += 1
-        except AttributeError:
-            logger.warning(f"Could not find parameter {target_key_full} in transformer to apply diff patch.")
-        except Exception as e:
-            logger.error(f"Error applying diff patch to {target_key_full}: {e}")
-
-
-    for key_base, diff_b_tensor in patches_store.get("diff_b", {}).items():
-        # key_base is like "blocks.0.self_attn.q.bias"
-        target_key_full = f"transformer.{key_base}"
-        try:
-            module_path_parts = target_key_full.split('.')
-            param_name = module_path_parts[-1]
-            module_path = ".".join(module_path_parts[:-1])
-            module = transformer
-            for part in module_path.split('.')[1:]:
-                module = getattr(module, part)
-
-            original_param = getattr(module, param_name)
-            if original_param is None:
-                logger.warning(f"Bias parameter {target_key_full} is None in model. Skipping diff_b patch.")
-                continue
-
-            if original_param.shape != diff_b_tensor.shape:
-                logger.warning(f"Shape mismatch for diff_b patch on {target_key_full}: model {original_param.shape}, lora {diff_b_tensor.shape}. Skipping.")
-                continue
-
-            with torch.no_grad():
-                scaled_diff_b = (lora_strength * diff_b_tensor.to(original_param.device, original_param.dtype))
-                original_param.data.add_(scaled_diff_b)
-            changed_params_count += 1
-        except AttributeError:
-            logger.warning(f"Could not find parameter {target_key_full} in transformer to apply diff_b patch.")
-        except Exception as e:
-            logger.error(f"Error applying diff_b patch to {target_key_full}: {e}")
-    if changed_params_count > 0:
-        logger.info(f"Applied {changed_params_count} manual diff/diff_b patches.")
-    else:
-        logger.info("No manual diff/diff_b patches were applied.")
-
-
 # --- Model Loading ---
 logger.info(f"Loading VAE for {MODEL_ID}...")
 vae = AutoencoderKLWan.from_pretrained(
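Editor's note: the sketch below condenses what the removed code did, reduced to the `self_attn` branch of the key renaming in `_custom_convert_non_diffusers_wan_lora_to_diffusers` plus the in-place bias patching that `apply_manual_diff_patches` performed. It runs on synthetic tensors with arbitrary shapes and a stand-in `nn.Linear`; it is an illustration of the scheme, not code from `app.py`.

```python
import torch

# Synthetic LoRA keys in the original (ComfyUI-style) naming; shapes are arbitrary.
sample = {
    "diffusion_model.blocks.0.self_attn.q.lora_down.weight": torch.zeros(16, 64),
    "diffusion_model.blocks.0.self_attn.q.lora_up.weight":   torch.zeros(64, 16),
    "diffusion_model.blocks.0.self_attn.q.diff_b":           torch.zeros(64),
}

peft_sd = {}                          # PEFT-style lora_A / lora_B entries
patches = {"diff": {}, "diff_b": {}}  # weight / bias offsets kept for manual patching

for k, v in sample.items():
    k = k.removeprefix("diffusion_model.")
    if k.endswith(".diff_b"):
        # Bias offsets have no PEFT equivalent; they were later added to the model in place.
        patches["diff_b"][k.removesuffix(".diff_b") + ".bias"] = v
    elif ".lora_down.weight" in k or ".lora_up.weight" in k:
        block, attn, proj = k.split(".")[1:4]
        ab = "lora_A" if ".lora_down." in k else "lora_B"     # lora_down -> A, lora_up -> B
        target = "attn1" if attn == "self_attn" else "attn2"  # self_attn -> attn1, cross_attn -> attn2
        peft_sd[f"transformer.blocks.{block}.{target}.to_{proj}.{ab}.weight"] = v

print(sorted(peft_sd))            # ...attn1.to_q.lora_A.weight, ...attn1.to_q.lora_B.weight
print(sorted(patches["diff_b"]))  # blocks.0.self_attn.q.bias

# A stored bias offset was then added to the matching parameter in place, scaled by the
# LoRA strength (here applied to a stand-in nn.Linear instead of the Wan transformer).
layer = torch.nn.Linear(16, 64)
with torch.no_grad():
    layer.bias.add_(1.0 * patches["diff_b"]["blocks.0.self_attn.q.bias"])
```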
@@ -214,26 +44,7 @@ logger.info(f"Downloading LoRA {LORA_FILENAME} from {LORA_REPO_ID}...")
 causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
 
 logger.info("Loading LoRA weights with custom converter...")
-
-from safetensors.torch import load_file as load_safetensors
-raw_lora_state_dict = load_safetensors(causvid_path)
-
-# Now call our custom converter which will populate MANUAL_PATCHES_STORE
-peft_state_dict = _custom_convert_non_diffusers_wan_lora_to_diffusers(raw_lora_state_dict)
-
-# Load the LoRA A/B matrices using PEFT
-if peft_state_dict:
-    pipe.load_lora_weights(
-        peft_state_dict,
-        adapter_name="causvid_lora"
-    )
-    logger.info("PEFT LoRA A/B weights loaded.")
-else:
-    logger.warning("No PEFT-compatible LoRA weights found after conversion.")
-
-# Apply manual diff_b and diff patches
-apply_manual_diff_patches(pipe, MANUAL_PATCHES_STORE, lora_strength=1.0)  # Assuming default strength 1.0
-logger.info("Manual diff_b/diff patches applied.")
+pipe.load_lora_weights(causvid_path,adapter_name="causvid_lora")
 
 
 # --- Gradio Interface Function ---
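Editor's note: the net effect of this commit is that diffusers' built-in Wan LoRA conversion is trusted to handle the raw checkpoint, so the custom converter and manual patching collapse into a single `load_lora_weights` call. A minimal sketch of the resulting loading flow follows; `MODEL_ID` and `LORA_REPO_ID` are placeholders for values defined earlier in `app.py` (not shown in this diff), `WanPipeline` is assumed as the pipeline class, and the dtypes and `subfolder="vae"` argument follow the usual diffusers Wan example rather than the app itself.

```python
import torch
from diffusers import AutoencoderKLWan, WanPipeline
from huggingface_hub import hf_hub_download

# Placeholders: the real values live near the top of app.py and are not part of this diff.
MODEL_ID = "<wan-2.1-t2v-diffusers-repo>"
LORA_REPO_ID = "<causvid-lora-repo>"
LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"

# Load the VAE separately (as app.py does), then build the pipeline around it.
vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(MODEL_ID, vae=vae, torch_dtype=torch.bfloat16)

# Download the raw .safetensors LoRA and let diffusers convert and attach it.
causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
pipe.set_adapters(["causvid_lora"], adapter_weights=[1.0])
```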