# Hugging Face Space: cgd-ui-TEST / app.py
# (uploaded by gigiliu12 — "Update app.py", commit feb2540 verified)
# NOTE(review): the four lines above were page chrome copied from the HF web
# view; commented out so the file parses as Python.
#!/usr/bin/env python3
# app.py – CGD Survey Explorer (keyword + semantic in one table)
import os, io, json, gc
import streamlit as st
import pandas as pd
import psycopg2
import boto3, torch
from sentence_transformers import SentenceTransformer, util
# ─────────────────────────────────────────────────────────────
# 1) Database credentials (HF Secrets or env vars)
# ─────────────────────────────────────────────────────────────
# Postgres connection settings, injected via HF Spaces secrets / env vars.
# Host, name, user and password have no defaults: if unset they are None and
# psycopg2.connect() will fail at query time.
DB_HOST = os.getenv("DB_HOST")
DB_PORT = os.getenv("DB_PORT", "5432")  # default Postgres port
DB_NAME = os.getenv("DB_NAME")
DB_USER = os.getenv("DB_USER")
DB_PASSWORD = os.getenv("DB_PASSWORD")
@st.cache_data(ttl=600)
def get_data() -> pd.DataFrame:
    """Read the full survey_info table; Streamlit caches it for 10 minutes.

    Returns:
        pd.DataFrame with columns id, country, year, section,
        question_code, question_text, answer_code, answer_text.
    """
    conn = psycopg2.connect(
        host=DB_HOST, port=DB_PORT,
        dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD,
        sslmode="require",
    )
    # try/finally guarantees the connection is released even when the query
    # raises (the original leaked the connection on error).
    try:
        return pd.read_sql_query("""
            SELECT id, country, year, section,
                   question_code, question_text,
                   answer_code, answer_text
            FROM survey_info;
        """, conn)
    finally:
        conn.close()
# Load the survey table once (cached by get_data).
df = get_data()
# Map each survey row's "id" value -> its DataFrame index label.
# dict(zip(...)) over the two columns is equivalent to the original
# per-row iterrows() comprehension but runs vectorized instead of O(n) Python
# Series construction per row.
row_lookup = dict(zip(df["id"], df.index))
# ─────────────────────────────────────────────────────────────
# 2) Cached resources
# ─────────────────────────────────────────────────────────────
@st.cache_resource
def load_embeddings():
    """Fetch the precomputed (ids, embeddings) checkpoint from S3 once per session."""
    bucket = "cgd-embeddings-bucket"
    key = "survey_info_embeddings.pt"
    # Download into an in-memory buffer; the context manager closes it after
    # torch.load has deserialized the checkpoint.
    with io.BytesIO() as buf:
        boto3.client("s3").download_fileobj(bucket, key, buf)
        buf.seek(0)
        ckpt = torch.load(buf, map_location="cpu")
    gc.collect()
    # Sanity-check the checkpoint layout before trusting it.
    well_formed = isinstance(ckpt, dict) and "ids" in ckpt and "embeddings" in ckpt
    if not well_formed:
        st.error("Bad checkpoint format in survey_info_embeddings.pt")
        st.stop()
    return ckpt["ids"], ckpt["embeddings"]
ids_list, emb_tensor = load_embeddings()
@st.cache_resource
def get_st_model():
    """Return the cached Mini-LM sentence-transformer encoder."""
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    # Pinned to CPU to avoid the meta-tensor initialisation bug.
    return SentenceTransformer(model_name, device="cpu")
# ─────────────────────────────────────────────────────────────
# 3) Streamlit UI
# ─────────────────────────────────────────────────────────────
st.title("🌍 CGD Survey Explorer (Live DB)")
st.sidebar.header("πŸ”Ž Filter Questions")
# Dropdown options come straight from the live table; NaNs are dropped so
# they never appear as selectable entries.
country_opts = sorted(df["country"].dropna().unique())
year_opts = sorted(df["year"].dropna().unique())
# Empty selections mean "no filter" (see the mask construction below).
sel_countries = st.sidebar.multiselect("Select Country/Countries", country_opts)
sel_years = st.sidebar.multiselect("Select Year(s)", year_opts)
keyword = st.sidebar.text_input("Keyword Search (Question / Answer / Code)")
group_by_q = st.sidebar.checkbox("Group by Question Text")
# ── Semantic search panel
st.sidebar.markdown("---")
st.sidebar.subheader("🧠 Semantic Search")
sem_query = st.sidebar.text_input("Enter a natural-language query")
# Button stays disabled until the query contains non-whitespace text.
search_clicked = st.sidebar.button("Search", disabled=not sem_query.strip())
# ── Always build the keyword/dropdown subset.
# An empty multiselect contributes a plain True (no filtering on that column).
# regex=False: user-typed keywords are matched literally — the default regex
# mode crashed the app on special characters such as "(", "?" or "[".
# An empty keyword matches every non-null row (contains("") is True).
filtered = df[
    (df["country"].isin(sel_countries) if sel_countries else True) &
    (df["year"].isin(sel_years) if sel_years else True) &
    (
        df["question_text"].str.contains(keyword, case=False, na=False, regex=False) |
        df["answer_text"].str.contains(keyword, case=False, na=False, regex=False) |
        df["question_code"].astype(str).str.contains(keyword, case=False, na=False, regex=False)
    )
]
# ─────────────────────────────────────────────────────────────
# 4) Semantic search -> merged results table
# ─────────────────────────────────────────────────────────────
if search_clicked:
    with st.spinner("Embedding & searching…"):
        model = get_st_model()
        # Encode on CPU so the query vector lives on the same device as the
        # checkpoint tensor.
        q_vec = model.encode(
            sem_query.strip(),
            convert_to_tensor=True,
            device="cpu"
        ).cpu()
        sims = util.cos_sim(q_vec, emb_tensor)[0]
        # Clamp k: torch.topk raises if k exceeds the number of embeddings
        # in the checkpoint.
        k = min(50, sims.shape[0])  # up to 50 candidates
        top_vals, top_idx = torch.topk(sims, k=k)
        sem_ids = [ids_list[i] for i in top_idx.tolist()]
        sem_rows = df.loc[df["id"].isin(sem_ids)].copy()
        # Attach the cosine-similarity score to each semantic hit and rank by it.
        score_map = dict(zip(sem_ids, top_vals.tolist()))
        sem_rows["Score"] = sem_rows["id"].map(score_map)
        sem_rows = sem_rows.sort_values("Score", ascending=False)
        # Rows that matched the keyword/dropdown filters but not the semantic
        # search are appended below the scored rows with a blank score.
        remainder = filtered.loc[~filtered["id"].isin(sem_ids)].copy()
        remainder["Score"] = ""  # blank score
        combined = pd.concat([sem_rows, remainder], ignore_index=True)
    st.subheader(f"πŸ” Combined Results ({len(combined)})")
    st.dataframe(
        combined[["Score", "country", "year", "question_text", "answer_text"]],
        use_container_width=True,
    )
    st.stop()  # skip the keyword-only display logic below when semantic ran
# ─────────────────────────────────────────────────────────────
# 5) Original display (keyword / filters only)
# ─────────────────────────────────────────────────────────────
if group_by_q:
    # Grouped view: one row per distinct question text, with the countries
    # and years it appears in plus a few sample answers.
    st.subheader("πŸ“Š Grouped by Question Text")
    agg_spec = {
        "country": lambda vals: sorted(set(vals)),
        "year": lambda vals: sorted(set(vals)),
        "answer_text": lambda vals: list(vals)[:3],
    }
    column_labels = {
        "country": "Countries",
        "year": "Years",
        "answer_text": "Sample Answers",
    }
    grouped = (
        filtered.groupby("question_text")
        .agg(agg_spec)
        .reset_index()
        .rename(columns=column_labels)
    )
    st.dataframe(grouped, use_container_width=True)
    if grouped.empty:
        st.info("No questions found with current filters.")
else:
    # Flat view: one row per answer, headed by the active filter summary.
    header_bits = []
    if sel_countries:
        header_bits.append("Countries: " + ", ".join(sel_countries))
    if sel_years:
        header_bits.append("Years: " + ", ".join(map(str, sel_years)))
    summary = " | ".join(header_bits) if header_bits else "All Countries and Years"
    st.markdown("### Results for " + summary)
    st.dataframe(
        filtered[["country", "year", "question_text", "answer_text"]],
        use_container_width=True,
    )
    if filtered.empty:
        st.info("No matching questions found.")