gigiliu12 committed on
Commit
6969959
·
verified ·
1 Parent(s): 6758cdb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -109
app.py CHANGED
@@ -1,5 +1,3 @@
1
- #!/usr/bin/env python3
2
- # app.py – CGD Survey Explorer + merged semantic search
3
  import os, io, json, gc
4
  import streamlit as st
5
  import pandas as pd
@@ -8,90 +6,119 @@ import boto3, torch
8
  from sentence_transformers import SentenceTransformer, util
9
 
10
  # ────────────────────────────────────────────────────────────────────────
11
- # 1) Database credentials (provided via HF Secrets / env vars)
12
  # ────────────────────────────────────────────────────────────────────────
13
- DB_HOST = os.getenv("DB_HOST") # set these in the Space's Secrets
14
  DB_PORT = os.getenv("DB_PORT", "5432")
15
  DB_NAME = os.getenv("DB_NAME")
16
  DB_USER = os.getenv("DB_USER")
17
  DB_PASSWORD = os.getenv("DB_PASSWORD")
18
 
 
19
  @st.cache_data(ttl=600)
20
  def get_data() -> pd.DataFrame:
21
- """Pull the full survey_info table (cached for 10 min)."""
22
- conn = psycopg2.connect(
23
- host=DB_HOST,
24
- port=DB_PORT,
25
- dbname=DB_NAME,
26
- user=DB_USER,
27
- password=DB_PASSWORD,
28
- sslmode="require",
29
- )
30
- df_ = pd.read_sql_query(
31
- """SELECT id, country, year, section,
32
- question_code, question_text,
33
- answer_code, answer_text
34
- FROM survey_info;""",
35
- conn,
36
- )
37
- conn.close()
38
- return df_
39
 
40
- df = get_data()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  row_lookup = {row.id: i for i, row in df.iterrows()}
42
 
43
  # ────────────────────────────────────────────────────────────────────────
44
- # 2) Pre-computed embeddings (ids + tensor) – download once per session
45
  # ────────────────────────────────────────────────────────────────────────
46
  @st.cache_resource
47
  def load_embeddings():
 
48
  BUCKET = "cgd-embeddings-bucket"
49
- KEY = "survey_info_embeddings.pt" # contains {'ids', 'embeddings'}
50
  buf = io.BytesIO()
51
  boto3.client("s3").download_fileobj(BUCKET, KEY, buf)
52
  buf.seek(0)
53
  ckpt = torch.load(buf, map_location="cpu")
54
  buf.close(); gc.collect()
55
 
56
- if not (isinstance(ckpt, dict) and {"ids", "embeddings"} <= ckpt.keys()):
57
  st.error("Bad checkpoint format in survey_info_embeddings.pt"); st.stop()
 
58
  return ckpt["ids"], ckpt["embeddings"]
59
 
60
  ids_list, emb_tensor = load_embeddings()
61
- @st.cache_resource
62
- def get_st_model():
63
- # force CPU so we avoid the meta-tensor copy error
64
- return SentenceTransformer(
65
- "sentence-transformers/all-MiniLM-L6-v2",
66
- device="cpu",
67
- )
68
  # ────────────────────────────────────────────────────────────────────────
69
- # 3) Streamlit UI
70
  # ────────────────────────────────────────────────────────────────────────
71
  st.title("🌍 CGD Survey Explorer (Live DB)")
72
 
73
  st.sidebar.header("πŸ”Ž Filter Questions")
74
 
 
75
  country_options = sorted(df["country"].dropna().unique())
76
  year_options = sorted(df["year"].dropna().unique())
77
 
78
- sel_countries = st.sidebar.multiselect("Select Country/Countries", country_options)
79
- sel_years = st.sidebar.multiselect("Select Year(s)", year_options)
80
  keyword = st.sidebar.text_input(
81
  "Keyword Search (Question text / Answer text / Question code)", ""
82
  )
83
  group_by_question = st.sidebar.checkbox("Group by Question Text")
84
 
85
- # --- Semantic-search input (kept in sidebar) ---------------------------
86
  st.sidebar.markdown("---")
87
  st.sidebar.subheader("🧠 Semantic Search")
88
  sem_query = st.sidebar.text_input("Enter a natural-language query")
89
- search_clicked = st.sidebar.button("Search", disabled=not sem_query.strip())
90
-
91
- # ── base_filtered: applies dropdown + keyword logic (always computed) ──
92
- base_filtered = df[
93
- (df["country"].isin(sel_countries) if sel_countries else True) &
94
- (df["year"].isin(sel_years) if sel_years else True) &
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  (
96
  df["question_text"].str.contains(keyword, case=False, na=False) |
97
  df["answer_text"].str.contains(keyword, case=False, na=False) |
@@ -99,72 +126,43 @@ base_filtered = df[
99
  )
100
  ]
101
 
102
- # ────────────────────────────────────────────────────────────────────────
103
- # 4) When the Search button is clicked β†’ build merged table
104
- # ────────────────────────────────────────────────────────────────────────
105
- if search_clicked:
106
- with st.spinner("Embedding & searching…"):
107
- model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
108
- q_vec = model.encode(sem_query.strip(), convert_to_tensor=True).cpu()
109
-
110
- model = get_st_model() # ← cached instance
111
- q_vec = model.encode(
112
- sem_query.strip(),
113
- convert_to_tensor=True,
114
- device="cpu"
115
- ).cpu()
116
-
117
- sims = util.cos_sim(q_vec, emb_tensor)[0]
118
- top_vals, top_idx = torch.topk(sims, k=50) # get 50 candidates
119
-
120
- sem_ids = [ids_list[i] for i in top_idx.tolist()]
121
- sem_rows = df.loc[df["id"].isin(sem_ids)].copy()
122
- score_map = dict(zip(sem_ids, top_vals.tolist()))
123
- sem_rows["Score"] = sem_rows["id"].map(score_map)
124
- sem_rows = sem_rows.sort_values("Score", ascending=False)
125
-
126
- remainder = base_filtered.loc[~base_filtered["id"].isin(sem_ids)].copy()
127
- remainder["Score"] = "" # blank score for keyword-only rows
128
-
129
- combined = pd.concat([sem_rows, remainder], ignore_index=True)
130
-
131
- st.subheader(f"πŸ” Combined Results ({len(combined)})")
132
- st.dataframe(
133
- combined[["Score", "country", "year", "question_text", "answer_text"]],
134
- use_container_width=True,
135
  )
136
 
137
- # ────────────────────────────────────────────────────────────────────────
138
- # 5) No semantic query β†’ use original keyword filter logic / grouping
139
- # ────────────────────────────────────────────────────────────────────────
 
 
140
  else:
141
- if group_by_question:
142
- st.subheader("πŸ“Š Grouped by Question Text")
143
- grouped = (
144
- base_filtered.groupby("question_text")
145
- .agg({
146
- "country": lambda x: sorted(set(x)),
147
- "year": lambda x: sorted(set(x)),
148
- "answer_text": lambda x: list(x)[:3]
149
- })
150
- .reset_index()
151
- .rename(columns={
152
- "country": "Countries",
153
- "year": "Years",
154
- "answer_text": "Sample Answers"
155
- })
156
- )
157
- st.dataframe(grouped, use_container_width=True)
158
- if grouped.empty:
159
- st.info("No questions found with current filters.")
160
- else:
161
- heading = []
162
- if sel_countries: heading.append("Countries: " + ", ".join(sel_countries))
163
- if sel_years: heading.append("Years: " + ", ".join(map(str, sel_years)))
164
- st.markdown("### Results for " + (" | ".join(heading) if heading else "All Countries and Years"))
165
- st.dataframe(
166
- base_filtered[["country", "year", "question_text", "answer_text"]],
167
- use_container_width=True,
168
- )
169
- if base_filtered.empty:
170
- st.info("No matching questions found.")
 
 
 
1
  import os, io, json, gc
2
  import streamlit as st
3
  import pandas as pd
 
6
  from sentence_transformers import SentenceTransformer, util
7
 
8
  # ────────────────────────────────────────────────────────────────────────
9
+ # 1) DB credentials (from HF secrets or env) – original
10
  # ────────────────────────────────────────────────────────────────────────
11
+ DB_HOST = os.getenv("DB_HOST")
12
  DB_PORT = os.getenv("DB_PORT", "5432")
13
  DB_NAME = os.getenv("DB_NAME")
14
  DB_USER = os.getenv("DB_USER")
15
  DB_PASSWORD = os.getenv("DB_PASSWORD")
16
 
17
+
18
  @st.cache_data(ttl=600)
19
  def get_data() -> pd.DataFrame:
20
+ try:
21
+ conn = psycopg2.connect(
22
+ host=DB_HOST,
23
+ dbname=DB_NAME,
24
+ user=DB_USER,
25
+ password=DB_PASSWORD,
26
+ sslmode="require",
 
 
 
 
 
 
 
 
 
 
 
27
 
28
+ )
29
+ query = """
30
+ SELECT id, country, year, section,
31
+ question_code, question_text,
32
+ answer_code, answer_text
33
+ FROM survey_info;
34
+ """
35
+ df_ = pd.read_sql_query(query, conn)
36
+ conn.close()
37
+ return df_
38
+ except Exception as e:
39
+ st.error(f"Failed to connect to the database: {e}")
40
+ st.stop()
41
+
42
+ df = get_data() # ← original DataFrame
43
+
44
+ # Build a quick lookup row-index β†’ DataFrame row for later
45
  row_lookup = {row.id: i for i, row in df.iterrows()}
46
 
47
  # ────────────────────────────────────────────────────────────────────────
48
+ # 2) Load embeddings + ids once per session (S3) – new, cached
49
  # ────────────────────────────────────────────────────────────────────────
50
  @st.cache_resource
51
  def load_embeddings():
52
+ # credentials already in env (HF secrets) – boto3 will pick them up
53
  BUCKET = "cgd-embeddings-bucket"
54
+ KEY = "survey_info_embeddings.pt" # dict {'ids', 'embeddings'}
55
  buf = io.BytesIO()
56
  boto3.client("s3").download_fileobj(BUCKET, KEY, buf)
57
  buf.seek(0)
58
  ckpt = torch.load(buf, map_location="cpu")
59
  buf.close(); gc.collect()
60
 
61
+ if not (isinstance(ckpt, dict) and {"ids","embeddings"} <= ckpt.keys()):
62
  st.error("Bad checkpoint format in survey_info_embeddings.pt"); st.stop()
63
+
64
  return ckpt["ids"], ckpt["embeddings"]
65
 
66
  ids_list, emb_tensor = load_embeddings()
67
+
 
 
 
 
 
 
68
  # ────────────────────────────────────────────────────────────────────────
69
+ # 3) Streamlit UI – original filters + new semantic search
70
  # ────────────────────────────────────────────────────────────────────────
71
  st.title("🌍 CGD Survey Explorer (Live DB)")
72
 
73
  st.sidebar.header("πŸ”Ž Filter Questions")
74
 
75
+
76
  country_options = sorted(df["country"].dropna().unique())
77
  year_options = sorted(df["year"].dropna().unique())
78
 
79
+ selected_countries = st.sidebar.multiselect("Select Country/Countries", country_options)
80
+ selected_years = st.sidebar.multiselect("Select Year(s)", year_options)
81
  keyword = st.sidebar.text_input(
82
  "Keyword Search (Question text / Answer text / Question code)", ""
83
  )
84
  group_by_question = st.sidebar.checkbox("Group by Question Text")
85
 
86
+ # ── new semantic search panel ───────────────────────────────────────────
87
  st.sidebar.markdown("---")
88
  st.sidebar.subheader("🧠 Semantic Search")
89
  sem_query = st.sidebar.text_input("Enter a natural-language query")
90
+ if st.sidebar.button("Search", disabled=not sem_query.strip()):
91
+ with st.spinner("Embedding & searching…"):
92
+ model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
93
+ q_vec = model.encode(sem_query.strip(), convert_to_tensor=True).cpu()
94
+ scores = util.cos_sim(q_vec, emb_tensor)[0]
95
+ top_vals, top_idx = torch.topk(scores, k=10) # grab extra
96
+
97
+ results = []
98
+ for score, emb_row in zip(top_vals.tolist(), top_idx.tolist()):
99
+ db_id = ids_list[emb_row]
100
+ if db_id in row_lookup:
101
+ row = df.iloc[row_lookup[db_id]]
102
+ if row["question_text"] and row["answer_text"]:
103
+ results.append({
104
+ "Score": f"{score:.3f}",
105
+ "Country": row["country"],
106
+ "Year": row["year"],
107
+ "Question": row["question_text"],
108
+ "Answer": row["answer_text"],
109
+ })
110
+ if results:
111
+ st.subheader(f"πŸ” Semantic Results ({len(results)} found)")
112
+ st.dataframe(pd.DataFrame(results).head(5))
113
+ else:
114
+ st.info("No semantic matches found.")
115
+
116
+ st.markdown("---")
117
+
118
+ # ── apply original filters ──────────────────────────────────────────────
119
+ filtered = df[
120
+ (df["country"].isin(selected_countries) if selected_countries else True) &
121
+ (df["year"].isin(selected_years) if selected_years else True) &
122
  (
123
  df["question_text"].str.contains(keyword, case=False, na=False) |
124
  df["answer_text"].str.contains(keyword, case=False, na=False) |
 
126
  )
127
  ]
128
 
129
+ # ── original output logic ───────────────────────
130
+ if group_by_question:
131
+ st.subheader("πŸ“Š Grouped by Question Text")
132
+
133
+ grouped = (
134
+ filtered.groupby("question_text")
135
+ .agg({
136
+ "country": lambda x: sorted(set(x)),
137
+ "year": lambda x: sorted(set(x)),
138
+ "answer_text": lambda x: list(x)[:3]
139
+ })
140
+ .reset_index()
141
+ .rename(columns={
142
+ "country": "Countries",
143
+ "year": "Years",
144
+ "answer_text": "Sample Answers"
145
+ })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
  )
147
 
148
+ st.dataframe(grouped)
149
+
150
+ if grouped.empty:
151
+ st.info("No questions found with current filters.")
152
+
153
  else:
154
+
155
+ heading_parts = []
156
+ if selected_countries:
157
+ heading_parts.append("Countries: " + ", ".join(selected_countries))
158
+ if selected_years:
159
+ heading_parts.append("Years: " + ", ".join(map(str, selected_years)))
160
+ st.markdown("### Results for " + (" | ".join(heading_parts) if heading_parts else "All Countries and Years"))
161
+
162
+
163
+
164
+
165
+ st.dataframe(filtered[["country", "year", "question_text", "answer_text"]])
166
+
167
+ if filtered.empty:
168
+ st.info("No matching questions found.")