duzx16
committed on
Commit
·
c949d03
1
Parent(s):
0cfae21
Use dynamic dtype for prompts
Browse files- modeling_chatglm.py +7 -5
modeling_chatglm.py
CHANGED
|
@@ -804,9 +804,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
| 804 |
def set_input_embeddings(self, new_embeddings: torch.Tensor):
|
| 805 |
self.word_embeddings = new_embeddings
|
| 806 |
|
| 807 |
-
def get_prompt(self, batch_size, device):
|
| 808 |
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
|
| 809 |
-
past_key_values = self.prefix_encoder(prefix_tokens).half()
|
| 810 |
past_key_values = past_key_values.view(
|
| 811 |
batch_size,
|
| 812 |
self.pre_seq_len,
|
|
@@ -896,9 +896,13 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
| 896 |
else:
|
| 897 |
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
| 898 |
|
|
|
|
|
|
|
|
|
|
| 899 |
if past_key_values is None:
|
| 900 |
if self.pre_seq_len is not None:
|
| 901 |
-
past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device)
|
|
|
|
| 902 |
else:
|
| 903 |
past_key_values = tuple([None] * len(self.layers))
|
| 904 |
|
|
@@ -927,8 +931,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
|
|
| 927 |
gmask=use_gmask
|
| 928 |
)
|
| 929 |
|
| 930 |
-
if inputs_embeds is None:
|
| 931 |
-
inputs_embeds = self.word_embeddings(input_ids)
|
| 932 |
|
| 933 |
# [seq_len, batch, hidden_size]
|
| 934 |
hidden_states = inputs_embeds.transpose(0, 1)
|
|
|
|
| 804 |
def set_input_embeddings(self, new_embeddings: torch.Tensor):
|
| 805 |
self.word_embeddings = new_embeddings
|
| 806 |
|
| 807 |
+
def get_prompt(self, batch_size, device, dtype=torch.half):
|
| 808 |
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
|
| 809 |
+
past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
|
| 810 |
past_key_values = past_key_values.view(
|
| 811 |
batch_size,
|
| 812 |
self.pre_seq_len,
|
|
|
|
| 896 |
else:
|
| 897 |
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
| 898 |
|
| 899 |
+
if inputs_embeds is None:
|
| 900 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
| 901 |
+
|
| 902 |
if past_key_values is None:
|
| 903 |
if self.pre_seq_len is not None:
|
| 904 |
+
past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device,
|
| 905 |
+
dtype=inputs_embeds.dtype)
|
| 906 |
else:
|
| 907 |
past_key_values = tuple([None] * len(self.layers))
|
| 908 |
|
|
|
|
| 931 |
gmask=use_gmask
|
| 932 |
)
|
| 933 |
|
|
|
|
|
|
|
| 934 |
|
| 935 |
# [seq_len, batch, hidden_size]
|
| 936 |
hidden_states = inputs_embeds.transpose(0, 1)
|