# app.py – Unified Panel App with Semantic Search + Filterable Tabulator

import os, io, gc
import panel as pn
import pandas as pd
import boto3, torch
import psycopg2
from sentence_transformers import SentenceTransformer, util

pn.extension('tabulator')

# ──────────────────────────────────────────────────────────────────────
# 1) Database and Resource Loading
# ──────────────────────────────────────────────────────────────────────
DB_HOST = os.getenv("DB_HOST")
DB_PORT = os.getenv("DB_PORT", "5432")
DB_NAME = os.getenv("DB_NAME")
DB_USER = os.getenv("DB_USER")
DB_PASSWORD = os.getenv("DB_PASSWORD")

@pn.cache()
def get_data():
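    # Pull the full survey table once per server process; @pn.cache memoises
    # the resulting DataFrame so later sessions reuse it.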
    conn = psycopg2.connect(
        host=DB_HOST, port=DB_PORT,
        dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD,
        sslmode="require"
    )
    df_ = pd.read_sql_query("""
        SELECT id, country, year, section,
               question_code, question_text,
               answer_code,  answer_text
          FROM survey_info;
    """, conn)
    conn.close()

    # Ensure year column is int, show blank instead of NaN
    if "year" in df_.columns:
        df_["year"] = pd.to_numeric(df_["year"], errors="coerce").astype("Int64").astype(str).replace({'<NA>': ''})
    return df_

df = get_data()

@pn.cache()
def load_embeddings():
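    # Download the precomputed embedding checkpoint from S3. It holds parallel
    # "ids" and "embeddings" entries: one survey row id per embedding vector.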
    BUCKET, KEY = "cgd-embeddings-bucket", "survey_info_embeddings.pt"
    buf = io.BytesIO()
    boto3.client("s3").download_fileobj(BUCKET, KEY, buf)
    buf.seek(0)
    ckpt = torch.load(buf, map_location="cpu")
    buf.close(); gc.collect()
    return ckpt["ids"], ckpt["embeddings"]

@pn.cache()
def get_st_model():
    return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device="cpu")

@pn.cache()
def get_semantic_resources():
    model = get_st_model()
    ids_list, emb_tensor = load_embeddings()
    return model, ids_list, emb_tensor

# ──────────────────────────────────────────────────────────────────────
# 2) Widgets
# ──────────────────────────────────────────────────────────────────────
country_opts = sorted(df["country"].dropna().unique())
year_opts = sorted(df["year"].dropna().unique())

ALL_COLUMNS = ["country","year","section","question_code","question_text","answer_code","answer_text","Score"]
w_columns = pn.widgets.MultiChoice(
    name="Columns to show",
    options=ALL_COLUMNS,
    value=["country","year","question_text","answer_text"]
)

w_countries = pn.widgets.MultiSelect(name="Countries", options=country_opts)
w_years = pn.widgets.MultiSelect(name="Years", options=year_opts)
w_keyword = pn.widgets.TextInput(name="Keyword Search", placeholder="Case-insensitive substring match on questions, answers and question codes")
w_group = pn.widgets.Checkbox(name="Group by Question Text", value=False)
w_topk = pn.widgets.Select(name="Top-K (semantic)", options=[5, 10, 20, 50, 100], value=10, disabled=True)

w_semquery = pn.widgets.TextInput(name="Semantic Query", placeholder="Semantic search powered by sentence embeddings")
w_search_button = pn.widgets.Button(name="Search", button_type="primary")
w_clear_filters = pn.widgets.Button(name="Clear Filters", button_type="warning")

# ──────────────────────────────────────────────────────────────────────
# 3) Unified Results Table (Tabulator)
# ──────────────────────────────────────────────────────────────────────
result_table = pn.widgets.Tabulator(
    pagination='remote',
    page_size=15,
    sizing_mode="stretch_width",
    layout='fit_columns',
    show_index=False
)

# ──────────────────────────────────────────────────────────────────────
# 4) Search Logic
# ──────────────────────────────────────────────────────────────────────

def _group_by_question(df_in: pd.DataFrame) -> pd.DataFrame:
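    # Collapse rows that share the same question text into one summary row
    # listing the countries and years it appears in, plus up to three sample answers.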
    if df_in.empty:
        return pd.DataFrame(columns=["question_text", "Countries", "Years", "Sample Answers"])
    tmp = df_in.copy()
    tmp["year"] = tmp["year"].replace('', pd.NA)
    grouped = (
        tmp.groupby("question_text", dropna=False)
        .agg({
            "country": lambda x: sorted({v for v in x if pd.notna(v)}),
            "year":    lambda x: sorted({str(v) for v in x if pd.notna(v)}),
            "answer_text": lambda x: list(x.dropna())[:3],
        })
        .reset_index()
        .rename(columns={"country": "Countries", "year": "Years", "answer_text": "Sample Answers"})
    )
    return grouped
    
def _selected_cols(has_score=False):
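    # Return the user-selected columns, dropping "Score" when the current
    # results carry no similarity score; fall back to a default set if empty.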
    allowed = set(ALL_COLUMNS)
    if not has_score and "Score" in w_columns.value:
        w_columns.value = [c for c in w_columns.value if c != "Score"]
    cols = [c for c in w_columns.value if c in allowed]
    if not cols:
        cols = ["country", "year", "question_text", "answer_text"]
    return cols

    
def search(event=None):
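    # Apply the structured filters first; if a semantic query is present,
    # re-rank the surviving rows by embedding similarity, otherwise show them as-is.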
    query = w_semquery.value.strip()
    filt = df.copy()
    if w_countries.value:
        filt = filt[filt["country"].isin(w_countries.value)]
    if w_years.value:
        filt = filt[filt["year"].isin(w_years.value)]
    if w_keyword.value:
        # Literal, case-insensitive substring match; regex=False keeps special
        # characters in the search term from being interpreted as a pattern.
        kw = w_keyword.value
        filt = filt[
            filt["question_text"].str.contains(kw, case=False, na=False, regex=False) |
            filt["answer_text"].str.contains(kw, case=False, na=False, regex=False) |
            filt["question_code"].astype(str).str.contains(kw, case=False, na=False, regex=False)
        ]

    if not query:
        result_table.value = _group_by_question(filt) if w_group.value else filt[_selected_cols(False)]
        return

    model, ids_list, emb_tensor = get_semantic_resources()
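    # Map each surviving row id to its position in the embedding tensor so
    # similarity is only computed over rows that passed the filters.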
    filtered_ids = filt["id"].tolist()
    id_to_index = {id_: i for i, id_ in enumerate(ids_list)}
    filtered_indices = [id_to_index[id_] for id_ in filtered_ids if id_ in id_to_index]
    if not filtered_indices:
        result_table.value = _group_by_question(filt.iloc[0:0]) if w_group.value else pd.DataFrame(columns=_selected_cols(True))
        return

    top_k = min(int(w_topk.value), len(filtered_indices))
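    # Slice the cached embeddings down to the filtered rows and rank them by
    # cosine similarity against the encoded query.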
    filtered_embs = emb_tensor[filtered_indices]
    q_vec = model.encode(query, convert_to_tensor=True, device="cpu").cpu()
    sims = util.cos_sim(q_vec, filtered_embs)[0]
    top_vals, top_idx = torch.topk(sims, k=top_k)

    top_filtered_ids = [filtered_ids[i] for i in top_idx.tolist()]
    sem_rows = filt[filt["id"].isin(top_filtered_ids)].copy()
    score_map = dict(zip(top_filtered_ids, top_vals.tolist()))
    sem_rows["Score"] = sem_rows["id"].map(score_map)
    sem_rows = sem_rows.sort_values("Score", ascending=False)

    result_table.value = _group_by_question(sem_rows.drop(columns=["Score"])) if w_group.value else sem_rows[_selected_cols(True)]


def clear_filters(event=None):
    w_countries.value = []
    w_years.value = []
    w_keyword.value = ""
    w_semquery.value = ""
    w_topk.disabled = True
    result_table.value = df[["country", "year", "question_text", "answer_text"]].copy()

w_search_button.on_click(search)
w_clear_filters.on_click(clear_filters)

# Live updates for filters (except semantic query and keyword)
w_group.param.watch(lambda e: search(), 'value')
w_countries.param.watch(lambda e: search(), 'value')
w_years.param.watch(lambda e: search(), 'value')
w_columns.param.watch(lambda e: search(), 'value')

# Allow pressing Enter in semantic query or keyword to trigger search
w_semquery.param.watch(lambda e: search(), 'enter_pressed')
w_keyword.param.watch(lambda e: search(), 'enter_pressed')

# Enable/disable Top-K based on semantic query presence
def _toggle_topk_disabled(event=None):
    w_topk.disabled = (w_semquery.value.strip() == '')
_toggle_topk_disabled()
w_semquery.param.watch(lambda e: _toggle_topk_disabled(), 'value')

# Show all data at startup
result_table.value = df[["country", "year", "question_text", "answer_text"]].copy()

# ──────────────────────────────────────────────────────────────────────
# 5) Layout
# ──────────────────────────────────────────────────────────────────────
sidebar = pn.Column(
    "## πŸ”Ž Filters",
    w_countries, w_years, w_keyword, w_group, w_columns,
    pn.Spacer(height=20),
    "## 🧠 Semantic Search",
    w_semquery, 
    w_topk,
    w_search_button,
    pn.Spacer(height=20),
    w_clear_filters,
    width=300
)

main = pn.Column(
    pn.pane.Markdown("## 🌍 CGD Survey Explorer"),
    result_table
)

pn.template.FastListTemplate(
    title="CGD Survey Explorer",
    sidebar=sidebar,
    main=main,
    theme_toggle=True,
).servable()
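
# Local usage (assuming the DB_* environment variables and AWS credentials for
# the embeddings bucket are already configured):
#   panel serve app.py --autoreload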