Update modeling_sdar.py
Browse files- modeling_sdar.py +86 -61
modeling_sdar.py
CHANGED
|
@@ -21,10 +21,11 @@
|
|
| 21 |
# See the License for the specific language governing permissions and
|
| 22 |
# limitations under the License.
|
| 23 |
|
| 24 |
-
from typing import Callable, Optional, Tuple, Union
|
| 25 |
|
| 26 |
import torch
|
| 27 |
from torch import nn
|
|
|
|
| 28 |
|
| 29 |
from transformers.activations import ACT2FN
|
| 30 |
from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
|
@@ -43,8 +44,9 @@ from transformers.modeling_outputs import (
|
|
| 43 |
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 44 |
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 45 |
from transformers.processing_utils import Unpack
|
| 46 |
-
from transformers.utils import
|
| 47 |
from .configuration_sdar import SDARConfig
|
|
|
|
| 48 |
|
| 49 |
from flash_attn.ops.triton.layer_norm import rms_norm_fn as flash_rms_norm
|
| 50 |
|
|
@@ -69,6 +71,10 @@ if is_torch_flex_attn_available():
|
|
| 69 |
|
| 70 |
logger = logging.get_logger(__name__)
|
| 71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
@use_kernel_forward_from_hub("RMSNorm")
|
| 74 |
class SDARRMSNorm(nn.Module):
|
|
@@ -272,34 +278,21 @@ class SDARAttention(nn.Module):
|
|
| 272 |
value_states = torch.cat(
|
| 273 |
[past_value_states, value_states], dim=-2)
|
| 274 |
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
else: # prefilling
|
| 289 |
-
attn_output = F.scaled_dot_product_attention(
|
| 290 |
-
query=query_states,
|
| 291 |
-
key=key_states,
|
| 292 |
-
value=value_states,
|
| 293 |
-
attn_mask=attention_mask,
|
| 294 |
-
is_causal=False,
|
| 295 |
-
scale=self.scaling,
|
| 296 |
-
enable_gqa=True
|
| 297 |
-
)
|
| 298 |
-
attn_output = attn_output.transpose(1, 2).contiguous()
|
| 299 |
-
|
| 300 |
-
attn_output = attn_output.reshape(*input_shape, -1).contiguous()
|
| 301 |
attn_output = self.o_proj(attn_output)
|
| 302 |
-
return attn_output,
|
| 303 |
|
| 304 |
|
| 305 |
class SDARDecoderLayer(GradientCheckpointingLayer):
|
|
@@ -733,10 +726,6 @@ class SDARModel(SDARPreTrainedModel):
|
|
| 733 |
return causal_mask
|
| 734 |
|
| 735 |
|
| 736 |
-
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
|
| 737 |
-
...
|
| 738 |
-
|
| 739 |
-
|
| 740 |
@auto_docstring
|
| 741 |
class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):
|
| 742 |
_tied_weights_keys = ["lm_head.weight"]
|
|
@@ -771,6 +760,49 @@ class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):
|
|
| 771 |
def get_decoder(self):
|
| 772 |
return self.model
|
| 773 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 774 |
@can_return_tuple
|
| 775 |
@auto_docstring
|
| 776 |
def forward(
|
|
@@ -785,8 +817,8 @@ class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):
|
|
| 785 |
output_attentions: Optional[bool] = None,
|
| 786 |
output_hidden_states: Optional[bool] = None,
|
| 787 |
cache_position: Optional[torch.LongTensor] = None,
|
| 788 |
-
logits_to_keep:
|
| 789 |
-
**kwargs: Unpack[
|
| 790 |
) -> CausalLMOutputWithPast:
|
| 791 |
r"""
|
| 792 |
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
@@ -814,40 +846,33 @@ class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):
|
|
| 814 |
output_hidden_states = (
|
| 815 |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 816 |
)
|
| 817 |
-
|
| 818 |
-
|
| 819 |
-
outputs: BaseModelOutputWithPast = self.model(
|
| 820 |
input_ids=input_ids,
|
| 821 |
attention_mask=attention_mask,
|
| 822 |
position_ids=position_ids,
|
| 823 |
-
past_key_values=past_key_values,
|
| 824 |
-
inputs_embeds=inputs_embeds,
|
| 825 |
-
use_cache=use_cache,
|
| 826 |
output_attentions=output_attentions,
|
| 827 |
output_hidden_states=output_hidden_states,
|
|
|
|
| 828 |
cache_position=cache_position,
|
| 829 |
-
**kwargs
|
| 830 |
-
|
| 831 |
-
|
| 832 |
hidden_states = outputs.last_hidden_state
|
| 833 |
-
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
|
| 834 |
-
slice_indices = slice(-logits_to_keep,
|
| 835 |
-
None) if isinstance(logits_to_keep, int) else logits_to_keep
|
| 836 |
-
hidden_states = hidden_states[:, slice_indices, :].contiguous()
|
| 837 |
-
fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
|
| 838 |
-
if fuse_linear_and_cross_entropy:
|
| 839 |
-
# When using fused_linear_ce_loss, we do not compute the whole logits on HBM
|
| 840 |
-
logits = None
|
| 841 |
-
else:
|
| 842 |
-
logits = self.lm_head(hidden_states)
|
| 843 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 844 |
loss = None
|
| 845 |
-
if labels is not None:
|
| 846 |
-
|
| 847 |
-
|
| 848 |
-
|
| 849 |
-
|
| 850 |
-
|
| 851 |
|
| 852 |
return CausalLMOutputWithPast(
|
| 853 |
loss=loss,
|
|
@@ -862,4 +887,4 @@ __all__ = [
|
|
| 862 |
"SDARForCausalLM",
|
| 863 |
"SDARModel",
|
| 864 |
"SDARPreTrainedModel",
|
| 865 |
-
]
|
|
|
|
| 21 |
# See the License for the specific language governing permissions and
|
| 22 |
# limitations under the License.
|
| 23 |
|
| 24 |
+
from typing import Callable, Optional, Tuple, Union, List
|
| 25 |
|
| 26 |
import torch
|
| 27 |
from torch import nn
|
| 28 |
+
from einops import rearrange
|
| 29 |
|
| 30 |
from transformers.activations import ACT2FN
|
| 31 |
from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
|
|
|
|
| 44 |
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
| 45 |
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
| 46 |
from transformers.processing_utils import Unpack
|
| 47 |
+
from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
|
| 48 |
from .configuration_sdar import SDARConfig
|
| 49 |
+
from .fused_linear_diffusion_cross_entropy import FusedLinearDiffusionCrossEntropyLoss
|
| 50 |
|
| 51 |
from flash_attn.ops.triton.layer_norm import rms_norm_fn as flash_rms_norm
|
| 52 |
|
|
|
|
| 71 |
|
| 72 |
logger = logging.get_logger(__name__)
|
| 73 |
|
| 74 |
+
@torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs")
|
| 75 |
+
def fused_flex_attention(query, key, value, attention_mask, **kwargs):
|
| 76 |
+
return flex_attention(query, key, value, block_mask=attention_mask, **kwargs)
|
| 77 |
+
|
| 78 |
|
| 79 |
@use_kernel_forward_from_hub("RMSNorm")
|
| 80 |
class SDARRMSNorm(nn.Module):
|
|
|
|
| 278 |
value_states = torch.cat(
|
| 279 |
[past_value_states, value_states], dim=-2)
|
| 280 |
|
| 281 |
+
attn_output, attn_weights = fused_flex_attention(
|
| 282 |
+
query=query_states,
|
| 283 |
+
key=key_states,
|
| 284 |
+
value=value_states,
|
| 285 |
+
attention_mask=attention_mask,
|
| 286 |
+
enable_gqa=True,
|
| 287 |
+
scale=self.scaling,
|
| 288 |
+
return_lse=True
|
| 289 |
+
)
|
| 290 |
+
attn_weights = attn_weights.to(
|
| 291 |
+
value_states.dtype) if attn_weights is not None else None
|
| 292 |
+
attn_output = rearrange(attn_output, 'b h l d -> b l (h d)')
|
| 293 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 294 |
attn_output = self.o_proj(attn_output)
|
| 295 |
+
return attn_output, attn_weights # , attn_weights
|
| 296 |
|
| 297 |
|
| 298 |
class SDARDecoderLayer(GradientCheckpointingLayer):
|
|
|
|
| 726 |
return causal_mask
|
| 727 |
|
| 728 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 729 |
@auto_docstring
|
| 730 |
class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):
|
| 731 |
_tied_weights_keys = ["lm_head.weight"]
|
|
|
|
| 760 |
def get_decoder(self):
|
| 761 |
return self.model
|
| 762 |
|
| 763 |
+
def prepare_for_bd_training(self, inputs_ids, position_ids, prompt_mask):
|
| 764 |
+
bsz, seq_len = inputs_ids.shape
|
| 765 |
+
num_tokens = calculate_token_nums(position_ids) # List[torch.Tensor]
|
| 766 |
+
noisy_inputs_ids, logits_to_keep_half, p_mask = forward_add_noise_packed(
|
| 767 |
+
inputs_ids=inputs_ids,
|
| 768 |
+
num_tokens_list=num_tokens,
|
| 769 |
+
prompt_mask=prompt_mask,
|
| 770 |
+
mask_id=self.config.mask_token_id,
|
| 771 |
+
)
|
| 772 |
+
router_noisy_part_list = []
|
| 773 |
+
for i in range(bsz):
|
| 774 |
+
cur_router_noisy_part = (torch.arange(num_tokens[i].shape[0] *2) % 2 == 0).to(inputs_ids.device)
|
| 775 |
+
cur_router_noisy_part = cur_router_noisy_part.repeat_interleave(num_tokens[i].repeat_interleave(2))
|
| 776 |
+
router_noisy_part_list.append(cur_router_noisy_part)
|
| 777 |
+
router_noisy_part = torch.stack(router_noisy_part_list, dim=0)
|
| 778 |
+
|
| 779 |
+
# concated inputs_ids: (bzs, seq_len x 2)
|
| 780 |
+
concat_inputs_ids = inputs_ids.repeat(1, 2)
|
| 781 |
+
# concated logits_to_keep: (bsz, seq_len x 2)
|
| 782 |
+
logits_to_keep = torch.zeros(
|
| 783 |
+
bsz, 2 * seq_len, dtype=torch.bool, device=inputs_ids.device)
|
| 784 |
+
# concated position_ids: (bsz, seq_len x 2)
|
| 785 |
+
concat_position_ids = torch.zeros(
|
| 786 |
+
bsz, 2 * seq_len, dtype=position_ids.dtype, device=position_ids.device)
|
| 787 |
+
for i in range(bsz):
|
| 788 |
+
concat_inputs_ids[i][router_noisy_part[i]] = noisy_inputs_ids[i]
|
| 789 |
+
concat_inputs_ids[i][~router_noisy_part[i]] = inputs_ids[i]
|
| 790 |
+
|
| 791 |
+
logits_to_keep[i][router_noisy_part[i]] = logits_to_keep_half[i]
|
| 792 |
+
|
| 793 |
+
concat_position_ids[i][router_noisy_part[i]] = position_ids[i]
|
| 794 |
+
concat_position_ids[i][~router_noisy_part[i]] = position_ids[i]
|
| 795 |
+
|
| 796 |
+
# create flex_attention mask
|
| 797 |
+
attention_mask = block_attn_mask(num_tokens, self.config.block_size, inputs_ids.device)
|
| 798 |
+
flex_attention_mask_3d = create_block_mask(
|
| 799 |
+
lambda b, h, q_idx, kv_idx: attention_mask[b, q_idx, kv_idx],
|
| 800 |
+
B=attention_mask.size(0), H=None,
|
| 801 |
+
Q_LEN=attention_mask.size(1), KV_LEN=attention_mask.size(2),
|
| 802 |
+
)
|
| 803 |
+
|
| 804 |
+
return concat_inputs_ids, concat_position_ids, flex_attention_mask_3d, logits_to_keep_half, logits_to_keep, p_mask
|
| 805 |
+
|
| 806 |
@can_return_tuple
|
| 807 |
@auto_docstring
|
| 808 |
def forward(
|
|
|
|
| 817 |
output_attentions: Optional[bool] = None,
|
| 818 |
output_hidden_states: Optional[bool] = None,
|
| 819 |
cache_position: Optional[torch.LongTensor] = None,
|
| 820 |
+
logits_to_keep: Optional[torch.Tensor] = None,
|
| 821 |
+
**kwargs: Unpack[TransformersKwargs],
|
| 822 |
) -> CausalLMOutputWithPast:
|
| 823 |
r"""
|
| 824 |
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
|
|
| 846 |
output_hidden_states = (
|
| 847 |
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 848 |
)
|
| 849 |
+
|
| 850 |
+
outputs = self.model(
|
|
|
|
| 851 |
input_ids=input_ids,
|
| 852 |
attention_mask=attention_mask,
|
| 853 |
position_ids=position_ids,
|
|
|
|
|
|
|
|
|
|
| 854 |
output_attentions=output_attentions,
|
| 855 |
output_hidden_states=output_hidden_states,
|
| 856 |
+
return_dict=True,
|
| 857 |
cache_position=cache_position,
|
| 858 |
+
**kwargs
|
| 859 |
+
)
|
|
|
|
| 860 |
hidden_states = outputs.last_hidden_state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 861 |
|
| 862 |
+
if logits_to_keep is not None:
|
| 863 |
+
B, _, H = hidden_states.shape
|
| 864 |
+
num_keep = logits_to_keep.sum(dim=1)
|
| 865 |
+
assert torch.all(num_keep == num_keep[0])
|
| 866 |
+
N = int(num_keep[0].item())
|
| 867 |
+
hidden_states = hidden_states[logits_to_keep].view(B, N, H).contiguous() # [B, N, H]
|
| 868 |
+
logits = self.lm_head(hidden_states)
|
| 869 |
loss = None
|
| 870 |
+
# if labels is not None:
|
| 871 |
+
# loss_fct = nn.CrossEntropyLoss(reduction="none", ignore_index=-100)
|
| 872 |
+
# loss = loss_fct(
|
| 873 |
+
# logits.view(-1, self.config.vocab_size),
|
| 874 |
+
# labels.view(-1)
|
| 875 |
+
# ).view(labels.size())
|
| 876 |
|
| 877 |
return CausalLMOutputWithPast(
|
| 878 |
loss=loss,
|
|
|
|
| 887 |
"SDARForCausalLM",
|
| 888 |
"SDARModel",
|
| 889 |
"SDARPreTrainedModel",
|
| 890 |
+
]
|