#!/usr/bin/env python3
# app.py – CGD Survey Explorer + merged semantic search
import os, io, json, gc
import streamlit as st
import pandas as pd
import psycopg2
import boto3, torch
from sentence_transformers import SentenceTransformer, util
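# Third-party dependencies: streamlit, pandas, psycopg2, boto3, torch,
# sentence-transformers (presumably pinned in the Space's requirements.txt,
# which is not shown here).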
# ────────────────────────────────────────────────────────────────────────
# 1) Database credentials (provided via HF Secrets / env vars)
# ────────────────────────────────────────────────────────────────────────
DB_HOST = os.getenv("DB_HOST") # set these in the Space’s Secrets
DB_PORT = os.getenv("DB_PORT", "5432")
DB_NAME = os.getenv("DB_NAME")
DB_USER = os.getenv("DB_USER")
DB_PASSWORD = os.getenv("DB_PASSWORD")
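# The values above come from the Space's Secrets panel in production. For local
# testing you could export the same variables before launching the app, e.g.
# (placeholder values, not real credentials):
#   export DB_HOST=localhost DB_PORT=5432 DB_NAME=cgd DB_USER=cgd_reader DB_PASSWORD=...
#   streamlit run app.py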
@st.cache_data(ttl=600)
def get_data() -> pd.DataFrame:
"""Pull the full survey_info table (cached for 10 min)."""
conn = psycopg2.connect(
host=DB_HOST,
port=DB_PORT,
dbname=DB_NAME,
user=DB_USER,
password=DB_PASSWORD,
sslmode="require",
)
df_ = pd.read_sql_query(
"""SELECT id, country, year, section,
question_code, question_text,
answer_code, answer_text
FROM survey_info;""",
conn,
)
conn.close()
return df_
df = get_data()
row_lookup = {row.id: i for i, row in df.iterrows()}  # survey id -> DataFrame row index
# ────────────────────────────────────────────────────────────────────────
# 2) Pre-computed embeddings (ids + tensor) – download once per session
# ────────────────────────────────────────────────────────────────────────
@st.cache_resource
def load_embeddings():
    BUCKET = "cgd-embeddings-bucket"
    KEY = "survey_info_embeddings.pt"  # contains {'ids', 'embeddings'}
    buf = io.BytesIO()
    boto3.client("s3").download_fileobj(BUCKET, KEY, buf)
    buf.seek(0)
    ckpt = torch.load(buf, map_location="cpu")
    buf.close()
    gc.collect()
    if not (isinstance(ckpt, dict) and {"ids", "embeddings"} <= ckpt.keys()):
        st.error("Bad checkpoint format in survey_info_embeddings.pt")
        st.stop()
    return ckpt["ids"], ckpt["embeddings"]
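# A sketch of the checkpoint layout this loader assumes (the script that builds
# survey_info_embeddings.pt is not part of this app, so this is inferred from
# the keys checked above):
#   torch.save(
#       {"ids": [... survey_info.id values, aligned row-for-row ...],
#        "embeddings": float tensor of shape (len(ids), embedding_dim)},
#       "survey_info_embeddings.pt",
#   )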
ids_list, emb_tensor = load_embeddings()
# ────────────────────────────────────────────────────────────────────────
# 3) Streamlit UI
# ────────────────────────────────────────────────────────────────────────
st.title("🌍 CGD Survey Explorer (Live DB)")
st.sidebar.header("🔎 Filter Questions")
country_options = sorted(df["country"].dropna().unique())
year_options = sorted(df["year"].dropna().unique())
sel_countries = st.sidebar.multiselect("Select Country/Countries", country_options)
sel_years = st.sidebar.multiselect("Select Year(s)", year_options)
keyword = st.sidebar.text_input(
"Keyword Search (Question text / Answer text / Question code)", ""
)
group_by_question = st.sidebar.checkbox("Group by Question Text")
# --- Semantic-search input (kept in sidebar) ---------------------------
st.sidebar.markdown("---")
st.sidebar.subheader("🧠 Semantic Search")
sem_query = st.sidebar.text_input("Enter a natural-language query")
search_clicked = st.sidebar.button("Search", disabled=not sem_query.strip())
# ── base_filtered: applies dropdown + keyword logic (always computed) ──
base_filtered = df[
    (df["country"].isin(sel_countries) if sel_countries else True) &
    (df["year"].isin(sel_years) if sel_years else True) &
    (
        df["question_text"].str.contains(keyword, case=False, na=False) |
        df["answer_text"].str.contains(keyword, case=False, na=False) |
        df["question_code"].astype(str).str.contains(keyword, case=False, na=False)
    )
]
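# Note: with an empty keyword, Series.str.contains("", na=False) matches every
# non-null field, so base_filtered effectively reduces to the country/year
# dropdown filters until a keyword is typed.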
# ────────────────────────────────────────────────────────────────────────
# 4) When the Search button is clicked → build merged table
# ────────────────────────────────────────────────────────────────────────
if search_clicked:
    with st.spinner("Embedding & searching…"):
        model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
        q_vec = model.encode(sem_query.strip(), convert_to_tensor=True).cpu()
        sims = util.cos_sim(q_vec, emb_tensor)[0]
        top_vals, top_idx = torch.topk(sims, k=min(50, sims.shape[0]))  # up to 50 candidates
        sem_ids = [ids_list[i] for i in top_idx.tolist()]
        sem_rows = df.loc[df["id"].isin(sem_ids)].copy()
        score_map = dict(zip(sem_ids, top_vals.tolist()))
        sem_rows["Score"] = sem_rows["id"].map(score_map)
        sem_rows = sem_rows.sort_values("Score", ascending=False)
        remainder = base_filtered.loc[~base_filtered["id"].isin(sem_ids)].copy()
        remainder["Score"] = ""  # blank score for keyword-only rows
        combined = pd.concat([sem_rows, remainder], ignore_index=True)
    st.subheader(f"🔍 Combined Results ({len(combined)})")
    st.dataframe(
        combined[["Score", "country", "year", "question_text", "answer_text"]],
        use_container_width=True,
    )
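    # "Score" is the cosine similarity between the query embedding and each
    # row's precomputed embedding (higher = closer match); rows that pass the
    # sidebar filters but are not semantic hits follow with a blank Score.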
# ────────────────────────────────────────────────────────────────────────
# 5) No semantic query → use original keyword filter logic / grouping
# ────────────────────────────────────────────────────────────────────────
else:
    if group_by_question:
        st.subheader("📊 Grouped by Question Text")
        grouped = (
            base_filtered.groupby("question_text")
            .agg({
                "country": lambda x: sorted(set(x)),
                "year": lambda x: sorted(set(x)),
                "answer_text": lambda x: list(x)[:3],
            })
            .reset_index()
            .rename(columns={
                "country": "Countries",
                "year": "Years",
                "answer_text": "Sample Answers",
            })
        )
        st.dataframe(grouped, use_container_width=True)
        if grouped.empty:
            st.info("No questions found with current filters.")
    else:
        heading = []
        if sel_countries:
            heading.append("Countries: " + ", ".join(sel_countries))
        if sel_years:
            heading.append("Years: " + ", ".join(map(str, sel_years)))
        st.markdown("### Results for " + (" | ".join(heading) if heading else "All Countries and Years"))
        st.dataframe(
            base_filtered[["country", "year", "question_text", "answer_text"]],
            use_container_width=True,
        )
        if base_filtered.empty:
            st.info("No matching questions found.")