import os, io, gc
import streamlit as st
import pandas as pd
import psycopg2
import boto3, torch
from sentence_transformers import SentenceTransformer, util

# ────────────────────────────────────────────────────────────────────────
# 1)  DB credentials (from HF secrets or env)  – original 
# ────────────────────────────────────────────────────────────────────────
DB_HOST = os.getenv("DB_HOST")
DB_PORT = os.getenv("DB_PORT", "5432")
DB_NAME = os.getenv("DB_NAME")
DB_USER = os.getenv("DB_USER")
DB_PASSWORD = os.getenv("DB_PASSWORD")
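
# A minimal fail-fast sketch (an addition, assuming all four values are
# required; DB_PORT already falls back to "5432" above): stop the app with
# a clear message instead of a cryptic psycopg2 error later on.
_missing = [k for k, v in {"DB_HOST": DB_HOST, "DB_NAME": DB_NAME,
                           "DB_USER": DB_USER, "DB_PASSWORD": DB_PASSWORD}.items() if not v]
if _missing:
    st.error("Missing DB settings: " + ", ".join(_missing))
    st.stop()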


@st.cache_data(ttl=600)
def get_data() -> pd.DataFrame:
    try:
        conn = psycopg2.connect(
            host=DB_HOST,
            port=DB_PORT,
            dbname=DB_NAME,
            user=DB_USER,
            password=DB_PASSWORD,
            sslmode="require",
        )
        query = """
            SELECT id, country, year, section,
                   question_code, question_text,
                   answer_code, answer_text
              FROM survey_info;
        """
        df_ = pd.read_sql_query(query, conn)
        conn.close()
        return df_
    except Exception as e:
        st.error(f"Failed to connect to the database: {e}")
        st.stop()

df = get_data()              # ← original DataFrame

# Quick id → row-index lookup (kept from the original; not used by the UI
# below, but convenient for mapping semantic hits back to DataFrame rows)
row_lookup = dict(zip(df["id"], df.index))

# ────────────────────────────────────────────────────────────────────────
# 2)  Load embeddings + ids once per session  (S3) – new, cached
# ────────────────────────────────────────────────────────────────────────
@st.cache_resource
def get_st_model():
    return SentenceTransformer(
        "sentence-transformers/all-MiniLM-L6-v2",
        device="cpu",
    )
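
# Note (an added comment): the first call downloads the model weights
# (roughly 80–90 MB) into the local Hugging Face cache; st.cache_resource
# then keeps a single CPU instance alive for the whole process.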

@st.cache_resource
def load_embeddings():
    # credentials already in env (HF secrets) – boto3 will pick them up
    BUCKET = "cgd-embeddings-bucket"
    KEY    = "survey_info_embeddings.pt"   # dict {'ids', 'embeddings'}
    buf = io.BytesIO()
    boto3.client("s3").download_fileobj(BUCKET, KEY, buf)
    buf.seek(0)
    ckpt = torch.load(buf, map_location="cpu")
    buf.close()
    gc.collect()

    if not (isinstance(ckpt, dict) and {"ids", "embeddings"} <= ckpt.keys()):
        st.error("Bad checkpoint format in survey_info_embeddings.pt")
        st.stop()

    return ckpt["ids"], ckpt["embeddings"]

ids_list, emb_tensor = load_embeddings()
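
# A hedged sanity check (an addition, assuming checkpoint ids should match
# the survey_info snapshot): a mismatch usually means a stale checkpoint.
if len(ids_list) != emb_tensor.shape[0]:
    st.error("Embedding checkpoint is inconsistent (ids vs. embedding rows).")
    st.stop()
_stale = set(ids_list) - set(df["id"])
if _stale:
    st.warning(f"{len(_stale)} embedding ids are missing from the database; "
               "semantic results may be incomplete.")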

# ────────────────────────────────────────────────────────────────────────
# 3)  Streamlit UI – original filters + new semantic search
# ────────────────────────────────────────────────────────────────────────
st.title("🌍 CGD Survey Explorer (Live DB)")

st.sidebar.header("🔎 Filter Questions")


country_options = sorted(df["country"].dropna().unique())
year_options    = sorted(df["year"].dropna().unique())

selected_countries = st.sidebar.multiselect("Select Country/Countries", country_options)
selected_years     = st.sidebar.multiselect("Select Year(s)", year_options)
keyword = st.sidebar.text_input(
    "Keyword Search (Question text / Answer text / Question code)", ""
)
group_by_question = st.sidebar.checkbox("Group by Question Text")

# ── apply original filters (moved ahead of the semantic panel so the
#    semantic results can be merged with the keyword/dropdown remainder) ──
filtered = df[
    (df["country"].isin(selected_countries) if selected_countries else True) &
    (df["year"].isin(selected_years)        if selected_years else True)     &
    (
        df["question_text"].str.contains(keyword, case=False, na=False) |
        df["answer_text"].str.contains(keyword, case=False, na=False)   |
        df["question_code"].astype(str).str.contains(keyword, case=False, na=False)
    )
]

# ── new semantic search panel ───────────────────────────────────────────
st.sidebar.markdown("---")
st.sidebar.subheader("🧠 Semantic Search")
sem_query = st.sidebar.text_input("Enter a natural-language query")
if st.sidebar.button("Search", disabled=not sem_query.strip()):
    with st.spinner("Embedding & searching…"):
        # 1) embed the query with the cached CPU model
        model = get_st_model()
        q_vec = model.encode(
            sem_query.strip(),
            convert_to_tensor=True,
            device="cpu",
        ).cpu()

        # 2) cosine similarity against the precomputed embeddings;
        #    cap k so corpora smaller than 50 rows don't raise
        sims = util.cos_sim(q_vec, emb_tensor)[0]
        top_vals, top_idx = torch.topk(sims, k=min(50, sims.shape[0]))

        sem_ids   = [ids_list[i] for i in top_idx.tolist()]
        sem_rows  = df.loc[df["id"].isin(sem_ids)].copy()
        score_map = dict(zip(sem_ids, top_vals.tolist()))
        sem_rows["Score"] = sem_rows["id"].map(score_map)
        sem_rows = sem_rows.sort_values("Score", ascending=False)

        # 3) append whatever the keyword/dropdown filters matched beyond
        #    the semantic hits (blank score for those rows)
        remainder = filtered.loc[~filtered["id"].isin(sem_ids)].copy()
        remainder["Score"] = ""
        combined = pd.concat([sem_rows, remainder], ignore_index=True)

    st.subheader(f"🔍 Combined Results ({len(combined)})")
    st.dataframe(
        combined[["Score", "country", "year", "question_text", "answer_text"]],
        use_container_width=True,
    )
    st.stop()   # skip the default display logic when semantic search ran

# ── original output logic ───────────────────────
if group_by_question:
    st.subheader("📊 Grouped by Question Text")

    grouped = (
        filtered.groupby("question_text")
        .agg({
            "country": lambda x: sorted(set(x)),
            "year":    lambda x: sorted(set(x)),
            "answer_text": lambda x: list(x)[:3]
        })
        .reset_index()
        .rename(columns={
            "country": "Countries",
            "year":    "Years",
            "answer_text": "Sample Answers"
        })
    )

    st.dataframe(grouped)

    if grouped.empty:
        st.info("No questions found with current filters.")

else:

    heading_parts = []
    if selected_countries:
        heading_parts.append("Countries: " + ", ".join(selected_countries))
    if selected_years:
        heading_parts.append("Years: " + ", ".join(map(str, selected_years)))
    st.markdown("### Results for " + (" | ".join(heading_parts) if heading_parts else "All Countries and Years"))




    st.dataframe(filtered[["country", "year", "question_text", "answer_text"]])

    if filtered.empty:
        st.info("No matching questions found.")