File size: 7,326 Bytes
feb2540
 
 
baa583b
d3a33c8
 
23381bb
baa583b
 
b40e2c2
feb2540
 
 
6969959
23381bb
 
d638a1c
23381bb
 
 
baa583b
feb2540
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
baa583b
 
feb2540
 
 
baa583b
 
feb2540
 
baa583b
 
 
 
 
feb2540
baa583b
 
 
 
6969959
feb2540
 
 
 
 
 
 
 
 
 
 
23381bb
d3a33c8
cc9cf8b
feb2540
 
cc9cf8b
feb2540
 
 
 
6969959
feb2540
baa583b
 
feb2540
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6969959
feb2540
ecd8944
 
 
 
 
 
 
feb2540
ecd8944
 
 
 
 
 
 
feb2540
ecd8944
feb2540
ecd8944
 
 
 
 
 
 
 
feb2540
81d02b5
feb2540
 
 
 
6969959
 
 
 
 
 
 
 
 
 
 
 
 
 
c43bc47
feb2540
6969959
 
c43bc47
feb2540
 
 
 
 
 
 
 
6969959
feb2540
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#!/usr/bin/env python3
# app.py  – CGD Survey Explorer (keyword + semantic in one table)

import os, io, json, gc
import streamlit as st
import pandas as pd
import psycopg2
import boto3, torch
from sentence_transformers import SentenceTransformer, util

# ─────────────────────────────────────────────────────────────
# 1)  Database credentials (HF Secrets or env vars)
# ─────────────────────────────────────────────────────────────
# Connection settings come from the process environment. Only the port has a
# fallback; the other four are None when their variable is unset.
DB_HOST, DB_NAME, DB_USER, DB_PASSWORD = (
    os.environ.get(var_name)
    for var_name in ("DB_HOST", "DB_NAME", "DB_USER", "DB_PASSWORD")
)
DB_PORT = os.environ.get("DB_PORT", "5432")

@st.cache_data(ttl=600)
def get_data() -> pd.DataFrame:
    """Load the ``survey_info`` table into a DataFrame.

    The result is cached through ``st.cache_data`` with a 600-second TTL, so
    the database is only queried when no fresh cached copy exists.

    Returns:
        A DataFrame with the columns id, country, year, section,
        question_code, question_text, answer_code and answer_text.

    Raises:
        Any exception raised by ``psycopg2.connect`` or
        ``pd.read_sql_query`` propagates to the caller; the connection is
        closed first if it was opened.
    """
    conn = psycopg2.connect(
        host=DB_HOST, port=DB_PORT,
        dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD,
        sslmode="require",
    )
    try:
        return pd.read_sql_query("""
        SELECT id, country, year, section,
               question_code, question_text,
               answer_code,  answer_text
          FROM survey_info;
    """, conn)
    finally:
        # Close explicitly in ``finally`` so a failing query does not leak
        # the connection.
        conn.close()

# Full survey table (served from the get_data cache when available).
df = get_data()
# Map each survey id to its row label in ``df``. Built column-wise instead of
# with iterrows(), which constructs a Series per row and is slow on a table
# that is re-processed on every script run.
row_lookup = dict(zip(df["id"], df.index))

# ─────────────────────────────────────────────────────────────
# 2)  Cached resources
# ─────────────────────────────────────────────────────────────
@st.cache_resource
def load_embeddings():
    """Fetch the embedding checkpoint from S3 and return ``(ids, embeddings)``.

    The checkpoint is downloaded into memory, deserialised onto the CPU and
    validated to be a dict holding the keys ``"ids"`` and ``"embeddings"``.
    If it is not, an error is shown and the Streamlit script is stopped.
    The return value is cached through ``st.cache_resource``.
    """
    bucket_name = "cgd-embeddings-bucket"
    object_key = "survey_info_embeddings.pt"

    stream = io.BytesIO()
    s3_client = boto3.client("s3")
    s3_client.download_fileobj(bucket_name, object_key, stream)
    stream.seek(0)

    # NOTE(review): torch.load unpickles the payload, so this trusts whatever
    # is stored in the bucket — confirm the bucket is write-restricted.
    checkpoint = torch.load(stream, map_location="cpu")

    # Drop the raw download buffer now that the checkpoint is deserialised.
    stream.close()
    gc.collect()

    well_formed = (
        isinstance(checkpoint, dict)
        and {"ids", "embeddings"} <= checkpoint.keys()
    )
    if not well_formed:
        st.error("Bad checkpoint format in survey_info_embeddings.pt")
        st.stop()
    return checkpoint["ids"], checkpoint["embeddings"]

# Unpack the cached checkpoint. The semantic search below indexes ids_list
# with positions taken from similarity scores computed against emb_tensor.
ids_list, emb_tensor = load_embeddings()

@st.cache_resource
def get_st_model():
    """Return the all-MiniLM-L6-v2 sentence-transformer, loaded on the CPU.

    The device is pinned to ``"cpu"``; the original author notes that this
    avoids a meta-tensor bug. The instance is cached through
    ``st.cache_resource``.
    """
    model_name = "sentence-transformers/all-MiniLM-L6-v2"
    encoder = SentenceTransformer(model_name, device="cpu")
    return encoder

# ─────────────────────────────────────────────────────────────
# 3)  Streamlit UI
# ─────────────────────────────────────────────────────────────
st.title("🌍 CGD Survey Explorer (Live DB)")

# Sidebar filters. The option lists are the distinct non-null values found in
# the loaded table. (The header emoji was mis-encoded as "πŸ”Ž"; restored.)
st.sidebar.header("🔎 Filter Questions")
country_opts = sorted(df["country"].dropna().unique())
year_opts    = sorted(df["year"].dropna().unique())

sel_countries = st.sidebar.multiselect("Select Country/Countries", country_opts)
sel_years     = st.sidebar.multiselect("Select Year(s)", year_opts)
keyword       = st.sidebar.text_input("Keyword Search (Question / Answer / Code)")
group_by_q    = st.sidebar.checkbox("Group by Question Text")

# ── Semantic search panel
st.sidebar.markdown("---")
st.sidebar.subheader("🧠 Semantic Search")
sem_query      = st.sidebar.text_input("Enter a natural-language query")
# The button stays disabled until the query contains non-whitespace text.
search_clicked = st.sidebar.button("Search", disabled=not sem_query.strip())

# ── Always build the keyword/dropdown subset
# Start from an all-True mask and narrow it with each active filter. An empty
# selection (or an empty keyword) leaves the mask unchanged.
_mask = pd.Series(True, index=df.index)
if sel_countries:
    _mask &= df["country"].isin(sel_countries)
if sel_years:
    _mask &= df["year"].isin(sel_years)
if keyword:
    # regex=False: the keyword is matched literally, so user input such as
    # "(" or "+" no longer raises re.error by being parsed as a pattern.
    _mask &= (
        df["question_text"].str.contains(keyword, case=False, na=False, regex=False) |
        df["answer_text"].str.contains(keyword, case=False, na=False, regex=False)   |
        df["question_code"].astype(str).str.contains(keyword, case=False, na=False, regex=False)
    )
filtered = df[_mask]

# ─────────────────────────────────────────────────────────────
# 4)  Semantic Search → merged table
# ─────────────────────────────────────────────────────────────
if search_clicked:
    # Rank every stored embedding against the query, keep the best matches,
    # then append the keyword/dropdown matches that were not already ranked.
    with st.spinner("Embedding & searching…"):
        model = get_st_model()
        q_vec = model.encode(
            sem_query.strip(),
            convert_to_tensor=True,
            device="cpu"
        ).cpu()

        sims = util.cos_sim(q_vec, emb_tensor)[0]
        # torch.topk raises when k exceeds the number of candidates, so clamp
        # the 50-candidate limit to what is actually available.
        top_k = min(50, sims.shape[0])
        top_vals, top_idx = torch.topk(sims, k=top_k)

        sem_ids   = [ids_list[i] for i in top_idx.tolist()]
        sem_rows  = df.loc[df["id"].isin(sem_ids)].copy()
        score_map = dict(zip(sem_ids, top_vals.tolist()))
        sem_rows["Score"] = sem_rows["id"].map(score_map)
        sem_rows = sem_rows.sort_values("Score", ascending=False)

        # rows that matched keyword/dropdown but not semantic
        remainder = filtered.loc[~filtered["id"].isin(sem_ids)].copy()
        remainder["Score"] = ""    # blank score for non-semantic rows

        # Semantic hits first (best score on top), then the remaining rows.
        combined = pd.concat([sem_rows, remainder], ignore_index=True)

    # The subheader emoji was mis-encoded; restored to the magnifying glass.
    st.subheader(f"🔍 Combined Results ({len(combined)})")
    st.dataframe(
        combined[["Score", "country", "year", "question_text", "answer_text"]],
        use_container_width=True,
    )
    st.stop()   # skip original display logic below when semantic ran

# ─────────────────────────────────────────────────────────────
# 5)  Original display (keyword / filters only)
# ─────────────────────────────────────────────────────────────
if group_by_q:
    # One row per distinct question text, listing where and when it was asked
    # plus up to three sample answers. (Subheader emoji was mis-encoded.)
    st.subheader("📊 Grouped by Question Text")
    grouped = (
        filtered.groupby("question_text")
        .agg({
            # Drop missing values before sorting: NaN mixed with strings
            # makes sorted() raise TypeError, and NaN among numbers gives an
            # unreliable order.
            "country": lambda x: sorted(set(x.dropna())),
            "year":    lambda x: sorted(set(x.dropna())),
            "answer_text": lambda x: list(x)[:3]
        })
        .reset_index()
        .rename(columns={
            "country": "Countries",
            "year":    "Years",
            "answer_text": "Sample Answers"
        })
    )
    st.dataframe(grouped, use_container_width=True)
    if grouped.empty:
        st.info("No questions found with current filters.")
else:
    # Flat listing; the heading summarises whichever filters are active.
    hdr = []
    if sel_countries:
        hdr.append("Countries: " + ", ".join(sel_countries))
    if sel_years:
        hdr.append("Years: " + ", ".join(map(str, sel_years)))
    st.markdown("### Results for " + (" | ".join(hdr) if hdr else "All Countries and Years"))
    st.dataframe(
        filtered[["country", "year", "question_text", "answer_text"]],
        use_container_width=True,
    )
    if filtered.empty:
        st.info("No matching questions found.")