Jonas Leeb committed
Commit dfc89a9
1 Parent(s): b4a0b98
query is also shown now
app.py CHANGED
@@ -23,6 +23,7 @@ class ArxivSearch:
         self.raw_texts = []
         self.arxiv_ids = []
         self.last_results = []
+        self.query_encoding = None

         self.embedding_dropdown = gr.Dropdown(
             choices=["tfidf", "word2vec", "bert"],
@@ -113,20 +114,7 @@ class ArxivSearch:
             self.documents.append(text.strip())
             self.arxiv_ids.append(arxiv_id)

-
-        query_terms = query.lower().split()
-        query_indices = [i for i, term in enumerate(self.feature_names) if term in query_terms]
-        if not query_indices:
-            return []
-        scores = []
-        for doc_idx in range(self.tfidf_matrix.shape[0]):
-            doc_vector = self.tfidf_matrix[doc_idx]
-            doc_score = sum(doc_vector[0, i] for i in query_indices)
-            if doc_score > 0:
-                scores.append((doc_idx, doc_score))
-        scores.sort(key=lambda x: x[1], reverse=True)
-        return scores[:top_n]
-
+
     def plot_3d_embeddings(self, embedding):
         # Example: plot random points, replace with your embeddings
         pca = PCA(n_components=3)
@@ -144,6 +132,7 @@ class ArxivSearch:
             pca.fit(all_data)
             reduced_data = pca.transform(self.word2vec_embeddings[:5000])
             reduced_results_points = pca.transform(self.word2vec_embeddings[results_indices]) if len(results_indices) > 0 else np.empty((0, 3))
+            query_point = pca.transform(self.query_encoding) if self.query_encoding is not None and self.query_encoding.shape[0] > 0 else np.empty((0, 3))

         elif embedding == "bert":
             all_indices = list(set(results_indices) | set(range(min(5000, self.bert_embeddings.shape[0]))))
@@ -151,6 +140,7 @@ class ArxivSearch:
             pca.fit(all_data)
             reduced_data = pca.transform(self.bert_embeddings[:5000])
             reduced_results_points = pca.transform(self.bert_embeddings[results_indices]) if len(results_indices) > 0 else np.empty((0, 3))
+            query_point = pca.transform(self.query_encoding) if self.query_encoding is not None and self.query_encoding.shape[0] > 0 else np.empty((0, 3))

         else:
             raise ValueError(f"Unsupported embedding type: {embedding}")
@@ -159,7 +149,8 @@ class ArxivSearch:
             y=reduced_data[:, 1],
             z=reduced_data[:, 2],
             mode='markers',
-            marker=dict(size=3.5, color='
+            marker=dict(size=3.5, color='#cccccc', opacity=0.35),
+            name='All Documents'
         )
         layout = go.Layout(
             margin=dict(l=0, r=0, b=0, t=0),
@@ -182,18 +173,42 @@ class ArxivSearch:
                 z=reduced_results_points[:, 2],
                 mode='markers',
                 marker=dict(size=3.5, color='orange', opacity=0.75),
+                name='Results'
+            )
+            query_trace = go.Scatter3d(
+                x=query_point[:, 0],
+                y=query_point[:, 1],
+                z=query_point[:, 2],
+                mode='markers',
+                marker=dict(size=5, color='red', opacity=0.8),
+                name='Query'
             )
-            fig = go.Figure(data=[trace, results_trace], layout=layout)
+            fig = go.Figure(data=[trace, results_trace, query_trace], layout=layout)
         else:
             fig = go.Figure(data=[trace], layout=layout)
         return fig
-
+
+    def keyword_match_ranking(self, query, top_n=5):
+        query_terms = query.lower().split()
+        query_indices = [i for i, term in enumerate(self.feature_names) if term in query_terms]
+        if not query_indices:
+            return []
+        scores = []
+        for doc_idx in range(self.tfidf_matrix.shape[0]):
+            doc_vector = self.tfidf_matrix[doc_idx]
+            doc_score = sum(doc_vector[0, i] for i in query_indices)
+            if doc_score > 0:
+                scores.append((doc_idx, doc_score))
+        scores.sort(key=lambda x: x[1], reverse=True)
+        return scores[:top_n]
+
     def word2vec_search(self, query, top_n=5):
         tokens = [word for word in query.split() if word in self.wv_model.key_to_index]
         if not tokens:
             return []
         vectors = np.array([self.wv_model[word] for word in tokens])
         query_vec = normalize(np.mean(vectors, axis=0).reshape(1, -1))
+        self.query_encoding = query_vec
         sims = cosine_similarity(query_vec, self.word2vec_embeddings).flatten()
         top_indices = sims.argsort()[::-1][:top_n]
         return [(i, sims[i]) for i in top_indices]
@@ -203,6 +218,7 @@ class ArxivSearch:
         inputs = self.tokenizer(query, return_tensors="pt", truncation=True, padding=True)
         outputs = self.model(**inputs)
         query_vec = normalize(outputs.last_hidden_state[:, 0, :].numpy())
+        self.query_encoding = query_vec
         sims = cosine_similarity(query_vec, self.bert_embeddings).flatten()
         top_indices = sims.argsort()[::-1][:top_n]
         return [(i, sims[i]) for i in top_indices]
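A minimal, self-contained sketch of the idea this commit adds (not code from the repository): the encoded query is stored, projected with the same PCA that was fitted on the document embeddings, and drawn as its own Scatter3d trace. The toy arrays, sizes, and variable names doc_embeddings/reduced_docs below are made up for illustration; query_encoding, query_point, and the 'Query' trace mirror the names in the diff.

import numpy as np
import plotly.graph_objects as go
from sklearn.decomposition import PCA

# Toy stand-ins: 100 document embeddings and one encoded query, both 50-dimensional.
doc_embeddings = np.random.rand(100, 50)
query_encoding = np.random.rand(1, 50)

# Fit PCA on the documents, then project documents and query into the same 3D space.
pca = PCA(n_components=3).fit(doc_embeddings)
reduced_docs = pca.transform(doc_embeddings)
query_point = pca.transform(query_encoding)  # shape (1, 3)

doc_trace = go.Scatter3d(
    x=reduced_docs[:, 0], y=reduced_docs[:, 1], z=reduced_docs[:, 2],
    mode='markers', marker=dict(size=3.5, color='#cccccc', opacity=0.35),
    name='All Documents',
)
query_trace = go.Scatter3d(
    x=query_point[:, 0], y=query_point[:, 1], z=query_point[:, 2],
    mode='markers', marker=dict(size=5, color='red', opacity=0.8),
    name='Query',
)
go.Figure(data=[doc_trace, query_trace]).show()

Reusing the already-fitted PCA for the query keeps the query point in the same coordinate system as the documents and results; fitting a separate PCA on the query alone would place it in an unrelated space.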