cgd-ui-panel / app.py
myshirk's picture
change button descriptions
c6179a9 verified
raw
history blame
7.73 kB
# app.py – Unified Panel App with Semantic Search + Filterable Tabulator
import os, io, gc
import panel as pn
import pandas as pd
import boto3, torch
import psycopg2
from sentence_transformers import SentenceTransformer, util
pn.extension('tabulator')
# ──────────────────────────────────────────────────────────────────────
# 1) Database and Resource Loading
# ──────────────────────────────────────────────────────────────────────
# PostgreSQL connection settings, taken from the environment. Any variable
# that is unset comes back as None; only the port has a fallback ("5432").
DB_HOST, DB_NAME, DB_USER, DB_PASSWORD = (
    os.getenv(var) for var in ("DB_HOST", "DB_NAME", "DB_USER", "DB_PASSWORD")
)
DB_PORT = os.getenv("DB_PORT", "5432")
@pn.cache()
def get_data():
    """Load the ``survey_info`` table from PostgreSQL into a DataFrame.

    Connection settings come from the module-level ``DB_*`` values and the
    connection requires SSL. The result is memoised via ``pn.cache``.

    Returns
    -------
    pandas.DataFrame
        Columns ``id, country, year, section, question_code, question_text,
        answer_code, answer_text``. ``year`` (when present) is coerced to the
        nullable ``Int64`` dtype, so missing or unparseable years become
        ``<NA>`` instead of turning the column into floats.
    """
    conn = psycopg2.connect(
        host=DB_HOST, port=DB_PORT,
        dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD,
        sslmode="require",
    )
    try:
        df_ = pd.read_sql_query(
            """
            SELECT id, country, year, section,
                   question_code, question_text,
                   answer_code, answer_text
            FROM survey_info;
            """,
            conn,
        )
    finally:
        # Always release the connection, even if the query fails.
        # (psycopg2's ``with conn:`` only ends the transaction; it does not
        # close the connection, so an explicit close is needed.)
        conn.close()
    # Ensure year column is an integer type; non-numeric values become <NA>.
    if "year" in df_.columns:
        df_["year"] = pd.to_numeric(df_["year"], errors="coerce").astype("Int64")
    return df_


df = get_data()
@pn.cache()
def load_embeddings():
    """Download the precomputed embedding checkpoint from S3 and unpack it.

    The checkpoint is read entirely in memory (no temp file) and loaded onto
    the CPU. The result is memoised via ``pn.cache``.

    Returns
    -------
    tuple
        ``(ckpt["ids"], ckpt["embeddings"])``. Presumably ``ids[i]`` is the
        ``survey_info.id`` whose embedding is row ``i`` of ``embeddings``;
        semantic search relies on that alignment — TODO confirm against the
        script that produced the checkpoint.
    """
    BUCKET, KEY = "cgd-embeddings-bucket", "survey_info_embeddings.pt"
    # The context manager closes the buffer even if download or load fails.
    with io.BytesIO() as buf:
        boto3.client("s3").download_fileobj(BUCKET, KEY, buf)
        buf.seek(0)
        # NOTE(security): torch.load unpickles the file, so this is only safe
        # while the bucket is fully trusted. If the checkpoint holds just
        # tensors and plain lists, consider weights_only=True (torch>=1.13).
        ckpt = torch.load(buf, map_location="cpu")
    # Encourage prompt release of the raw download bytes.
    gc.collect()
    return ckpt["ids"], ckpt["embeddings"]
@pn.cache()
def get_st_model():
    """Return the all-MiniLM-L6-v2 sentence-transformer, on CPU (memoised)."""
    model_id = "sentence-transformers/all-MiniLM-L6-v2"
    return SentenceTransformer(model_id, device="cpu")
@pn.cache()
def get_semantic_resources():
    """Gather everything semantic search needs in one memoised call.

    Returns ``(model, ids, embeddings)``: the sentence-transformer model from
    ``get_st_model()`` followed by the pair returned by ``load_embeddings()``.
    """
    st_model = get_st_model()
    stored_ids, stored_embs = load_embeddings()
    return st_model, stored_ids, stored_embs
# ──────────────────────────────────────────────────────────────────────
# 2) Widgets
# ──────────────────────────────────────────────────────────────────────
# Filter choices: the distinct, non-null values present in the loaded data,
# in sorted order.
country_opts = sorted(df["country"].dropna().unique())
year_opts = sorted(df["year"].dropna().unique())

# --- filter widgets ---
w_countries = pn.widgets.MultiSelect(
    name="Countries",
    options=country_opts,
)
w_years = pn.widgets.MultiSelect(
    name="Years",
    options=year_opts,
)
w_keyword = pn.widgets.TextInput(
    name="Keyword Search",
    placeholder="Search questions or answers with exact string matching",
)
# NOTE(review): w_group is placed in the sidebar, but nothing in the visible
# code reads its value, so toggling it currently has no effect.
w_group = pn.widgets.Checkbox(
    name="Group by Question Text",
    value=False,
)

# --- semantic search + actions ---
w_semquery = pn.widgets.TextInput(
    name="Semantic Query",
    placeholder="LLM-powered semantic search",
)
w_search_button = pn.widgets.Button(
    name="Search",
    button_type="primary",
)
w_clear_filters = pn.widgets.Button(
    name="Clear Filters",
    button_type="warning",
)
# ──────────────────────────────────────────────────────────────────────
# 3) Unified Results Table (Tabulator)
# ──────────────────────────────────────────────────────────────────────
# Shared results grid, refreshed by the Search and Clear Filters handlers.
# It starts with the same unfiltered view that "Clear Filters" restores, so
# the page does not open on an empty table.
result_table = pn.widgets.Tabulator(
    value=df[["country", "year", "question_text", "answer_text"]],
    pagination='remote',
    page_size=15,
    sizing_mode="stretch_width",
    layout='fit_columns',
    show_index=False
)
# ──────────────────────────────────────────────────────────────────────
# 4) Semantic Search with Filtering
# ──────────────────────────────────────────────────────────────────────
def _apply_filters(frame):
    """Return the rows of *frame* selected by the country, year and keyword widgets."""
    if w_countries.value:
        frame = frame[frame["country"].isin(w_countries.value)]
    if w_years.value:
        frame = frame[frame["year"].isin(w_years.value)]
    keyword = w_keyword.value
    if keyword:
        # regex=False: the UI promises exact string matching, and user text
        # such as "(" or "?" must not be parsed as a regular expression
        # (an unbalanced "(" would otherwise raise re.error).
        frame = frame[
            frame["question_text"].str.contains(keyword, case=False, na=False, regex=False)
            | frame["answer_text"].str.contains(keyword, case=False, na=False, regex=False)
            | frame["question_code"].astype(str).str.contains(keyword, case=False, na=False, regex=False)
        ]
    return frame


def _rank_by_similarity(filt, query, top_k=50):
    """Return up to *top_k* rows of *filt* ranked by similarity to *query*.

    Adds a ``Score`` column (cosine similarity) and sorts it descending.
    Returns None when no row of *filt* has a stored embedding.
    """
    model, ids_list, emb_tensor = get_semantic_resources()
    id_to_index = {id_: i for i, id_ in enumerate(ids_list)}

    # Collect candidate ids and their embedding rows in lockstep. Rows with no
    # stored embedding are skipped, so a top-k position must be mapped back
    # through cand_ids, not through the unskipped list of filtered ids.
    # NOTE(review): if the checkpoint ids are a different type than the
    # database ids (e.g. str vs int), nothing matches here — confirm.
    cand_ids, cand_indices = [], []
    for id_ in filt["id"].tolist():
        idx = id_to_index.get(id_)
        if idx is not None:
            cand_ids.append(id_)
            cand_indices.append(idx)
    if not cand_indices:
        return None

    cand_embs = emb_tensor[cand_indices]
    q_vec = model.encode(query, convert_to_tensor=True, device="cpu").cpu()
    sims = util.cos_sim(q_vec, cand_embs)[0]
    # torch.topk raises if k exceeds the number of candidates.
    top_vals, top_idx = torch.topk(sims, k=min(top_k, len(cand_ids)))

    top_ids = [cand_ids[i] for i in top_idx.tolist()]
    score_map = dict(zip(top_ids, top_vals.tolist()))
    rows = filt[filt["id"].isin(top_ids)].copy()
    rows["Score"] = rows["id"].map(score_map)
    return rows.sort_values("Score", ascending=False)


def semantic_search(event=None):
    """Refresh ``result_table`` from the sidebar widgets.

    The country / year / keyword filters are always applied first. With an
    empty semantic query the filtered rows are shown as-is (no ``Score``
    column). Otherwise the filtered rows are ranked by cosine similarity to
    the query embedding, and the top 50 are shown with a ``Score`` column,
    highest first.

    Parameters
    ----------
    event : optional
        Click event passed by ``Button.on_click``; unused.
    """
    display_cols = ["country", "year", "question_text", "answer_text"]
    query = (w_semquery.value or "").strip()
    filt = _apply_filters(df)

    # No semantic query: show the filtered data only.
    if not query:
        if filt.empty:
            result_table.value = pd.DataFrame(columns=display_cols)
        else:
            result_table.value = filt[display_cols]
        return

    # Semantic search *within* the filtered subset.
    ranked = _rank_by_similarity(filt, query)
    if ranked is None:
        result_table.value = pd.DataFrame(columns=["Score", *display_cols])
        return
    result_table.value = ranked[["Score", *display_cols]]
def clear_filters(event=None):
    """Reset all filter and query widgets, then show the full, unfiltered table.

    ``event`` is the click event supplied by ``Button.on_click``; it is ignored.
    """
    for multi in (w_countries, w_years):
        multi.value = []
    # Chained assignment: w_keyword is set first, then w_semquery.
    w_keyword.value = w_semquery.value = ""
    columns = ["country", "year", "question_text", "answer_text"]
    result_table.value = df[columns].copy()
# Wire the sidebar buttons to their handlers.
w_search_button.on_click(semantic_search)
w_clear_filters.on_click(clear_filters)
# ──────────────────────────────────────────────────────────────────────
# 5) Layout
# ──────────────────────────────────────────────────────────────────────
# Headings contain literal emoji; keep this file saved as UTF-8 so they do
# not turn into mojibake again.
sidebar = pn.Column(
    "## 🔎 Filters",
    w_countries, w_years, w_keyword, w_group,
    pn.Spacer(height=20),
    "## 🧠 Semantic Search",
    w_semquery, w_search_button,
    pn.Spacer(height=20),
    w_clear_filters,
    width=300
)
main = pn.Column(
    pn.pane.Markdown("## 🌍 CGD Survey Explorer"),
    result_table
)
# Register the page as the app served by `panel serve`.
pn.template.FastListTemplate(
    title="CGD Survey Explorer",
    sidebar=sidebar,
    main=main,
    theme_toggle=True,
).servable()