# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
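"""Utilities for LoRA-based fine-tuning of Cosmos diffusion models.

This module attaches LoRA adapters to the attention sub-blocks of a model and
restricts gradient updates to the LoRA parameters only.
"""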
from cosmos_predict1.diffusion.training.utils.peft.lora_attn import build_attn_lora
from cosmos_predict1.diffusion.utils.customization.customization_manager import CustomizationType
from cosmos_predict1.utils import log
from cosmos_predict1.utils.misc import count_params


def get_all_lora_params(model):
    """
    Get all LoRA weight parameters in the model as (name, parameter) pairs.
    """
    # LoRA adapters expose their low-rank weights as "lora.net.0" and "lora.net.1"
    # submodules, so matching those names collects every adapter weight matrix.
    lora_modules = [mod for name, mod in model.named_modules() if "lora.net.0" in name or "lora.net.1" in name]
    lora_params = [(name, param) for mod in lora_modules for name, param in mod.named_parameters()]
    log.info(f"Found {len(lora_params)} LoRA weight matrices")
    return lora_params


def setup_lora_requires_grad(model):
    """
    Freeze all model parameters except the LoRA parameters.

    Returns the number of trainable LoRA parameters.
    """
    num_param = count_params(model, verbose=True)
    log.critical(f"Model has {num_param * 1e-6:.2f}M parameters before freezing")
    lora_params = get_all_lora_params(model)
    num_lora_param = sum(p.numel() for _, p in lora_params)
    log.info(f"Total number of LoRA parameters: {num_lora_param * 1e-6:.2f}M")
    if num_lora_param > 0:
        # Freeze the full model first, then re-enable gradients only for the
        # LoRA weights so that fine-tuning touches nothing else.
        log.info("Freezing all parameters")
        model.requires_grad_(False)
        log.info("Unfreezing LoRA parameters")
        for name, param in lora_params:
            # log.info(f"Unfreezing LoRA param: {name}")
            param.requires_grad_(True)
    num_param = count_params(model, verbose=True)
    log.critical(f"Model has {num_param * 1e-6:.2f}M parameters after freezing")
    return num_lora_param


def add_lora_layers(model, peft_control_config):
    """
    Attach LoRA adapters to the attention sub-blocks selected in peft_control_config.

    peft_control_config maps a block index to a per-sub-block configuration; a
    sub-block is customized only if its entry requests CustomizationType.LORA.
    """
    for i, block_name in enumerate(model.net.blocks):
        block = model.net.blocks[block_name]
        peft_control = peft_control_config.get(i, {})
        for j, subblock in enumerate(block.blocks):
            block_type = subblock.block_type
            peft_control_subblock = peft_control.get(block_type.upper(), {})
            customization_type = peft_control_subblock.get("customization_type", None)
            if customization_type == CustomizationType.LORA:
                # Only cross-attention (CA) and full-attention (FA) sub-blocks
                # currently receive LoRA injection.
                if block_type.upper() in ["CA", "FA"]:
                    build_attn_lora(subblock.block.attn, peft_control_subblock)
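

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the upstream API): the model
# object, block indices, and any extra keys that build_attn_lora may expect are
# assumptions here; only "customization_type" is read by this module. A typical
# flow is to inject the adapters first and then restrict training to them:
#
#   peft_control_config = {
#       0: {"FA": {"customization_type": CustomizationType.LORA}},
#       1: {"CA": {"customization_type": CustomizationType.LORA}},
#   }
#   add_lora_layers(model, peft_control_config)      # attach LoRA adapters
#   num_trainable = setup_lora_requires_grad(model)  # train only LoRA weights
#   assert num_trainable > 0, "no LoRA parameters were found"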