fredzzp committed
Commit 5e27940 · verified · 1 Parent(s): ed3972d

Initial model upload with custom code

Files changed (1)
  1. modeling_qwen2.py +115 -345
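Since the upload ships its own modeling_qwen2.py, the checkpoint has to be loaded with remote code enabled. A minimal loading sketch, assuming the repository's config.json maps the auto classes to this custom file; the repo id below is a placeholder, not taken from this commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "fredzzp/<model-repo>"  # hypothetical id; substitute the actual repository

tokenizer = AutoTokenizer.from_pretrained(repo_id)
# trust_remote_code=True tells transformers to execute the uploaded
# modeling_qwen2.py (Qwen2ForCausalLM below) instead of the built-in Qwen2 classes.
model = AutoModelForCausalLM.from_pretrained(
    repo_id,
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)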
modeling_qwen2.py CHANGED
@@ -1,10 +1,4 @@
1
- import logging
2
- from transformers import GenerationMixin
3
- import torch
4
- from typing import Optional, Union, List
5
- from transformers.modeling_outputs import CausalLMOutputWithPast
6
  # Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved.
7
- # Copyright 2025 Bytedance Ltd. and/or its affiliates
8
  #
9
  # Licensed under the Apache License, Version 2.0 (the "License");
10
  # you may not use this file except in compliance with the License.
@@ -18,7 +12,10 @@ from transformers.modeling_outputs import CausalLMOutputWithPast
18
  # See the License for the specific language governing permissions and
19
  # limitations under the License.
20
 
21
- # adapted from https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2/modeling_qwen2.py
 
 
 
22
  from typing import Callable, List, Optional, Tuple, Union
23
 
24
  import torch
@@ -37,26 +34,12 @@ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
37
  from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
38
  from transformers.processing_utils import Unpack
39
  from transformers.utils import (
40
- # LossKwargs,
41
  add_start_docstrings,
42
  add_start_docstrings_to_model_forward,
43
  replace_return_docstrings,
44
  )
45
 
46
-
47
- gather_heads_scatter_seq,
48
- gather_seq_scatter_heads,
49
- reduce_sequence_parallel_loss,
50
- )
51
-
52
-
53
- if False:
54
- from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss # type: ignore
55
- from liger_kernel.transformers.rms_norm import LigerRMSNorm
56
- from liger_kernel.transformers.rope import liger_rotary_pos_emb
57
- from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
58
-
59
- logger = logging.get_logger(__name__)
60
 
61
  _CHECKPOINT_FOR_DOC = "meta-qwen2/Qwen2-2-7b-hf"
62
  _CONFIG_FOR_DOC = "Qwen2Config"
@@ -86,25 +69,6 @@ def rotate_half(x):
86
 
87
 
88
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
89
- """Applies Rotary Position Embedding to the query and key tensors.
90
-
91
- Args:
92
- q (`torch.Tensor`): The query tensor.
93
- k (`torch.Tensor`): The key tensor.
94
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
95
- sin (`torch.Tensor`): The sine part of the rotary embedding.
96
- position_ids (`torch.Tensor`, *optional*):
97
- Deprecated and unused.
98
- unsqueeze_dim (`int`, *optional*, defaults to 1):
99
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
100
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
101
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
102
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
103
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
104
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
105
- Returns:
106
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
107
- """
108
  cos = cos.unsqueeze(unsqueeze_dim)
109
  sin = sin.unsqueeze(unsqueeze_dim)
110
  q_embed = (q * cos) + (rotate_half(q) * sin)
@@ -176,24 +140,18 @@ class Qwen2Attention(nn.Module):
176
  is_causal: bool = True,
177
  **kwargs: Unpack[FlashAttentionKwargs],
178
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
179
- bsz, q_len, _ = hidden_states.size() # q_len = seq_length / sp_size
180
  hidden_shape = (bsz, q_len, -1, self.head_dim)
181
 
182
  query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
183
  key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
184
  value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
185
- if False:
186
- query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
187
- key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
188
- value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)
189
- # (batch_size, num_head / sp_size, seq_length, head_size)
190
 
191
- full_q_len = query_states.size(2) # full_q_len = seq_length
192
  cos, sin = position_embeddings
193
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
194
 
195
  if past_key_value is not None:
196
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
197
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
198
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
199
 
@@ -223,14 +181,11 @@ class Qwen2Attention(nn.Module):
223
  attention_mask,
224
  dropout=0.0 if not self.training else self.attention_dropout,
225
  scaling=self.scaling,
226
- sliding_window=sliding_window, # main diff with Llama
227
  **kwargs,
228
  )
229
 
230
  attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()
231
- if False:
232
- attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)
233
-
234
  attn_output = attn_output.reshape(bsz, q_len, self.config.hidden_size).contiguous()
235
  attn_output = self.o_proj(attn_output)
236
  return attn_output, attn_weights
@@ -238,9 +193,6 @@ class Qwen2Attention(nn.Module):
238
 
239
  class Qwen2RMSNorm(nn.Module):
240
  def __init__(self, hidden_size, eps=1e-6):
241
- """
242
- Qwen2RMSNorm is equivalent to T5LayerNorm
243
- """
244
  super().__init__()
245
  self.weight = nn.Parameter(torch.ones(hidden_size))
246
  self.variance_epsilon = eps
@@ -284,17 +236,12 @@ class Qwen2DecoderLayer(nn.Module):
284
  **kwargs: Unpack[FlashAttentionKwargs],
285
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
286
  residual = hidden_states
287
-
288
  hidden_states = self.input_layernorm(hidden_states)
289
 
290
- # Self Attention
291
  hidden_states, self_attn_weights = self.self_attn(
292
  hidden_states=hidden_states,
293
  attention_mask=attention_mask,
294
- position_ids=position_ids,
295
  past_key_value=past_key_value,
296
- output_attentions=output_attentions,
297
- use_cache=use_cache,
298
  cache_position=cache_position,
299
  position_embeddings=position_embeddings,
300
  is_causal=is_causal,
@@ -302,7 +249,6 @@ class Qwen2DecoderLayer(nn.Module):
302
  )
303
  hidden_states = residual + hidden_states
304
 
305
- # Fully Connected
306
  residual = hidden_states
307
  hidden_states = self.post_attention_layernorm(hidden_states)
308
  hidden_states = self.mlp(hidden_states)
@@ -318,36 +264,25 @@ class Qwen2DecoderLayer(nn.Module):
318
  class Qwen2RotaryEmbedding(nn.Module):
319
  def __init__(self, config: Qwen2Config, device=None):
320
  super().__init__()
321
- # BC: "rope_type" was originally "type"
322
  if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
323
  self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
324
  else:
325
  self.rope_type = "default"
326
  self.max_seq_len_cached = config.max_position_embeddings
327
  self.original_max_seq_len = config.max_position_embeddings
328
-
329
  self.config = config
330
  self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
331
-
332
  inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
333
  self.register_buffer("inv_freq", inv_freq, persistent=False)
334
  self.original_inv_freq = self.inv_freq
335
 
336
  def _dynamic_frequency_update(self, position_ids, device):
337
- """
338
- dynamic RoPE layers should recompute `inv_freq` in the following situations:
339
- 1 - growing beyond the cached sequence length (allow scaling)
340
- 2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
341
- """
342
  seq_len = torch.max(position_ids) + 1
343
- if seq_len > self.max_seq_len_cached: # growth
344
  inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
345
- self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: may break with compilation
346
  self.max_seq_len_cached = seq_len
347
-
348
- if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len: # reset
349
- # This .to() is needed if the model has been moved to a device after being initialized (because
350
- # the buffer is automatically moved, but not the original copy)
351
  self.original_inv_freq = self.original_inv_freq.to(device)
352
  self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
353
  self.max_seq_len_cached = self.original_max_seq_len
@@ -356,11 +291,8 @@ class Qwen2RotaryEmbedding(nn.Module):
356
  def forward(self, x, position_ids):
357
  if "dynamic" in self.rope_type:
358
  self._dynamic_frequency_update(position_ids, device=x.device)
359
-
360
- # Core RoPE block
361
  inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
362
  position_ids_expanded = position_ids[:, None, :].float()
363
- # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
364
  device_type = x.device.type
365
  device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
366
  with torch.autocast(device_type=device_type, enabled=False):
@@ -368,11 +300,8 @@ class Qwen2RotaryEmbedding(nn.Module):
368
  emb = torch.cat((freqs, freqs), dim=-1)
369
  cos = emb.cos()
370
  sin = emb.sin()
371
-
372
- # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
373
  cos = cos * self.attention_scaling
374
  sin = sin * self.attention_scaling
375
-
376
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
377
 
378
 
@@ -426,75 +355,25 @@ class Qwen2PreTrainedModel(PreTrainedModel):
426
  QWEN2_INPUTS_DOCSTRING = r"""
427
  Args:
428
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
429
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
430
- it.
431
-
432
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
433
- [`PreTrainedTokenizer.__call__`] for details.
434
-
435
- [What are input IDs?](../glossary#input-ids)
436
  attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
437
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
438
-
439
- - 1 for tokens that are **not masked**,
440
- - 0 for tokens that are **masked**.
441
-
442
- [What are attention masks?](../glossary#attention-mask)
443
-
444
- Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
445
- [`PreTrainedTokenizer.__call__`] for details.
446
-
447
- If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
448
- `past_key_values`).
449
-
450
- If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
451
- and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
452
- information on the default strategy.
453
-
454
- - 1 indicates the head is **not masked**,
455
- - 0 indicates the head is **masked**.
456
  position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
457
- Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
458
- config.n_positions - 1]`.
459
-
460
- [What are position IDs?](../glossary#position-ids)
461
  past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
462
- Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
463
- blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
464
- returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
465
-
466
- Two formats are allowed:
467
- - a [`~cache_utils.Cache`] instance, see our
468
- [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
469
- - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
470
- shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
471
- cache format.
472
-
473
- The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
474
- legacy cache format will be returned.
475
-
476
- If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
477
- have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
478
- of shape `(batch_size, sequence_length)`.
479
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
480
- Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
481
- is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
482
- model's internal embedding lookup matrix.
483
  use_cache (`bool`, *optional*):
484
- If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
485
- `past_key_values`).
486
  output_attentions (`bool`, *optional*):
487
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
488
- tensors for more detail.
489
  output_hidden_states (`bool`, *optional*):
490
- Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
491
- more detail.
492
  return_dict (`bool`, *optional*):
493
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
494
  cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
495
- Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
496
- this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
497
- the complete sequence length.
498
  """
499
 
500
 
@@ -503,18 +382,10 @@ QWEN2_INPUTS_DOCSTRING = r"""
503
  QWEN2_START_DOCSTRING,
504
  )
505
  class Qwen2Model(Qwen2PreTrainedModel):
506
- """
507
- Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]
508
-
509
- Args:
510
- config: Qwen2Config
511
- """
512
-
513
  def __init__(self, config: Qwen2Config):
514
  super().__init__(config)
515
  self.padding_idx = config.pad_token_id
516
  self.vocab_size = config.vocab_size
517
-
518
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
519
  self.layers = nn.ModuleList(
520
  [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
@@ -522,8 +393,6 @@ class Qwen2Model(Qwen2PreTrainedModel):
522
  self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
523
  self.rotary_emb = Qwen2RotaryEmbedding(config=config)
524
  self.gradient_checkpointing = False
525
-
526
- # Initialize weights and apply final processing
527
  self.post_init()
528
 
529
  def get_input_embeddings(self):
@@ -532,79 +401,96 @@ class Qwen2Model(Qwen2PreTrainedModel):
532
  def set_input_embeddings(self, value):
533
  self.embed_tokens = value
534
 
535
-
536
  @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
537
- @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
538
  def forward(
539
  self,
540
  input_ids: torch.LongTensor = None,
541
  attention_mask: Optional[torch.Tensor] = None,
542
  position_ids: Optional[torch.LongTensor] = None,
543
- past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
544
  inputs_embeds: Optional[torch.FloatTensor] = None,
545
- labels: Optional[torch.LongTensor] = None,
546
  use_cache: Optional[bool] = None,
547
  output_attentions: Optional[bool] = None,
548
  output_hidden_states: Optional[bool] = None,
549
  return_dict: Optional[bool] = None,
550
  cache_position: Optional[torch.LongTensor] = None,
551
  is_causal: bool = True,
552
- **kwargs,
553
- ) -> Union[Tuple, CausalLMOutputWithPast]:
554
- r"""
555
- Args:
556
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
557
- Labels for computing the masked language modeling loss. Indices should be in `[0, ...,
558
- config.vocab_size - 1]`. Tokens with indices set to `-100` are ignored (masked), the loss is only
559
- computed for the tokens with labels in `[0, ..., config.vocab_size - 1]`.
560
- """
561
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
562
  output_hidden_states = (
563
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
564
  )
 
565
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
566
 
567
- outputs = self.model(
568
- input_ids=input_ids,
569
- attention_mask=attention_mask,
570
- position_ids=position_ids,
571
- past_key_values=past_key_values,
572
- inputs_embeds=inputs_embeds,
573
- use_cache=use_cache,
574
- output_attentions=output_attentions,
575
- output_hidden_states=output_hidden_states,
576
- return_dict=return_dict,
577
- cache_position=cache_position,
578
- is_causal=is_causal,
579
- **kwargs,
 
 
 
 
 
580
  )
581
-
582
- hidden_states = outputs[0]
583
- logits = self.lm_head(hidden_states)
584
- logits = logits.float()
585
- loss = None
586
-
587
- if labels is not None:
588
- # Maintained for compatibility with Trainer API, but not essential for pure inference
589
- shift_logits = logits[..., :-1, :].contiguous()
590
- shift_labels = labels[..., 1:].contiguous()
591
- loss_fct = torch.nn.CrossEntropyLoss()
592
- shift_logits = shift_logits.view(-1, self.config.vocab_size)
593
- shift_labels = shift_labels.view(-1)
594
- shift_labels = shift_labels.to(shift_logits.device)
595
- loss = loss_fct(shift_logits, shift_labels)
596
-
597
- if not return_dict:
598
- output = (logits,) + outputs[1:]
599
- return (loss,) + output if loss is not None else output
600
-
601
- return CausalLMOutputWithPast(
602
- loss=loss,
603
- logits=logits,
604
- past_key_values=outputs.past_key_values,
605
- hidden_states=outputs.hidden_states,
606
- attentions=outputs.attentions,
607
  )
 
608
 
609
  def _update_causal_mask(
610
  self,
@@ -614,166 +500,52 @@ class Qwen2Model(Qwen2PreTrainedModel):
614
  past_key_values: Cache,
615
  output_attentions: bool,
616
  ):
 
617
  if self.config._attn_implementation == "flash_attention_2":
618
- if attention_mask is not None and past_key_values is not None:
619
- is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
620
- if is_padding_right:
621
- raise ValueError(
622
- "You are attempting to perform batched generation with padding_side='right'"
623
- " this may lead to unexpected behaviour for Flash Attention version of Qwen2. Make sure to "
624
- " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
625
- )
626
  if attention_mask is not None and 0.0 in attention_mask:
627
  return attention_mask
628
  return None
629
-
630
- # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
631
- # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
632
- # to infer the attention mask.
633
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
634
  using_static_cache = isinstance(past_key_values, StaticCache)
635
- using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
636
-
637
- # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
638
- if (
639
- self.config._attn_implementation == "sdpa"
640
- and not (using_static_cache or using_sliding_window_cache)
641
- and not output_attentions
642
- ):
643
  if AttentionMaskConverter._ignore_causal_mask_sdpa(
644
- attention_mask,
645
- inputs_embeds=input_tensor,
646
- past_key_values_length=past_seen_tokens,
647
- sliding_window=self.config.sliding_window,
648
- is_training=self.training,
649
  ):
650
  return None
651
-
652
  dtype, device = input_tensor.dtype, input_tensor.device
653
  min_dtype = torch.finfo(dtype).min
654
  sequence_length = input_tensor.shape[1]
655
- # SlidingWindowCache or StaticCache
656
- if using_sliding_window_cache or using_static_cache:
657
  target_length = past_key_values.get_max_cache_shape()
658
- # DynamicCache or no cache
659
  else:
660
- target_length = (
661
- attention_mask.shape[-1]
662
- if isinstance(attention_mask, torch.Tensor)
663
- else past_seen_tokens + sequence_length + 1
664
- )
665
-
666
- # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
667
- causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
668
- attention_mask,
669
- sequence_length=sequence_length,
670
- target_length=target_length,
671
- dtype=dtype,
672
- device=device,
673
- cache_position=cache_position,
674
- batch_size=input_tensor.shape[0],
675
- config=self.config,
676
- past_key_values=past_key_values,
677
- )
678
-
679
- if (
680
- self.config._attn_implementation == "sdpa"
681
- and attention_mask is not None
682
- and attention_mask.device.type in ["cuda", "xpu"]
683
- and not output_attentions
684
- ):
685
- # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
686
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
687
- # Details: https://github.com/pytorch/pytorch/issues/110213
688
  causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
689
-
690
- return causal_mask
691
-
692
- @staticmethod
693
- def _prepare_4d_causal_attention_mask_with_cache_position(
694
- attention_mask: torch.Tensor,
695
- sequence_length: int,
696
- target_length: int,
697
- dtype: torch.dtype,
698
- device: torch.device,
699
- cache_position: torch.Tensor,
700
- batch_size: int,
701
- config: Qwen2Config,
702
- past_key_values: Cache,
703
- ):
704
- """
705
- Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
706
- `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
707
-
708
- Args:
709
- attention_mask (`torch.Tensor`):
710
- A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`.
711
- sequence_length (`int`):
712
- The sequence length being processed.
713
- target_length (`int`):
714
- The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet.
715
- dtype (`torch.dtype`):
716
- The dtype to use for the 4D attention mask.
717
- device (`torch.device`):
718
- The device to place the 4D attention mask on.
719
- cache_position (`torch.Tensor`):
720
- Indices depicting the position of the input sequence tokens in the sequence.
721
- batch_size (`torch.Tensor`):
722
- Batch size.
723
- config (`Qwen2Config`):
724
- The model's configuration class
725
- past_key_values (`Cache`):
726
- The cache class that is being used currently to generate
727
- """
728
- if attention_mask is not None and attention_mask.dim() == 4:
729
- # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
730
- causal_mask = attention_mask
731
- else:
732
- min_dtype = torch.finfo(dtype).min
733
- causal_mask = torch.full(
734
- (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
735
- )
736
- diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
737
- if config.sliding_window is not None:
738
- # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also
739
- # the check is needed to verify is current checkpoint was trained with sliding window or not
740
- if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length:
741
- sliding_attend_mask = torch.arange(target_length, device=device) <= (
742
- cache_position.reshape(-1, 1) - config.sliding_window
743
- )
744
- diagonal_attend_mask.bitwise_or_(sliding_attend_mask)
745
- causal_mask *= diagonal_attend_mask
746
- causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
747
- if attention_mask is not None:
748
- causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
749
- if attention_mask.shape[-1] > target_length:
750
- attention_mask = attention_mask[:, :target_length]
751
- mask_length = attention_mask.shape[-1]
752
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
753
- causal_mask.device
754
- )
755
- padding_mask = padding_mask == 0
756
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
757
- padding_mask, min_dtype
758
- )
759
  return causal_mask
760
 
761
 
762
  class KwargsForCausalLM(FlashAttentionKwargs): ...
763
 
764
 
765
- class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
766
  _tied_weights_keys = ["lm_head.weight"]
767
- _tp_plan = {"lm_head": "colwise_rep"}
768
- _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
769
 
770
  def __init__(self, config):
771
  super().__init__(config)
772
  self.model = Qwen2Model(config)
773
  self.vocab_size = config.vocab_size
774
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
775
-
776
- # Initialize weights and apply final processing
777
  self.post_init()
778
 
779
  def get_input_embeddings(self):
@@ -794,7 +566,6 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
794
  def get_decoder(self):
795
  return self.model
796
 
797
-
798
  @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
799
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
800
  def forward(
@@ -811,15 +582,8 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
811
  return_dict: Optional[bool] = None,
812
  cache_position: Optional[torch.LongTensor] = None,
813
  is_causal: bool = True,
814
- **kwargs,
815
  ) -> Union[Tuple, CausalLMOutputWithPast]:
816
- r"""
817
- Args:
818
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
819
- Labels for computing the masked language modeling loss. Indices should be in `[0, ...,
820
- config.vocab_size - 1]`. Tokens with indices set to `-100` are ignored (masked), the loss is only
821
- computed for the tokens with labels in `[0, ..., config.vocab_size - 1]`.
822
- """
823
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
824
  output_hidden_states = (
825
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -847,12 +611,14 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
847
  loss = None
848
 
849
  if labels is not None:
850
- # Maintained for compatibility with Trainer API, but not essential for pure inference
851
  shift_logits = logits[..., :-1, :].contiguous()
852
  shift_labels = labels[..., 1:].contiguous()
 
853
  loss_fct = torch.nn.CrossEntropyLoss()
854
  shift_logits = shift_logits.view(-1, self.config.vocab_size)
855
  shift_labels = shift_labels.view(-1)
 
856
  shift_labels = shift_labels.to(shift_logits.device)
857
  loss = loss_fct(shift_logits, shift_labels)
858
 
@@ -867,3 +633,7 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
867
  hidden_states=outputs.hidden_states,
868
  attentions=outputs.attentions,
869
  )
 
1
  # Copyright 2024 The Qwen team, Alibaba Group and The HuggingFace Inc. team. All rights reserved.
 
2
  #
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
  # you may not use this file except in compliance with the License.
 
12
  # See the License for the specific language governing permissions and
13
  # limitations under the License.
14
 
15
+ # This is a cleaned version of the original script, with proprietary dependencies
16
+ # and training-specific code removed for public release.
17
+
18
+ import logging
19
  from typing import Callable, List, Optional, Tuple, Union
20
 
21
  import torch
 
34
  from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
35
  from transformers.processing_utils import Unpack
36
  from transformers.utils import (
 
37
  add_start_docstrings,
38
  add_start_docstrings_to_model_forward,
39
  replace_return_docstrings,
40
  )
41
 
42
+ logger = logging.getLogger(__name__)
 
43
 
44
  _CHECKPOINT_FOR_DOC = "meta-qwen2/Qwen2-2-7b-hf"
45
  _CONFIG_FOR_DOC = "Qwen2Config"
 
69
 
70
 
71
  def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
72
  cos = cos.unsqueeze(unsqueeze_dim)
73
  sin = sin.unsqueeze(unsqueeze_dim)
74
  q_embed = (q * cos) + (rotate_half(q) * sin)
 
140
  is_causal: bool = True,
141
  **kwargs: Unpack[FlashAttentionKwargs],
142
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
143
+ bsz, q_len, _ = hidden_states.size()
144
  hidden_shape = (bsz, q_len, -1, self.head_dim)
145
 
146
  query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
147
  key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
148
  value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 
 
 
 
 
149
 
150
+ full_q_len = query_states.size(2)
151
  cos, sin = position_embeddings
152
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
153
 
154
  if past_key_value is not None:
 
155
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
156
  key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
157
 
 
181
  attention_mask,
182
  dropout=0.0 if not self.training else self.attention_dropout,
183
  scaling=self.scaling,
184
+ sliding_window=sliding_window,
185
  **kwargs,
186
  )
187
 
188
  attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()
 
 
 
189
  attn_output = attn_output.reshape(bsz, q_len, self.config.hidden_size).contiguous()
190
  attn_output = self.o_proj(attn_output)
191
  return attn_output, attn_weights
 
193
 
194
  class Qwen2RMSNorm(nn.Module):
195
  def __init__(self, hidden_size, eps=1e-6):
 
 
 
196
  super().__init__()
197
  self.weight = nn.Parameter(torch.ones(hidden_size))
198
  self.variance_epsilon = eps
 
236
  **kwargs: Unpack[FlashAttentionKwargs],
237
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
238
  residual = hidden_states
 
239
  hidden_states = self.input_layernorm(hidden_states)
240
 
 
241
  hidden_states, self_attn_weights = self.self_attn(
242
  hidden_states=hidden_states,
243
  attention_mask=attention_mask,
 
244
  past_key_value=past_key_value,
 
 
245
  cache_position=cache_position,
246
  position_embeddings=position_embeddings,
247
  is_causal=is_causal,
 
249
  )
250
  hidden_states = residual + hidden_states
251
 
 
252
  residual = hidden_states
253
  hidden_states = self.post_attention_layernorm(hidden_states)
254
  hidden_states = self.mlp(hidden_states)
 
264
  class Qwen2RotaryEmbedding(nn.Module):
265
  def __init__(self, config: Qwen2Config, device=None):
266
  super().__init__()
 
267
  if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
268
  self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
269
  else:
270
  self.rope_type = "default"
271
  self.max_seq_len_cached = config.max_position_embeddings
272
  self.original_max_seq_len = config.max_position_embeddings
 
273
  self.config = config
274
  self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
 
275
  inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
276
  self.register_buffer("inv_freq", inv_freq, persistent=False)
277
  self.original_inv_freq = self.inv_freq
278
 
279
  def _dynamic_frequency_update(self, position_ids, device):
 
 
 
 
 
280
  seq_len = torch.max(position_ids) + 1
281
+ if seq_len > self.max_seq_len_cached:
282
  inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
283
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
284
  self.max_seq_len_cached = seq_len
285
+ if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:
 
 
 
286
  self.original_inv_freq = self.original_inv_freq.to(device)
287
  self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
288
  self.max_seq_len_cached = self.original_max_seq_len
 
291
  def forward(self, x, position_ids):
292
  if "dynamic" in self.rope_type:
293
  self._dynamic_frequency_update(position_ids, device=x.device)
 
 
294
  inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
295
  position_ids_expanded = position_ids[:, None, :].float()
 
296
  device_type = x.device.type
297
  device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
298
  with torch.autocast(device_type=device_type, enabled=False):
 
300
  emb = torch.cat((freqs, freqs), dim=-1)
301
  cos = emb.cos()
302
  sin = emb.sin()
 
 
303
  cos = cos * self.attention_scaling
304
  sin = sin * self.attention_scaling
 
305
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
306
 
307
 
 
355
  QWEN2_INPUTS_DOCSTRING = r"""
356
  Args:
357
  input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
358
+ Indices of input sequence tokens in the vocabulary.
 
 
 
 
 
 
359
  attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
360
+ Mask to avoid performing attention on padding token indices.
361
  position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
362
+ Indices of positions of each input sequence tokens in the position embeddings.
 
 
 
363
  past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
364
+ Pre-computed hidden-states (key and values in the self-attention blocks) that can be used to speed up sequential decoding.
365
  inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
366
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
 
 
367
  use_cache (`bool`, *optional*):
368
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding.
 
369
  output_attentions (`bool`, *optional*):
370
+ Whether or not to return the attentions tensors of all attention layers.
 
371
  output_hidden_states (`bool`, *optional*):
372
+ Whether or not to return the hidden states of all layers.
 
373
  return_dict (`bool`, *optional*):
374
  Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
375
  cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
376
+ Indices depicting the position of the input sequence tokens in the sequence.
 
 
377
  """
378
 
379
 
 
382
  QWEN2_START_DOCSTRING,
383
  )
384
  class Qwen2Model(Qwen2PreTrainedModel):
 
 
 
 
 
 
 
385
  def __init__(self, config: Qwen2Config):
386
  super().__init__(config)
387
  self.padding_idx = config.pad_token_id
388
  self.vocab_size = config.vocab_size
 
389
  self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
390
  self.layers = nn.ModuleList(
391
  [Qwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
 
393
  self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
394
  self.rotary_emb = Qwen2RotaryEmbedding(config=config)
395
  self.gradient_checkpointing = False
 
 
396
  self.post_init()
397
 
398
  def get_input_embeddings(self):
 
401
  def set_input_embeddings(self, value):
402
  self.embed_tokens = value
403
 
 
404
  @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
 
405
  def forward(
406
  self,
407
  input_ids: torch.LongTensor = None,
408
  attention_mask: Optional[torch.Tensor] = None,
409
  position_ids: Optional[torch.LongTensor] = None,
410
+ past_key_values: Optional[Cache] = None,
411
  inputs_embeds: Optional[torch.FloatTensor] = None,
 
412
  use_cache: Optional[bool] = None,
413
  output_attentions: Optional[bool] = None,
414
  output_hidden_states: Optional[bool] = None,
415
  return_dict: Optional[bool] = None,
416
  cache_position: Optional[torch.LongTensor] = None,
417
  is_causal: bool = True,
418
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
419
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
 
 
 
 
 
 
 
420
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
421
  output_hidden_states = (
422
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
423
  )
424
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
425
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
426
 
427
+ if (input_ids is None) ^ (inputs_embeds is not None):
428
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
429
+ if self.gradient_checkpointing and self.training and use_cache:
430
+ logger.warning("`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.")
431
+ use_cache = False
432
+ if inputs_embeds is None:
433
+ inputs_embeds = self.embed_tokens(input_ids)
434
+ if use_cache and past_key_values is None:
435
+ past_key_values = DynamicCache()
436
+ if cache_position is None:
437
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
438
+ cache_position = torch.arange(
439
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
440
+ )
441
+ if position_ids is None:
442
+ position_ids = cache_position.unsqueeze(0)
443
+ causal_mask = self._update_causal_mask(
444
+ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
445
  )
446
+ hidden_states = inputs_embeds
447
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
448
+ all_hidden_states = () if output_hidden_states else None
449
+ all_self_attns = () if output_attentions else None
450
+
451
+ for decoder_layer in self.layers:
452
+ if output_hidden_states:
453
+ all_hidden_states += (hidden_states,)
454
+ if self.gradient_checkpointing and self.training:
455
+ layer_outputs = self._gradient_checkpointing_func(
456
+ decoder_layer.__call__,
457
+ hidden_states,
458
+ causal_mask,
459
+ position_ids,
460
+ past_key_values,
461
+ output_attentions,
462
+ use_cache,
463
+ cache_position,
464
+ position_embeddings,
465
+ is_causal,
466
+ )
467
+ else:
468
+ layer_outputs = decoder_layer(
469
+ hidden_states,
470
+ attention_mask=causal_mask,
471
+ position_ids=position_ids,
472
+ past_key_value=past_key_values,
473
+ output_attentions=output_attentions,
474
+ use_cache=use_cache,
475
+ cache_position=cache_position,
476
+ position_embeddings=position_embeddings,
477
+ is_causal=is_causal,
478
+ **flash_attn_kwargs,
479
+ )
480
+ hidden_states = layer_outputs[0]
481
+ if output_attentions:
482
+ all_self_attns += (layer_outputs[1],)
483
+
484
+ hidden_states = self.norm(hidden_states)
485
+ if output_hidden_states:
486
+ all_hidden_states += (hidden_states,)
487
+ output = BaseModelOutputWithPast(
488
+ last_hidden_state=hidden_states,
489
+ past_key_values=past_key_values if use_cache else None,
490
+ hidden_states=all_hidden_states,
491
+ attentions=all_self_attns,
492
  )
493
+ return output if return_dict else output.to_tuple()
494
 
495
  def _update_causal_mask(
496
  self,
 
500
  past_key_values: Cache,
501
  output_attentions: bool,
502
  ):
503
+ # Standard causal mask creation logic from transformers, no changes needed here.
504
  if self.config._attn_implementation == "flash_attention_2":
 
 
 
 
 
 
 
 
505
  if attention_mask is not None and 0.0 in attention_mask:
506
  return attention_mask
507
  return None
 
 
 
 
508
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
509
  using_static_cache = isinstance(past_key_values, StaticCache)
510
+ if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
 
 
 
 
 
 
 
511
  if AttentionMaskConverter._ignore_causal_mask_sdpa(
512
+ attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens, is_training=self.training
 
 
 
 
513
  ):
514
  return None
 
515
  dtype, device = input_tensor.dtype, input_tensor.device
516
  min_dtype = torch.finfo(dtype).min
517
  sequence_length = input_tensor.shape[1]
518
+ if isinstance(past_key_values, StaticCache):
 
519
  target_length = past_key_values.get_max_cache_shape()
 
520
  else:
521
+ target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + sequence_length
522
+ causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
523
+ if sequence_length != 1:
524
+ causal_mask = torch.triu(causal_mask, diagonal=1)
525
+ causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
526
+ causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
527
+ if attention_mask is not None:
528
+ causal_mask = causal_mask.clone()
529
+ mask_length = attention_mask.shape[-1]
530
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
531
+ padding_mask = padding_mask == 0
532
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)
533
+ if self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type == "cuda" and not output_attentions:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
534
  causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
535
  return causal_mask
536
 
537
 
538
  class KwargsForCausalLM(FlashAttentionKwargs): ...
539
 
540
 
541
+ class Qwen2ForCausalLM(Qwen2PreTrainedModel, GenerationMixin):
542
  _tied_weights_keys = ["lm_head.weight"]
 
 
543
 
544
  def __init__(self, config):
545
  super().__init__(config)
546
  self.model = Qwen2Model(config)
547
  self.vocab_size = config.vocab_size
548
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
 
549
  self.post_init()
550
 
551
  def get_input_embeddings(self):
 
566
  def get_decoder(self):
567
  return self.model
568
 
 
569
  @add_start_docstrings_to_model_forward(QWEN2_INPUTS_DOCSTRING)
570
  @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
571
  def forward(
 
582
  return_dict: Optional[bool] = None,
583
  cache_position: Optional[torch.LongTensor] = None,
584
  is_causal: bool = True,
585
+ **kwargs: Unpack[KwargsForCausalLM],
586
  ) -> Union[Tuple, CausalLMOutputWithPast]:
 
 
 
 
 
 
 
587
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
588
  output_hidden_states = (
589
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
611
  loss = None
612
 
613
  if labels is not None:
614
+ # Shift so that tokens < n predict n
615
  shift_logits = logits[..., :-1, :].contiguous()
616
  shift_labels = labels[..., 1:].contiguous()
617
+ # Flatten the tokens
618
  loss_fct = torch.nn.CrossEntropyLoss()
619
  shift_logits = shift_logits.view(-1, self.config.vocab_size)
620
  shift_labels = shift_labels.view(-1)
621
+ # Ensure labels are on the same device as logits
622
  shift_labels = shift_labels.to(shift_logits.device)
623
  loss = loss_fct(shift_logits, shift_labels)
624
 
 
633
  hidden_states=outputs.hidden_states,
634
  attentions=outputs.attentions,
635
  )
636
+
637
+ ModelClass = Qwen2ForCausalLM
638
+
639
+ __all__ = ["Qwen2ForCausalLM", "Qwen2Model", "Qwen2PreTrainedModel"]
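
For reference, the label handling in Qwen2ForCausalLM.forward above is the standard next-token objective: logits at position t are scored against the label at position t+1, and positions labeled -100 are ignored. A minimal standalone sketch of that shift; the shapes and values below are illustrative assumptions, not taken from the model:

import torch

vocab_size = 32
logits = torch.randn(2, 5, vocab_size)         # (batch, seq_len, vocab), dummy values
labels = torch.randint(0, vocab_size, (2, 5))  # (batch, seq_len)
labels[:, 2] = -100                            # -100 positions are ignored by the loss

# Shift so that tokens < n predict n, mirroring the forward pass above.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

loss_fct = torch.nn.CrossEntropyLoss()         # ignore_index defaults to -100
loss = loss_fct(
    shift_logits.view(-1, vocab_size),
    shift_labels.view(-1).to(shift_logits.device),
)
print(loss)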