# app.py - Unified Panel App with Semantic Search + Filterable Tabulator
import os, io, gc
import panel as pn
import pandas as pd
import boto3, torch
import psycopg2
from sentence_transformers import SentenceTransformer, util
pn.extension('tabulator')
# ──────────────────────────────────────────────────────────────────────
# 1) Database and Resource Loading
# ──────────────────────────────────────────────────────────────────────
DB_HOST = os.getenv("DB_HOST")
DB_PORT = os.getenv("DB_PORT", "5432")
DB_NAME = os.getenv("DB_NAME")
DB_USER = os.getenv("DB_USER")
DB_PASSWORD = os.getenv("DB_PASSWORD")
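# Connection settings come from the environment (e.g. the hosting platform's
# secrets); pn.cache memoises the result so the query below runs once per process.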
@pn.cache()
def get_data():
    conn = psycopg2.connect(
        host=DB_HOST, port=DB_PORT,
        dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD,
        sslmode="require"
    )
    df_ = pd.read_sql_query("""
        SELECT id, country, year, section,
               question_code, question_text,
               answer_code, answer_text
        FROM survey_info;
    """, conn)
    conn.close()
    # Normalise year to a clean string: whole numbers where present, blank instead of NaN
    if "year" in df_.columns:
        df_["year"] = (
            pd.to_numeric(df_["year"], errors="coerce")
            .astype("Int64").astype(str)
            .replace({'<NA>': ''})
        )
    return df_
df = get_data()
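# The embeddings and the sentence-transformer model below are loaded lazily and
# cached, so the first semantic query pays the one-off download cost.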
@pn.cache()
def load_embeddings():
    BUCKET, KEY = "cgd-embeddings-bucket", "survey_info_embeddings.pt"
    buf = io.BytesIO()
    boto3.client("s3").download_fileobj(BUCKET, KEY, buf)
    buf.seek(0)
    ckpt = torch.load(buf, map_location="cpu")
    buf.close(); gc.collect()
    # The checkpoint holds parallel "ids" and "embeddings": row i of the tensor
    # is the embedding of the survey row whose id is ids[i].
    return ckpt["ids"], ckpt["embeddings"]
@pn.cache()
def get_st_model():
    return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device="cpu")
@pn.cache()
def get_semantic_resources():
    model = get_st_model()
    ids_list, emb_tensor = load_embeddings()
    return model, ids_list, emb_tensor
# ──────────────────────────────────────────────────────────────────────
# 2) Widgets
# ──────────────────────────────────────────────────────────────────────
country_opts = sorted(df["country"].dropna().unique())
year_opts = sorted(y for y in df["year"].dropna().unique() if y)  # drop blank years from the options
ALL_COLUMNS = ["country","year","section","question_code","question_text","answer_code","answer_text","Score"]
w_columns = pn.widgets.MultiChoice(
    name="Columns to show",
    options=ALL_COLUMNS,
    value=["country","year","question_text","answer_text"]
)
w_countries = pn.widgets.MultiSelect(name="Countries", options=country_opts)
w_years = pn.widgets.MultiSelect(name="Years", options=year_opts)
w_keyword = pn.widgets.TextInput(name="Keyword Search", placeholder="Case-insensitive text match on questions, answers, and question codes")
w_group = pn.widgets.Checkbox(name="Group by Question Text", value=False)
w_topk = pn.widgets.Select(name="Top-K (semantic)", options=[5, 10, 20, 50, 100], value=10, disabled=True)
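# Top-K only applies to semantic search; it starts disabled and is re-enabled
# once a semantic query is entered (see _toggle_topk_disabled below).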
w_semquery = pn.widgets.TextInput(name="Semantic Query", placeholder="Embedding-based semantic search: describe what you're looking for")
w_search_button = pn.widgets.Button(name="Search", button_type="primary")
w_clear_filters = pn.widgets.Button(name="Clear Filters", button_type="warning")
# ──────────────────────────────────────────────────────────────────────
# 3) Unified Results Table (Tabulator)
# ──────────────────────────────────────────────────────────────────────
result_table = pn.widgets.Tabulator(
    pagination='remote',
    page_size=15,
    sizing_mode="stretch_width",
    layout='fit_columns',
    show_index=False
)
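# The table starts empty; its value is assigned at startup and on every search().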
# ──────────────────────────────────────────────────────────────────────
# 4) Search Logic
# ──────────────────────────────────────────────────────────────────────
def _group_by_question(df_in: pd.DataFrame) -> pd.DataFrame:
    if df_in.empty:
        return pd.DataFrame(columns=["question_text", "Countries", "Years", "Sample Answers"])
    tmp = df_in.copy()
    tmp["year"] = tmp["year"].replace('', pd.NA)
    grouped = (
        tmp.groupby("question_text", dropna=False)
        .agg({
            "country": lambda x: sorted({v for v in x if pd.notna(v)}),
            "year": lambda x: sorted({str(v) for v in x if pd.notna(v)}),
            "answer_text": lambda x: list(x.dropna())[:3],
        })
        .reset_index()
        .rename(columns={"country": "Countries", "year": "Years", "answer_text": "Sample Answers"})
    )
    return grouped
def _selected_cols(has_score=False):
    allowed = set(ALL_COLUMNS)
    if not has_score and "Score" in w_columns.value:
        w_columns.value = [c for c in w_columns.value if c != "Score"]
    cols = [c for c in w_columns.value if c in allowed]
    if not cols:
        cols = ["country", "year", "question_text", "answer_text"]
    return cols
def search(event=None):
    query = w_semquery.value.strip()
    filt = df.copy()
    # Structured filters first: country, year, then keyword substring match
    if w_countries.value:
        filt = filt[filt["country"].isin(w_countries.value)]
    if w_years.value:
        filt = filt[filt["year"].isin(w_years.value)]
    if w_keyword.value:
        filt = filt[
            filt["question_text"].str.contains(w_keyword.value, case=False, na=False) |
            filt["answer_text"].str.contains(w_keyword.value, case=False, na=False) |
            filt["question_code"].astype(str).str.contains(w_keyword.value, case=False, na=False)
        ]
    # Without a semantic query, just show the filtered (optionally grouped) rows
    if not query:
        result_table.value = _group_by_question(filt) if w_group.value else filt[_selected_cols(False)]
        return
    model, ids_list, emb_tensor = get_semantic_resources()
    filtered_ids = filt["id"].tolist()
    id_to_index = {id_: i for i, id_ in enumerate(ids_list)}
    # Keep only ids that actually have an embedding, so positions in the
    # similarity tensor line up with the ids they were computed for
    kept_ids = [id_ for id_ in filtered_ids if id_ in id_to_index]
    filtered_indices = [id_to_index[id_] for id_ in kept_ids]
    if not filtered_indices:
        result_table.value = _group_by_question(filt.iloc[0:0]) if w_group.value else pd.DataFrame(columns=_selected_cols(True))
        return
    top_k = min(int(w_topk.value), len(filtered_indices))
    filtered_embs = emb_tensor[filtered_indices]
    q_vec = model.encode(query, convert_to_tensor=True, device="cpu").cpu()
    # Cosine similarity between the query and the pre-computed row embeddings
    sims = util.cos_sim(q_vec, filtered_embs)[0]
    top_vals, top_idx = torch.topk(sims, k=top_k)
    top_filtered_ids = [kept_ids[i] for i in top_idx.tolist()]
    sem_rows = filt[filt["id"].isin(top_filtered_ids)].copy()
    score_map = dict(zip(top_filtered_ids, top_vals.tolist()))
    sem_rows["Score"] = sem_rows["id"].map(score_map)
    sem_rows = sem_rows.sort_values("Score", ascending=False)
    result_table.value = _group_by_question(sem_rows.drop(columns=["Score"])) if w_group.value else sem_rows[_selected_cols(True)]
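# Reset the filter and query widgets and restore the unfiltered default view.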
def clear_filters(event=None):
    w_countries.value = []
    w_years.value = []
    w_keyword.value = ""
    w_semquery.value = ""
    w_topk.disabled = True
    result_table.value = df[["country", "year", "question_text", "answer_text"]].copy()
w_search_button.on_click(search)
w_clear_filters.on_click(clear_filters)
# Live updates for filters (except semantic query and keyword)
w_group.param.watch(lambda e: search(), 'value')
w_countries.param.watch(lambda e: search(), 'value')
w_years.param.watch(lambda e: search(), 'value')
w_columns.param.watch(lambda e: search(), 'value')
# Allow pressing Enter in semantic query or keyword to trigger search
w_semquery.param.watch(lambda e: search(), 'enter_pressed')
w_keyword.param.watch(lambda e: search(), 'enter_pressed')
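# Note: the 'enter_pressed' event is assumed to exist on this Panel version's
# TextInput; on releases without it, watching 'value' is the fallback.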
# Enable/disable Top-K based on semantic query presence
def _toggle_topk_disabled(event=None):
    w_topk.disabled = (w_semquery.value.strip() == '')
_toggle_topk_disabled()
w_semquery.param.watch(lambda e: _toggle_topk_disabled(), 'value')
# Show all data at startup
result_table.value = df[["country", "year", "question_text", "answer_text"]].copy()
# ──────────────────────────────────────────────────────────────────────
# 5) Layout
# ──────────────────────────────────────────────────────────────────────
sidebar = pn.Column(
"## π Filters",
w_countries, w_years, w_keyword, w_group, w_columns,
pn.Spacer(height=20),
"## π§ Semantic Search",
w_semquery,
w_topk,
w_search_button,
pn.Spacer(height=20),
w_clear_filters,
width=300
)
main = pn.Column(
    pn.pane.Markdown("## CGD Survey Explorer"),
    result_table
)
pn.template.FastListTemplate(
    title="CGD Survey Explorer",
    sidebar=sidebar,
    main=main,
    theme_toggle=True,
).servable()