File size: 7,673 Bytes
c863afd
 
 
d3a33c8
c863afd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23381bb
 
 
c863afd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc9cf8b
 
c863afd
 
4e71d04
c863afd
 
 
 
 
 
81d02b5
c863afd
81d02b5
c863afd
 
cca0254
 
c863afd
 
cca0254
81d02b5
 
c863afd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc9cf8b
 
 
 
 
c863afd
 
cc9cf8b
 
 
 
c863afd
cc9cf8b
 
 
 
 
 
4e71d04
c863afd
 
 
 
 
cc9cf8b
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#!/usr/bin/env python3
import os, io, json, gc
import boto3, psycopg2, pandas as pd, torch
import streamlit as st
from sentence_transformers import SentenceTransformer, util

# ────────────────────────────────────────────────────────────────────────
# 0)  Credentials
#     DB_HOST / DB_PORT / DB_NAME / DB_USER / DB_PASSWORD are expected as
#     environment variables (Hugging Face Spaces secrets are exposed that way).
#     AWS keys are read from aws_creds.json when that file exists; otherwise
#     the AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY already in the environment
#     are used as-is instead of crashing on a missing file.
# ────────────────────────────────────────────────────────────────────────
# NOTE(review): committing aws_creds.json to the app repo exposes the keys to
# anyone who can read the repo; prefer Spaces secrets and drop the file.
if os.path.exists("aws_creds.json"):
    with open("aws_creds.json", encoding="utf-8") as f:
        creds = json.load(f)
    os.environ["AWS_ACCESS_KEY_ID"]     = creds["AccessKey"]
    os.environ["AWS_SECRET_ACCESS_KEY"] = creds["SecretAccessKey"]
os.environ["AWS_DEFAULT_REGION"]    = "us-east-2"

# ────────────────────────────────────────────────────────────────────────
# 1)  DB → DataFrame  (cached 10 min)                                     |
# ────────────────────────────────────────────────────────────────────────
# Postgres connection settings, all read from the environment.
# Only the port has a fallback ("5432"); every other setting is None if unset.
DB_HOST, DB_PORT, DB_NAME, DB_USER, DB_PASSWORD = (
    os.environ.get("DB_HOST"),
    os.environ.get("DB_PORT", "5432"),
    os.environ.get("DB_NAME"),
    os.environ.get("DB_USER"),
    os.environ.get("DB_PASSWORD"),
)

@st.cache_data(ttl=600)
def load_survey_dataframe() -> pd.DataFrame:
    """Fetch every row of the ``survey_info`` table as a DataFrame.

    Streamlit caches the result for 600 seconds, so Postgres is queried at
    most once every 10 minutes.

    Returns:
        DataFrame with columns id, country, year, section, question_code,
        question_text, answer_code, answer_text.
    """
    query = """SELECT id, country, year, section,
                  question_code, question_text,
                  answer_code, answer_text
             FROM survey_info
        """
    conn = psycopg2.connect(
        host=DB_HOST, port=DB_PORT,
        dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD,
        sslmode="require",  # refuse unencrypted connections
    )
    try:
        # NOTE(review): pandas warns (UserWarning) when given a raw DBAPI2
        # connection other than sqlite3; it works, but a SQLAlchemy engine
        # would silence the warning.
        df = pd.read_sql_query(query, conn)
    finally:
        # Close even when the query fails, so errors don't leak connections.
        conn.close()
    return df


df = load_survey_dataframe()

# ────────────────────────────────────────────────────────────────────────
# 2)  S3 → ids + embeddings  (cached process-wide via st.cache_resource)  |
# ────────────────────────────────────────────────────────────────────────
@st.cache_resource
def load_embeddings():
    """Download the precomputed embedding checkpoint from S3 and unpack it.

    The checkpoint is expected to be a dict with ``"ids"`` and
    ``"embeddings"``; the search code maps embedding row ``i`` to ``ids[i]``,
    so both must have the same length.

    Returns:
        ``(ids, embeddings)`` exactly as stored in the checkpoint.

    If the checkpoint is malformed, shows an error and stops the Streamlit run.
    """
    bucket = "cgd-embeddings-bucket"
    key = "survey_info_embeddings.pt"   # contains {'ids', 'embeddings'}
    bio = io.BytesIO()
    boto3.client("s3").download_fileobj(bucket, key, bio)
    bio.seek(0)
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a bucket you control. If the ids are plain Python values,
    # weights_only=True would restrict unpickling to safe types.
    ckpt = torch.load(bio, map_location="cpu")
    bio.close()
    gc.collect()  # release the raw download buffer early
    if not (isinstance(ckpt, dict) and {"ids", "embeddings"} <= ckpt.keys()):
        st.error("Bad checkpoint format")
        st.stop()
    ids, embeddings = ckpt["ids"], ckpt["embeddings"]
    # A length mismatch would otherwise surface later as an IndexError during
    # search, or silently attach scores to the wrong DB rows.
    if len(ids) != len(embeddings):
        st.error(
            f"Bad checkpoint: {len(ids)} ids but {len(embeddings)} embeddings"
        )
        st.stop()
    return ids, embeddings


ids_list, emb_tensor = load_embeddings()

# Map each DB id to its positional row in df (the position df.iloc expects).
row_lookup = {row_id: i for i, row_id in enumerate(df["id"])}

# ────────────────────────────────────────────────────────────────────────
# 3)  Streamlit UI                                                        |
# ────────────────────────────────────────────────────────────────────────
st.title("🌍 CGD Survey Explorer (Live DB + Semantic Search)")

# ── 3a) Sidebar filters (original UI) ───────────────────────────────────
st.sidebar.header("🔎 Filter Questions")

country_opts = sorted(df["country"].dropna().unique())
year_opts    = sorted(df["year"].dropna().unique())

sel_countries = st.sidebar.multiselect("Select Country/Countries", country_opts)
sel_years     = st.sidebar.multiselect("Select Year(s)", year_opts)
keyword       = st.sidebar.text_input(
    "Keyword Search (Question / Answer / Code)", ""
)
group_by_q    = st.sidebar.checkbox("Group by Question Text")

# Keyword match on question text, answer text, or question code.
# regex=False: the user's keyword is matched literally, so input such as "("
# or "c++" no longer raises re.error and crashes the app. An empty keyword
# matches every non-NaN cell.
keyword_mask = (
    df["question_text"].str.contains(keyword, case=False, na=False, regex=False)
    | df["answer_text"].str.contains(keyword, case=False, na=False, regex=False)
    | df["question_code"].astype(str).str.contains(
        keyword, case=False, na=False, regex=False
    )
)

# Apply keyword & dropdown filters. An empty selection means "no filter":
# the scalar True combines with the boolean Series without excluding rows.
filtered = df[
    (df["country"].isin(sel_countries) if sel_countries else True)
    & (df["year"].isin(sel_years) if sel_years else True)
    & keyword_mask
]

# ── 3b) Semantic-search panel ───────────────────────────────────────────
@st.cache_resource
def _load_query_model():
    """Load the query-embedding model once per server process.

    Constructing it inside the button handler reloaded the model on every
    search click.
    """
    return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")


st.sidebar.markdown("---")
st.sidebar.subheader("🧠 Semantic Search")
sem_query = st.sidebar.text_input("Enter a natural-language query")
if st.sidebar.button("Search", disabled=not sem_query.strip()):
    with st.spinner("Embedding & searching…"):
        # NOTE(review): this must be the model that produced the S3 embeddings,
        # and cos_sim needs q_vec and emb_tensor to share a dtype. Confirm the
        # checkpoint stores float32 vectors.
        model = _load_query_model()
        q_vec = model.encode(sem_query.strip(), convert_to_tensor=True).cpu()
        scores = util.cos_sim(q_vec, emb_tensor)[0]  # one score per embedding row
        # torch.topk raises if k exceeds the number of candidates.
        top_vals, top_idx = torch.topk(scores, k=min(10, scores.numel()))
        results = []
        for score, emb_row in zip(top_vals.tolist(), top_idx.tolist()):
            db_id = ids_list[emb_row]
            # If ids were saved as a tensor, indexing yields a 0-d tensor. Its
            # hash is identity-based, so it would never match a row_lookup key.
            # Convert it to a plain Python scalar first.
            if isinstance(db_id, torch.Tensor):
                db_id = db_id.item()
            # Skip ids that have an embedding but no row in the current table.
            if db_id not in row_lookup:
                continue
            row = df.iloc[row_lookup[db_id]]
            results.append({
                "score": f"{score:.3f}",
                "country": row["country"],
                "year": row["year"],
                "question": row["question_text"],
                "answer": row["answer_text"],
            })
        if results:
            st.subheader("🔍 Semantic Results")
            st.write(f"Showing top {len(results)} for **{sem_query}**")
            st.dataframe(pd.DataFrame(results))
        else:
            st.info("No semantic matches found.")

st.markdown("---")

# ── 3c) Original results table / grouped view ───────────────────────────
if group_by_q:
    st.subheader("📊 Grouped by Question Text")
    grouped = (
        filtered.groupby("question_text")
        .agg({
            # dropna(): sorting a set that mixes NaN (float) with strings
            # raises TypeError, and NaN is not a meaningful country/year.
            "country": lambda x: sorted(set(x.dropna())),
            "year":    lambda x: sorted(set(x.dropna())),
            "answer_text": lambda x: list(x)[:3],  # first three answers only
        })
        .reset_index()
        .rename(columns={
            "country": "Countries",
            "year":    "Years",
            "answer_text": "Sample Answers"
        })
    )
    st.dataframe(grouped)
    if grouped.empty:
        st.info("No questions found with current filters.")
else:
    # Contextual heading built from the active dropdown filters.
    hdr = []
    if sel_countries:
        hdr.append("Countries: " + ", ".join(map(str, sel_countries)))
    if sel_years:
        hdr.append("Years: " + ", ".join(map(str, sel_years)))
    st.markdown("### Results for " + (" | ".join(hdr) if hdr else "All Countries and Years"))
    st.dataframe(filtered[["country", "year", "question_text", "answer_text"]])
    if filtered.empty:
        st.info("No matching questions found.")