#!/usr/bin/env python3
# app.py - CGD Survey Explorer + merged semantic search
import os, io, json, gc
import streamlit as st
import pandas as pd
import psycopg2
import boto3, torch
from sentence_transformers import SentenceTransformer, util

# ────────────────────────────────────────────────────────────────────────
# 1)  Database credentials (provided via HF Secrets / env vars)
# ────────────────────────────────────────────────────────────────────────
DB_HOST = os.getenv("DB_HOST")            # set these in the Space's Secrets
DB_PORT = os.getenv("DB_PORT", "5432")
DB_NAME = os.getenv("DB_NAME")
DB_USER = os.getenv("DB_USER")
DB_PASSWORD = os.getenv("DB_PASSWORD")
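
# Fail fast if any required secret is missing (a small sanity-check sketch;
# it assumes the four variables above are the only required secrets).
_required = {"DB_HOST": DB_HOST, "DB_NAME": DB_NAME,
             "DB_USER": DB_USER, "DB_PASSWORD": DB_PASSWORD}
_missing = [name for name, value in _required.items() if not value]
if _missing:
    st.error("Missing database secrets: " + ", ".join(_missing))
    st.stop()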

@st.cache_data(ttl=600)
def get_data() -> pd.DataFrame:
    """Pull the full survey_info table (cached for 10 min)."""
    conn = psycopg2.connect(
        host=DB_HOST,
        port=DB_PORT,
        dbname=DB_NAME,
        user=DB_USER,
        password=DB_PASSWORD,
        sslmode="require",
    )
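    # Note: pandas prefers a SQLAlchemy engine/connectable here; a raw psycopg2
    # connection works but triggers a UserWarning in recent pandas versions.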
    df_ = pd.read_sql_query(
        """SELECT id, country, year, section,
                  question_code, question_text,
                  answer_code, answer_text
             FROM survey_info;""",
        conn,
    )
    conn.close()
    return df_

df = get_data()
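# Map each survey_info id to its positional index in df (for quick row lookups).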
row_lookup = {row.id: i for i, row in df.iterrows()}

# ────────────────────────────────────────────────────────────────────────
# 2)  Pre-computed embeddings (ids + tensor) – download once per session
# ────────────────────────────────────────────────────────────────────────
@st.cache_resource
def load_embeddings():
    BUCKET = "cgd-embeddings-bucket"
    KEY    = "survey_info_embeddings.pt"  # contains {'ids', 'embeddings'}
    buf = io.BytesIO()
    boto3.client("s3").download_fileobj(BUCKET, KEY, buf)
    buf.seek(0)
    ckpt = torch.load(buf, map_location="cpu")
    buf.close(); gc.collect()

    if not (isinstance(ckpt, dict) and {"ids", "embeddings"} <= ckpt.keys()):
        st.error("Bad checkpoint format in survey_info_embeddings.pt"); st.stop()
    return ckpt["ids"], ckpt["embeddings"]

ids_list, emb_tensor = load_embeddings()
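
# Light consistency check (assumes the checkpoint stores exactly one embedding
# per id; a mismatch would silently mis-score rows during semantic search).
if len(ids_list) != emb_tensor.shape[0]:
    st.error("Embedding checkpoint mismatch: id count != embedding count.")
    st.stop()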

# ────────────────────────────────────────────────────────────────────────
# 3)  Streamlit UI
# ────────────────────────────────────────────────────────────────────────
st.title("🌍 CGD Survey Explorer (Live DB)")

st.sidebar.header("🔎 Filter Questions")

country_options = sorted(df["country"].dropna().unique())
year_options    = sorted(df["year"].dropna().unique())

sel_countries = st.sidebar.multiselect("Select Country/Countries", country_options)
sel_years     = st.sidebar.multiselect("Select Year(s)", year_options)
keyword = st.sidebar.text_input(
    "Keyword Search (Question text / Answer text / Question code)", ""
)
group_by_question = st.sidebar.checkbox("Group by Question Text")

# --- Semantic-search input (kept in sidebar) ---------------------------
st.sidebar.markdown("---")
st.sidebar.subheader("🧠 Semantic Search")
sem_query = st.sidebar.text_input("Enter a natural-language query")
search_clicked = st.sidebar.button("Search", disabled=not sem_query.strip())

# ── base_filtered: applies dropdown + keyword logic (always computed) ──
base_filtered = df[
    (df["country"].isin(sel_countries) if sel_countries else True) &
    (df["year"].isin(sel_years)        if sel_years     else True) &
    (
        df["question_text"].str.contains(keyword, case=False, na=False) |
        df["answer_text"].str.contains(keyword, case=False, na=False)   |
        df["question_code"].astype(str).str.contains(keyword, case=False, na=False)
    )
]
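
# Cache the query encoder so the model is loaded once per session instead of
# on every search click (a minimal sketch; it assumes the stored embeddings
# were produced with the same MiniLM model named below).
@st.cache_resource
def get_query_encoder() -> SentenceTransformer:
    return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")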

# ────────────────────────────────────────────────────────────────────────
# 4)  When the Search button is clicked → build merged table
# ────────────────────────────────────────────────────────────────────────
if search_clicked:
    with st.spinner("Embedding & searching…"):
        model = get_query_encoder()  # cached encoder (helper defined above)
        q_vec = model.encode(sem_query.strip(), convert_to_tensor=True).cpu()

        sims = util.cos_sim(q_vec, emb_tensor)[0]
        top_vals, top_idx = torch.topk(sims, k=min(50, sims.shape[0]))   # up to 50 candidates

        sem_ids   = [ids_list[i] for i in top_idx.tolist()]
        sem_rows  = df.loc[df["id"].isin(sem_ids)].copy()
        score_map = dict(zip(sem_ids, top_vals.tolist()))
        sem_rows["Score"] = sem_rows["id"].map(score_map)
        sem_rows = sem_rows.sort_values("Score", ascending=False)

        remainder = base_filtered.loc[~base_filtered["id"].isin(sem_ids)].copy()
        remainder["Score"] = ""   # blank score for keyword-only rows

        combined = pd.concat([sem_rows, remainder], ignore_index=True)

    st.subheader(f"🔍 Combined Results ({len(combined)})")
    st.dataframe(
        combined[["Score", "country", "year", "question_text", "answer_text"]],
        use_container_width=True,
    )

# ────────────────────────────────────────────────────────────────────────
# 5)  No semantic query → use original keyword filter logic / grouping
# ────────────────────────────────────────────────────────────────────────
else:
    if group_by_question:
        st.subheader("📊 Grouped by Question Text")
        grouped = (
            base_filtered.groupby("question_text")
            .agg({
                "country": lambda x: sorted(set(x)),
                "year":    lambda x: sorted(set(x)),
                "answer_text": lambda x: list(x)[:3]
            })
            .reset_index()
            .rename(columns={
                "country": "Countries",
                "year":    "Years",
                "answer_text": "Sample Answers"
            })
        )
        st.dataframe(grouped, use_container_width=True)
        if grouped.empty:
            st.info("No questions found with current filters.")
    else:
        heading = []
        if sel_countries: heading.append("Countries: " + ", ".join(sel_countries))
        if sel_years:     heading.append("Years: " + ", ".join(map(str, sel_years)))
        st.markdown("### Results for " + (" | ".join(heading) if heading else "All Countries and Years"))
        st.dataframe(
            base_filtered[["country", "year", "question_text", "answer_text"]],
            use_container_width=True,
        )
        if base_filtered.empty:
            st.info("No matching questions found.")