Commit 898922b
Parent(s): 9214f47
Checking generation
app.py CHANGED
@@ -86,15 +86,15 @@ def infer(message, history):
 
     combined_embeds = torch.cat(combined_embeds,dim=1)
 
-    #val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1) # 4, 69, 2560
     predicted_caption = torch.full((1,max_generate_length),50256).to(device)
 
     for g in range(max_generate_length):
-
-
-
+        print(g)
+        phi_output_logits = model(inputs_embeds=combined_embeds)['logits']
+        predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1)
+        predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1)
         predicted_caption[:,g] = predicted_word_token.view(1,-1)
-        next_token_embeds = model.get_input_embeddings()(prompt_tokens)
+        next_token_embeds = model.get_input_embeddings()(prompt_tokens)
         combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
 
     predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)[0]
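For reference, below is a minimal sketch of the greedy decoding loop this diff is building, under the assumption that model is a Hugging Face causal LM that accepts inputs_embeds and that tokenizer is its paired tokenizer. Two details differ from the committed code and are assumptions of this sketch, not part of the commit: it feeds the newly predicted token back into the context (the commit embeds prompt_tokens on every step, which re-appends the prompt embedding rather than the new token), and it decodes with skip_special_tokens=True, since the standard batch_decode signature exposes that flag rather than an ignore_index argument. The early break on end-of-text is also an addition.

import torch

def greedy_generate(model, tokenizer, combined_embeds,
                    max_generate_length=100, eos_token_id=50256, device="cpu"):
    # combined_embeds: (1, seq_len, hidden) prompt embeddings already on `device`.
    # 50256 is GPT-2's <|endoftext|> id, matching the fill value used in the commit.
    predicted_caption = torch.full((1, max_generate_length), eos_token_id, device=device)
    for g in range(max_generate_length):
        logits = model(inputs_embeds=combined_embeds)['logits']        # (1, seq, vocab)
        predicted_word_token = torch.argmax(logits[:, -1, :], dim=-1)  # (1,) greedy pick
        predicted_caption[:, g] = predicted_word_token
        if predicted_word_token.item() == eos_token_id:                # stop at <|endoftext|>
            break
        # Assumption: embed the *predicted* token and append it to the context.
        # The committed loop embeds `prompt_tokens` here instead.
        next_token_embeds = model.get_input_embeddings()(
            predicted_word_token.unsqueeze(0))                         # (1, 1, hidden)
        combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
    return tokenizer.batch_decode(predicted_caption, skip_special_tokens=True)[0]

Because the context grows by one embedding per step, each iteration re-runs the model over the full sequence; that matches the committed loop and keeps the sketch simple, though a KV cache would avoid the quadratic cost.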