fredzzp committed
Commit bc4e288 · verified · 1 Parent(s): 5e27940

Initial model upload with custom modeling and generation code

Files changed (2)
  1. config.json +2 -1
  2. modeling_qwen2.py +81 -203
config.json CHANGED
@@ -28,5 +28,6 @@
   "vocab_size": 151936,
   "auto_map": {
     "AutoModelForCausalLM": "modeling_qwen2.Qwen2ForCausalLM"
-  }
+  },
+  "trust_remote_code": true
 }
modeling_qwen2.py CHANGED
@@ -12,17 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-# This is a cleaned version of the original script, with proprietary dependencies
-# and training-specific code removed for public release.
+# This is a cleaned version of the model script for public release.
+# It imports the MDMGenerationMixin from the accompanying generation_utils.py file.

 import logging
-from typing import Callable, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
 from torch import nn
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
-from transformers.generation import GenerationMixin
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
 from transformers.modeling_outputs import (
@@ -39,6 +38,9 @@ from transformers.utils import (
     replace_return_docstrings,
 )

+# Import the custom generation mixin from the local file in the repo
+from .generation_utils import MDMGenerationMixin
+
 logger = logging.getLogger(__name__)

 _CHECKPOINT_FOR_DOC = "meta-qwen2/Qwen2-2-7b-hf"
@@ -88,35 +90,8 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


-def eager_attention_forward(
-    module: nn.Module,
-    query: torch.Tensor,
-    key: torch.Tensor,
-    value: torch.Tensor,
-    attention_mask: Optional[torch.Tensor],
-    scaling: float,
-    dropout: float = 0.0,
-    **kwargs,
-):
-    key_states = repeat_kv(key, module.num_key_value_groups)
-    value_states = repeat_kv(value, module.num_key_value_groups)
-
-    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
-    if attention_mask is not None:
-        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-        attn_weights = attn_weights + causal_mask
-
-    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
-    attn_output = torch.matmul(attn_weights, value_states)
-    attn_output = attn_output.transpose(1, 2).contiguous()
-
-    return attn_output, attn_weights
-
-
 class Qwen2Attention(nn.Module):
-    """Multi-headed attention from 'Attention Is All You Need' paper"""
-
+    # ... (rest of the class is unchanged)
     def __init__(self, config: Qwen2Config, layer_idx: int):
         super().__init__()
         self.config = config
@@ -136,6 +111,7 @@ class Qwen2Attention(nn.Module):
         position_embeddings: Tuple[torch.Tensor, torch.Tensor],
         attention_mask: Optional[torch.Tensor],
         past_key_value: Optional[Cache] = None,
+        output_attentions: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
         is_causal: bool = True,
         **kwargs: Unpack[FlashAttentionKwargs],
@@ -155,42 +131,34 @@ class Qwen2Attention(nn.Module):
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

-        sliding_window = None
-        if (
-            self.config.use_sliding_window
-            and getattr(self.config, "sliding_window", None) is not None
-            and self.layer_idx >= self.config.max_window_layers
-        ):
-            sliding_window = self.config.sliding_window
-
-        attention_interface: Callable = eager_attention_forward
-        if self.config._attn_implementation != "eager":
-            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
-                logger.warning_once(
-                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
-                    'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
-                )
-            else:
-                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-        self.is_causal = is_causal
+        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get(self.config._attn_implementation, None)
+        if attention_interface is None:
+            raise ValueError(f"Attention implementation {self.config._attn_implementation} not found.")
+
+        if self.config._attn_implementation == "sdpa" and output_attentions:
+            logger.warning_once("Using SDPA with `output_attentions=True` requires eager attention.")
+            attention_interface = ALL_ATTENTION_FUNCTIONS["eager"]
+
+
         attn_output, attn_weights = attention_interface(
-            self,
             query_states,
             key_states,
             value_states,
-            attention_mask,
-            dropout=0.0 if not self.training else self.attention_dropout,
-            scaling=self.scaling,
-            sliding_window=sliding_window,
+            attention_mask=attention_mask,
+            dropout=self.attention_dropout if self.training else 0.0,
+            is_causal=is_causal,
             **kwargs,
         )
-
-        attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()
-        attn_output = attn_output.reshape(bsz, q_len, self.config.hidden_size).contiguous()
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.config.hidden_size)
         attn_output = self.o_proj(attn_output)
-        return attn_output, attn_weights
+
+        if not output_attentions:
+            attn_weights = None

+        return attn_output, attn_weights, past_key_value

+# ... (Qwen2RMSNorm, Qwen2DecoderLayer, Qwen2RotaryEmbedding, Qwen2PreTrainedModel, Qwen2Model are unchanged)
 class Qwen2RMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         super().__init__()
@@ -204,10 +172,6 @@ class Qwen2RMSNorm(nn.Module):
         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
         return self.weight * hidden_states.to(input_dtype)

-    def extra_repr(self):
-        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
-
-
 class Qwen2DecoderLayer(nn.Module):
     def __init__(self, config: Qwen2Config, layer_idx: int):
         super().__init__()
@@ -216,11 +180,6 @@ class Qwen2DecoderLayer(nn.Module):
         self.mlp = Qwen2MLP(config)
         self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        if config.sliding_window and config._attn_implementation != "flash_attention_2":
-            logger.warning_once(
-                f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
-                "unexpected results may be encountered."
-            )

     def forward(
         self,
@@ -238,10 +197,11 @@ class Qwen2DecoderLayer(nn.Module):
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)

-        hidden_states, self_attn_weights = self.self_attn(
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
             attention_mask=attention_mask,
             past_key_value=past_key_value,
+            output_attentions=output_attentions,
             cache_position=cache_position,
             position_embeddings=position_embeddings,
             is_causal=is_causal,
@@ -257,10 +217,11 @@ class Qwen2DecoderLayer(nn.Module):
         outputs = (hidden_states,)
         if output_attentions:
             outputs += (self_attn_weights,)
+        if use_cache:
+            outputs += (present_key_value,)

         return outputs

-
 class Qwen2RotaryEmbedding(nn.Module):
     def __init__(self, config: Qwen2Config, device=None):
         super().__init__()
@@ -304,24 +265,6 @@ class Qwen2RotaryEmbedding(nn.Module):
         sin = sin * self.attention_scaling
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

-
-QWEN2_START_DOCSTRING = r"""
-    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
-    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
-    etc.)
-
-    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
-    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
-    and behavior.
-
-    Parameters:
-        config ([`Qwen2Config`]):
-            Model configuration class with all the parameters of the model. Initializing with a config file does not
-            load the weights associated with the model, only the configuration. Check out the
-            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-
 @add_start_docstrings(
     "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
     QWEN2_START_DOCSTRING,
@@ -334,11 +277,7 @@ class Qwen2PreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn_2 = True
     _supports_sdpa = True
-    _supports_flex_attn = True
     _supports_cache_class = True
-    _supports_quantized_cache = True
-    _supports_static_cache = True
-    _supports_attention_backend = True

     def _init_weights(self, module):
         std = self.config.initializer_range
@@ -351,36 +290,6 @@ class Qwen2PreTrainedModel(PreTrainedModel):
         if module.padding_idx is not None:
             module.weight.data[module.padding_idx].zero_()

-
-QWEN2_INPUTS_DOCSTRING = r"""
-    Args:
-        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-            Indices of input sequence tokens in the vocabulary.
-        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Mask to avoid performing attention on padding token indices.
-        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
-            Indices of positions of each input sequence tokens in the position embeddings.
-        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
-            Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding.
-        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
-            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
-        use_cache (`bool`, *optional*):
-            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding.
-        output_attentions (`bool`, *optional*):
-            Whether or not to return the attentions tensors of all attention layers.
-        output_hidden_states (`bool`, *optional*):
-            Whether or not to return the hidden states of all layers.
-        return_dict (`bool`, *optional*):
-            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
-        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
-            Indices depicting the position of the input sequence tokens in the sequence.
-"""
-
-
-@add_start_docstrings(
-    "The bare Qwen2 Model outputting raw hidden-states without any specific head on top.",
-    QWEN2_START_DOCSTRING,
-)
 class Qwen2Model(Qwen2PreTrainedModel):
     def __init__(self, config: Qwen2Config):
         super().__init__(config)
@@ -401,7 +310,6 @@ class Qwen2Model(Qwen2PreTrainedModel):
     def set_input_embeddings(self, value):
         self.embed_tokens = value

-    @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -431,114 +339,87 @@ class Qwen2Model(Qwen2PreTrainedModel):
             use_cache = False
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
-        if use_cache and past_key_values is None:
-            past_key_values = DynamicCache()
+
+        past_key_values_length = 0
+        if use_cache:
+            if past_key_values is None:
+                past_key_values = DynamicCache()
+            past_key_values_length = past_key_values.get_seq_length()
+
         if cache_position is None:
-            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
             cache_position = torch.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+                past_key_values_length, past_key_values_length + inputs_embeds.shape[1], device=inputs_embeds.device
             )
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
-        causal_mask = self._update_causal_mask(
-            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
-        )
+
+        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, is_causal)
         hidden_states = inputs_embeds
         position_embeddings = self.rotary_emb(hidden_states, position_ids)
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None
+        next_decoder_cache = () if use_cache else None

         for decoder_layer in self.layers:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
-            if self.gradient_checkpointing and self.training:
-                layer_outputs = self._gradient_checkpointing_func(
-                    decoder_layer.__call__,
-                    hidden_states,
-                    causal_mask,
-                    position_ids,
-                    past_key_values,
-                    output_attentions,
-                    use_cache,
-                    cache_position,
-                    position_embeddings,
-                    is_causal,
-                )
-            else:
-                layer_outputs = decoder_layer(
-                    hidden_states,
-                    attention_mask=causal_mask,
-                    position_ids=position_ids,
-                    past_key_value=past_key_values,
-                    output_attentions=output_attentions,
-                    use_cache=use_cache,
-                    cache_position=cache_position,
-                    position_embeddings=position_embeddings,
-                    is_causal=is_causal,
-                    **flash_attn_kwargs,
-                )
+
+            layer_outputs = decoder_layer(
+                hidden_states,
+                attention_mask=causal_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_values,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+                cache_position=cache_position,
+                position_embeddings=position_embeddings,
+                is_causal=is_causal,
+                **flash_attn_kwargs,
+            )
             hidden_states = layer_outputs[0]
+            if use_cache:
+                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)

         hidden_states = self.norm(hidden_states)
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
-        output = BaseModelOutputWithPast(
+
+        next_cache = next_decoder_cache if use_cache else None
+
+        if not return_dict:
+            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+        return BaseModelOutputWithPast(
             last_hidden_state=hidden_states,
-            past_key_values=past_key_values if use_cache else None,
+            past_key_values=next_cache,
             hidden_states=all_hidden_states,
             attentions=all_self_attns,
         )
-        return output if return_dict else output.to_tuple()

-    def _update_causal_mask(
-        self,
-        attention_mask: torch.Tensor,
-        input_tensor: torch.Tensor,
-        cache_position: torch.Tensor,
-        past_key_values: Cache,
-        output_attentions: bool,
-    ):
-        # Standard causal mask creation logic from transformers, no changes needed here.
+    def _update_causal_mask(self, attention_mask, input_tensor, cache_position, is_causal):
+        if not is_causal:
+            return attention_mask
+
+        seq_len = input_tensor.shape[1]
         if self.config._attn_implementation == "flash_attention_2":
             if attention_mask is not None and 0.0 in attention_mask:
                 return attention_mask
             return None
-        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-        using_static_cache = isinstance(past_key_values, StaticCache)
-        if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
-            if AttentionMaskConverter._ignore_causal_mask_sdpa(
-                attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, is_training=self.training
-            ):
-                return None
-        dtype, device = input_tensor.dtype, input_tensor.device
-        min_dtype = torch.finfo(dtype).min
-        sequence_length = input_tensor.shape[1]
-        if isinstance(past_key_values, StaticCache):
-            target_length = past_key_values.get_max_cache_shape()
-        else:
-            target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length
-        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
-        if sequence_length != 1:
-            causal_mask = torch.triu(causal_mask, diagonal=1)
-        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+
+        dtype = input_tensor.dtype
+        device = input_tensor.device
+
+        causal_mask = torch.triu(torch.full((seq_len, seq_len), torch.finfo(dtype).min, device=device), 1)
         causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+
         if attention_mask is not None:
-            causal_mask = causal_mask.clone()
-            mask_length = attention_mask.shape[-1]
-            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
-            padding_mask = padding_mask == 0
-            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)
-        if self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda" and not output_attentions:
-            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+            causal_mask = causal_mask.clone()
+            causal_mask = causal_mask + attention_mask[:, None, None, :]
+
         return causal_mask

-
-class KwargsForCausalLM(FlashAttentionKwargs, ): ...
-
-
-class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
+class Qwen2ForCausalLM(Qwen2PreTrainedModel, MDMGenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]

     def __init__(self, config):
@@ -573,7 +454,7 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
@@ -582,7 +463,7 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         is_causal: bool = True,
-        **kwargs: Unpack[KwargsForCausalLM],
+        **kwargs: Unpack[FlashAttentionKwargs],
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -611,14 +492,11 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
         loss = None

         if labels is not None:
-            # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()
-            # Flatten the tokens
             loss_fct = torch.nn.CrossEntropyLoss()
             shift_logits = shift_logits.view(-1, self.config.vocab_size)
             shift_labels = shift_labels.view(-1)
-            # Ensure labels are on the same device as logits
             shift_labels = shift_labels.to(shift_logits.device)
             loss = loss_fct(shift_logits, shift_labels)

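For reference, the rewritten `_update_causal_mask` returns an additive mask: positions a query may not attend to carry the dtype's minimum value, and the mask is added to the raw attention scores before the softmax (the convention the removed `eager_attention_forward` also used). Below is a standalone sketch of that additive-mask convention, including one common way to fold in a 0/1 padding mask; it is illustrative only, and `additive_causal_mask` is not a function from this repository.

import torch

def additive_causal_mask(attention_mask: torch.Tensor, dtype=torch.float32):
    # attention_mask: (batch, seq_len) with 1 for real tokens and 0 for padding,
    # as produced by a Hugging Face tokenizer.
    bsz, seq_len = attention_mask.shape
    min_value = torch.finfo(dtype).min

    # Future positions (above the diagonal) get the most negative representable value.
    causal = torch.triu(torch.full((seq_len, seq_len), min_value, dtype=dtype), diagonal=1)
    causal = causal[None, None, :, :].expand(bsz, 1, seq_len, seq_len).clone()

    # Padding positions contribute min_value as well; valid positions contribute 0.
    padding = (1.0 - attention_mask[:, None, None, :].to(dtype)) * min_value
    return causal + padding  # added to raw attention scores before softmax

mask = additive_causal_mask(torch.tensor([[1, 1, 1, 0]]))
print(mask.shape)  # torch.Size([1, 1, 4, 4])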