orionweller committed
Commit a667370 · 1 Parent(s): 798b478
Files changed (1)
  1. app.py +70 -58
app.py CHANGED
@@ -38,13 +38,23 @@ datasets = ["scifact"]
 current_dataset = "scifact"
 
 
-def pool(last_hidden_states, attention_mask):
+def pool(last_hidden_states, attention_mask, pool_type="last"):
     last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
-    sequence_lengths = attention_mask.sum(dim=1) - 1
-    batch_size = last_hidden.shape[0]
-    return last_hidden[torch.arange(batch_size, device=last_hidden.device), sequence_lengths]
+
+    if pool_type == "last":
+        left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
+        if left_padding:
+            emb = last_hidden[:, -1]
+        else:
+            sequence_lengths = attention_mask.sum(dim=1) - 1
+            batch_size = last_hidden.shape[0]
+            emb = last_hidden[torch.arange(batch_size, device=last_hidden.device), sequence_lengths]
+    else:
+        raise ValueError(f"pool_type {pool_type} not supported")
+
+    return emb
 
-def create_batch_dict(tokenizer, input_texts, max_length=512):
+def create_batch_dict(tokenizer, input_texts, always_add_eos="last", max_length=512):
     batch_dict = tokenizer(
         input_texts,
         max_length=max_length - 1,
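The rewritten `pool` makes the pooling strategy an explicit `pool_type` argument and adds a left-padding fast path: if every sequence's final position is unmasked, the batch is left-padded and the last column already holds each sequence's last real token; otherwise it indexes by per-sequence length. A minimal sketch of the right-padded branch on dummy tensors (shapes and values are illustrative only, not from the app):

```python
import torch

# Two sequences, hidden size 4; the second is right-padded by one position.
last_hidden_states = torch.randn(2, 3, 4)
attention_mask = torch.tensor([[1, 1, 1],
                               [1, 1, 0]])

# Right padding: the last real token sits at index (length - 1) per row.
sequence_lengths = attention_mask.sum(dim=1) - 1   # tensor([2, 1])
emb = last_hidden_states[torch.arange(2), sequence_lengths]
assert emb.shape == (2, 4)                         # one vector per sequence
```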
@@ -53,7 +63,10 @@ def create_batch_dict(tokenizer, input_texts, max_length=512):
         padding=False,
         truncation=True
     )
-    batch_dict['input_ids'] = [input_ids + [tokenizer.eos_token_id] for input_ids in batch_dict['input_ids']]
+
+    if always_add_eos == "last":
+        batch_dict['input_ids'] = [input_ids + [tokenizer.eos_token_id] for input_ids in batch_dict['input_ids']]
+
     return tokenizer.pad(
         batch_dict,
         padding=True,
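Tokenizing to `max_length - 1` and only then appending EOS guarantees the EOS token survives truncation, which matters because `pool` reads the embedding at that final position. A quick sketch of the invariant; the GPT-2 tokenizer stands in for the gated Llama-2 one:

```python
from transformers import AutoTokenizer

# Stand-in tokenizer; the app uses "meta-llama/Llama-2-7b-hf", which is gated.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

max_length = 512
ids = tokenizer("some very long query " * 200,
                max_length=max_length - 1, truncation=True)["input_ids"]
ids = ids + [tokenizer.eos_token_id]  # EOS is appended after truncation
assert len(ids) <= max_length         # so it can never be truncated away
```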
@@ -62,18 +75,44 @@ def create_batch_dict(tokenizer, input_texts, max_length=512):
         return_tensors="pt",
     )
 
-def load_model():
-    global tokenizer, model, CUR_MODEL, BASE_MODEL
-    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
-    tokenizer.pad_token_id = tokenizer.eos_token_id
-    tokenizer.pad_token = tokenizer.eos_token
-    tokenizer.padding_side = "right"
-
-    # model = AutoModel.from_pretrained(CUR_MODEL, max_memory={"cpu": "12GiB"}, torch_dtype=torch.bfloat16, offload_state_dict=True)
-    base_model_instance = AutoModel.from_pretrained(BASE_MODEL).cpu()
-    model = PeftModel.from_pretrained(base_model_instance, CUR_MODEL)
-    model = model.merge_and_unload()
-    model.eval()
+class RepLlamaModel:
+    def __init__(self, model_name_or_path):
+        self.base_model = "meta-llama/Llama-2-7b-hf"
+        self.tokenizer = AutoTokenizer.from_pretrained(self.base_model)
+        self.tokenizer.model_max_length = 2048
+        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.tokenizer.padding_side = "right"
+
+        self.model = self.get_model(model_name_or_path)
+        self.model.config.max_length = 2048
+
+    def get_model(self, peft_model_name):
+        base_model = AutoModel.from_pretrained(self.base_model)
+        model = PeftModel.from_pretrained(base_model, peft_model_name)
+        model = model.merge_and_unload()
+        model.eval()
+        return model
+
+    @spaces.GPU
+    def encode(self, texts, batch_size=32, **kwargs):
+        self.model = self.model.cuda()
+        all_embeddings = []
+        for i in range(0, len(texts), batch_size):
+            batch_texts = texts[i:i+batch_size]
+
+            batch_dict = create_batch_dict(self.tokenizer, batch_texts, always_add_eos="last")
+            batch_dict = {key: value.cuda() for key, value in batch_dict.items()}
+
+            with torch.cuda.amp.autocast():
+                with torch.no_grad():
+                    outputs = self.model(**batch_dict)
+                    embeddings = pool(outputs.last_hidden_state, batch_dict['attention_mask'], 'last')
+                    embeddings = F.normalize(embeddings, p=2, dim=-1)
+                    all_embeddings.append(embeddings.cpu().numpy())
+
+        self.model = self.model.cpu()
+        return np.concatenate(all_embeddings, axis=0)
 
 
 def load_faiss_index(dataset_name):
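For orientation, a hedged sketch of how the new class is meant to be driven (`CUR_MODEL` is defined earlier in `app.py`, outside this diff, and names the LoRA adapter that `get_model` merges into the Llama-2 base):

```python
# CUR_MODEL comes from earlier in app.py (not shown in this diff); it points
# at a PEFT/LoRA adapter on top of the Llama-2-7b base model.
model = RepLlamaModel(model_name_or_path=CUR_MODEL)

texts = ["query: what treats scurvy?", "query: vitamin C deficiency symptoms"]
q_reps = model.encode(texts, batch_size=32)
print(q_reps.shape)  # (2, hidden_size); rows are L2-normalized vectors
```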
@@ -128,31 +167,6 @@ def load_queries(dataset_name):
         qrels[dataset_name][qrel.query_id] = {}
     qrels[dataset_name][qrel.query_id][qrel.doc_id] = qrel.relevance
 
-@spaces.GPU
-def encode_queries(dataset_name, postfix):
-    global queries, tokenizer, model
-    input_texts = [f"query: {query.strip()} {postfix}".strip() for query in queries[dataset_name]]
-
-    encoded_embeds = []
-    batch_size = 32
-    model = model.cuda()
-
-    for start_idx in tqdm.tqdm(range(0, len(input_texts), batch_size), desc="Encoding queries"):
-        batch_input_texts = input_texts[start_idx: start_idx + batch_size]
-
-        batch_dict = create_batch_dict(tokenizer, batch_input_texts)
-        batch_dict = {k: v.to(model.device) for k, v in batch_dict.items()}
-
-        with torch.cuda.amp.autocast():
-            with torch.no_grad():
-                outputs = model(**batch_dict)
-                embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'])
-                embeds = F.normalize(embeds, p=2, dim=-1)
-                encoded_embeds.append(embeds.float().cpu().numpy())
-
-    model = model.cpu()
-    return np.concatenate(encoded_embeds, axis=0)
-
 
 def evaluate(qrels, results, k_values):
     evaluator = pytrec_eval.RelevanceEvaluator(
@@ -168,15 +182,11 @@ def evaluate(qrels, results, k_values):
     return metrics
 
 def run_evaluation(dataset, postfix):
-    global current_dataset
-
-    if dataset not in corpus_lookups or dataset not in queries:
-        load_corpus_lookups(dataset)
-        load_queries(dataset)
-
+    global current_dataset, queries, model
     current_dataset = dataset
-
-    q_reps = encode_queries(dataset, postfix)
+
+    input_texts = [f"query: {query.strip()} {postfix}".strip() for query in queries[current_dataset]]
+    q_reps = model.encode(input_texts)
     all_scores, psg_indices = search_queries(dataset, q_reps)
 
     results = {qid: dict(zip(doc_ids, map(float, scores)))
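`run_evaluation` still funnels the FAISS scores into the nested-dict shape `pytrec_eval` expects: `{query_id: {doc_id: score}}` evaluated against `{query_id: {doc_id: relevance}}`. A self-contained sketch with invented IDs:

```python
import pytrec_eval

# Invented IDs and scores, purely to show the expected shapes.
qrels = {"q1": {"d1": 1, "d2": 0}}
results = {"q1": {"d1": 12.3, "d2": 7.1}}

evaluator = pytrec_eval.RelevanceEvaluator(qrels, {"ndcg_cut.10", "recall.100"})
per_query = evaluator.evaluate(results)
print(per_query["q1"]["ndcg_cut_10"], per_query["q1"]["recall_100"])
```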
@@ -189,16 +199,18 @@ def run_evaluation(dataset, postfix):
         "Recall@100": metrics["Recall@100"]
     }
 
-def gradio_interface(dataset, postfix):
-    if 'model' not in globals() or model is None:
-        load_model()
-    for dataset in datasets:
-        print(f"Loading dataset: {dataset}")
-        load_corpus_lookups(dataset)
-        load_queries(dataset)
 
+@spaces.GPU
+def gradio_interface(dataset, postfix):
     return run_evaluation(dataset, postfix)
 
+
+if model is None:
+    model = RepLlamaModel(model_name_or_path=CUR_MODEL)
+    load_corpus_lookups(current_dataset)
+    load_queries(current_dataset)
+
+
 # Create Gradio interface
 iface = gr.Interface(
     fn=gradio_interface,
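Net effect of the last hunk: model construction and corpus/query loading move from the first Gradio request to module scope, so they run once at startup, and `@spaces.GPU` now wraps the request path so ZeroGPU attaches a GPU per call. The `gr.Interface` below the hunk is unchanged; a hedged sketch of that wiring (the actual `inputs`/`outputs` widgets sit outside this diff and are assumptions here):

```python
import gradio as gr

# Hypothetical widgets; the real ones are defined outside this diff.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[gr.Dropdown(choices=datasets, value="scifact", label="Dataset"),
            gr.Textbox(label="Query postfix")],
    outputs=gr.JSON(label="Evaluation metrics"),
)
iface.launch()
```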