Spaces:
Running
Running
# app.py — Unified Panel App with Semantic Search + Filterable Tabulator
import os, io, gc | |
import panel as pn | |
import pandas as pd | |
import boto3, torch | |
import psycopg2 | |
from sentence_transformers import SentenceTransformer, util | |
# Load the Tabulator extension before any Tabulator widget is created.
pn.extension("tabulator")

# ──────────────────────────────────────────────────────────────────────
# 1) Database and Resource Loading
# ──────────────────────────────────────────────────────────────────────
# PostgreSQL connection settings, read once from the environment.
# Only the port has a fallback value; the others are None when unset.
DB_HOST = os.environ.get("DB_HOST")
DB_PORT = os.environ.get("DB_PORT", "5432")
DB_NAME = os.environ.get("DB_NAME")
DB_USER = os.environ.get("DB_USER")
DB_PASSWORD = os.environ.get("DB_PASSWORD")
def get_data():
    """Load the survey table from PostgreSQL into a DataFrame.

    Connects with the module-level ``DB_*`` settings (SSL required) and
    reads the ``survey_info`` columns the app displays and filters on.

    Returns:
        pd.DataFrame: the query result. When a ``year`` column is present
        it is converted to strings, with values that are missing or not
        numeric rendered as ``''`` instead of ``NaN``.
    """
    conn = psycopg2.connect(
        host=DB_HOST, port=DB_PORT,
        dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD,
        sslmode="require",
    )
    # Close the connection even when the query raises; psycopg2's
    # ``with conn`` only ends the transaction and would leave it open.
    try:
        df_ = pd.read_sql_query(
            """
            SELECT id, country, year, section,
                   question_code, question_text,
                   answer_code, answer_text
            FROM survey_info;
            """,
            conn,
        )
    finally:
        conn.close()

    # Ensure year column is int, show blank instead of NaN
    if "year" in df_.columns:
        df_["year"] = (
            pd.to_numeric(df_["year"], errors="coerce")
            .astype("Int64")
            .astype(str)
            .replace({'<NA>': ''})
        )
    return df_
# Fetch the survey table once at import time; search() filters this frame.
df: pd.DataFrame = get_data()
def load_embeddings():
    """Download the embedding checkpoint from S3 and unpack it.

    The object is streamed into an in-memory buffer and deserialised on
    the CPU.

    Returns:
        tuple: the checkpoint's ``"ids"`` and ``"embeddings"`` entries,
        in that order.
    """
    bucket, key = "cgd-embeddings-bucket", "survey_info_embeddings.pt"
    # The context manager releases the buffer even if the download or the
    # deserialisation raises.
    with io.BytesIO() as buf:
        boto3.client("s3").download_fileobj(bucket, key, buf)
        buf.seek(0)
        # SECURITY NOTE(review): torch.load unpickles the payload, so anyone
        # able to write to this bucket can execute code here. Consider
        # ``weights_only=True`` once the checkpoint contents are confirmed
        # to be compatible with it.
        ckpt = torch.load(buf, map_location="cpu")
    gc.collect()
    return ckpt["ids"], ckpt["embeddings"]
def get_st_model():
    """Instantiate the MiniLM sentence encoder on the CPU."""
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    encoder = SentenceTransformer(model_name, device="cpu")
    return encoder
# Memoised: search() calls this on every semantic query, and without the
# cache each call would rebuild the encoder and re-download the checkpoint.
@pn.cache
def get_semantic_resources():
    """Return everything a semantic search needs.

    Returns:
        tuple: ``(model, ids_list, emb_tensor)`` — the encoder from
        ``get_st_model()`` followed by the two values returned by
        ``load_embeddings()``.
    """
    model = get_st_model()
    ids_list, emb_tensor = load_embeddings()
    return model, ids_list, emb_tensor
# ──────────────────────────────────────────────────────────────────────
# 2) Widgets
# ──────────────────────────────────────────────────────────────────────
# Option lists for the filter widgets, taken from the loaded data.
country_opts = sorted(df["country"].dropna().unique())
# get_data() has already turned missing years into '' strings, which
# dropna() cannot see; skip them so the Years list has no blank entry.
year_opts = sorted(y for y in df["year"].dropna().unique() if y != "")

# Every column the table can show; "Score" only exists for semantic results.
ALL_COLUMNS = ["country","year","section","question_code","question_text","answer_code","answer_text","Score"]

w_columns = pn.widgets.MultiChoice(
    name="Columns to show",
    options=ALL_COLUMNS,
    value=["country","year","question_text","answer_text"]
)
w_countries = pn.widgets.MultiSelect(name="Countries", options=country_opts)
w_years = pn.widgets.MultiSelect(name="Years", options=year_opts)
w_keyword = pn.widgets.TextInput(name="Keyword Search", placeholder="Search questions or answers with exact string matching")
w_group = pn.widgets.Checkbox(name="Group by Question Text", value=False)
# Starts disabled; it is enabled only while a semantic query is present.
w_topk = pn.widgets.Select(name="Top-K (semantic)", options=[5, 10, 20, 50, 100], value=10, disabled=True)
w_semquery = pn.widgets.TextInput(name="Semantic Query", placeholder="LLM-powered semantic search")
w_search_button = pn.widgets.Button(name="Search", button_type="primary")
w_clear_filters = pn.widgets.Button(name="Clear Filters", button_type="warning")
# ──────────────────────────────────────────────────────────────────────
# 3) Unified Results Table (Tabulator)
# ──────────────────────────────────────────────────────────────────────
# Single results grid shared by every search mode; callbacks assign its
# ``value`` to change what is shown.
result_table = pn.widgets.Tabulator(
    show_index=False,
    layout="fit_columns",
    sizing_mode="stretch_width",
    pagination="remote",
    page_size=15,
)
# ──────────────────────────────────────────────────────────────────────
# 4) Search Logic
# ──────────────────────────────────────────────────────────────────────
def _group_by_question(df_in: pd.DataFrame) -> pd.DataFrame: | |
if df_in.empty: | |
return pd.DataFrame(columns=["question_text", "Countries", "Years", "Sample Answers"]) | |
tmp = df_in.copy() | |
tmp["year"] = tmp["year"].replace('', pd.NA) | |
grouped = ( | |
tmp.groupby("question_text", dropna=False) | |
.agg({ | |
"country": lambda x: sorted({v for v in x if pd.notna(v)}), | |
"year": lambda x: sorted({str(v) for v in x if pd.notna(v)}), | |
"answer_text": lambda x: list(x.dropna())[:3], | |
}) | |
.reset_index() | |
.rename(columns={"country": "Countries", "year": "Years", "answer_text": "Sample Answers"}) | |
) | |
return grouped | |
def _selected_cols(has_score=False):
    """Return the columns to display, in the order chosen in ``w_columns``.

    Args:
        has_score: whether the frame about to be shown has a ``Score``
            column. When it does not, ``Score`` is also removed from the
            widget's current selection.

    Returns:
        list: the selected names that appear in ``ALL_COLUMNS``, or the
        four default columns when nothing valid is selected.
    """
    chosen = w_columns.value
    if "Score" in chosen and not has_score:
        # Writes back to the widget, so the selection itself loses "Score".
        w_columns.value = [name for name in chosen if name != "Score"]
    visible = [name for name in w_columns.value if name in ALL_COLUMNS]
    return visible or ["country", "year", "question_text", "answer_text"]
def search(event=None):
    """Apply the sidebar filters and refresh ``result_table``.

    Rows of ``df`` are narrowed by the selected countries and years and by
    the keyword (case-insensitive literal substring of the question text,
    answer text or question code). With no semantic query the filtered
    rows are shown directly. With a query, the filtered rows that have an
    embedding are ranked by cosine similarity to the query and the best
    Top-K are shown with their ``Score``. In both modes the rows are
    grouped per question when the Group-by checkbox is ticked.

    Args:
        event: ignored; accepted so the function can be used directly as
            a button or watcher callback.
    """
    query = w_semquery.value.strip()

    # No defensive copy needed: each step below builds a new frame and
    # nothing writes into ``df``.
    filt = df
    if w_countries.value:
        filt = filt[filt["country"].isin(w_countries.value)]
    if w_years.value:
        filt = filt[filt["year"].isin(w_years.value)]

    keyword = w_keyword.value
    if keyword:
        # regex=False: the widget promises exact string matching, and a
        # keyword such as "(" or "c++" is not a valid or faithful regex.
        match_opts = {"case": False, "na": False, "regex": False}
        filt = filt[
            filt["question_text"].str.contains(keyword, **match_opts)
            | filt["answer_text"].str.contains(keyword, **match_opts)
            | filt["question_code"].astype(str).str.contains(keyword, **match_opts)
        ]

    if not query:
        result_table.value = _group_by_question(filt) if w_group.value else filt[_selected_cols(False)]
        return

    model, ids_list, emb_tensor = get_semantic_resources()
    id_to_index = {id_: i for i, id_ in enumerate(ids_list)}

    # Build the candidate ids and their embedding positions as two aligned
    # lists. Rows without an embedding are dropped from BOTH, so position i
    # in the similarity vector always refers to candidate_ids[i].
    candidate_ids = []
    candidate_pos = []
    for row_id in filt["id"].tolist():
        pos = id_to_index.get(row_id)
        if pos is not None:
            candidate_ids.append(row_id)
            candidate_pos.append(pos)

    if not candidate_pos:
        result_table.value = (
            _group_by_question(filt.iloc[0:0])
            if w_group.value
            else pd.DataFrame(columns=_selected_cols(True))
        )
        return

    # Never ask topk for more entries than there are candidates.
    top_k = min(int(w_topk.value), len(candidate_pos))
    filtered_embs = emb_tensor[candidate_pos]
    q_vec = model.encode(query, convert_to_tensor=True, device="cpu").cpu()
    sims = util.cos_sim(q_vec, filtered_embs)[0]
    top_vals, top_idx = torch.topk(sims, k=top_k)

    top_ids = [candidate_ids[i] for i in top_idx.tolist()]
    score_map = dict(zip(top_ids, top_vals.tolist()))

    sem_rows = filt[filt["id"].isin(top_ids)].copy()
    sem_rows["Score"] = sem_rows["id"].map(score_map)
    sem_rows = sem_rows.sort_values("Score", ascending=False)

    if w_group.value:
        result_table.value = _group_by_question(sem_rows.drop(columns=["Score"]))
    else:
        result_table.value = sem_rows[_selected_cols(True)]
def clear_filters(event=None):
    """Reset the filter and search inputs, then redraw the table.

    The Group-by checkbox and the column selection are left as they are.

    Args:
        event: ignored; accepted so the function can be used as a button
            callback.
    """
    w_countries.value = []
    w_years.value = []
    w_keyword.value = ""
    w_semquery.value = ""
    w_topk.disabled = True
    # Redraw through search() instead of assigning a fixed four-column
    # frame, so the table keeps honouring the Group-by checkbox and the
    # "Columns to show" selection, which this function does not reset.
    search()
# Button callbacks.
w_search_button.on_click(search)
w_clear_filters.on_click(clear_filters)

# Live updates for filters (except semantic query and keyword)
for _live_widget in (w_group, w_countries, w_years, w_columns):
    _live_widget.param.watch(search, "value")

# Allow pressing Enter in semantic query or keyword to trigger search
for _text_widget in (w_semquery, w_keyword):
    _text_widget.param.watch(search, "enter_pressed")
# Enable/disable Top-K based on semantic query presence | |
def _toggle_topk_disabled(event=None):
    """Disable Top-K unless the semantic query has non-whitespace text."""
    has_query = bool(w_semquery.value.strip())
    w_topk.disabled = not has_query
# Set the initial Top-K state, then keep it in sync with the query text.
_toggle_topk_disabled()
w_semquery.param.watch(_toggle_topk_disabled, "value")

# Show all data at startup
result_table.value = df.loc[:, ["country", "year", "question_text", "answer_text"]].copy()
# ──────────────────────────────────────────────────────────────────────
# 5) Layout
# ──────────────────────────────────────────────────────────────────────
# NOTE(review): the heading icons in this section had been mis-decoded into
# stray Greek letters. The brain emoji on "Semantic Search" could be
# recovered from the garbled bytes; the icons on "Filters" and
# "CGD Survey Explorer" could not, so those headings are left as plain
# text — re-add icons there if desired.

# Sidebar: filter widgets first, semantic-search controls below them.
sidebar = pn.Column(
    "## Filters",
    w_countries, w_years, w_keyword, w_group, w_columns,
    pn.Spacer(height=20),
    "## 🧠 Semantic Search",
    w_semquery,
    w_topk,
    w_search_button,
    pn.Spacer(height=20),
    w_clear_filters,
    width=300,
)

# Main area: page heading above the shared results table.
main = pn.Column(
    pn.pane.Markdown("## CGD Survey Explorer"),
    result_table,
)

# Assemble the page and mark it as servable.
pn.template.FastListTemplate(
    title="CGD Survey Explorer",
    sidebar=sidebar,
    main=main,
    theme_toggle=True,
).servable()