AkashDataScience commited on
Commit
898922b
·
1 Parent(s): 9214f47

Checking generation

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -86,15 +86,15 @@ def infer(message, history):
86
 
87
  combined_embeds = torch.cat(combined_embeds,dim=1)
88
 
89
- #val_combined_embeds = torch.cat([val_image_embeds, img_token_embeds, val_q_embeds], dim=1) # 4, 69, 2560
90
  predicted_caption = torch.full((1,max_generate_length),50256).to(device)
91
 
92
  for g in range(max_generate_length):
93
- phi_output_logits = model(inputs_embeds=combined_embeds)['logits'] # 4, 69, 51200
94
- predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1) # 4,1,51200
95
- predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1) # 4,1
 
96
  predicted_caption[:,g] = predicted_word_token.view(1,-1)
97
- next_token_embeds = model.get_input_embeddings()(prompt_tokens) # 4,1,2560
98
  combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
99
 
100
  predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)[0]
 
86
 
87
  combined_embeds = torch.cat(combined_embeds,dim=1)
88
 
 
89
  predicted_caption = torch.full((1,max_generate_length),50256).to(device)
90
 
91
  for g in range(max_generate_length):
92
+ print(g)
93
+ phi_output_logits = model(inputs_embeds=combined_embeds)['logits']
94
+ predicted_word_token_logits = phi_output_logits[:, -1, :].unsqueeze(1)
95
+ predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1)
96
  predicted_caption[:,g] = predicted_word_token.view(1,-1)
97
+ next_token_embeds = model.get_input_embeddings()(prompt_tokens)
98
  combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
99
 
100
  predicted_captions_decoded = tokenizer.batch_decode(predicted_caption,ignore_index = 50256)[0]