echarlaix HF Staff commited on
Commit
59d431b
·
1 Parent(s): c9c8ab1
Files changed (1) hide show
  1. modeling_jais.py +0 -3
modeling_jais.py CHANGED
@@ -813,17 +813,14 @@ class JAISModel(JAISPreTrainedModel):
813
  if position_ids is not None:
814
  position_ids = position_ids.view(-1, input_shape[-1])
815
 
816
- import pdb;pdb.set_trace()
817
  if past_key_values is None:
818
  past_length = 0
819
  past_key_values = tuple([None] * len(self.h))
820
  else:
821
  if isinstance(past_key_values, tuple):
822
- import pdb;pdb.set_trace()
823
  past_length = past_key_values[0][0].size(-2)
824
  else:
825
  past_length = past_key_values.get_seq_length()
826
- #past_length = past_key_values[0][0].size(-2)
827
  if position_ids is None:
828
  position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
829
  position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
 
813
  if position_ids is not None:
814
  position_ids = position_ids.view(-1, input_shape[-1])
815
 
 
816
  if past_key_values is None:
817
  past_length = 0
818
  past_key_values = tuple([None] * len(self.h))
819
  else:
820
  if isinstance(past_key_values, tuple):
 
821
  past_length = past_key_values[0][0].size(-2)
822
  else:
823
  past_length = past_key_values.get_seq_length()
 
824
  if position_ids is None:
825
  position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
826
  position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])