Spaces:
Build error
Build error
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
# SPDX-License-Identifier: Apache-2.0 | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import torch | |
from einops import rearrange | |
from torch.utils.checkpoint import checkpoint | |
from transformer_engine.pytorch.attention import apply_rotary_pos_emb | |
from cosmos_predict1.diffusion.module.attention import Attention | |
from cosmos_predict1.diffusion.training.utils.peft.lora_net import LoRALinearLayer, TELoRALinearLayer | |
from cosmos_predict1.diffusion.utils.customization.customization_manager import CustomizationType | |
try: | |
from megatron.core import parallel_state | |
USE_MEGATRON = True | |
except ImportError: | |
USE_MEGATRON = False | |
def enable_attn_lora(attn: Attention, peft_control: dict) -> None:
    """
    Switch LoRA on or off for an attention block based on the peft_control dictionary.

    ``attn.peft_lora_enabled`` is always reset to ``False`` first. It is set to
    ``True`` only when ``peft_control`` is non-empty and its
    ``"customization_type"`` entry equals ``CustomizationType.LORA``.

    Args:
        attn (Attention): The attention block to configure.
        peft_control (dict): Dictionary containing PEFT configuration. A falsy
            value (``None`` or an empty dict) leaves LoRA disabled.

    Raises:
        KeyError: If ``peft_control`` is non-empty but lacks ``"customization_type"``.
        Exception: If the customization type is anything other than LoRA.
    """
    attn.peft_lora_enabled = False
    if not peft_control:
        return

    # Only the dictionary lookup can raise KeyError, so the try body is limited to it.
    try:
        customization_type = peft_control["customization_type"]
    except KeyError as e:
        # Chain the original error so the traceback shows the failing lookup as the cause.
        raise KeyError(f"peft_control dictionary expected to have attribute {e.args[0]}.") from e

    if customization_type == CustomizationType.LORA:
        attn.peft_lora_enabled = True
    else:
        raise Exception(f"Unsupported Customization type {customization_type}")
def configure_attn_lora(attn: Attention, peft_control: dict) -> None:
    """
    Copy per-projection LoRA settings from the peft_control dictionary onto the attention block.

    For each of the ``to_q``, ``to_k``, ``to_v`` and ``to_out`` entries this sets
    ``attn.<prefix>_lora_enabled`` (prefix ``q``/``k``/``v``/``out``) from the
    entry's ``"activate"`` flag, defaulting to ``False`` when the entry or flag
    is missing. For every activated projection it also sets
    ``attn.<prefix>_lora_rank`` and ``attn.<prefix>_lora_scale`` (as ``float``).

    Args:
        attn (Attention): The attention block to configure.
        peft_control (dict): Dictionary containing PEFT configuration. ``None``
            or an empty dict leaves all four projections disabled.

    Raises:
        KeyError: If an activated projection lacks ``"lora_rank"`` or ``"lora_scale"``.
        ValueError: If a ``"lora_scale"`` value cannot be converted to float.
    """
    # enable_attn_lora accepts a falsy peft_control, so tolerate None here as well
    # instead of failing on ``None.get``.
    peft_control = peft_control or {}

    # (attribute prefix on attn, key in peft_control)
    projections = (("q", "to_q"), ("k", "to_k"), ("v", "to_v"), ("out", "to_out"))

    try:
        # First record every activation flag, then read the settings of the
        # activated projections, so all flags are set even if a later lookup fails.
        for prefix, layer_key in projections:
            activated = peft_control.get(layer_key, {}).get("activate", False)
            setattr(attn, f"{prefix}_lora_enabled", activated)

        for prefix, layer_key in projections:
            if not getattr(attn, f"{prefix}_lora_enabled"):
                continue
            layer_cfg = peft_control[layer_key]
            setattr(attn, f"{prefix}_lora_rank", layer_cfg["lora_rank"])
            setattr(attn, f"{prefix}_lora_scale", float(layer_cfg["lora_scale"]))
    except KeyError as e:
        raise KeyError(f"All layers (to_q, etc) specified must have attribute {e.args[0]}.") from e
    except ValueError as e:
        raise ValueError(f"Could not convert string to float: {e}") from e
def cal_qkv_lora(
    self,
    x: torch.Tensor,
    context: torch.Tensor = None,
    mask: torch.Tensor = None,
    rope_emb: torch.Tensor = None,
    **kwargs,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Calculate the Q, K, V matrices with LoRA adjustments. Derived from cosmos_predict1/diffusion/module/attention.py cal_qkv.

    Q is projected from ``x``; K and V are projected from ``context``, which
    falls back to ``x`` when no context is given. When LoRA is enabled, each
    activated projection gets ``<scale> * <lora branch output>`` added before
    the heads are split.

    Args:
        x (torch.Tensor): Input tensor.
        context (torch.Tensor, optional): Context tensor. Defaults to ``x``.
        mask (torch.Tensor, optional): Mask tensor. Not used by this function.
        rope_emb (torch.Tensor, optional): Rotary positional embedding. Applied to
            Q and K only when ``self.is_selfattn`` is true.
        **kwargs: Accepted and discarded.

    Returns:
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The Q, K, V matrices.

    Raises:
        AttributeError: If LoRA is enabled but an expected LoRA attribute is missing.
    """
    del kwargs  # extra keyword arguments are deliberately ignored

    context = x if context is None else context
    q = self.to_q[0](x)
    k = self.to_k[0](context)
    v = self.to_v[0](context)

    if self.peft_lora_enabled:
        try:
            if self.q_lora_enabled:
                q = q + self.q_lora_scale * self.to_q_lora(x)
            if self.k_lora_enabled:
                k = k + self.k_lora_scale * self.to_k_lora(context)
            if self.v_lora_enabled:
                v = v + self.v_lora_scale * self.to_v_lora(context)
        except AttributeError as e:
            raise AttributeError(
                f"lora enabled, but missing class attribute {e.args[0]} of Attention block"
            ) from e

    # Split the fused projection dimension into (heads on this TP rank, dim_head).
    local_heads = self.heads // self.tp_size
    q, k, v = (
        rearrange(t, "b ... (n c) -> b ... n c", n=local_heads, c=self.dim_head) for t in (q, k, v)
    )

    def apply_norm_and_rotary_pos_emb(q, k, v, rope_emb):
        # Second stage of each projection container, applied after the head split.
        q = self.to_q[1](q)
        k = self.to_k[1](k)
        v = self.to_v[1](v)
        if self.is_selfattn and rope_emb is not None:  # only apply to self-attention!
            q = apply_rotary_pos_emb(q, rope_emb, tensor_format=self.qkv_format, fused=True)
            k = apply_rotary_pos_emb(k, rope_emb, tensor_format=self.qkv_format, fused=True)
        return q, k, v

    # Run norm + RoPE under (non-reentrant) activation checkpointing.
    q, k, v = checkpoint(apply_norm_and_rotary_pos_emb, q, k, v, rope_emb, use_reentrant=False)
    return q, k, v
def cal_attn_lora(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
    """
    Calculate the attention output with LoRA adjustments. Derived from cosmos_predict1/diffusion/module/attention.py cal_attn.

    The attention op is selected by ``self.backend``. Its result goes through
    ``self.to_out``; when LoRA is enabled for the output projection,
    ``self.out_lora_scale * self.to_out_lora(attn_out)`` is added on top.

    Args:
        q (torch.Tensor): Query tensor.
        k (torch.Tensor): Key tensor.
        v (torch.Tensor): Value tensor.
        mask (torch.Tensor, optional): Mask tensor. Only forwarded to the attention
            op on the ``"torch"`` backend.

    Returns:
        torch.Tensor: The attention output.

    Raises:
        ValueError: If ``self.backend`` is neither ``"transformer_engine"`` nor ``"torch"``.
        AttributeError: If output LoRA is enabled but an expected LoRA attribute is missing.
    """
    if self.backend == "transformer_engine":
        seq_dim = self.qkv_format.index("s")
        assert (
            q.shape[seq_dim] > 1 and k.shape[seq_dim] > 1
        ), "Seqlen must be larger than 1 for TE Attention starting with 1.8 TE version."
        # The TE op's result is fed to the output projection as-is (no rearrange).
        attn_out = self.attn_op(q, k, v, core_attention_bias_type="no_bias", core_attention_bias=None)
    elif self.backend == "torch":
        attn_out = self.attn_op(q, k, v, mask=mask)
        # Merge the head and per-head channel axes back into one feature axis.
        attn_out = rearrange(attn_out, " b ... n c -> b ... (n c)")
    else:
        raise ValueError(f"Backend {self.backend} not found")

    # Output projection and optional LoRA branch are identical for both backends.
    out = self.to_out(attn_out)
    if self.peft_lora_enabled and self.out_lora_enabled:
        try:
            out_lora = self.to_out_lora(attn_out)
            out = out + self.out_lora_scale * out_lora
        except AttributeError as e:
            # Same wording as cal_qkv_lora: this is an Attention block, not a FeedForward one.
            raise AttributeError(
                f"lora enabled, but missing class attribute {e.args[0]} of Attention block"
            ) from e
    return out
def build_attn_lora(attn: Attention, peft_control: dict) -> None:
    """
    Configure, build and add LoRA layers to the attention block.

    Runs ``enable_attn_lora`` and ``configure_attn_lora``. If LoRA ends up
    enabled, it creates one LoRA layer per activated projection
    (``to_q_lora``, ``to_k_lora``, ``to_v_lora``, ``to_out_lora``) and rebinds
    ``attn.cal_qkv`` / ``attn.cal_attn`` to the LoRA-aware implementations.
    ``LoRALinearLayer`` is used when ``attn.tp_size == 1``; otherwise
    ``TELoRALinearLayer`` is used, column-parallel for q/k/v and row-parallel
    for the output projection.

    Args:
        attn (Attention): The attention block to add LoRA layers to.
        peft_control (dict): Dictionary containing PEFT configuration.
    """
    enable_attn_lora(attn, peft_control)
    # enable_attn_lora accepts a falsy peft_control; pass an empty dict on so the
    # per-projection flags are still initialised (to False) in that case.
    configure_attn_lora(attn, peft_control or {})
    if not attn.peft_lora_enabled:
        return

    query_dim = attn.query_dim
    inner_dim = attn.inner_dim
    context_dim = attn.context_dim
    use_tensor_parallel = attn.tp_size != 1

    # parallel_state only exists when the optional megatron import succeeded, so
    # every use of it is guarded by USE_MEGATRON to avoid a NameError.
    tp_group = parallel_state.get_tensor_model_parallel_group(check_initialized=False) if USE_MEGATRON else None
    sequence_parallel = getattr(parallel_state, "sequence_parallel", False) if USE_MEGATRON else False

    # (attribute prefix, input dim, output dim, TE parallel mode)
    layer_specs = (
        ("q", query_dim, inner_dim, "column"),
        ("k", context_dim, inner_dim, "column"),
        ("v", context_dim, inner_dim, "column"),
        ("out", inner_dim, query_dim, "row"),
    )
    for prefix, in_dim, out_dim, parallel_mode in layer_specs:
        if not getattr(attn, f"{prefix}_lora_enabled"):
            continue
        rank = getattr(attn, f"{prefix}_lora_rank")
        if use_tensor_parallel:
            lora_layer = TELoRALinearLayer(
                in_dim,
                out_dim,
                rank=rank,
                linear=True,
                tp_size=attn.tp_size,
                tp_group=tp_group,
                sequence_parallel=sequence_parallel,
                parallel_mode=parallel_mode,
            )
        else:
            lora_layer = LoRALinearLayer(in_dim, out_dim, rank=rank, linear=True)
        setattr(attn, f"to_{prefix}_lora", lora_layer)

    # Bind the LoRA-aware functions to this instance so they replace its
    # cal_qkv / cal_attn methods.
    attn.cal_qkv = cal_qkv_lora.__get__(attn, attn.__class__)
    attn.cal_attn = cal_attn_lora.__get__(attn, attn.__class__)