roll-ai's picture
Upload 381 files
b6af722 verified
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from cosmos_predict1.diffusion.training.utils.peft.lora_attn import build_attn_lora
from cosmos_predict1.diffusion.utils.customization.customization_manager import CustomizationType
from cosmos_predict1.utils import log
from cosmos_predict1.utils.misc import count_params
def get_all_lora_params(model):
"""
Get all LoRA weight parameters in the model
"""
lora_modules = [mod for name, mod in model.named_modules() if "lora.net.0" in name or "lora.net.1" in name]
lora_params = [(name, param) for mod in lora_modules for name, param in mod.named_parameters()]
log.info(f"Found {len(lora_params)} LoRA weight matrices")
return lora_params
def setup_lora_requires_grad(model):
"""
Freeze all model parameters except LoRA parameters.
"""
num_param = count_params(model, verbose=True)
log.critical(f"Model has {num_param * 1e-6:.2f}M parameters before freezing")
lora_params = get_all_lora_params(model)
num_lora_param = sum([p.numel() for _, p in lora_params])
log.info(f"Total number of LoRA parameters: {num_lora_param * 1e-6:.2f}M")
if num_lora_param > 0:
log.info("Freezing all parameters")
model.requires_grad_(False)
log.info("Unfreezing LoRA parameters")
for name, param in lora_params:
# log.info(f"Unfreezing loRA : {name}")
param.requires_grad_(True)
num_param = count_params(model, verbose=True)
log.critical(f"Model has {num_param * 1e-6:.2f}M parameters after freezing")
return num_lora_param
def add_lora_layers(model, peft_control_config):
for i, block_name in enumerate(model.net.blocks):
block = model.net.blocks[block_name]
peft_control = peft_control_config.get(i, {})
for j, subblock in enumerate(block.blocks):
block_type = subblock.block_type
peft_control_subblock = peft_control.get(block_type.upper(), {})
customization_type = peft_control_subblock.get("customization_type", None)
if customization_type == CustomizationType.LORA:
if block_type.upper() in ["CA", "FA"]:
build_attn_lora(subblock.block.attn, peft_control_subblock)