fix
Browse files- modeling_jais.py +0 -3
modeling_jais.py
CHANGED
@@ -813,17 +813,14 @@ class JAISModel(JAISPreTrainedModel):
|
|
813 |
if position_ids is not None:
|
814 |
position_ids = position_ids.view(-1, input_shape[-1])
|
815 |
|
816 |
-
import pdb;pdb.set_trace()
|
817 |
if past_key_values is None:
|
818 |
past_length = 0
|
819 |
past_key_values = tuple([None] * len(self.h))
|
820 |
else:
|
821 |
if isinstance(past_key_values, tuple):
|
822 |
-
import pdb;pdb.set_trace()
|
823 |
past_length = past_key_values[0][0].size(-2)
|
824 |
else:
|
825 |
past_length = past_key_values.get_seq_length()
|
826 |
-
#past_length = past_key_values[0][0].size(-2)
|
827 |
if position_ids is None:
|
828 |
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
829 |
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
|
|
813 |
if position_ids is not None:
|
814 |
position_ids = position_ids.view(-1, input_shape[-1])
|
815 |
|
|
|
816 |
if past_key_values is None:
|
817 |
past_length = 0
|
818 |
past_key_values = tuple([None] * len(self.h))
|
819 |
else:
|
820 |
if isinstance(past_key_values, tuple):
|
|
|
821 |
past_length = past_key_values[0][0].size(-2)
|
822 |
else:
|
823 |
past_length = past_key_values.get_seq_length()
|
|
|
824 |
if position_ids is None:
|
825 |
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
826 |
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|