chengs18 committed on
Commit
56bca3c
·
verified ·
1 Parent(s): d8a00c3

Update modeling_sdar.py

Browse files
Files changed (1) hide show
  1. modeling_sdar.py +86 -61
modeling_sdar.py CHANGED
@@ -21,10 +21,11 @@
21
  # See the License for the specific language governing permissions and
22
  # limitations under the License.
23
 
24
- from typing import Callable, Optional, Tuple, Union
25
 
26
  import torch
27
  from torch import nn
 
28
 
29
  from transformers.activations import ACT2FN
30
  from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
@@ -43,8 +44,9 @@ from transformers.modeling_outputs import (
43
  from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
44
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
45
  from transformers.processing_utils import Unpack
46
- from transformers.utils import LossKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
47
  from .configuration_sdar import SDARConfig
 
48
 
49
  from flash_attn.ops.triton.layer_norm import rms_norm_fn as flash_rms_norm
50
 
@@ -69,6 +71,10 @@ if is_torch_flex_attn_available():
69
 
70
  logger = logging.get_logger(__name__)
71
 
 
 
 
 
72
 
73
  @use_kernel_forward_from_hub("RMSNorm")
74
  class SDARRMSNorm(nn.Module):
@@ -272,34 +278,21 @@ class SDARAttention(nn.Module):
272
  value_states = torch.cat(
273
  [past_value_states, value_states], dim=-2)
274
 
275
- attention_mask = attention_mask.bool() if attention_mask is not None else None
276
- if torch.all(attention_mask): # decoding
277
- query_states = query_states.transpose(1, 2)
278
- key_states = key_states.transpose(1, 2)
279
- value_states = value_states.transpose(1, 2)
280
- attn_output = flash_attn_func(
281
- query_states,
282
- key_states,
283
- value_states,
284
- causal=False,
285
- softmax_scale=self.scaling
286
- )
287
-
288
- else: # prefilling
289
- attn_output = F.scaled_dot_product_attention(
290
- query=query_states,
291
- key=key_states,
292
- value=value_states,
293
- attn_mask=attention_mask,
294
- is_causal=False,
295
- scale=self.scaling,
296
- enable_gqa=True
297
- )
298
- attn_output = attn_output.transpose(1, 2).contiguous()
299
-
300
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
301
  attn_output = self.o_proj(attn_output)
302
- return attn_output, None # , attn_weights
303
 
304
 
305
  class SDARDecoderLayer(GradientCheckpointingLayer):
@@ -733,10 +726,6 @@ class SDARModel(SDARPreTrainedModel):
733
  return causal_mask
734
 
735
 
736
- class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
737
- ...
738
-
739
-
740
  @auto_docstring
741
  class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):
742
  _tied_weights_keys = ["lm_head.weight"]
@@ -771,6 +760,49 @@ class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):
771
  def get_decoder(self):
772
  return self.model
773
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
774
  @can_return_tuple
775
  @auto_docstring
776
  def forward(
@@ -785,8 +817,8 @@ class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):
785
  output_attentions: Optional[bool] = None,
786
  output_hidden_states: Optional[bool] = None,
787
  cache_position: Optional[torch.LongTensor] = None,
788
- logits_to_keep: Union[int, torch.Tensor] = 0,
789
- **kwargs: Unpack[KwargsForCausalLM],
790
  ) -> CausalLMOutputWithPast:
791
  r"""
792
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -814,40 +846,33 @@ class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):
814
  output_hidden_states = (
815
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
816
  )
817
-
818
- # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
819
- outputs: BaseModelOutputWithPast = self.model(
820
  input_ids=input_ids,
821
  attention_mask=attention_mask,
822
  position_ids=position_ids,
823
- past_key_values=past_key_values,
824
- inputs_embeds=inputs_embeds,
825
- use_cache=use_cache,
826
  output_attentions=output_attentions,
827
  output_hidden_states=output_hidden_states,
 
828
  cache_position=cache_position,
829
- **kwargs,
830
- )
831
-
832
  hidden_states = outputs.last_hidden_state
833
- # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
834
- slice_indices = slice(-logits_to_keep,
835
- None) if isinstance(logits_to_keep, int) else logits_to_keep
836
- hidden_states = hidden_states[:, slice_indices, :].contiguous()
837
- fuse_linear_and_cross_entropy = self.config.fuse_cross_entropy and self.training
838
- if fuse_linear_and_cross_entropy:
839
- # When using fused_linear_ce_loss, we do not compute the whole logits on HBM
840
- logits = None
841
- else:
842
- logits = self.lm_head(hidden_states)
843
 
 
 
 
 
 
 
 
844
  loss = None
845
- if labels is not None:
846
- # FusedLinearCrossEntropyLoss will be implemented by monkey patch when training
847
- # We don't use it when inferencing
848
- loss_fct = nn.CrossEntropyLoss() # nn.CE
849
- loss = loss_fct(
850
- logits.view(-1, self.config.vocab_size), labels.view(-1))
851
 
852
  return CausalLMOutputWithPast(
853
  loss=loss,
@@ -862,4 +887,4 @@ __all__ = [
862
  "SDARForCausalLM",
863
  "SDARModel",
864
  "SDARPreTrainedModel",
865
- ]
 
21
  # See the License for the specific language governing permissions and
22
  # limitations under the License.
23
 
24
+ from typing import Callable, Optional, Tuple, Union, List
25
 
26
  import torch
27
  from torch import nn
28
+ from einops import rearrange
29
 
30
  from transformers.activations import ACT2FN
31
  from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
 
44
  from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
45
  from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
46
  from transformers.processing_utils import Unpack
47
+ from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
48
  from .configuration_sdar import SDARConfig
49
+ from .fused_linear_diffusion_cross_entropy import FusedLinearDiffusionCrossEntropyLoss
50
 
51
  from flash_attn.ops.triton.layer_norm import rms_norm_fn as flash_rms_norm
52
 
 
71
 
72
  logger = logging.get_logger(__name__)
73
 
74
+ @torch.compile(fullgraph=True, mode="max-autotune-no-cudagraphs")
75
+ def fused_flex_attention(query, key, value, attention_mask, **kwargs):
76
+ return flex_attention(query, key, value, block_mask=attention_mask, **kwargs)
77
+
78
 
79
  @use_kernel_forward_from_hub("RMSNorm")
80
  class SDARRMSNorm(nn.Module):
 
278
  value_states = torch.cat(
279
  [past_value_states, value_states], dim=-2)
280
 
281
+ attn_output, attn_weights = fused_flex_attention(
282
+ query=query_states,
283
+ key=key_states,
284
+ value=value_states,
285
+ attention_mask=attention_mask,
286
+ enable_gqa=True,
287
+ scale=self.scaling,
288
+ return_lse=True
289
+ )
290
+ attn_weights = attn_weights.to(
291
+ value_states.dtype) if attn_weights is not None else None
292
+ attn_output = rearrange(attn_output, 'b h l d -> b l (h d)')
293
+
 
 
 
 
 
 
 
 
 
 
 
 
 
294
  attn_output = self.o_proj(attn_output)
295
+ return attn_output, attn_weights # , attn_weights
296
 
297
 
298
  class SDARDecoderLayer(GradientCheckpointingLayer):
 
726
  return causal_mask
727
 
728
 
 
 
 
 
729
  @auto_docstring
730
  class SDARForCausalLM(SDARPreTrainedModel, GenerationMixin):
731
  _tied_weights_keys = ["lm_head.weight"]
 
760
  def get_decoder(self):
761
  return self.model
762
 
763
+ def prepare_for_bd_training(self, inputs_ids, position_ids, prompt_mask):
764
+ bsz, seq_len = inputs_ids.shape
765
+ num_tokens = calculate_token_nums(position_ids) # List[torch.Tensor]
766
+ noisy_inputs_ids, logits_to_keep_half, p_mask = forward_add_noise_packed(
767
+ inputs_ids=inputs_ids,
768
+ num_tokens_list=num_tokens,
769
+ prompt_mask=prompt_mask,
770
+ mask_id=self.config.mask_token_id,
771
+ )
772
+ router_noisy_part_list = []
773
+ for i in range(bsz):
774
+ cur_router_noisy_part = (torch.arange(num_tokens[i].shape[0] *2) % 2 == 0).to(inputs_ids.device)
775
+ cur_router_noisy_part = cur_router_noisy_part.repeat_interleave(num_tokens[i].repeat_interleave(2))
776
+ router_noisy_part_list.append(cur_router_noisy_part)
777
+ router_noisy_part = torch.stack(router_noisy_part_list, dim=0)
778
+
779
+ # concatenated inputs_ids: (bsz, seq_len x 2)
780
+ concat_inputs_ids = inputs_ids.repeat(1, 2)
781
+ # concatenated logits_to_keep: (bsz, seq_len x 2)
782
+ logits_to_keep = torch.zeros(
783
+ bsz, 2 * seq_len, dtype=torch.bool, device=inputs_ids.device)
784
+ # concatenated position_ids: (bsz, seq_len x 2)
785
+ concat_position_ids = torch.zeros(
786
+ bsz, 2 * seq_len, dtype=position_ids.dtype, device=position_ids.device)
787
+ for i in range(bsz):
788
+ concat_inputs_ids[i][router_noisy_part[i]] = noisy_inputs_ids[i]
789
+ concat_inputs_ids[i][~router_noisy_part[i]] = inputs_ids[i]
790
+
791
+ logits_to_keep[i][router_noisy_part[i]] = logits_to_keep_half[i]
792
+
793
+ concat_position_ids[i][router_noisy_part[i]] = position_ids[i]
794
+ concat_position_ids[i][~router_noisy_part[i]] = position_ids[i]
795
+
796
+ # create flex_attention mask
797
+ attention_mask = block_attn_mask(num_tokens, self.config.block_size, inputs_ids.device)
798
+ flex_attention_mask_3d = create_block_mask(
799
+ lambda b, h, q_idx, kv_idx: attention_mask[b, q_idx, kv_idx],
800
+ B=attention_mask.size(0), H=None,
801
+ Q_LEN=attention_mask.size(1), KV_LEN=attention_mask.size(2),
802
+ )
803
+
804
+ return concat_inputs_ids, concat_position_ids, flex_attention_mask_3d, logits_to_keep_half, logits_to_keep, p_mask
805
+
806
  @can_return_tuple
807
  @auto_docstring
808
  def forward(
 
817
  output_attentions: Optional[bool] = None,
818
  output_hidden_states: Optional[bool] = None,
819
  cache_position: Optional[torch.LongTensor] = None,
820
+ logits_to_keep: Optional[torch.Tensor] = None,
821
+ **kwargs: Unpack[TransformersKwargs],
822
  ) -> CausalLMOutputWithPast:
823
  r"""
824
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
 
846
  output_hidden_states = (
847
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
848
  )
849
+
850
+ outputs = self.model(
 
851
  input_ids=input_ids,
852
  attention_mask=attention_mask,
853
  position_ids=position_ids,
 
 
 
854
  output_attentions=output_attentions,
855
  output_hidden_states=output_hidden_states,
856
+ return_dict=True,
857
  cache_position=cache_position,
858
+ **kwargs
859
+ )
 
860
  hidden_states = outputs.last_hidden_state
 
 
 
 
 
 
 
 
 
 
861
 
862
+ if logits_to_keep is not None:
863
+ B, _, H = hidden_states.shape
864
+ num_keep = logits_to_keep.sum(dim=1)
865
+ assert torch.all(num_keep == num_keep[0])
866
+ N = int(num_keep[0].item())
867
+ hidden_states = hidden_states[logits_to_keep].view(B, N, H).contiguous() # [B, N, H]
868
+ logits = self.lm_head(hidden_states)
869
  loss = None
870
+ # if labels is not None:
871
+ # loss_fct = nn.CrossEntropyLoss(reduction="none", ignore_index=-100)
872
+ # loss = loss_fct(
873
+ # logits.view(-1, self.config.vocab_size),
874
+ # labels.view(-1)
875
+ # ).view(labels.size())
876
 
877
  return CausalLMOutputWithPast(
878
  loss=loss,
 
887
  "SDARForCausalLM",
888
  "SDARModel",
889
  "SDARPreTrainedModel",
890
+ ]