orionweller commited on
Commit
3118b6d
·
verified ·
1 Parent(s): c27007f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -141,10 +141,11 @@ def encode_queries(dataset_name, postfix):
141
  batch_dict = {k: v.to(model.device) for k, v in batch_dict.items()}
142
 
143
  with torch.cuda.amp.autocast():
144
- outputs = model(**batch_dict)
145
- embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'])
146
- embeds = F.normalize(embeds, p=2, dim=-1)
147
- encoded_embeds.append(embeds.cpu().numpy())
 
148
 
149
  return np.concatenate(encoded_embeds, axis=0)
150
 
 
141
  batch_dict = {k: v.to(model.device) for k, v in batch_dict.items()}
142
 
143
  with torch.cuda.amp.autocast():
144
+ with torch.no_grad():
145
+ outputs = model(**batch_dict)
146
+ embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'])
147
+ embeds = F.normalize(embeds, p=2, dim=-1)
148
+ encoded_embeds.append(embeds.cpu().numpy())
149
 
150
  return np.concatenate(encoded_embeds, axis=0)
151