Commit
·
18f12de
1
Parent(s):
898922b
Minor fix
Browse files
app.py
CHANGED
@@ -94,7 +94,7 @@ def infer(message, history):
|
|
94 |
predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1)
|
95 |
predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1)
|
96 |
predicted_caption[:,g] = predicted_word_token.view(1,-1)
|
97 |
-
next_token_embeds = model.get_input_embeddings()(
|
98 |
combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
|
99 |
|
100 |
predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)[0]
|
|
|
94 |
predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1)
|
95 |
predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1)
|
96 |
predicted_caption[:,g] = predicted_word_token.view(1,-1)
|
97 |
+
next_token_embeds = model.get_input_embeddings()(predicted_word_token)
|
98 |
combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
|
99 |
|
100 |
predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)[0]
|