# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import transformer_engine as te
from megatron.core import parallel_state
from torch import nn

from cosmos_predict1.utils import log


class LoRALinearLayer(nn.Module):
    """
    Ported from
    https://github.com/huggingface/diffusers/blob/7a32b6beeb0cfdefed645253dce23d9b0a78597f/src/diffusers/models/attention_processor.py#L470.
    """

    def __init__(self, in_features, out_features, rank=4, linear=False):
        super().__init__()

        if rank > min(in_features, out_features):
            raise ValueError(f"LoRA rank {rank} must be less than or equal to {min(in_features, out_features)}")

        if linear:
            down = nn.Linear(in_features, rank, bias=False)
            up = nn.Linear(rank, out_features, bias=False)
        else:
            down = nn.Conv1d(in_features, rank, 1, bias=False)
            up = nn.Conv1d(rank, out_features, 1, bias=False)

        # Standard LoRA init: small random down projection, zero up projection,
        # so the adapter initially contributes nothing to the wrapped layer's output.
        nn.init.normal_(down.weight, std=1 / rank)
        nn.init.zeros_(up.weight)

        self.net = nn.Sequential(down, up)

    def forward(self, hidden_states):
        # Cast the input to the LoRA weights' dtype for the matmul, then cast back.
        orig_dtype = hidden_states.dtype
        dtype = self.net[0].weight.dtype

        up_hidden_states = self.net(hidden_states.to(dtype))

        return up_hidden_states.to(orig_dtype)
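

# Usage sketch (illustrative, not part of the original module): a LoRA layer is
# typically applied as a residual branch alongside a frozen base projection.
# The names `base`, `lora`, `x`, and the shapes below are assumptions for
# illustration only.
#
#     base = nn.Linear(1024, 1024, bias=False)
#     base.requires_grad_(False)                        # freeze the pretrained weight
#     lora = LoRALinearLayer(1024, 1024, rank=8, linear=True)
#     x = torch.randn(2, 16, 1024)
#     y = base(x) + lora(x)                             # only the LoRA branch receives gradients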


class TELoRALinearLayer(nn.Module):
    """
    Ported from
    https://github.com/huggingface/diffusers/blob/7a32b6beeb0cfdefed645253dce23d9b0a78597f/src/diffusers/models/attention_processor.py#L470.
    """

    def __init__(self, in_features, out_features, rank, linear, tp_size, tp_group, sequence_parallel, parallel_mode):
        super().__init__()

        if rank > min(in_features, out_features):
            raise ValueError(f"LoRA rank {rank} must be less than or equal to {min(in_features, out_features)}")

        if linear:
            down = te.pytorch.Linear(
                in_features,
                rank,
                bias=False,
                tp_size=1,
                tp_group=tp_group,
                sequence_parallel=sequence_parallel,
                parallel_mode=None,
            )
            up = te.pytorch.Linear(
                rank,
                out_features,
                bias=False,
                tp_size=tp_size,
                tp_group=tp_group,
                sequence_parallel=sequence_parallel,
                parallel_mode=parallel_mode,
            )
        else:
            down = te.pytorch.Conv1d(
                in_features,
                rank,
                1,
                bias=False,
                tp_size=1,
                tp_group=tp_group,
                sequence_parallel=sequence_parallel,
                parallel_mode=None,
            )
            up = te.pytorch.Conv1d(
                rank,
                out_features,
                1,
                bias=False,
                tp_size=tp_size,
                tp_group=tp_group,
                sequence_parallel=sequence_parallel,
                parallel_mode=parallel_mode,
            )

        tp_rank = parallel_state.get_tensor_model_parallel_rank()

        # Create generator
        gen = torch.Generator(device=down.weight.device)
        # Save the current random state
        gen_state = gen.get_state()

        # Set constant seed for non-tp layers
        log.info(f"rank {tp_rank}: setting seed to 0")
        gen.manual_seed(0)
        nn.init.normal_(down.weight, std=1 / rank, generator=gen)

        # Set a new random seed based on the tensor parallel rank
        gen.manual_seed(tp_rank)
        log.info(f"rank {tp_rank}: setting seed to {tp_rank}")
        nn.init.zeros_(up.weight)

        # Restore the original random state
        gen.set_state(gen_state)

        self.net = nn.Sequential(down, up)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.net[0].weight.dtype

        up_hidden_states = self.net(hidden_states.to(dtype))

        return up_hidden_states.to(orig_dtype)
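

# Usage sketch for the tensor-parallel variant (illustrative, not part of the
# original module). It assumes torch.distributed and Megatron-Core parallel state
# have already been initialized and a CUDA device is available; the sizes and
# `parallel_mode="column"` below are example choices, not requirements. The
# `down` projection is replicated (tp_size=1), while the `up` projection is
# sharded according to `parallel_mode`, matching the layout of the TE layer it
# augments.
#
#     tp_group = parallel_state.get_tensor_model_parallel_group()
#     tp_size = parallel_state.get_tensor_model_parallel_world_size()
#     lora = TELoRALinearLayer(
#         in_features=1024,
#         out_features=1024,
#         rank=8,
#         linear=True,
#         tp_size=tp_size,
#         tp_group=tp_group,
#         sequence_parallel=False,
#         parallel_mode="column",
#     )
#     x = torch.randn(2, 16, 1024, device="cuda")
#     y = lora(x)  # output is sharded like the `up` projection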