# app.py - Unified Panel App with Semantic Search + Filterable Tabulator
import os, io, gc
import panel as pn
import pandas as pd
import boto3, torch
import psycopg2
from sentence_transformers import SentenceTransformer, util
pn.extension('tabulator')
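# Served as a Panel app, e.g. `panel serve app.py`; the FastListTemplate
# at the bottom of this file is marked .servable().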
# ──────────────────────────────────────────────────────────────────────
# 1) Database and Resource Loading
# ──────────────────────────────────────────────────────────────────────
DB_HOST = os.getenv("DB_HOST")
DB_PORT = os.getenv("DB_PORT", "5432")
DB_NAME = os.getenv("DB_NAME")
DB_USER = os.getenv("DB_USER")
DB_PASSWORD = os.getenv("DB_PASSWORD")
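# The credentials above are expected in the environment (e.g. exported in
# the shell or configured as deployment secrets); the app cannot connect
# without them.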
@pn.cache()
def get_data():
    # Fetch the full survey table once; pn.cache() memoizes the result
    # for the lifetime of the server process.
    conn = psycopg2.connect(
        host=DB_HOST, port=DB_PORT,
        dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD,
        sslmode="require",
    )
    df_ = pd.read_sql_query(
        """
        SELECT id, country, year, section,
               question_code, question_text,
               answer_code, answer_text
        FROM survey_info;
        """,
        conn,
    )
    conn.close()
    return df_
df = get_data()
@pn.cache()
def load_embeddings():
    # Stream the precomputed embedding checkpoint from S3 into memory.
    # The checkpoint is a dict with two aligned entries: "ids" (row ids
    # matching survey_info.id) and "embeddings" (one row per id).
    BUCKET, KEY = "cgd-embeddings-bucket", "survey_info_embeddings.pt"
    buf = io.BytesIO()
    boto3.client("s3").download_fileobj(BUCKET, KEY, buf)
    buf.seek(0)
    ckpt = torch.load(buf, map_location="cpu")
    buf.close()
    gc.collect()
    return ckpt["ids"], ckpt["embeddings"]
@pn.cache()
def get_st_model():
    # CPU-only sentence-transformer; the weights download on first use.
    return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device="cpu")

@pn.cache()
def get_semantic_resources():
    model = get_st_model()
    ids_list, emb_tensor = load_embeddings()
    return model, ids_list, emb_tensor
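# Note: the model and embedding tensor load lazily on the first semantic
# search. To pay that cost at startup instead, uncomment the warm-up call:
# get_semantic_resources()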
# ──────────────────────────────────────────────────────────────────────
# 2) Widgets
# ──────────────────────────────────────────────────────────────────────
country_opts = sorted(df["country"].dropna().unique())
year_opts = sorted(df["year"].dropna().unique())
w_countries = pn.widgets.MultiSelect(name="Countries", options=country_opts)
w_years = pn.widgets.MultiSelect(name="Years", options=year_opts)
w_keyword = pn.widgets.TextInput(name="Keyword Search", placeholder="Search questions or answers")
w_group = pn.widgets.Checkbox(name="Group by Question Text", value=False)
w_semquery = pn.widgets.TextInput(name="Semantic Query")
w_search_button = pn.widgets.Button(name="Semantic Search", button_type="primary")
# ──────────────────────────────────────────────────────────────────────
# 3) Unified Results Table (Tabulator)
# ──────────────────────────────────────────────────────────────────────
result_table = pn.widgets.Tabulator(
    pagination='remote',
    page_size=15,
    sizing_mode="stretch_width",
    layout='fit_columns',
    show_index=False,
)
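# With pagination='remote', the full dataframe stays on the server and the
# browser receives one 15-row page at a time.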
@pn.depends(w_countries, w_years, w_keyword, w_group, watch=True)
def update_table(countries, years, keyword, group):
    # Apply the sidebar filters to the full dataframe.
    filt = df.copy()
    if countries:
        filt = filt[filt["country"].isin(countries)]
    if years:
        filt = filt[filt["year"].isin(years)]
    if keyword:
        filt = filt[
            filt["question_text"].str.contains(keyword, case=False, na=False) |
            filt["answer_text"].str.contains(keyword, case=False, na=False) |
            filt["question_code"].astype(str).str.contains(keyword, case=False, na=False)
        ]
    if group:
        # Collapse duplicate questions across countries/years, keeping a
        # few sample answers per question.
        grouped = (
            filt.groupby("question_text")
            .agg({
                "country": lambda x: sorted(set(x)),
                "year": lambda x: sorted(set(x)),
                "answer_text": lambda x: list(x)[:3],
            })
            .reset_index()
            .rename(columns={
                "country": "Countries",
                "year": "Years",
                "answer_text": "Sample Answers",
            })
        )
        result_table.value = grouped
    else:
        result_table.value = filt[["country", "year", "question_text", "answer_text"]]
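# @pn.depends(..., watch=True) only fires on widget changes, so populate
# the table once at startup with the current (empty) filter values.
update_table(w_countries.value, w_years.value, w_keyword.value, w_group.value)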
def semantic_search(event=None):
    query = w_semquery.value.strip()
    if not query:
        return
    # Step 1: apply the same sidebar filters as update_table.
    filt = df.copy()
    if w_countries.value:
        filt = filt[filt["country"].isin(w_countries.value)]
    if w_years.value:
        filt = filt[filt["year"].isin(w_years.value)]
    if w_keyword.value:
        filt = filt[
            filt["question_text"].str.contains(w_keyword.value, case=False, na=False) |
            filt["answer_text"].str.contains(w_keyword.value, case=False, na=False) |
            filt["question_code"].astype(str).str.contains(w_keyword.value, case=False, na=False)
        ]
    # Step 2: subset the embedding tensor to the filtered rows.
    model, ids_list, emb_tensor = get_semantic_resources()
    filtered_ids = filt["id"].tolist()
    id_to_index = {id_: i for i, id_ in enumerate(ids_list)}
    filtered_indices = [id_to_index[id_] for id_ in filtered_ids if id_ in id_to_index]
    if not filtered_indices:
        result_table.value = filt.iloc[0:0][["country", "year", "question_text", "answer_text"]]
        return
    filtered_embs = emb_tensor[filtered_indices]
    # Keep ids aligned row-for-row with filtered_embs (some filtered ids
    # may be missing from the checkpoint and were skipped above).
    aligned_ids = [id_ for id_ in filtered_ids if id_ in id_to_index]
    # Step 3: rank the filtered subset by cosine similarity to the query.
    q_vec = model.encode(query, convert_to_tensor=True, device="cpu").cpu()
    sims = util.cos_sim(q_vec, filtered_embs)[0]
    k = min(50, sims.shape[0])
    top_vals, top_idx = torch.topk(sims, k=k)
    top_filtered_ids = [aligned_ids[i] for i in top_idx.tolist()]
    sem_rows = filt[filt["id"].isin(top_filtered_ids)].copy()
    score_map = dict(zip(top_filtered_ids, top_vals.tolist()))
    sem_rows["Score"] = sem_rows["id"].map(score_map)
    sem_rows = sem_rows.sort_values("Score", ascending=False)
    # Step 4: append the remaining filtered rows (unscored) below the hits.
    remainder = filt.loc[~filt["id"].isin(top_filtered_ids)].copy()
    remainder["Score"] = ""
    combined = pd.concat([sem_rows, remainder], ignore_index=True)
    result_table.value = combined[["Score", "country", "year", "question_text", "answer_text"]]
w_search_button.on_click(semantic_search)
# ──────────────────────────────────────────────────────────────────────
# 4) Layout
# ──────────────────────────────────────────────────────────────────────
sidebar = pn.Column(
    "## Filters",
    w_countries, w_years, w_keyword, w_group,
    pn.Spacer(height=20),
    "## Semantic Search",
    w_semquery, w_search_button,
    width=300,
)

main = pn.Column(
    pn.pane.Markdown("## CGD Survey Explorer"),
    result_table,
)
pn.template.FastListTemplate(
    title="CGD Survey Explorer",
    sidebar=sidebar,
    main=main,
    theme_toggle=True,
).servable()