cgd-ui / app.py
gigiliu12's picture
Update app.py
c863afd verified
raw
history blame
7.67 kB
#!/usr/bin/env python3
import os, io, json, gc
import boto3, psycopg2, pandas as pd, torch
import streamlit as st
from sentence_transformers import SentenceTransformer, util
# ────────────────────────────────────────────────────────────────────────
# 0) Hugging Face secrets β†’ env vars (already set inside Spaces)
# DB_HOST / DB_PORT / DB_NAME / DB_USER / DB_PASSWORD
# AWS creds must be in aws_creds.json pushed with the app repo
# ────────────────────────────────────────────────────────────────────────
with open("aws_creds.json") as f:
creds = json.load(f)
os.environ["AWS_ACCESS_KEY_ID"] = creds["AccessKey"]
os.environ["AWS_SECRET_ACCESS_KEY"] = creds["SecretAccessKey"]
os.environ["AWS_DEFAULT_REGION"] = "us-east-2"
# ────────────────────────────────────────────────────────────────────────
# 1) DB → DataFrame (cached 10 min)                                      |
# ────────────────────────────────────────────────────────────────────────
DB_HOST = os.getenv("DB_HOST")
DB_PORT = os.getenv("DB_PORT", "5432")
DB_NAME = os.getenv("DB_NAME")
DB_USER = os.getenv("DB_USER")
DB_PASSWORD = os.getenv("DB_PASSWORD")


@st.cache_data(ttl=600)
def load_survey_dataframe() -> pd.DataFrame:
    """Load the full ``survey_info`` table from Postgres.

    Returns a DataFrame with columns id, country, year, section,
    question_code, question_text, answer_code, answer_text.
    Cached by Streamlit for 10 minutes so app reruns don't hit the DB.
    """
    conn = psycopg2.connect(
        host=DB_HOST,
        port=DB_PORT,
        dbname=DB_NAME,
        user=DB_USER,
        password=DB_PASSWORD,
        sslmode="require",
    )
    try:
        # try/finally fixes a leak in the original: the connection was
        # never closed when the query raised.
        return pd.read_sql_query(
            """SELECT id, country, year, section,
                      question_code, question_text,
                      answer_code, answer_text
               FROM survey_info
            """,
            conn,
        )
    finally:
        conn.close()


df = load_survey_dataframe()
# ────────────────────────────────────────────────────────────────────────
# 2) S3 → ids + embeddings (cached for session)                          |
# ────────────────────────────────────────────────────────────────────────
@st.cache_resource
def load_embeddings():
    """Download the precomputed embeddings checkpoint from S3.

    Returns ``(ids, embeddings)`` where ``ids[i]`` is the survey_info.id
    whose text produced ``embeddings[i]``.  Cached for the whole session.
    Stops the app with an error if the checkpoint is malformed.
    """
    bucket = "cgd-embeddings-bucket"
    key = "survey_info_embeddings.pt"  # contains {'ids', 'embeddings'}
    # Context-manage the buffer: the original leaked it when the
    # download or torch.load raised.
    with io.BytesIO() as bio:
        boto3.client("s3").download_fileobj(bucket, key, bio)
        bio.seek(0)
        # SECURITY: torch.load unpickles arbitrary objects.  Acceptable
        # only because this bucket is app-controlled — never point it at
        # untrusted data.
        ckpt = torch.load(bio, map_location="cpu")
    gc.collect()
    if not (isinstance(ckpt, dict) and {"ids", "embeddings"} <= ckpt.keys()):
        st.error("Bad checkpoint format")
        st.stop()
    return ckpt["ids"], ckpt["embeddings"]


ids_list, emb_tensor = load_embeddings()

# Quick lookup from DB id → positional row index; df comes straight from
# load_survey_dataframe with a default RangeIndex, so iloc works below.
row_lookup = {row_id: i for i, row_id in enumerate(df["id"])}
# ────────────────────────────────────────────────────────────────────────
# 3) Streamlit UI                                                        |
# ────────────────────────────────────────────────────────────────────────
st.title("🌍 CGD Survey Explorer (Live DB + Semantic Search)")

# ── 3a) Sidebar filters (original UI) ───────────────────────────────────
st.sidebar.header("πŸ”Ž Filter Questions")

country_opts = sorted(df["country"].dropna().unique())
year_opts = sorted(df["year"].dropna().unique())

sel_countries = st.sidebar.multiselect("Select Country/Countries", country_opts)
sel_years = st.sidebar.multiselect("Select Year(s)", year_opts)
keyword = st.sidebar.text_input(
    "Keyword Search (Question / Answer / Code)", ""
)
group_by_q = st.sidebar.checkbox("Group by Question Text")

# Apply keyword & dropdown filters.  An empty multiselect means "no
# restriction" — the `... if sel_x else True` terms broadcast True over
# the boolean Series under &.
# regex=False fixes a crash in the original: str.contains defaults to
# regex=True, so a keyword like "(" or "?" raised re.error.  The keyword
# is a literal substring; an empty keyword matches every non-null row,
# exactly as before.
filtered = df[
    (df["country"].isin(sel_countries) if sel_countries else True)
    & (df["year"].isin(sel_years) if sel_years else True)
    & (
        df["question_text"].str.contains(keyword, case=False, na=False, regex=False)
        | df["answer_text"].str.contains(keyword, case=False, na=False, regex=False)
        | df["question_code"].astype(str).str.contains(keyword, case=False, na=False, regex=False)
    )
]
# ── 3b) Semantic-search panel ───────────────────────────────────────────
st.sidebar.markdown("---")
st.sidebar.subheader("🧠 Semantic Search")
sem_query = st.sidebar.text_input("Enter a natural-language query")


@st.cache_resource
def _get_encoder():
    """Load the MiniLM sentence encoder once per session.

    Fixes a perf defect: the original constructed a SentenceTransformer
    inside the button handler, re-initialising the model on every search.
    """
    return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")


if st.sidebar.button("Search", disabled=not sem_query.strip()):
    with st.spinner("Embedding & searching…"):
        model = _get_encoder()
        q_vec = model.encode(sem_query.strip(), convert_to_tensor=True).cpu()
        scores = util.cos_sim(q_vec, emb_tensor)[0]
        # Clamp k: torch.topk raises when k exceeds the corpus size.
        top_vals, top_idx = torch.topk(scores, k=min(10, scores.shape[0]))

        results = []
        for score, emb_row in zip(top_vals.tolist(), top_idx.tolist()):
            db_id = ids_list[emb_row]
            # Skip embeddings whose row no longer exists in the live DB.
            if db_id in row_lookup:
                row = df.iloc[row_lookup[db_id]]
                results.append({
                    "score": f"{score:.3f}",
                    "country": row["country"],
                    "year": row["year"],
                    "question": row["question_text"],
                    "answer": row["answer_text"],
                })

    if results:
        st.subheader("πŸ” Semantic Results")
        st.write(f"Showing top {len(results)} for **{sem_query}**")
        st.dataframe(pd.DataFrame(results))
    else:
        st.info("No semantic matches found.")
st.markdown("---")
# ── 3c) Original results table / grouped view ───────────────────────────
if group_by_q:
    st.subheader("πŸ“Š Grouped by Question Text")
    # Collapse to one row per distinct question: which countries/years
    # asked it, plus up to three sample answers.
    grouped = (
        filtered.groupby("question_text")
        .agg(**{
            "Countries": ("country", lambda vals: sorted(set(vals))),
            "Years": ("year", lambda vals: sorted(set(vals))),
            "Sample Answers": ("answer_text", lambda vals: list(vals)[:3]),
        })
        .reset_index()
    )
    st.dataframe(grouped)
    if grouped.empty:
        st.info("No questions found with current filters.")
else:
    # Flat table of the filtered rows, headed by the active filter context.
    heading_parts = []
    if sel_countries:
        heading_parts.append("Countries: " + ", ".join(sel_countries))
    if sel_years:
        heading_parts.append("Years: " + ", ".join(map(str, sel_years)))
    context = " | ".join(heading_parts) if heading_parts else "All Countries and Years"
    st.markdown("### Results for " + context)
    st.dataframe(filtered[["country", "year", "question_text", "answer_text"]])
    if filtered.empty:
        st.info("No matching questions found.")