SamuelYang and ctranslate2-4you committed
Commit e64fa74 (verified) · 1 Parent(s): bec332c

Update modeling_qwen.py (#5)


- Update modeling_qwen.py (79a97c4ae947769db3fd26494a4c5cf5dc744ee8)


Co-authored-by: Blair Chintella <[email protected]>

Files changed (1):
  1. modeling_qwen.py (+55 −15)

modeling_qwen.py CHANGED
@@ -17,6 +17,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+# includes edits by https://github.com/BBC-Esq to fix cache errors following transformers version post 4.53.3 major cache refactor
+
 """ PyTorch Qwen2 model."""
 from transformers import Qwen2Config
 import inspect
@@ -274,7 +277,9 @@ class Qwen2Attention(nn.Module):
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            past_len = past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else 0
+            kv_seq_len += past_len
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

@@ -378,7 +383,9 @@ class Qwen2FlashAttention2(Qwen2Attention):
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            past_len = past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else 0
+            kv_seq_len += past_len

         # Because the input can be padded, the absolute sequence length depends on the max position id.
         rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
@@ -676,7 +683,9 @@ class Qwen2SdpaAttention(Qwen2Attention):

         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            past_len = past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else 0
+            kv_seq_len += past_len
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
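Note (illustrative, not part of the commit): the same swap from get_usable_length() to get_seq_length() appears in all three attention variants above. A minimal sketch of what the replacement computes, assuming a transformers version that exposes DynamicCache.update() and get_seq_length(layer_idx); tensor shapes are made up:

import torch
from transformers import DynamicCache

cache = DynamicCache()
layer_idx = 0
# pretend an earlier decoding step already cached 7 key/value positions for layer 0
key = torch.zeros(1, 2, 7, 64)    # (batch, num_kv_heads, seq_len, head_dim)
value = torch.zeros(1, 2, 7, 64)
cache.update(key, value, layer_idx)

kv_seq_len = 1                                # length of the current step's key_states
past_len = cache.get_seq_length(layer_idx)    # 7; replaces get_usable_length(kv_seq_len, layer_idx)
kv_seq_len += past_len                        # 8, the length fed to the rotary embedding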
@@ -972,7 +981,6 @@ class Qwen2Model(Qwen2PreTrainedModel):
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
-
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         # retrieve input_ids and inputs_embeds
@@ -993,12 +1001,28 @@ class Qwen2Model(Qwen2PreTrainedModel):
                 use_cache = False

         past_key_values_length = 0
+        use_legacy_cache = False

         if use_cache:
-            use_legacy_cache = not isinstance(past_key_values, Cache)
-            if use_legacy_cache:
+            # OLD behavior (removed in HF >= 4.55): treat anything not Cache as "legacy" but then
+            # directly used legacy methods on it (would crash if None or new API).
+            # use_legacy_cache = not isinstance(past_key_values, Cache)
+            # if use_legacy_cache:
+            #     # past_key_values_length = past_key_values.get_seq_length()
+            #     past_key_values_length = past_key_values.get_usable_length(seq_length)
+
+            # NEW behavior: if a legacy tuple is passed, convert it to the new Cache API,
+            # compute length via .get_seq_length(), and remember to return legacy if that's what came in.
+            if past_key_values is not None and not isinstance(past_key_values, Cache):
+                use_legacy_cache = True  # remember input format for return
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_usable_length(seq_length)
+
+            if isinstance(past_key_values, Cache):
+                # Layer-agnostic total length; cache_position is handled deeper if needed
+                past_key_values_length = past_key_values.get_seq_length()
+            else:
+                # No cache given on first forward, keep length at 0
+                past_key_values_length = 0

         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
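Note (illustrative, not from the commit): the legacy-tuple handling above can be reproduced standalone. A small sketch, with normalize_cache() as a hypothetical helper name and invented shapes, assuming DynamicCache.from_legacy_cache() / to_legacy_cache() are available:

import torch
from transformers import Cache, DynamicCache

def normalize_cache(past_key_values):
    # hypothetical helper mirroring the updated forward(): accept None, a legacy tuple, or a Cache
    use_legacy_cache = False
    if past_key_values is not None and not isinstance(past_key_values, Cache):
        use_legacy_cache = True
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    past_len = past_key_values.get_seq_length() if isinstance(past_key_values, Cache) else 0
    return past_key_values, past_len, use_legacy_cache

legacy = tuple((torch.zeros(1, 2, 5, 64), torch.zeros(1, 2, 5, 64)) for _ in range(2))
cache, past_len, was_legacy = normalize_cache(legacy)
print(past_len, was_legacy)            # 5 True
print(type(cache.to_legacy_cache()))   # tuple, returned when the caller passed the legacy format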
@@ -1104,7 +1128,10 @@ class Qwen2Model(Qwen2PreTrainedModel):

         next_cache = None
         if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+            # If the caller passed legacy, return legacy. Otherwise return the Cache object.
+            next_cache = (
+                next_decoder_cache.to_legacy_cache() if (use_legacy_cache and next_decoder_cache is not None) else next_decoder_cache
+            )

         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
@@ -1116,6 +1143,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
         )


+
 class Qwen2ForCausalLM(Qwen2PreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]

@@ -1243,21 +1271,32 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
         # Omit tokens covered by past_key_values
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
+                # NEW API (HF >= 4.55): use Cache methods
                 cache_length = past_key_values.get_seq_length()
-                past_length = past_key_values.seen_tokens
-                max_cache_length = past_key_values.get_max_length()
+                past_length = cache_length  # `seen_tokens` removed; use total seq length instead
+                try:
+                    max_cache_length = past_key_values.get_max_cache_shape()
+                except Exception:
+                    max_cache_length = None
+
+                # OLD API (deprecated/removed):
+                # cache_length = past_key_values.get_seq_length()
+                # past_length = past_key_values.seen_tokens
+                # max_cache_length = past_key_values.get_max_length()
+
             else:
+                # Legacy tuple format: keep computing lengths directly from tensors
+                # (We keep it compatible without forcing a conversion here)
                 cache_length = past_length = past_key_values[0][0].shape[2]
                 max_cache_length = None

             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
-            # input)
+            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
             if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                 input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-            # input_ids based on the past_length.
+            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens.
+            # We can discard input_ids based on the past_length.
             elif past_length < input_ids.shape[1]:
                 input_ids = input_ids[:, past_length:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
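Note (illustrative, not from the commit): the length bookkeeping in prepare_inputs_for_generation can be checked on a bare cache. A sketch under the assumption that get_seq_length() and get_max_cache_shape() exist in the installed transformers; the try/except mirrors the diff's defensive fallback:

import torch
from transformers import DynamicCache

cache = DynamicCache()
cache.update(torch.zeros(1, 2, 9, 64), torch.zeros(1, 2, 9, 64), layer_idx=0)

cache_length = cache.get_seq_length()                 # 9
past_length = cache_length                            # Cache.seen_tokens no longer exists
try:
    max_cache_length = cache.get_max_cache_shape()    # None for an unbounded DynamicCache
except AttributeError:
    max_cache_length = None                           # older versions expose get_max_length() instead
print(cache_length, past_length, max_cache_length)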
@@ -1287,13 +1326,14 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
         model_inputs.update(
             {
                 "position_ids": position_ids,
-                "past_key_values": past_key_values,
+                "past_key_values": past_key_values,  # pass through unchanged (legacy or new Cache object)
                 "use_cache": kwargs.get("use_cache"),
                 "attention_mask": attention_mask,
             }
         )
         return model_inputs

+
     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
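Note (illustrative): since this modeling file ships with the repository rather than with transformers, it is only picked up when the model is loaded with remote code enabled. A minimal usage sketch with a placeholder repository id:

from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-org/your-model"   # placeholder; substitute the actual repository id
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True)

inputs = tokenizer("Hello", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=8, use_cache=True)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))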
 