# NOTE: removed scraped-page artifacts (commit hashes, "File size" banner,
# and a 1-158 line-number gutter) that were not part of the source file.
import os, io, json, gc
import streamlit as st
import pandas as pd
import psycopg2
import boto3, torch
from sentence_transformers import SentenceTransformer, util

# ────────────────────────────────────────────────────────────────────────
# 1)  DB credentials (from HF secrets or env)  – original 
# ────────────────────────────────────────────────────────────────────────
DB_HOST = os.getenv("DB_HOST")
DB_PORT = os.getenv("DB_PORT", "5432")
DB_NAME = os.getenv("DB_NAME")
DB_USER = os.getenv("DB_USER")
DB_PASSWORD = os.getenv("DB_PASSWORD")

@st.cache_data(ttl=600)
def get_data() -> pd.DataFrame:
    try:
        conn = psycopg2.connect(
            host=DB_HOST,
            port=DB_PORT,
            dbname=DB_NAME,
            user=DB_USER,
            password=DB_PASSWORD,
            sslmode="require",
        )
        query = """
            SELECT id, country, year, section,
                   question_code, question_text,
                   answer_code, answer_text
              FROM survey_info;
        """
        df_ = pd.read_sql_query(query, conn)
        conn.close()
        return df_
    except Exception as e:
        st.error(f"Failed to connect to the database: {e}")
        st.stop()

df = get_data()              # ← original DataFrame

# Build a quick lookup row-index β†’ DataFrame row for later
row_lookup = {row.id: i for i, row in df.iterrows()}

# ────────────────────────────────────────────────────────────────────────
# 2)  Load embeddings + ids once per session  (S3) – new, cached
# ────────────────────────────────────────────────────────────────────────
@st.cache_resource
def load_embeddings():
    # credentials already in env (HF secrets) – boto3 will pick them up
    BUCKET = "cgd-embeddings-bucket"
    KEY    = "survey_info_embeddings.pt"   # dict {'ids', 'embeddings'}
    buf = io.BytesIO()
    boto3.client("s3").download_fileobj(BUCKET, KEY, buf)
    buf.seek(0)
    ckpt = torch.load(buf, map_location="cpu")
    buf.close(); gc.collect()

    if not (isinstance(ckpt, dict) and {"ids","embeddings"} <= ckpt.keys()):
        st.error("Bad checkpoint format in survey_info_embeddings.pt"); st.stop()

    return ckpt["ids"], ckpt["embeddings"]

ids_list, emb_tensor = load_embeddings()

# ────────────────────────────────────────────────────────────────────────
# 3)  Streamlit UI – original filters + new semantic search
# ────────────────────────────────────────────────────────────────────────
st.title("🌍 CGD Survey Explorer (Live DB)")

st.sidebar.header("πŸ”Ž Filter Questions")

country_options = sorted(df["country"].dropna().unique())
year_options    = sorted(df["year"].dropna().unique())

selected_countries = st.sidebar.multiselect("Select Country/Countries", country_options)
selected_years     = st.sidebar.multiselect("Select Year(s)", year_options)
keyword = st.sidebar.text_input(
    "Keyword Search (Question text / Answer text / Question code)", ""
)
group_by_question = st.sidebar.checkbox("Group by Question Text")

# ── new semantic search panel ───────────────────────────────────────────
st.sidebar.markdown("---")
st.sidebar.subheader("🧠 Semantic Search")
sem_query = st.sidebar.text_input("Enter a natural-language query")
if st.sidebar.button("Search", disabled=not sem_query.strip()):
    with st.spinner("Embedding & searching…"):
        model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
        q_vec = model.encode(sem_query.strip(), convert_to_tensor=True).cpu()
        scores = util.cos_sim(q_vec, emb_tensor)[0]
        top_vals, top_idx = torch.topk(scores, k=10)   # grab extra

        results = []
        for score, emb_row in zip(top_vals.tolist(), top_idx.tolist()):
            db_id = ids_list[emb_row]
            if db_id in row_lookup:
                row = df.iloc[row_lookup[db_id]]
                if row["question_text"] and row["answer_text"]:
                    results.append({
                        "Score": f"{score:.3f}",
                        "Country": row["country"],
                        "Year": row["year"],
                        "Question": row["question_text"],
                        "Answer": row["answer_text"],
                    })
        if results:
            st.subheader(f"πŸ” Semantic Results ({len(results)} found)")
            st.dataframe(pd.DataFrame(results).head(5))
        else:
            st.info("No semantic matches found.")

st.markdown("---")

# ── apply original filters ──────────────────────────────────────────────
filtered = df[
    (df["country"].isin(selected_countries) if selected_countries else True) &
    (df["year"].isin(selected_years)        if selected_years else True)       &
    (
        df["question_text"].str.contains(keyword, case=False, na=False) |
        df["answer_text"].str.contains(keyword, case=False, na=False)   |
        df["question_code"].astype(str).str.contains(keyword, case=False, na=False)
    )
]

# ── original output logic ───────────────────────
if group_by_question:
    st.subheader("πŸ“Š Grouped by Question Text")
    grouped = (
        filtered.groupby("question_text")
        .agg({
            "country": lambda x: sorted(set(x)),
            "year":    lambda x: sorted(set(x)),
            "answer_text": lambda x: list(x)[:3]
        })
        .reset_index()
        .rename(columns={
            "country": "Countries",
            "year":    "Years",
            "answer_text": "Sample Answers"
        })
    )
    st.dataframe(grouped)
    if grouped.empty:
        st.info("No questions found with current filters.")
else:
    heading_parts = []
    if selected_countries:
        heading_parts.append("Countries: " + ", ".join(selected_countries))
    if selected_years:
        heading_parts.append("Years: " + ", ".join(map(str, selected_years)))
    st.markdown("### Results for " + (" | ".join(heading_parts) if heading_parts else "All Countries and Years"))
    st.dataframe(filtered[["country", "year", "question_text", "answer_text"]])
    if filtered.empty:
        st.info("No matching questions found.")