# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import transformer_engine as te
from megatron.core import parallel_state
from torch import nn

from cosmos_predict1.utils import log


class LoRALinearLayer(nn.Module):
    """
    ported from
    https://github.com/huggingface/diffusers/blob/7a32b6beeb0cfdefed645253dce23d9b0a78597f/src/diffusers/models/attention_processor.py#L470.
    """

    def __init__(self, in_features, out_features, rank=4, linear=False):
        super().__init__()

        if rank > min(in_features, out_features):
            raise ValueError(f"LoRA rank {rank} must be less than or equal to {min(in_features, out_features)}")

        if linear:
            down = nn.Linear(in_features, rank, bias=False)
            up = nn.Linear(rank, out_features, bias=False)
        else:
            down = nn.Conv1d(in_features, rank, 1, bias=False)
            up = nn.Conv1d(rank, out_features, 1, bias=False)

        nn.init.normal_(down.weight, std=1 / rank)
        nn.init.zeros_(up.weight)
        self.net = nn.Sequential(down, up)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.net[0].weight.dtype

        # Compute the low-rank update in the LoRA weights' dtype, then cast back to the input dtype.
        up_hidden_states = self.net(hidden_states.to(dtype))

        return up_hidden_states.to(orig_dtype)
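

# Illustrative usage sketch (not part of the original module): the standard LoRA
# pattern adds the low-rank update to the output of a frozen base projection,
# y = base(x) + scale * lora(x). The function name, shapes, and `scale` below are
# hypothetical, chosen only for this example.
def _example_lora_linear_usage(scale: float = 1.0) -> torch.Tensor:
    base = nn.Linear(64, 128, bias=False)
    base.requires_grad_(False)  # base weights stay frozen; only the LoRA weights would train
    lora = LoRALinearLayer(in_features=64, out_features=128, rank=4, linear=True)

    x = torch.randn(2, 64)
    # The up projection is zero-initialized, so at this point the LoRA term is exactly zero.
    return base(x) + scale * lora(x)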


class TELoRALinearLayer(nn.Module):
    """
    ported from
    https://github.com/huggingface/diffusers/blob/7a32b6beeb0cfdefed645253dce23d9b0a78597f/src/diffusers/models/attention_processor.py#L470.
    """

    def __init__(self, in_features, out_features, rank, linear, tp_size, tp_group, sequence_parallel, parallel_mode):
        super().__init__()

        if rank > min(in_features, out_features):
            raise ValueError(f"LoRA rank {rank} must be less than or equal to {min(in_features, out_features)}")

        if linear:
            down = te.pytorch.Linear(
                in_features,
                rank,
                bias=False,
                tp_size=1,
                tp_group=tp_group,
                sequence_parallel=sequence_parallel,
                parallel_mode=None,
            )
            up = te.pytorch.Linear(
                rank,
                out_features,
                bias=False,
                tp_size=tp_size,
                tp_group=tp_group,
                sequence_parallel=sequence_parallel,
                parallel_mode=parallel_mode,
            )
        else:
            down = te.pytorch.Conv1d(
                in_features,
                rank,
                1,
                bias=False,
                tp_size=1,
                tp_group=tp_group,
                sequence_parallel=sequence_parallel,
                parallel_mode=None,
            )
            up = te.pytorch.Conv1d(
                rank,
                out_features,
                1,
                bias=False,
                tp_size=tp_size,
                tp_group=tp_group,
                sequence_parallel=sequence_parallel,
                parallel_mode=parallel_mode,
            )
        tp_rank = parallel_state.get_tensor_model_parallel_rank()
        # Use a dedicated generator so LoRA initialization does not disturb the global RNG state
        gen = torch.Generator(device=down.weight.device)
        # Save the generator's state so it can be restored afterwards
        gen_state = gen.get_state()

        # Constant seed so the replicated (non-tensor-parallel) down projection is identical on every rank
        log.info(f"rank {tp_rank}: setting seed to 0")
        gen.manual_seed(0)
        nn.init.normal_(down.weight, std=1 / rank, generator=gen)
        # Rank-dependent seed for the tensor-parallel up projection
        # (the up weight is zero-initialized below, so this seed does not affect the values)
        gen.manual_seed(tp_rank)
        log.info(f"rank {tp_rank}: setting seed to {tp_rank}")
        nn.init.zeros_(up.weight)
        # Restore the generator's original state
        gen.set_state(gen_state)

        self.net = nn.Sequential(down, up)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.net[0].weight.dtype
        up_hidden_states = self.net(hidden_states.to(dtype))

        return up_hidden_states.to(orig_dtype)
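

# Illustrative construction sketch (not part of the original module): a minimal example
# of building TELoRALinearLayer, assuming Megatron's model-parallel state has already been
# initialized (e.g. via parallel_state.initialize_model_parallel) and that the wrapped base
# layer is a column-parallel projection, hence parallel_mode="column". The function name,
# feature sizes, and parallel_mode choice are assumptions made for this example only.
def _example_te_lora_construction(in_features: int = 64, out_features: int = 128) -> TELoRALinearLayer:
    tp_group = parallel_state.get_tensor_model_parallel_group()
    tp_size = parallel_state.get_tensor_model_parallel_world_size()
    return TELoRALinearLayer(
        in_features=in_features,
        out_features=out_features,
        rank=4,
        linear=True,
        tp_size=tp_size,
        tp_group=tp_group,
        sequence_parallel=False,
        parallel_mode="column",  # assumption: the up projection shards its output dimension
    )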