AkashDataScience commited on
Commit
18f12de
·
1 Parent(s): 898922b
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -94,7 +94,7 @@ def infer(message, history):
94
  predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1)
95
  predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1)
96
  predicted_caption[:,g] = predicted_word_token.view(1,-1)
97
- next_token_embeds = model.get_input_embeddings()(prompt_tokens)
98
  combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
99
 
100
  predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)[0]
 
94
  predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1)
95
  predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1)
96
  predicted_caption[:,g] = predicted_word_token.view(1,-1)
97
+ next_token_embeds = model.get_input_embeddings()(predicted_word_token)
98
  combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
99
 
100
  predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)[0]