Update modeling_qwen.py (#5)
Commit: 79a97c4ae947769db3fd26494a4c5cf5dc744ee8
Co-authored-by: Blair Chintella <[email protected]>

modeling_qwen.py  CHANGED  (+55 -15)
@@ -17,6 +17,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+# includes edits by https://github.com/BBC-Esq to fix cache errors following transformers version post 4.53.3 major cache refactor
+
 """ PyTorch Qwen2 model."""
 from transformers import Qwen2Config
 import inspect
@@ -274,7 +277,9 @@ class Qwen2Attention(nn.Module):
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            past_len = past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else 0
+            kv_seq_len += past_len
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

@@ -378,7 +383,9 @@ class Qwen2FlashAttention2(Qwen2Attention):
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            past_len = past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else 0
+            kv_seq_len += past_len

         # Because the input can be padded, the absolute sequence length depends on the max position id.
         rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
@@ -676,7 +683,9 @@ class Qwen2SdpaAttention(Qwen2Attention):

         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            past_len = past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else 0
+            kv_seq_len += past_len
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
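The three hunks above make the same substitution in Qwen2Attention, Qwen2FlashAttention2, and Qwen2SdpaAttention: the removed Cache.get_usable_length(kv_seq_len, layer_idx) call becomes Cache.get_seq_length(layer_idx). A minimal sketch of the new call, assuming torch and a post-refactor transformers (tensor shapes are illustrative only):

# Sketch: how the patched attention code computes kv_seq_len with the new Cache API.
import torch
from transformers import DynamicCache

past_key_value = DynamicCache()

# Pretend layer 0 already holds 5 cached positions: [batch, kv_heads, seq_len, head_dim]
past_key_value.update(torch.zeros(1, 2, 5, 64), torch.zeros(1, 2, 5, 64), layer_idx=0)

kv_seq_len = 1  # key_states.shape[-2] for the current step in the real code
past_len = past_key_value.get_seq_length(0) if past_key_value is not None else 0
kv_seq_len += past_len

print(past_len, kv_seq_len)  # 5, 6

get_seq_length(layer_idx) reports how many positions are already cached for that layer, which is the offset the rotary embedding call on the following context line needs.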
@@ -972,7 +981,6 @@ class Qwen2Model(Qwen2PreTrainedModel):
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
-
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         # retrieve input_ids and inputs_embeds
@@ -993,12 +1001,28 @@ class Qwen2Model(Qwen2PreTrainedModel):
             use_cache = False

         past_key_values_length = 0
+        use_legacy_cache = False

         if use_cache:
-            use_legacy_cache = not isinstance(past_key_values, Cache)
-            if use_legacy_cache:
+            # OLD behavior (removed in HF >= 4.55): treat anything not Cache as "legacy" but then
+            # directly used legacy methods on it (would crash if None or new API).
+            # use_legacy_cache = not isinstance(past_key_values, Cache)
+            # if use_legacy_cache:
+            # # past_key_values_length = past_key_values.get_seq_length()
+            # past_key_values_length = past_key_values.get_usable_length(seq_length)
+
+            # NEW behavior: if a legacy tuple is passed, convert it to the new Cache API,
+            # compute length via .get_seq_length(), and remember to return legacy if that’s what came in.
+            if past_key_values is not None and not isinstance(past_key_values, Cache):
+                use_legacy_cache = True  # remember input format for return
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_usable_length(seq_length)
+
+            if isinstance(past_key_values, Cache):
+                # Layer-agnostic total length; cache_position is handled deeper if needed
+                past_key_values_length = past_key_values.get_seq_length()
+            else:
+                # No cache given on first forward, keep length at 0
+                past_key_values_length = 0

         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
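For reference, a self-contained sketch of the conversion logic introduced above: accept either a legacy tuple-of-tuples cache or a Cache object, normalize to DynamicCache, and compute the past length with the new API. The normalize_cache helper name and the tensor shapes are made up for illustration, assuming a transformers version that still exposes DynamicCache.from_legacy_cache:

# Sketch of the legacy-to-Cache normalization performed in the hunk above.
import torch
from transformers import Cache, DynamicCache

def normalize_cache(past_key_values):
    """Return (cache, past_len, was_legacy) for a legacy tuple, a Cache object, or None."""
    use_legacy_cache = False
    if past_key_values is not None and not isinstance(past_key_values, Cache):
        use_legacy_cache = True  # remember the input format so it can be returned unchanged
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    past_len = past_key_values.get_seq_length() if isinstance(past_key_values, Cache) else 0
    return past_key_values, past_len, use_legacy_cache

# Legacy format: one (key, value) pair per layer, each [batch, kv_heads, seq_len, head_dim]
legacy = tuple((torch.zeros(1, 2, 3, 64), torch.zeros(1, 2, 3, 64)) for _ in range(2))

cache, past_len, was_legacy = normalize_cache(legacy)
print(type(cache).__name__, past_len, was_legacy)  # DynamicCache 3 True

cache, past_len, was_legacy = normalize_cache(None)
print(cache, past_len, was_legacy)  # None 0 False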
@@ -1104,7 +1128,10 @@ class Qwen2Model(Qwen2PreTrainedModel):

         next_cache = None
         if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+            # If the caller passed legacy, return legacy. Otherwise return the Cache object.
+            next_cache = (
+                next_decoder_cache.to_legacy_cache() if (use_legacy_cache and next_decoder_cache is not None) else next_decoder_cache
+            )

         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
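And the mirror image on the return path, as the hunk above does for next_cache, sketched against the same assumed DynamicCache API (to_legacy_cache converts back to the tuple-of-tuples layout; shapes are illustrative):

# Sketch of the return-path handling: hand callers back the same format they passed in.
import torch
from transformers import DynamicCache

next_decoder_cache = DynamicCache()
next_decoder_cache.update(torch.zeros(1, 2, 4, 64), torch.zeros(1, 2, 4, 64), layer_idx=0)

use_legacy_cache = True  # pretend the caller originally sent a legacy tuple
next_cache = (
    next_decoder_cache.to_legacy_cache() if (use_legacy_cache and next_decoder_cache is not None) else next_decoder_cache
)

print(type(next_cache))        # <class 'tuple'>
print(next_cache[0][0].shape)  # torch.Size([1, 2, 4, 64])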
@@ -1116,6 +1143,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
         )


+
 class Qwen2ForCausalLM(Qwen2PreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]

@@ -1243,21 +1271,32 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
         # Omit tokens covered by past_key_values
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
+                # NEW API (HF >= 4.55): use Cache methods
                 cache_length = past_key_values.get_seq_length()
-                past_length = past_key_values.seen_tokens
-                max_cache_length = past_key_values.get_max_length()
+                past_length = cache_length  # `seen_tokens` removed; use total seq length instead
+                try:
+                    max_cache_length = past_key_values.get_max_cache_shape()
+                except Exception:
+                    max_cache_length = None
+
+                # OLD API (deprecated/removed):
+                # cache_length = past_key_values.get_seq_length()
+                # past_length = past_key_values.seen_tokens
+                # max_cache_length = past_key_values.get_max_length()
+
             else:
+                # Legacy tuple format: keep computing lengths directly from tensors
+                # (We keep it compatible without forcing a conversion here)
                 cache_length = past_length = past_key_values[0][0].shape[2]
                 max_cache_length = None

             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
-            # input)
+            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
             if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                 input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-            # input_ids based on the past_length.
+            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens.
+            # We can discard input_ids based on the past_length.
             elif past_length < input_ids.shape[1]:
                 input_ids = input_ids[:, past_length:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
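A sketch of the length bookkeeping the rewritten branch above performs; the cache_lengths helper is hypothetical, and the try/except mirrors the patch's own hedge in case get_max_cache_shape() is missing on some Cache subclass:

# Sketch of the cache length computation in prepare_inputs_for_generation after the patch.
import torch
from transformers import Cache, DynamicCache

def cache_lengths(past_key_values):
    if isinstance(past_key_values, Cache):
        cache_length = past_key_values.get_seq_length()
        past_length = cache_length  # `seen_tokens` is gone; total cached length stands in for it
        try:
            max_cache_length = past_key_values.get_max_cache_shape()  # None for an unbounded DynamicCache
        except Exception:
            max_cache_length = None
    else:
        # Legacy tuple format: read the sequence dimension straight off the tensors
        cache_length = past_length = past_key_values[0][0].shape[2]
        max_cache_length = None
    return cache_length, past_length, max_cache_length

cache = DynamicCache()
cache.update(torch.zeros(1, 2, 7, 64), torch.zeros(1, 2, 7, 64), layer_idx=0)
print(cache_lengths(cache))  # (7, 7, None)

legacy = ((torch.zeros(1, 2, 7, 64), torch.zeros(1, 2, 7, 64)),)
print(cache_lengths(legacy))  # (7, 7, None)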
@@ -1287,13 +1326,14 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
         model_inputs.update(
             {
                 "position_ids": position_ids,
-                "past_key_values": past_key_values,
+                "past_key_values": past_key_values,  # pass through unchanged (legacy or new Cache object)
                 "use_cache": kwargs.get("use_cache"),
                 "attention_mask": attention_mask,
             }
         )
         return model_inputs

+
     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
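Finally, a quick way to exercise the patched cache path end to end is a short greedy generation, which goes through prepare_inputs_for_generation and the attention cache code on every step. The repository id below is a placeholder, and trust_remote_code=True is what makes transformers load a repo's bundled modeling_qwen.py instead of its built-in Qwen2 implementation:

# Smoke test for the patched cache handling; repo id is a placeholder.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo_id = "your-namespace/your-qwen2-remote-code-model"  # placeholder, not a real repo

tok = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True, torch_dtype=torch.float16)
model.eval()

inputs = tok("Hello", return_tensors="pt")
with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=16, do_sample=False, use_cache=True)
print(tok.decode(out[0], skip_special_tokens=True))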