myshirk commited on
Commit
0acec27
Β·
verified Β·
1 Parent(s): 8041be5

add group by function. allow filters to adjust automatically

Browse files
Files changed (1) hide show
  1. app.py +35 -18
app.py CHANGED
@@ -87,13 +87,28 @@ result_table = pn.widgets.Tabulator(
87
  )
88
 
89
  # ──────────────────────────────────────────────────────────────────────
90
- # 4) Semantic Search with Filtering
91
  # ──────────────────────────────────────────────────────────────────────
92
- def semantic_search(event=None):
93
- """Run filtered view if no semantic query; otherwise do semantic within filtered subset."""
94
- query = w_semquery.value.strip()
95
 
96
- # 1) Apply filters first (country/year/keyword)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  filt = df.copy()
98
  if w_countries.value:
99
  filt = filt[filt["country"].isin(w_countries.value)]
@@ -106,30 +121,23 @@ def semantic_search(event=None):
106
  filt["question_code"].astype(str).str.contains(w_keyword.value, case=False, na=False)
107
  ]
108
 
109
- # 2) If no semantic query, just show the filtered data (no Score column)
110
  if not query:
111
- if filt.empty:
112
- result_table.value = pd.DataFrame(columns=["country", "year", "question_text", "answer_text"])
113
- else:
114
- result_table.value = filt[["country", "year", "question_text", "answer_text"]]
115
  return
116
 
117
- # 3) Otherwise, do semantic search *within* the filtered subset
118
  model, ids_list, emb_tensor = get_semantic_resources()
119
-
120
  filtered_ids = filt["id"].tolist()
121
  id_to_index = {id_: i for i, id_ in enumerate(ids_list)}
122
  filtered_indices = [id_to_index[id_] for id_ in filtered_ids if id_ in id_to_index]
123
-
124
  if not filtered_indices:
125
- result_table.value = pd.DataFrame(columns=["Score", "country", "year", "question_text", "answer_text"])
126
  return
127
 
128
  filtered_embs = emb_tensor[filtered_indices]
129
-
130
  q_vec = model.encode(query, convert_to_tensor=True, device="cpu").cpu()
131
  sims = util.cos_sim(q_vec, filtered_embs)[0]
132
- top_vals, top_idx = torch.topk(sims, k=50)
 
133
 
134
  top_filtered_ids = [filtered_ids[i] for i in top_idx.tolist()]
135
  sem_rows = filt[filt["id"].isin(top_filtered_ids)].copy()
@@ -137,7 +145,7 @@ def semantic_search(event=None):
137
  sem_rows["Score"] = sem_rows["id"].map(score_map)
138
  sem_rows = sem_rows.sort_values("Score", ascending=False)
139
 
140
- result_table.value = sem_rows[["Score", "country", "year", "question_text", "answer_text"]]
141
 
142
 
143
  def clear_filters(event=None):
@@ -147,9 +155,18 @@ def clear_filters(event=None):
147
  w_semquery.value = ""
148
  result_table.value = df[["country", "year", "question_text", "answer_text"]].copy()
149
 
150
- w_search_button.on_click(semantic_search)
151
  w_clear_filters.on_click(clear_filters)
152
 
 
 
 
 
 
 
 
 
 
153
  # ──────────────────────────────────────────────────────────────────────
154
  # 5) Layout
155
  # ──────────────────────────────────────────────────────────────────────
 
87
  )
88
 
89
  # ──────────────────────────────────────────────────────────────────────
90
+ # 4) Search Logic
91
  # ──────────────────────────────────────────────────────────────────────
 
 
 
92
 
93
+ def _group_by_question(df_in: pd.DataFrame) -> pd.DataFrame:
94
+ if df_in.empty:
95
+ return pd.DataFrame(columns=["question_text", "Countries", "Years", "Sample Answers"])
96
+ tmp = df_in.copy()
97
+ tmp["year"] = tmp["year"].replace('', pd.NA)
98
+ grouped = (
99
+ tmp.groupby("question_text", dropna=False)
100
+ .agg({
101
+ "country": lambda x: sorted({v for v in x if pd.notna(v)}),
102
+ "year": lambda x: sorted({str(v) for v in x if pd.notna(v)}),
103
+ "answer_text": lambda x: list(x.dropna())[:3],
104
+ })
105
+ .reset_index()
106
+ .rename(columns={"country": "Countries", "year": "Years", "answer_text": "Sample Answers"})
107
+ )
108
+ return grouped
109
+
110
+ def search(event=None):
111
+ query = w_semquery.value.strip()
112
  filt = df.copy()
113
  if w_countries.value:
114
  filt = filt[filt["country"].isin(w_countries.value)]
 
121
  filt["question_code"].astype(str).str.contains(w_keyword.value, case=False, na=False)
122
  ]
123
 
 
124
  if not query:
125
+ result_table.value = _group_by_question(filt) if w_group.value else filt[["country", "year", "question_text", "answer_text"]]
 
 
 
126
  return
127
 
 
128
  model, ids_list, emb_tensor = get_semantic_resources()
 
129
  filtered_ids = filt["id"].tolist()
130
  id_to_index = {id_: i for i, id_ in enumerate(ids_list)}
131
  filtered_indices = [id_to_index[id_] for id_ in filtered_ids if id_ in id_to_index]
 
132
  if not filtered_indices:
133
+ result_table.value = _group_by_question(filt.iloc[0:0]) if w_group.value else pd.DataFrame(columns=["Score", "country", "year", "question_text", "answer_text"])
134
  return
135
 
136
  filtered_embs = emb_tensor[filtered_indices]
 
137
  q_vec = model.encode(query, convert_to_tensor=True, device="cpu").cpu()
138
  sims = util.cos_sim(q_vec, filtered_embs)[0]
139
+ top_k = min(50, len(filtered_indices))
140
+ top_vals, top_idx = torch.topk(sims, k=top_k)
141
 
142
  top_filtered_ids = [filtered_ids[i] for i in top_idx.tolist()]
143
  sem_rows = filt[filt["id"].isin(top_filtered_ids)].copy()
 
145
  sem_rows["Score"] = sem_rows["id"].map(score_map)
146
  sem_rows = sem_rows.sort_values("Score", ascending=False)
147
 
148
+ result_table.value = _group_by_question(sem_rows.drop(columns=["Score"])) if w_group.value else sem_rows[["Score", "country", "year", "question_text", "answer_text"]]
149
 
150
 
151
  def clear_filters(event=None):
 
155
  w_semquery.value = ""
156
  result_table.value = df[["country", "year", "question_text", "answer_text"]].copy()
157
 
158
+ w_search_button.on_click(search)
159
  w_clear_filters.on_click(clear_filters)
160
 
161
+ # Live updates for filters (except semantic query and keyword)
162
+ w_group.param.watch(lambda e: search(), 'value')
163
+ w_countries.param.watch(lambda e: search(), 'value')
164
+ w_years.param.watch(lambda e: search(), 'value')
165
+
166
+ # Allow pressing Enter in semantic query or keyword to trigger search
167
+ w_semquery.param.watch(lambda e: search(), 'enter_pressed')
168
+ w_keyword.param.watch(lambda e: search(), 'enter_pressed')
169
+
170
  # ──────────────────────────────────────────────────────────────────────
171
  # 5) Layout
172
  # ──────────────────────────────────────────────────────────────────────