Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -155,45 +155,7 @@ def search(query):
|
|
155 |
)
|
156 |
|
157 |
|
158 |
-
|
159 |
-
# Encode the query using the bi-encoder and find potentially relevant passages
|
160 |
-
question_embedding = bi_encoder.encode(query, convert_to_tensor=True)
|
161 |
-
# question_embedding = question_embedding.cuda()
|
162 |
-
hits = util.semantic_search(question_embedding, corpus_embeddings, top_k=top_k)
|
163 |
-
hits = hits[0] # Get the hits for the first query
|
164 |
-
|
165 |
-
##### Re-Ranking #####
|
166 |
-
# Now, score all retrieved passages with the cross_encoder
|
167 |
-
cross_inp = [[query, passages[hit['corpus_id']]] for hit in hits]
|
168 |
-
cross_scores = cross_encoder.predict(cross_inp)
|
169 |
-
|
170 |
-
# Sort results by the cross-encoder scores
|
171 |
-
for idx in range(len(cross_scores)):
|
172 |
-
hits[idx]['cross-score'] = cross_scores[idx]
|
173 |
-
|
174 |
-
# Output of top-5 hits from bi-encoder
|
175 |
-
print("\n-------------------------\n")
|
176 |
-
print("Top-5 Bi-Encoder Retrieval hits")
|
177 |
-
hits = sorted(hits, key=lambda x: x['score'], reverse=True)
|
178 |
-
for hit in hits[0:5]:
|
179 |
-
# print("\t{:.3f}\t{}".format(hit['score'], passages[hit['corpus_id']].replace("\n", " ")))
|
180 |
-
row_dict = df.loc[df['all_review']== corpus[hit['corpus_id']]]
|
181 |
-
print("\t{:.3f}\t".format(hit['score']),row_dict['Hotel'].values[0])
|
182 |
-
de = df_basic.loc[df_basic.Hotel == row_dict['Hotel'].values[0]]
|
183 |
-
print(f'\tPrice Per night: {de.price_per_night.values[0]}')
|
184 |
-
print(de.description.values[0])
|
185 |
-
|
186 |
-
# Output of top-5 hits from re-ranker
|
187 |
-
print("\n-------------------------\n")
|
188 |
-
print("Top-5 Cross-Encoder Re-ranker hits")
|
189 |
-
hits = sorted(hits, key=lambda x: x['cross-score'], reverse=True)
|
190 |
-
for hit in hits[0:5]:
|
191 |
-
# print("\t{:.3f}\t{}".format(hit['cross-score'], passages[hit['corpus_id']].replace("\n", " ")))
|
192 |
-
row_dict = df.loc[df['all_review']== corpus[hit['corpus_id']]]
|
193 |
-
print("\t{:.3f}\t".format(hit['cross-score']),row_dict['Hotel'].values[0])
|
194 |
-
de = df_basic.loc[df_basic.Hotel == row_dict['Hotel'].values[0]]
|
195 |
-
print(f'\tPrice Per night: {de.price_per_night.values[0]}')
|
196 |
-
print(de.description.values[0])
|
197 |
|
198 |
|
199 |
return bm25list
|
|
|
155 |
)
|
156 |
|
157 |
|
158 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
|
160 |
|
161 |
return bm25list
|