orionweller commited on
Commit
a08f1c2
·
1 Parent(s): 00b89f1
Files changed (1) hide show
  1. app.py +2 -0
app.py CHANGED
@@ -133,6 +133,7 @@ def encode_queries(dataset_name, postfix):
133
 
134
  encoded_embeds = []
135
  batch_size = 32
 
136
 
137
  for start_idx in tqdm.tqdm(range(0, len(input_texts), batch_size), desc="Encoding queries"):
138
  batch_input_texts = input_texts[start_idx: start_idx + batch_size]
@@ -147,6 +148,7 @@ def encode_queries(dataset_name, postfix):
147
  embeds = F.normalize(embeds, p=2, dim=-1)
148
  encoded_embeds.append(embeds.float().cpu().numpy())
149
 
 
150
  return np.concatenate(encoded_embeds, axis=0)
151
 
152
 
 
133
 
134
  encoded_embeds = []
135
  batch_size = 32
136
+ model = model.cuda()
137
 
138
  for start_idx in tqdm.tqdm(range(0, len(input_texts), batch_size), desc="Encoding queries"):
139
  batch_input_texts = input_texts[start_idx: start_idx + batch_size]
 
148
  embeds = F.normalize(embeds, p=2, dim=-1)
149
  encoded_embeds.append(embeds.float().cpu().numpy())
150
 
151
+ model = model.cpu()
152
  return np.concatenate(encoded_embeds, axis=0)
153
 
154