# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from cosmos_predict1.diffusion.training.utils.peft.lora_attn import build_attn_lora
from cosmos_predict1.diffusion.utils.customization.customization_manager import CustomizationType
from cosmos_predict1.utils import log
from cosmos_predict1.utils.misc import count_params


def get_all_lora_params(model):
    """
    Collect all LoRA weight parameters in the model.

    Returns a list of (name, parameter) tuples taken from modules whose names
    contain ``lora.net.0`` or ``lora.net.1``, i.e. the LoRA weight matrices.
    """
    lora_modules = [mod for name, mod in model.named_modules() if "lora.net.0" in name or "lora.net.1" in name]
    lora_params = [(name, param) for mod in lora_modules for name, param in mod.named_parameters()]
    log.info(f"Found {len(lora_params)} LoRA weight matrices")
    return lora_params


def setup_lora_requires_grad(model):
    """
    Freeze all model parameters except the LoRA parameters.

    Returns the total number of LoRA parameters left trainable.
    """
    num_param = count_params(model, verbose=True)
    log.critical(f"Model has {num_param * 1e-6:.2f}M parameters before freezing")
    lora_params = get_all_lora_params(model)
    num_lora_param = sum(p.numel() for _, p in lora_params)
    log.info(f"Total number of LoRA parameters: {num_lora_param * 1e-6:.2f}M")
    if num_lora_param > 0:
        # Freeze everything first, then re-enable gradients only for the LoRA weights.
        log.info("Freezing all parameters")
        model.requires_grad_(False)
        log.info("Unfreezing LoRA parameters")
        for name, param in lora_params:
            # log.info(f"Unfreezing LoRA: {name}")
            param.requires_grad_(True)
        num_param = count_params(model, verbose=True)
        log.critical(f"Model has {num_param * 1e-6:.2f}M parameters after freezing")
    return num_lora_param


def add_lora_layers(model, peft_control_config):
    """
    Inject LoRA adapters into the attention sub-blocks selected by ``peft_control_config``.

    ``peft_control_config`` maps a block index to a per-sub-block-type dict
    (keyed by the upper-cased block type, e.g. "CA" or "FA") whose
    "customization_type" entry decides whether LoRA is added.
    """
    for i, block_name in enumerate(model.net.blocks):
        block = model.net.blocks[block_name]
        peft_control = peft_control_config.get(i, {})
        for subblock in block.blocks:
            block_type = subblock.block_type
            peft_control_subblock = peft_control.get(block_type.upper(), {})
            customization_type = peft_control_subblock.get("customization_type", None)
            if customization_type == CustomizationType.LORA:
                # Only attention sub-blocks ("CA" or "FA") receive LoRA adapters.
                if block_type.upper() in ["CA", "FA"]:
                    build_attn_lora(subblock.block.attn, peft_control_subblock)
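

# Illustrative usage sketch (hypothetical caller, not part of this module's API;
# `model` and `peft_control_config` are assumed to follow the structures expected
# by the functions above):
#
#   add_lora_layers(model, peft_control_config)       # attach LoRA adapters to attention sub-blocks
#   num_lora_param = setup_lora_requires_grad(model)  # freeze base weights, keep LoRA trainable
#   assert num_lora_param > 0, "no LoRA parameters found; check peft_control_config"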