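# cgd-ui / app.py
# Last change: commit e35f77b ("Update app.py", verified) by gigiliu12 — 7.18 kB.
# (Header recovered from a Hugging Face file-view page; the "raw",
# "history" and "blame" entries were page links, not file content.)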
import os, io, json, gc
import streamlit as st
import pandas as pd
import psycopg2
import boto3, torch
from sentence_transformers import SentenceTransformer, util
# ────────────────────────────────────────────────────────────────────────
# 1) DB credentials (from HF secrets or env) – original
# ────────────────────────────────────────────────────────────────────────
# Connection settings are read from environment variables (HF Space
# secrets). Every setting except the port is None when its variable is unset.
DB_HOST, DB_NAME, DB_USER, DB_PASSWORD = (
    os.environ.get(var) for var in ("DB_HOST", "DB_NAME", "DB_USER", "DB_PASSWORD")
)
DB_PORT = os.environ.get("DB_PORT", "5432")  # fall back to the standard Postgres port
@st.cache_data(ttl=600)
def get_data() -> pd.DataFrame:
    """Load the ``survey_info`` table from Postgres into a DataFrame.

    The result is cached by Streamlit for 600 seconds. On failure the error
    is shown in the app and the script run is halted with ``st.stop()``.
    """
    query = """
        SELECT id, country, year, section,
               question_code, question_text,
               answer_code, answer_text
        FROM survey_info;
    """
    try:
        conn = psycopg2.connect(
            host=DB_HOST,
            port=DB_PORT,
            dbname=DB_NAME,
            user=DB_USER,
            password=DB_PASSWORD,
            sslmode="require",
        )
    except Exception as e:
        st.error(f"Failed to connect to the database: {e}")
        st.stop()
    # A psycopg2 connection used as a context manager only ends the
    # transaction; it does not close. Close explicitly so a failing query
    # (or st.stop() raising) cannot leak the connection.
    try:
        return pd.read_sql_query(query, conn)
    except Exception as e:
        st.error(f"Failed to load survey data: {e}")
        st.stop()
    finally:
        conn.close()
df = get_data()  # ← original DataFrame
# Map survey_info.id -> *positional* row number in df. The semantic search
# reads rows with df.iloc, so positions (not index labels) are required.
# If an id repeats, the last occurrence wins.
row_lookup = {db_id: pos for pos, db_id in enumerate(df["id"].tolist())}
# ────────────────────────────────────────────────────────────────────────
# 2) Load embeddings + ids once (S3) – cached with st.cache_resource
# ────────────────────────────────────────────────────────────────────────
@st.cache_resource
def load_embeddings():
    """Download the precomputed survey embeddings from S3.

    The checkpoint must be a dict with keys ``"ids"`` and ``"embeddings"``.
    ``ids[i]`` is the ``survey_info.id`` for row ``i`` of ``embeddings``.

    Returns:
        ``(ids, embeddings)``. ``ids`` is a plain list, so its elements hash
        by value and can be looked up in ``row_lookup``. ``embeddings`` is a
        2-D float32 CPU tensor, matching the dtype of the query vectors.

    On a malformed checkpoint the error is shown and the run is halted with
    ``st.stop()``.
    """
    # credentials already in env (HF secrets) – boto3 will pick them up
    BUCKET = "cgd-embeddings-bucket"
    KEY = "survey_info_embeddings.pt"  # dict {'ids', 'embeddings'}
    with io.BytesIO() as buf:
        boto3.client("s3").download_fileobj(BUCKET, KEY, buf)
        buf.seek(0)
        # NOTE(review): torch.load unpickles its input, so only load from a
        # trusted bucket. Consider weights_only=True if the installed torch
        # supports it (>= 1.13).
        ckpt = torch.load(buf, map_location="cpu")
    gc.collect()
    if not (isinstance(ckpt, dict) and {"ids", "embeddings"} <= ckpt.keys()):
        st.error("Bad checkpoint format in survey_info_embeddings.pt"); st.stop()

    ids = ckpt["ids"]
    if isinstance(ids, torch.Tensor):
        # A tensor element hashes by identity and would never match a dict key.
        ids = ids.tolist()
    # Queries are encoded as float32; torch.mm inside util.cos_sim rejects mixed dtypes.
    embeddings = torch.as_tensor(ckpt["embeddings"], dtype=torch.float32)
    if embeddings.ndim != 2 or len(ids) != embeddings.shape[0]:
        st.error("Bad checkpoint format in survey_info_embeddings.pt: "
                 "ids and embedding rows do not line up"); st.stop()
    return ids, embeddings


ids_list, emb_tensor = load_embeddings()
# ────────────────────────────────────────────────────────────────────────
# 3) Streamlit UI – original filters + new semantic search
# ────────────────────────────────────────────────────────────────────────
st.title("🌍 CGD Survey Explorer (Live DB)")
st.sidebar.header("🔎 Filter Questions")

# Filter choices come from the loaded data; nulls are not offered.
country_options = sorted(df["country"].dropna().unique())
year_options = sorted(df["year"].dropna().unique())

# An empty selection / keyword means "no restriction" in the filter step.
selected_countries = st.sidebar.multiselect("Select Country/Countries", country_options)
selected_years = st.sidebar.multiselect("Select Year(s)", year_options)
keyword = st.sidebar.text_input(
    "Keyword Search (Question text / Answer text / Question code)", ""
)
group_by_question = st.sidebar.checkbox("Group by Question Text")
# ── new semantic search panel ───────────────────────────────────────────
@st.cache_resource
def _load_sentence_model(model_name="sentence-transformers/all-MiniLM-L6-v2"):
    """Load the sentence-embedding model once and reuse it across reruns."""
    return SentenceTransformer(model_name)


st.sidebar.markdown("---")
st.sidebar.subheader("🧠 Semantic Search")
sem_query = st.sidebar.text_input("Enter a natural-language query")

SEM_CANDIDATES = 10  # over-fetch: some hits may be absent from df or lack text
SEM_SHOW = 5         # number of results actually displayed

if st.sidebar.button("Search", disabled=not sem_query.strip()):
    with st.spinner("Embedding & searching…"):
        model = _load_sentence_model()
        q_vec = model.encode(sem_query.strip(), convert_to_tensor=True).cpu()
        scores = util.cos_sim(q_vec, emb_tensor)[0]
        # torch.topk raises if k exceeds the number of embeddings.
        top_vals, top_idx = torch.topk(scores, k=min(SEM_CANDIDATES, scores.numel()))

        results = []
        for score, emb_row in zip(top_vals.tolist(), top_idx.tolist()):
            pos = row_lookup.get(ids_list[emb_row])
            if pos is None:
                continue  # embedding id not present in the loaded DataFrame
            row = df.iloc[pos]
            question, answer = row["question_text"], row["answer_text"]
            # Check pd.isna first: NaN is truthy, so a plain `if question` lets it through.
            if pd.isna(question) or pd.isna(answer) or not question or not answer:
                continue
            results.append({
                "Score": f"{score:.3f}",
                "Country": row["country"],
                "Year": row["year"],
                "Question": question,
                "Answer": answer,
            })
            if len(results) == SEM_SHOW:
                break  # topk is sorted, so these are the best SEM_SHOW usable hits

    if results:
        st.subheader(f"🔍 Semantic Results ({len(results)} found)")
        st.dataframe(pd.DataFrame(results))
    else:
        st.info("No semantic matches found.")
    st.markdown("---")
# ── apply original filters ──────────────────────────────────────────────
def _apply_filters(frame, countries, years, kw):
    """Return the rows of *frame* that match every active sidebar filter.

    An empty *countries*, *years* or *kw* means "no restriction". The keyword
    is matched as a literal, case-insensitive substring of question_text,
    answer_text or question_code. Null cells never match.
    """
    mask = pd.Series(True, index=frame.index)
    if countries:
        mask &= frame["country"].isin(countries)
    if years:
        mask &= frame["year"].isin(years)
    if kw:
        # regex=False: user input such as "C++" or "(" must not be parsed as a
        # regular expression (it would raise re.error).
        opts = dict(case=False, regex=False, na=False)
        mask &= (
            frame["question_text"].str.contains(kw, **opts)
            | frame["answer_text"].str.contains(kw, **opts)
            # "string" dtype keeps nulls as <NA>. astype(str) would turn them
            # into "nan"/"None", which a keyword could then match.
            | frame["question_code"].astype("string").str.contains(kw, **opts)
        )
    return frame[mask]


filtered = _apply_filters(df, selected_countries, selected_years, keyword)
# ── original output logic ───────────────────────────────────────────────
if group_by_question:
    st.subheader("📊 Grouped by Question Text")
    grouped = (
        filtered.groupby("question_text")
        .agg({
            # Drop nulls first: a NaN mixed with strings makes sorted() raise TypeError.
            "country": lambda s: sorted(set(s.dropna())),
            "year": lambda s: sorted(set(s.dropna())),
            "answer_text": lambda s: s.head(3).tolist(),  # first three answers
        })
        .reset_index()
        .rename(columns={
            "country": "Countries",
            "year": "Years",
            "answer_text": "Sample Answers",
        })
    )
    if grouped.empty:
        st.info("No questions found with current filters.")
    else:
        st.dataframe(grouped)
else:
    heading_parts = []
    if selected_countries:
        heading_parts.append("Countries: " + ", ".join(selected_countries))
    if selected_years:
        heading_parts.append("Years: " + ", ".join(map(str, selected_years)))
    heading = " | ".join(heading_parts) if heading_parts else "All Countries and Years"
    st.markdown("### Results for " + heading)
    if filtered.empty:
        st.info("No matching questions found.")
    else:
        st.dataframe(filtered[["country", "year", "question_text", "answer_text"]])