roll-ai's picture
Upload 381 files
b6af722 verified
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from einops import rearrange
from torch.utils.checkpoint import checkpoint
from transformer_engine.pytorch.attention import apply_rotary_pos_emb
from cosmos_predict1.diffusion.module.attention import Attention
from cosmos_predict1.diffusion.training.utils.peft.lora_net import LoRALinearLayer, TELoRALinearLayer
from cosmos_predict1.diffusion.utils.customization.customization_manager import CustomizationType
try:
from megatron.core import parallel_state
USE_MEGATRON = True
except ImportError:
USE_MEGATRON = False
def enable_attn_lora(attn: Attention, peft_control: dict) -> None:
    """
    Enable LoRA for the attention block based on the peft_control dictionary.

    Sets ``attn.peft_lora_enabled``. It is ``False`` when ``peft_control`` is falsy (e.g. ``None`` or ``{}``).
    It is ``True`` when ``peft_control["customization_type"]`` equals ``CustomizationType.LORA``.

    Args:
        attn (Attention): The attention block to configure.
        peft_control (dict): Dictionary containing PEFT configuration.

    Raises:
        KeyError: If a non-empty ``peft_control`` lacks ``"customization_type"``.
        ValueError: If the customization type is not LoRA. This is a subclass of ``Exception``,
            so callers that caught the former bare ``Exception`` still work.
    """
    attn.peft_lora_enabled = False
    if not peft_control:
        return
    try:
        customization_type = peft_control["customization_type"]
    except KeyError as e:
        raise KeyError(f"peft_control dictionary expected to have attribute {e.args[0]}.") from e
    if customization_type == CustomizationType.LORA:
        attn.peft_lora_enabled = True
    else:
        raise ValueError(f"Unsupported Customization type {customization_type}")
def configure_attn_lora(attn: Attention, peft_control: dict) -> None:
    """
    Configure LoRA for the attention block based on the peft_control dictionary.

    For each projection key ``"to_q"``, ``"to_k"``, ``"to_v"`` and ``"to_out"``, it reads the optional
    sub-dict ``{"activate": ..., "lora_rank": ..., "lora_scale": ...}``. It then sets the corresponding
    ``attn.{q,k,v,out}_lora_enabled`` flag. For enabled projections it also sets
    ``attn.*_lora_rank`` and ``attn.*_lora_scale``, with the scale converted via ``float``.
    A missing projection key means that projection is disabled.

    Args:
        attn (Attention): The attention block to configure.
        peft_control (dict): Dictionary containing PEFT configuration. A falsy value (e.g. ``None``)
            is treated as "no projections enabled", matching ``enable_attn_lora``.

    Raises:
        KeyError: If an activated projection lacks ``"lora_rank"`` or ``"lora_scale"``.
        ValueError: If ``"lora_scale"`` cannot be converted to float.
    """
    # enable_attn_lora accepts a falsy peft_control, and build_attn_lora forwards the same value here.
    # Without this, `None.get` would raise AttributeError.
    peft_control = peft_control or {}
    # (attribute prefix on attn, key in peft_control)
    projections = (("q", "to_q"), ("k", "to_k"), ("v", "to_v"), ("out", "to_out"))
    # Set every enabled flag first, then fill in rank/scale, mirroring the original ordering.
    for prefix, key in projections:
        setattr(attn, f"{prefix}_lora_enabled", peft_control.get(key, {}).get("activate", False))
    try:
        for prefix, key in projections:
            if getattr(attn, f"{prefix}_lora_enabled"):
                setattr(attn, f"{prefix}_lora_rank", peft_control[key]["lora_rank"])
                setattr(attn, f"{prefix}_lora_scale", float(peft_control[key]["lora_scale"]))
    except KeyError as e:
        raise KeyError(f"All layers (to_q, etc) specified must have attribute {e.args[0]}.") from e
    except ValueError as e:
        raise ValueError(f"Could not convert string to float: {e}") from e
def cal_qkv_lora(
    self,
    x: torch.Tensor,
    context: torch.Tensor = None,
    mask: torch.Tensor = None,
    rope_emb: torch.Tensor = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Calculate the Q, K, V matrices with LoRA adjustments. Derived from cosmos_predict1/diffusion/module/attention.py cal_qkv.

    Meant to be bound as the attention block's ``cal_qkv`` method (``build_attn_lora`` does this).

    Steps:
    1. Compute the base projections ``to_q[0]`` / ``to_k[0]`` / ``to_v[0]``.
    2. If ``peft_lora_enabled``, add each enabled adapter's output multiplied by its ``*_lora_scale``.
    3. Split the last dimension into ``(heads // tp_size, dim_head)``.
    4. Apply the ``to_*[1]`` modules and, for self-attention with ``rope_emb`` given, rotary
       embeddings. This step runs inside ``torch.utils.checkpoint.checkpoint``.

    Args:
        x (torch.Tensor): Input tensor.
        context (torch.Tensor, optional): Context tensor for K/V; defaults to ``x``.
        mask (torch.Tensor, optional): Mask tensor. Accepted for interface compatibility; not used here.
        rope_emb (torch.Tensor, optional): Rotary positional embedding.
        **kwargs: Ignored.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The Q, K, V matrices.

    Raises:
        AttributeError: If LoRA is enabled but a required LoRA attribute is missing on the block.
    """
    # Extra keyword arguments are accepted for signature compatibility and discarded.
    # This must come after the docstring, or the string above would not become __doc__.
    del kwargs
    q = self.to_q[0](x)
    context = x if context is None else context
    k = self.to_k[0](context)
    v = self.to_v[0](context)
    if self.peft_lora_enabled:
        try:
            if self.q_lora_enabled:
                q = q + self.q_lora_scale * self.to_q_lora(x)
            if self.k_lora_enabled:
                k = k + self.k_lora_scale * self.to_k_lora(context)
            if self.v_lora_enabled:
                v = v + self.v_lora_scale * self.to_v_lora(context)
        except AttributeError as e:
            raise AttributeError(f"lora enabled, but missing class attribute {e.args[0]} of Attention block") from e
    # Split the fused head dimension; with tensor parallelism each rank holds heads // tp_size heads.
    q, k, v = map(
        lambda t: rearrange(t, "b ... (n c) -> b ... n c", n=self.heads // self.tp_size, c=self.dim_head),
        (q, k, v),
    )

    def apply_norm_and_rotary_pos_emb(q, k, v, rope_emb):
        q = self.to_q[1](q)
        k = self.to_k[1](k)
        v = self.to_v[1](v)
        if self.is_selfattn and rope_emb is not None:  # only apply to self-attention!
            q = apply_rotary_pos_emb(q, rope_emb, tensor_format=self.qkv_format, fused=True)
            k = apply_rotary_pos_emb(k, rope_emb, tensor_format=self.qkv_format, fused=True)
        return q, k, v

    # Non-reentrant activation checkpointing: activations of this step are recomputed during backward.
    q, k, v = checkpoint(apply_norm_and_rotary_pos_emb, q, k, v, rope_emb, use_reentrant=False)
    return q, k, v
def _add_out_lora(self, attn_out: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    """Return ``out`` plus ``out_lora_scale * to_out_lora(attn_out)`` when the output LoRA is enabled."""
    if not (self.peft_lora_enabled and self.out_lora_enabled):
        return out
    # Only the attribute lookups are guarded, so an AttributeError raised inside the
    # adapter's forward pass propagates unchanged instead of being mislabelled.
    try:
        lora_layer = self.to_out_lora
        lora_scale = self.out_lora_scale
    except AttributeError as e:
        raise AttributeError(f"out lora enabled, but missing class attribute {e.args[0]} of Attention block") from e
    return out + lora_scale * lora_layer(attn_out)


def cal_attn_lora(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
    """
    Calculate the attention output with LoRA adjustments. Derived from cosmos_predict1/diffusion/module/attention.py cal_attn.

    Runs the backend attention op and applies ``to_out``. When ``peft_lora_enabled`` and
    ``out_lora_enabled`` are both set, it adds the scaled ``to_out_lora`` term, computed on the same input.

    Args:
        q (torch.Tensor): Query tensor.
        k (torch.Tensor): Key tensor.
        v (torch.Tensor): Value tensor.
        mask (torch.Tensor, optional): Mask tensor. Only forwarded for the ``"torch"`` backend.

    Returns:
        torch.Tensor: The attention output.

    Raises:
        AssertionError: For the TE backend, if the q or k sequence length is not greater than 1.
        AttributeError: If the output LoRA is enabled but ``to_out_lora``/``out_lora_scale`` is missing.
        ValueError: If ``self.backend`` is neither ``"transformer_engine"`` nor ``"torch"``.
    """
    if self.backend == "transformer_engine":
        seq_dim = self.qkv_format.index("s")
        assert (
            q.shape[seq_dim] > 1 and k.shape[seq_dim] > 1
        ), "Seqlen must be larger than 1 for TE Attention starting with 1.8 TE version."
        # NOTE(review): `mask` is not passed to the TE op. Its output goes to `to_out` without a
        # rearrange, so it is presumably already head-merged — confirm.
        attn_out = self.attn_op(q, k, v, core_attention_bias_type="no_bias", core_attention_bias=None)
    elif self.backend == "torch":
        attn_out = self.attn_op(q, k, v, mask=mask)  # [B, Mq, H, V]
        attn_out = rearrange(attn_out, " b ... n c -> b ... (n c)")
    else:
        raise ValueError(f"Backend {self.backend} not found")
    out = self.to_out(attn_out)
    return _add_out_lora(self, attn_out, out)
def build_attn_lora(attn: Attention, peft_control: dict) -> None:
    """
    Configure, build and add LoRA layers to the attention block.

    Steps:
    1. Call ``enable_attn_lora`` and ``configure_attn_lora``.
    2. If ``attn.peft_lora_enabled``, attach a ``to_{q,k,v,out}_lora`` adapter for each enabled
       projection. Use ``LoRALinearLayer`` when ``attn.tp_size == 1``. Otherwise use
       ``TELoRALinearLayer``: column-parallel for q/k/v, row-parallel for out.
    3. Always rebind ``attn.cal_qkv`` / ``attn.cal_attn`` to the LoRA-aware implementations. Those
       check ``peft_lora_enabled`` themselves.

    Args:
        attn (Attention): The attention block to add LoRA layers to.
        peft_control (dict): Dictionary containing PEFT configuration.
    """
    enable_attn_lora(attn, peft_control)
    configure_attn_lora(attn, peft_control)
    if attn.peft_lora_enabled:
        query_dim = attn.query_dim
        inner_dim = attn.inner_dim
        context_dim = attn.context_dim
        tp_group = parallel_state.get_tensor_model_parallel_group(check_initialized=False) if USE_MEGATRON else None
        # `parallel_state` is only bound when the megatron import succeeded. Guard it to avoid a NameError.
        sequence_parallel = getattr(parallel_state, "sequence_parallel", False) if USE_MEGATRON else False
        # NOTE(review): without megatron, tp_size > 1 passes tp_group=None to TELoRALinearLayer — confirm intended.
        use_tensor_parallel = attn.tp_size != 1
        # (attribute prefix, in_features, out_features, TE parallel mode).
        # Creation order q, k, v, out is kept so parameter initialisation order is unchanged.
        lora_specs = (
            ("q", query_dim, inner_dim, "column"),
            ("k", context_dim, inner_dim, "column"),
            ("v", context_dim, inner_dim, "column"),
            ("out", inner_dim, query_dim, "row"),
        )
        for prefix, in_features, out_features, parallel_mode in lora_specs:
            if not getattr(attn, f"{prefix}_lora_enabled"):
                continue
            rank = getattr(attn, f"{prefix}_lora_rank")
            if use_tensor_parallel:
                layer = TELoRALinearLayer(
                    in_features,
                    out_features,
                    rank=rank,
                    linear=True,
                    tp_size=attn.tp_size,
                    tp_group=tp_group,
                    sequence_parallel=sequence_parallel,
                    parallel_mode=parallel_mode,
                )
            else:
                layer = LoRALinearLayer(in_features, out_features, rank=rank, linear=True)
            setattr(attn, f"to_{prefix}_lora", layer)
    # Bind the module-level functions as methods of this specific instance.
    attn.cal_qkv = cal_qkv_lora.__get__(attn, attn.__class__)
    attn.cal_attn = cal_attn_lora.__get__(attn, attn.__class__)