#!/usr/bin/env python3
# app.py - CGD Survey Explorer + merged semantic search
import os, io, json, gc
import streamlit as st
import pandas as pd
import psycopg2
import boto3, torch
from sentence_transformers import SentenceTransformer, util
# ────────────────────────────────────────────────────────────────────────
# 1) Database credentials (provided via HF Secrets / env vars)
# ────────────────────────────────────────────────────────────────────────
DB_HOST = os.getenv("DB_HOST")          # set these in the Space's Secrets
DB_PORT = os.getenv("DB_PORT", "5432")
DB_NAME = os.getenv("DB_NAME")
DB_USER = os.getenv("DB_USER")
DB_PASSWORD = os.getenv("DB_PASSWORD")
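# These must be set as Secrets on the Space; os.getenv() returns None for any
# that are missing. A minimal fail-fast check could look like:
#   if not all([DB_HOST, DB_NAME, DB_USER, DB_PASSWORD]):
#       st.error("Database credentials are not fully configured."); st.stop()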
@st.cache_data(ttl=600)
def get_data() -> pd.DataFrame:
"""Pull the full survey_info table (cached for 10 min)."""
conn = psycopg2.connect(
host=DB_HOST,
port=DB_PORT,
dbname=DB_NAME,
user=DB_USER,
password=DB_PASSWORD,
sslmode="require",
)
df_ = pd.read_sql_query(
"""SELECT id, country, year, section,
question_code, question_text,
answer_code, answer_text
FROM survey_info;""",
conn,
)
conn.close()
return df_
df = get_data()
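# Map each DB id to its positional index in df (handy for id-based row lookups).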
row_lookup = {row.id: i for i, row in df.iterrows()}
# ────────────────────────────────────────────────────────────────────────
# 2) Pre-computed embeddings (ids + tensor) – download once per session
# ────────────────────────────────────────────────────────────────────────
@st.cache_resource
def load_embeddings():
    BUCKET = "cgd-embeddings-bucket"
    KEY = "survey_info_embeddings.pt"   # contains {'ids', 'embeddings'}
    buf = io.BytesIO()
    boto3.client("s3").download_fileobj(BUCKET, KEY, buf)
    buf.seek(0)
    ckpt = torch.load(buf, map_location="cpu")
    buf.close(); gc.collect()
    if not (isinstance(ckpt, dict) and {"ids", "embeddings"} <= ckpt.keys()):
        st.error("Bad checkpoint format in survey_info_embeddings.pt"); st.stop()
    return ckpt["ids"], ckpt["embeddings"]
ids_list, emb_tensor = load_embeddings()
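# Expected checkpoint layout (an assumption, inferred from the key check above):
#   {"ids": [<survey_info.id>, ...], "embeddings": FloatTensor of shape (N, 384)}
# where row k of the tensor embeds the row whose DB id is ids[k]; 384 is the
# output width of all-MiniLM-L6-v2, the model used for queries below.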
# ────────────────────────────────────────────────────────────────────────
# 3) Streamlit UI
# ────────────────────────────────────────────────────────────────────────
st.title("CGD Survey Explorer (Live DB)")
st.sidebar.header("Filter Questions")
country_options = sorted(df["country"].dropna().unique())
year_options = sorted(df["year"].dropna().unique())
sel_countries = st.sidebar.multiselect("Select Country/Countries", country_options)
sel_years = st.sidebar.multiselect("Select Year(s)", year_options)
keyword = st.sidebar.text_input(
    "Keyword Search (Question text / Answer text / Question code)", ""
)
group_by_question = st.sidebar.checkbox("Group by Question Text")
# --- Semantic-search input (kept in sidebar) ---------------------------
st.sidebar.markdown("---")
st.sidebar.subheader("Semantic Search")
sem_query = st.sidebar.text_input("Enter a natural-language query")
search_clicked = st.sidebar.button("Search", disabled=not sem_query.strip())
# ── base_filtered: applies dropdown + keyword logic (always computed) ──
base_filtered = df[
    (df["country"].isin(sel_countries) if sel_countries else True) &
    (df["year"].isin(sel_years) if sel_years else True) &
    (
        df["question_text"].str.contains(keyword, case=False, na=False) |
        df["answer_text"].str.contains(keyword, case=False, na=False) |
        df["question_code"].astype(str).str.contains(keyword, case=False, na=False)
    )
]
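# Note: an empty multiselect leaves its clause as a scalar True, which pandas
# broadcasts across every row, i.e. "no filter". An empty keyword matches all
# non-null text cells, since str.contains("") is True for any string.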
# ────────────────────────────────────────────────────────────────────────
# 4) When the Search button is clicked → build merged table
# ────────────────────────────────────────────────────────────────────────
if search_clicked:
    with st.spinner("Embedding & searching…"):
        model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
        q_vec = model.encode(sem_query.strip(), convert_to_tensor=True).cpu()
        sims = util.cos_sim(q_vec, emb_tensor)[0]
        top_vals, top_idx = torch.topk(sims, k=50)   # get 50 candidates
        sem_ids = [ids_list[i] for i in top_idx.tolist()]
        sem_rows = df.loc[df["id"].isin(sem_ids)].copy()
        score_map = dict(zip(sem_ids, top_vals.tolist()))
        sem_rows["Score"] = sem_rows["id"].map(score_map)
        sem_rows = sem_rows.sort_values("Score", ascending=False)
        remainder = base_filtered.loc[~base_filtered["id"].isin(sem_ids)].copy()
        remainder["Score"] = ""                      # blank score for keyword-only rows
        combined = pd.concat([sem_rows, remainder], ignore_index=True)
    st.subheader(f"Combined Results ({len(combined)})")
    st.dataframe(
        combined[["Score", "country", "year", "question_text", "answer_text"]],
        use_container_width=True,
    )
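    # In the combined table, Score holds cosine similarities for semantic hits
    # and "" for rows that matched only the sidebar filters, so semantic
    # results are listed first.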
# ────────────────────────────────────────────────────────────────────────
# 5) No semantic query → use original keyword filter logic / grouping
# ────────────────────────────────────────────────────────────────────────
else:
    if group_by_question:
        st.subheader("Grouped by Question Text")
        grouped = (
            base_filtered.groupby("question_text")
            .agg({
                "country": lambda x: sorted(set(x)),
                "year": lambda x: sorted(set(x)),
                "answer_text": lambda x: list(x)[:3]
            })
            .reset_index()
            .rename(columns={
                "country": "Countries",
                "year": "Years",
                "answer_text": "Sample Answers"
            })
        )
        st.dataframe(grouped, use_container_width=True)
        if grouped.empty:
            st.info("No questions found with current filters.")
    else:
        heading = []
        if sel_countries: heading.append("Countries: " + ", ".join(sel_countries))
        if sel_years: heading.append("Years: " + ", ".join(map(str, sel_years)))
        st.markdown("### Results for " + (" | ".join(heading) if heading else "All Countries and Years"))
        st.dataframe(
            base_filtered[["country", "year", "question_text", "answer_text"]],
            use_container_width=True,
        )
        if base_filtered.empty:
            st.info("No matching questions found.")