	Update app.py
app.py CHANGED
@@ -192,7 +192,7 @@ def _custom_convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     return final_peft_state_dict
 
 
-def apply_manual_diff_patches(pipe_model, patches):
+def apply_manual_diff_patches(pipe_model: torch.nn.Module, patches: Dict[str, torch.Tensor]):
     """
     Manually applies diff_b/diff patches to the model.
     Assumes PEFT LoRA layers have already been loaded.
@@ -204,87 +204,95 @@ def apply_manual_diff_patches(pipe_model, patches):
     logger.info(f"Applying {len(patches)} manual diff patches...")
     patched_keys_count = 0
     unpatched_keys_count = 0
+    skipped_keys_details = []
 
     for key, diff_tensor in patches.items():
         try:
-            # Navigate to the parent module
+            # key is like "transformer.blocks.0.attn1.to_q.bias"
+            current_module = pipe_model # Starts from pipe.transformer
+            path_parts = key.split('.')[1:] # Remove "transformer." prefix for getattr navigation
+                                            # e.g., ["blocks", "0", "attn1", "to_q", "bias"]
 
+            # Navigate to the parent module of the parameter
+            # Example: for "blocks.0.attn1.to_q.bias", parent_module_path is "blocks.0.attn1.to_q"
+            parent_module_path = path_parts[:-1]
+            param_name_to_patch = path_parts[-1] # "bias" or "weight"
+
+            for part in parent_module_path:
+                if hasattr(current_module, part):
+                    current_module = getattr(current_module, part)
+                elif hasattr(current_module, 'base_layer') and hasattr(current_module.base_layer, part):
+                    # This case is unlikely here as we are navigating *to* the layer,
+                    # not trying to access a sub-component of a base_layer.
+                    # PEFT wrapping affects the layer itself, not its parent structure.
+                    current_module = getattr(current_module.base_layer, part)
                 else:
+                    raise AttributeError(f"Submodule '{part}' not found in path '{'.'.join(parent_module_path)}' within {key}")
+
+            # Now, current_module is the layer whose parameter we want to patch
+            # e.g., if key was transformer.blocks.0.attn1.to_q.bias,
+            # current_module is the to_q Linear layer (or LoraLayer wrapping it)
+
-            # If PEFT wrapped it, the actual nn.Linear or nn.LayerNorm is in `base_layer`
-            if hasattr(target_layer, "base_layer") and isinstance(target_layer.base_layer, (torch.nn.Linear, torch.nn.LayerNorm)):
-                layer_to_modify = target_layer.base_layer
-            else:
-                layer_to_modify = target_layer
+            layer_to_modify = current_module
+            # If PEFT wrapped the Linear layer (common for attention q,k,v,o and ffn projections)
+            if hasattr(layer_to_modify, "base_layer") and isinstance(layer_to_modify.base_layer, (torch.nn.Linear, torch.nn.LayerNorm)):
+                actual_param_owner = layer_to_modify.base_layer
+            else: # For non-wrapped layers like LayerNorm, or if it's already the base_layer
+                actual_param_owner = layer_to_modify
 
+            if not hasattr(actual_param_owner, param_name_to_patch):
+                skipped_keys_details.append(f"Key: {key}, Reason: Parameter '{param_name_to_patch}' not found in layer '{actual_param_owner}'. Layer type: {type(actual_param_owner)}")
+                unpatched_keys_count += 1
+                continue
 
-            original_param = getattr(layer_to_modify, param_name)
-            if original_param is None and param_name == "bias":
-                # If bias is None (e.g., LayerNorm with elementwise_affine=False, or Linear(bias=False)),
-                # we might need to initialize it if the diff expects to add to it.
-                # For Linear layers, if bias was False, it should remain False unless LoRA intends to add one.
-                # For LayerNorm, if elementwise_affine was False, adding a bias diff means it becomes affine.
-                if isinstance(layer_to_modify, torch.nn.Linear):
-                    if layer_to_modify.bias is None: # Check if bias was intentionally None
-                        logger.warning(f"Original layer {layer_to_modify} for key '{key}' has no bias. Creating one to apply diff_b. This might be unintended if bias=False was set.")
-                        layer_to_modify.bias = torch.nn.Parameter(torch.zeros_like(diff_tensor, device=diff_tensor.device, dtype=diff_tensor.dtype))
-                        original_param = layer_to_modify.bias
-                    else: # Should not happen if original_param was None but layer_to_modify.bias isn't
-                        pass
-                elif isinstance(layer_to_modify, torch.nn.LayerNorm):
-                    if not layer_to_modify.elementwise_affine:
-                        logger.warning(f"LayerNorm {layer_to_modify} for key '{key}' was not elementwise_affine. Applying bias diff will make it effectively affine for bias.")
-                        # LayerNorm bias is initialized to zeros if elementwise_affine is True
-                        layer_to_modify.bias = torch.nn.Parameter(torch.zeros_like(diff_tensor, device=diff_tensor.device, dtype=diff_tensor.dtype))
-                        original_param = layer_to_modify.bias
-                        # Also need to ensure weight exists if a weight diff is applied later
-                        if param_name == "bias" and not hasattr(layer_to_modify, "weight"):
-                             layer_to_modify.weight = torch.nn.Parameter(torch.ones_like(diff_tensor, device=diff_tensor.device, dtype=diff_tensor.dtype)) # Norm weights init to 1
+            original_param = getattr(actual_param_owner, param_name_to_patch)
+
+            if original_param is None and param_name_to_patch == "bias":
+                logger.info(f"Key '{key}': Original bias is None. Attempting to initialize.")
+                if isinstance(actual_param_owner, torch.nn.Linear) or isinstance(actual_param_owner, torch.nn.LayerNorm):
+                    # For LayerNorm, bias exists if elementwise_affine=True (default).
+                    # If it was False, we are making it affine by adding a bias.
+                    # For Linear, if bias was False, we are adding one.
+                    actual_param_owner.bias = torch.nn.Parameter(torch.zeros_like(diff_tensor, device=diff_tensor.device, dtype=diff_tensor.dtype))
+                    original_param = actual_param_owner.bias
+                    logger.info(f"Key '{key}': Initialized bias for {type(actual_param_owner)}.")
+                else:
+                    skipped_keys_details.append(f"Key: {key}, Reason: Original bias is None and layer '{actual_param_owner}' is not Linear or LayerNorm. Cannot initialize.")
+                    unpatched_keys_count +=1
+                    continue
+
+            # Special handling for RMSNorm which typically has no bias
+            if isinstance(actual_param_owner, torch.nn.RMSNorm) and param_name_to_patch == "bias":
+                skipped_keys_details.append(f"Key: {key}, Reason: Layer '{actual_param_owner}' is RMSNorm which has no bias parameter. Skipping bias diff.")
                 unpatched_keys_count +=1
                 continue
 
             if original_param is not None:
                 if original_param.shape != diff_tensor.shape:
-                    unpatched_keys_count +=1
+                    skipped_keys_details.append(f"Key: {key}, Reason: Shape mismatch. Model param: {original_param.shape}, LoRA diff: {diff_tensor.shape}. Layer: {actual_param_owner}")
+                    unpatched_keys_count += 1
                     continue
                 with torch.no_grad():
                     original_param.add_(diff_tensor.to(original_param.device, original_param.dtype))
-                logger.info(f"Successfully applied diff to '{key}'")
-                patched_keys_count +=1
+                # logger.info(f"Successfully applied diff to '{key}'") # Too verbose, will log summary
+                patched_keys_count += 1
             else:
-                unpatched_keys_count +=1
+                skipped_keys_details.append(f"Key: {key}, Reason: Original parameter '{param_name_to_patch}' is None and was not initialized. Layer: {actual_param_owner}")
+                unpatched_keys_count += 1
 
         except AttributeError as e:
-            unpatched_keys_count +=1
+            skipped_keys_details.append(f"Key: {key}, Reason: AttributeError - {e}")
+            unpatched_keys_count += 1
         except Exception as e:
-            unpatched_keys_count +=1
-    logger.info(f"Manual patching summary: {patched_keys_count} keys patched, {unpatched_keys_count} keys failed or skipped.")
+            skipped_keys_details.append(f"Key: {key}, Reason: General Exception - {e}")
+            unpatched_keys_count += 1
+
+    logger.info(f"Manual patching summary: {patched_keys_count} keys patched, {unpatched_keys_count} keys failed or skipped.")
+    if unpatched_keys_count > 0:
+        logger.warning("Details of unpatched/skipped keys:")
+        for detail in skipped_keys_details:
+            logger.warning(f"  - {detail}")
 
 # --- Model Loading ---
 logger.info(f"Loading VAE for {MODEL_ID}...")
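The `base_layer` check above is the crux of the rewrite: once PEFT injects LoRA adapters, the original module is wrapped and the real weights live one level down. A standalone sketch of that structure (the `Tiny` module, the dimensions, and the `inject_adapter_in_model` call are illustrative assumptions, not code from app.py):

# Illustration only - not from app.py. Shows why diff_b must target base_layer.bias.
import torch
from peft import LoraConfig, inject_adapter_in_model

class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.to_q = torch.nn.Linear(8, 8)

model = inject_adapter_in_model(LoraConfig(r=4, target_modules=["to_q"]), Tiny())
wrapped = model.to_q

print(hasattr(wrapped, "base_layer"))                   # True: PEFT kept the original Linear inside
print(isinstance(wrapped.base_layer, torch.nn.Linear))  # True: this is what owns .weight and .bias

That inner Linear is what `actual_param_owner` resolves to in the new code, so the additive diff_b lands on the underlying bias rather than on the wrapper.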
@@ -411,6 +419,7 @@ with gr.Blocks() as demo:
             width_input,
             num_frames_input,
             guidance_scale_input,
+            steps,
             fps_input
         ],
         outputs=video_output
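Taken together, the intended flow elsewhere in app.py is presumably: convert the non-diffusers WAN LoRA, load the low-rank weights through PEFT, then hand the leftover diff/diff_b tensors to the new helper. A minimal sketch of that wiring, assuming a loaded `pipe` object, a `wan_lora.safetensors` file, and a pre-assembled `diff_patches` dict (none of which are shown in this diff):

# Hypothetical wiring - the file name, `pipe`, the load_lora_weights call, and how
# diff_patches is assembled are all assumptions beyond what the commit shows.
import torch
from safetensors.torch import load_file

raw_lora_sd = load_file("wan_lora.safetensors")

# 1) Low-rank A/B matrices go through the converter and the normal PEFT path.
peft_state_dict = _custom_convert_non_diffusers_wan_lora_to_diffusers(raw_lora_sd)
pipe.load_lora_weights(peft_state_dict, adapter_name="wan_lora")

# 2) Additive offsets are applied in place afterwards. Keys use diffusers naming,
#    e.g. "transformer.blocks.0.attn1.to_q.bias"; the tensor here is a placeholder.
diff_patches = {"transformer.blocks.0.attn1.to_q.bias": torch.zeros(1536)}
apply_manual_diff_patches(pipe.transformer, diff_patches)

The division of labour matches the docstring: PEFT owns the low-rank matrices, while the helper adds plain bias/weight offsets in place under torch.no_grad() via Tensor.add_().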