Commit c837e28 · load faiss
1 Parent(s): 05ff7af

app.py CHANGED
@@ -7,7 +7,6 @@ import torch
 import torch.nn.functional as F
 from transformers import AutoTokenizer, AutoModel, set_seed
 from peft import PeftModel
-from tevatron.retriever.searcher import FaissFlatSearcher
 import logging
 import os
 import json
@@ -47,7 +46,6 @@ current_dataset = "scifact"
 def log_system_info():
     logger.info("System Information:")
     logger.info(f"Python version: {sys.version}")
-    # logger.info(f"Platform: {platform.platform()}")

     logger.info("\nPackage Versions:")
     logger.info(f"torch: {torch.__version__}")
@@ -55,7 +53,6 @@ def log_system_info():
     logger.info(f"peft: {peft.__version__}")
     logger.info(f"faiss: {faiss.__version__}")
     logger.info(f"gradio: {gr.__version__}")
-    # logger.info(f"pytrec_eval: {pytrec_eval.__version__}")
     logger.info(f"ir_datasets: {ir_datasets.__version__}")

     if torch.cuda.is_available():
@@ -70,11 +67,8 @@ def log_system_info():
         logger.info("\nCUDA Information:")
         logger.info("CUDA available: No")

-
 log_system_info()

-
-
 def pool(last_hidden_states, attention_mask, pool_type="last"):
     last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)

@@ -151,18 +145,45 @@ class RepLlamaModel:
         self.model = self.model.cpu()
         return np.concatenate(all_embeddings, axis=0)

+def load_corpus_embeddings(dataset_name):
+    corpus_path = f"{dataset_name}/corpus_emb.*.pkl"
+    index_files = glob.glob(corpus_path)
+    index_files.sort(key=lambda x: int(x.split('.')[-2]))
+
+    all_embeddings = []
+    corpus_lookups = []
+
+    for file in index_files:
+        with open(file, 'rb') as f:
+            embeddings, p_lookup = pickle.load(f)
+        all_embeddings.append(embeddings)
+        corpus_lookups.extend(p_lookup)
+
+    all_embeddings = np.concatenate(all_embeddings, axis=0)
+    logger.info(f"Loaded corpus embeddings for {dataset_name}. Shape: {all_embeddings.shape}")
+
+    return all_embeddings, corpus_lookups
+
+def create_faiss_index(embeddings):
+    dimension = embeddings.shape[1]
+    index = faiss.IndexFlatIP(dimension)
+    index.add(embeddings)
+    logger.info(f"Created FAISS index with {index.ntotal} vectors of dimension {dimension}")
+    return index
+
+def load_or_create_faiss_index(dataset_name):
+    embeddings, corpus_lookups = load_corpus_embeddings(dataset_name)
+    index = create_faiss_index(embeddings)
+    return index, corpus_lookups

-def
-
-
-
-
-    return None
+def initialize_faiss_and_corpus(dataset_name):
+    global corpus_lookups
+    index, corpus_lookups[dataset_name] = load_or_create_faiss_index(dataset_name)
+    logger.info(f"Initialized FAISS index and corpus lookups for {dataset_name}")
+    return index

 def search_queries(dataset_name, q_reps, depth=1000):
-    faiss_index =
-    if faiss_index is None:
-        raise ValueError(f"No FAISS index found for dataset {dataset_name}")
+    faiss_index = initialize_faiss_and_corpus(dataset_name)

     logger.info(f"Searching queries. Shape of q_reps: {q_reps.shape}")

@@ -171,28 +192,11 @@ def search_queries(dataset_name, q_reps, depth=1000):

     logger.info(f"Search completed. Shape of all_scores: {all_scores.shape}, all_indices: {all_indices.shape}")
     logger.info(f"Sample scores: {all_scores[0][:5]}, Sample indices: {all_indices[0][:5]}")
-

     psg_indices = [[str(corpus_lookups[dataset_name][x]) for x in q_dd] for q_dd in all_indices]

     return all_scores, np.array(psg_indices)

-def load_corpus_lookups(dataset_name):
-    global corpus_lookups
-    corpus_path = f"{dataset_name}/corpus_emb.*.pkl"
-    index_files = glob.glob(corpus_path)
-    # sort them
-    index_files.sort(key=lambda x: int(x.split('.')[-2]))
-
-    corpus_lookups[dataset_name] = []
-    for file in index_files:
-        with open(file, 'rb') as f:
-            _, p_lookup = pickle.load(f)
-        corpus_lookups[dataset_name] += p_lookup
-
-    logger.info(f"Loaded corpus lookups for {dataset_name}. Total entries: {len(corpus_lookups[dataset_name])}")
-    logger.info(f"Sample corpus lookup entry: {corpus_lookups[dataset_name][:10]}")
-
 def load_queries(dataset_name):
     global queries, q_lookups, qrels, query2qid
     dataset = ir_datasets.load(f"beir/{dataset_name.lower()}" + ("/test" if dataset_name == "scifact" else ""))
@@ -214,7 +218,6 @@ def load_queries(dataset_name):
     logger.info(f"Loaded queries for {dataset_name}. Total queries: {len(queries[dataset_name])}")
     logger.info(f"Loaded qrels for {dataset_name}. Total query IDs: {len(qrels[dataset_name])}")

-
 def evaluate(qrels, results, k_values):
     qrels = {str(k): {str(k2): v2 for k2, v2 in v.items()} for k, v in qrels.items()}
     results = {str(k): {str(k2): v2 for k2, v2 in v.items()} for k, v in results.items()}
@@ -273,7 +276,6 @@ def run_evaluation(dataset, postfix):
     logger.info(f"Number of results: {len(results)}")
     logger.info(f"Sample result: {list(results.items())[0]}")

-    # Add these lines
     logger.info(f"Number of queries in qrels: {len(qrels[dataset])}")
     logger.info(f"Sample qrel: {list(qrels[dataset].items())[0]}")
     logger.info(f"Number of queries in results: {len(results)}")
@@ -293,13 +295,10 @@ def run_evaluation(dataset, postfix):
 def gradio_interface(dataset, postfix):
     return run_evaluation(dataset, postfix)

-
 if model is None:
     model = RepLlamaModel(model_name_or_path=CUR_MODEL)
-load_corpus_lookups(current_dataset)
 load_queries(current_dataset)

-
 # Create Gradio interface
 iface = gr.Interface(
     fn=gradio_interface,
@@ -318,4 +317,4 @@ iface = gr.Interface(
 )

 # Launch the interface
-iface.launch()
+iface.launch(share=False)
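
Note on the change: the commit removes tevatron's FaissFlatSearcher and instead builds a plain faiss.IndexFlatIP in-process from the pickled embedding shards. Below is a minimal standalone sketch of that flow, assuming (as the new code does) that each corpus_emb.<shard>.pkl file unpickles to an (embeddings, lookup) pair of a numpy array and a list of passage ids; build_flat_index is a hypothetical helper name, and the search call at the end is illustrative, since the real one sits in unchanged lines of search_queries.

import glob
import pickle
import numpy as np
import faiss

def build_flat_index(pattern="scifact/corpus_emb.*.pkl"):
    # Shards are named corpus_emb.<shard>.pkl; sort numerically by shard id,
    # mirroring index_files.sort(...) in the diff.
    files = sorted(glob.glob(pattern), key=lambda x: int(x.split('.')[-2]))
    embs, lookup = [], []
    for path in files:
        with open(path, 'rb') as f:
            e, ids = pickle.load(f)  # (embeddings, passage-id lookup) per shard
        embs.append(e)
        lookup.extend(ids)
    mat = np.concatenate(embs, axis=0).astype(np.float32)  # faiss expects float32
    index = faiss.IndexFlatIP(mat.shape[1])  # exact (brute-force) inner-product index
    index.add(mat)
    return index, lookup

# Illustrative query-side use (q_reps: float32 array of shape [num_queries, dim]):
#   scores, idx = index.search(q_reps, 1000)
#   doc_ids = [[str(lookup[i]) for i in row] for row in idx]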
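
Design note: IndexFlatIP performs exhaustive inner-product search, so results should match what the removed FaissFlatSearcher path produced while dropping the tevatron dependency; if the embeddings are L2-normalized (typical for RepLLaMA-style encoders, though that step is outside this diff), inner product coincides with cosine similarity. One trade-off visible in the diff: search_queries now calls initialize_faiss_and_corpus on every invocation, so the index is rebuilt per search rather than loaded once at startup (the old startup call to load_corpus_lookups was removed); caching the index per dataset would avoid the repeated rebuild.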