multimodalart (HF Staff) committed
Commit 2acf5ad · verified · 1 Parent(s): 5158fc3

Update app.py

Files changed (1):
  1. app.py +1 -190
app.py CHANGED
@@ -19,176 +19,6 @@ LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

- MANUAL_PATCHES_STORE = {"diff": {}, "diff_b": {}}
-
- def _custom_convert_non_diffusers_wan_lora_to_diffusers(state_dict):
-     global MANUAL_PATCHES_STORE
-     MANUAL_PATCHES_STORE = {"diff": {}, "diff_b": {}} # Reset for each conversion
-     peft_compatible_state_dict = {}
-     unhandled_keys = []
-
-     original_keys_map_to_diffusers = {}
-
-     # Mapping based on ComfyUI's WanModel structure and PeftAdapterMixin logic
-     # This needs to map the original LoRA key naming to Diffusers' expected PEFT keys
-     # diffusion_model.blocks.0.self_attn.q.lora_down.weight -> transformer.blocks.0.attn1.to_q.lora_A.weight
-     # diffusion_model.blocks.0.ffn.0.lora_down.weight -> transformer.blocks.0.ffn.net.0.proj.lora_A.weight
-     # diffusion_model.text_embedding.0.lora_down.weight -> transformer.condition_embedder.text_embedder.linear_1.lora_A.weight (example)
-
-     # Strip "diffusion_model." and map
-     for k, v in state_dict.items():
-         original_k = k # Keep for logging/debugging
-         if k.startswith("diffusion_model."):
-             k_stripped = k[len("diffusion_model."):]
-         elif k.startswith("difusion_model."): # Handle potential typo
-             k_stripped = k[len("difusion_model."):]
-             logger.warning(f"Key '{original_k}' starts with 'difusion_model.' (potential typo), processing as 'diffusion_model.'.")
-         else:
-             unhandled_keys.append(original_k)
-             continue
-
-         # Handle .diff and .diff_b keys by storing them separately
-         if k_stripped.endswith(".diff"):
-             target_model_key = k_stripped[:-len(".diff")] + ".weight"
-             MANUAL_PATCHES_STORE["diff"][target_model_key] = v
-             continue
-         elif k_stripped.endswith(".diff_b"):
-             target_model_key = k_stripped[:-len(".diff_b")] + ".bias"
-             MANUAL_PATCHES_STORE["diff_b"][target_model_key] = v
-             continue
-
-         # Handle standard LoRA A/B matrices
-         if ".lora_down.weight" in k_stripped:
-             diffusers_key_base = k_stripped.replace(".lora_down.weight", "")
-             # Apply transformations similar to _convert_non_diffusers_wan_lora_to_diffusers from diffusers
-             # but adapt to the PEFT naming convention (lora_A/lora_B)
-             # This part needs careful mapping based on WanTransformer3DModel structure
-
-             # Example mappings (these need to be comprehensive for all layers)
-             if diffusers_key_base.startswith("blocks."):
-                 parts = diffusers_key_base.split(".")
-                 block_idx = parts[1]
-                 attn_type = parts[2] # self_attn or cross_attn
-                 proj_type = parts[3] # q, k, v, o
-
-                 if attn_type == "self_attn":
-                     diffusers_peft_key = f"transformer.blocks.{block_idx}.attn1.to_{proj_type}.lora_A.weight"
-                 elif attn_type == "cross_attn":
-                     # WanTransformer3DModel uses attn2 for cross-attention like features
-                     diffusers_peft_key = f"transformer.blocks.{block_idx}.attn2.to_{proj_type}.lora_A.weight"
-                 else: # ffn
-                     ffn_idx = proj_type # "0" or "2"
-                     diffusers_peft_key = f"transformer.blocks.{block_idx}.ffn.net.{ffn_idx}.proj.lora_A.weight"
-             elif diffusers_key_base.startswith("text_embedding."):
-                 idx_map = {"0": "linear_1", "2": "linear_2"}
-                 idx = diffusers_key_base.split(".")[1]
-                 diffusers_peft_key = f"transformer.condition_embedder.text_embedder.{idx_map[idx]}.lora_A.weight"
-             elif diffusers_key_base.startswith("time_embedding."):
-                 idx_map = {"0": "linear_1", "2": "linear_2"}
-                 idx = diffusers_key_base.split(".")[1]
-                 diffusers_peft_key = f"transformer.condition_embedder.time_embedder.{idx_map[idx]}.lora_A.weight"
-             elif diffusers_key_base.startswith("time_projection."): # Assuming '1' from your example
-                 diffusers_peft_key = f"transformer.condition_embedder.time_proj.lora_A.weight"
-             elif diffusers_key_base.startswith("patch_embedding"):
-                 # WanTransformer3DModel has 'patch_embedding' at the top level
-                 diffusers_peft_key = f"transformer.patch_embedding.lora_A.weight" # This needs to match how PEFT would name it
-             elif diffusers_key_base.startswith("head.head"):
-                 diffusers_peft_key = f"transformer.proj_out.lora_A.weight"
-             else:
-                 unhandled_keys.append(original_k)
-                 continue
-
-             peft_compatible_state_dict[diffusers_peft_key] = v
-             original_keys_map_to_diffusers[k_stripped] = diffusers_peft_key
-
-         elif ".lora_up.weight" in k_stripped:
-             # Find the corresponding lora_down key to determine the base name
-             down_key_stripped = k_stripped.replace(".lora_up.weight", ".lora_down.weight")
-             if down_key_stripped in original_keys_map_to_diffusers:
-                 diffusers_peft_key_A = original_keys_map_to_diffusers[down_key_stripped]
-                 diffusers_peft_key_B = diffusers_peft_key_A.replace(".lora_A.weight", ".lora_B.weight")
-                 peft_compatible_state_dict[diffusers_peft_key_B] = v
-             else:
-                 unhandled_keys.append(original_k)
-         elif not (k_stripped.endswith(".alpha") or k_stripped.endswith(".dora_scale")): # Alphas are handled by PEFT if lora_A/B present
-             unhandled_keys.append(original_k)
-
-
-     if unhandled_keys:
-         logger.warning(f"Custom Wan LoRA Converter: Unhandled keys: {unhandled_keys}")
-
-     return peft_compatible_state_dict
-
-
- def apply_manual_diff_patches(pipe_model, patches_store, lora_strength=1.0):
-     if not hasattr(pipe_model, "transformer"):
-         logger.error("Pipeline model does not have a 'transformer' attribute to patch.")
-         return
-
-     transformer = pipe_model.transformer
-     changed_params_count = 0
-
-     for key_base, diff_tensor in patches_store.get("diff", {}).items():
-         # key_base is like "blocks.0.self_attn.q.weight"
-         # We need to prepend "transformer." to match diffusers internal naming
-         target_key_full = f"transformer.{key_base}"
-         try:
-             module_path_parts = target_key_full.split('.')
-             param_name = module_path_parts[-1]
-             module_path = ".".join(module_path_parts[:-1])
-             module = transformer
-             for part in module_path.split('.')[1:]: # Skip the first 'transformer'
-                 module = getattr(module, part)
-
-             original_param = getattr(module, param_name)
-             if original_param.shape != diff_tensor.shape:
-                 logger.warning(f"Shape mismatch for diff patch on {target_key_full}: model {original_param.shape}, lora {diff_tensor.shape}. Skipping.")
-                 continue
-
-             with torch.no_grad():
-                 scaled_diff = (lora_strength * diff_tensor.to(original_param.device, original_param.dtype))
-                 original_param.data.add_(scaled_diff)
-                 changed_params_count +=1
-         except AttributeError:
-             logger.warning(f"Could not find parameter {target_key_full} in transformer to apply diff patch.")
-         except Exception as e:
-             logger.error(f"Error applying diff patch to {target_key_full}: {e}")
-
-
-     for key_base, diff_b_tensor in patches_store.get("diff_b", {}).items():
-         # key_base is like "blocks.0.self_attn.q.bias"
-         target_key_full = f"transformer.{key_base}"
-         try:
-             module_path_parts = target_key_full.split('.')
-             param_name = module_path_parts[-1]
-             module_path = ".".join(module_path_parts[:-1])
-             module = transformer
-             for part in module_path.split('.')[1:]:
-                 module = getattr(module, part)
-
-             original_param = getattr(module, param_name)
-             if original_param is None:
-                 logger.warning(f"Bias parameter {target_key_full} is None in model. Skipping diff_b patch.")
-                 continue
-
-             if original_param.shape != diff_b_tensor.shape:
-                 logger.warning(f"Shape mismatch for diff_b patch on {target_key_full}: model {original_param.shape}, lora {diff_b_tensor.shape}. Skipping.")
-                 continue
-
-             with torch.no_grad():
-                 scaled_diff_b = (lora_strength * diff_b_tensor.to(original_param.device, original_param.dtype))
-                 original_param.data.add_(scaled_diff_b)
-                 changed_params_count +=1
-         except AttributeError:
-             logger.warning(f"Could not find parameter {target_key_full} in transformer to apply diff_b patch.")
-         except Exception as e:
-             logger.error(f"Error applying diff_b patch to {target_key_full}: {e}")
-     if changed_params_count > 0:
-         logger.info(f"Applied {changed_params_count} manual diff/diff_b patches.")
-     else:
-         logger.info("No manual diff/diff_b patches were applied.")
-
-
  # --- Model Loading ---
  logger.info(f"Loading VAE for {MODEL_ID}...")
  vae = AutoencoderKLWan.from_pretrained(
@@ -214,26 +44,7 @@ logger.info(f"Downloading LoRA {LORA_FILENAME} from {LORA_REPO_ID}...")
  causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)

  logger.info("Loading LoRA weights with custom converter...")
-
- from safetensors.torch import load_file as load_safetensors
- raw_lora_state_dict = load_safetensors(causvid_path)
-
- # Now call our custom converter which will populate MANUAL_PATCHES_STORE
- peft_state_dict = _custom_convert_non_diffusers_wan_lora_to_diffusers(raw_lora_state_dict)
-
- # Load the LoRA A/B matrices using PEFT
- if peft_state_dict:
-     pipe.load_lora_weights(
-         peft_state_dict,
-         adapter_name="causvid_lora"
-     )
-     logger.info("PEFT LoRA A/B weights loaded.")
- else:
-     logger.warning("No PEFT-compatible LoRA weights found after conversion.")
-
- # Apply manual diff_b and diff patches
- apply_manual_diff_patches(pipe, MANUAL_PATCHES_STORE, lora_strength=1.0) # Assuming default strength 1.0
- logger.info("Manual diff_b/diff patches applied.")
+ pipe.load_lora_weights(causvid_path,adapter_name="causvid_lora")


  # --- Gradio Interface Function ---
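
For reference, a minimal sketch of the simplified loading path on the new side of the diff. Only LORA_FILENAME and the pipe.load_lora_weights(...) call come from the hunks above; the WanPipeline class and the MODEL_ID / LORA_REPO_ID values are assumptions standing in for constants defined earlier in app.py, and the sketch assumes a diffusers build whose built-in Wan LoRA conversion handles this checkpoint's key layout.

import torch
from huggingface_hub import hf_hub_download
from diffusers import AutoencoderKLWan, WanPipeline

MODEL_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"  # assumed value; the real constant is defined earlier in app.py
LORA_REPO_ID = "Kijai/WanVideo_comfy"         # assumed value; the real constant is defined earlier in app.py
LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"

# Load the Wan 2.1 VAE and text-to-video pipeline.
vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
pipe = WanPipeline.from_pretrained(MODEL_ID, vae=vae, torch_dtype=torch.bfloat16)

# One call replaces the custom key converter and the manual diff/diff_b patching removed
# by this commit: diffusers' Wan LoRA loader converts the non-diffusers key names internally.
causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")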
 