gigiliu12 committed on
Commit
feb2540
·
verified ·
1 Parent(s): ecd8944

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -98
app.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import os, io, json, gc
2
  import streamlit as st
3
  import pandas as pd
@@ -5,106 +8,106 @@ import psycopg2
5
  import boto3, torch
6
  from sentence_transformers import SentenceTransformer, util
7
 
8
- # ────────────────────────────────────────────────────────────────────────
9
- # 1) DB credentials (from HF secrets or env) – original
10
- # ────────────────────────────────────────────────────────────────────────
11
  DB_HOST = os.getenv("DB_HOST")
12
  DB_PORT = os.getenv("DB_PORT", "5432")
13
  DB_NAME = os.getenv("DB_NAME")
14
  DB_USER = os.getenv("DB_USER")
15
  DB_PASSWORD = os.getenv("DB_PASSWORD")
16
 
17
-
18
  @st.cache_data(ttl=600)
19
  def get_data() -> pd.DataFrame:
20
- try:
21
- conn = psycopg2.connect(
22
- host=DB_HOST,
23
- dbname=DB_NAME,
24
- user=DB_USER,
25
- password=DB_PASSWORD,
26
- sslmode="require",
27
-
28
- )
29
- query = """
30
- SELECT id, country, year, section,
31
- question_code, question_text,
32
- answer_code, answer_text
33
- FROM survey_info;
34
- """
35
- df_ = pd.read_sql_query(query, conn)
36
- conn.close()
37
- return df_
38
- except Exception as e:
39
- st.error(f"Failed to connect to the database: {e}")
40
- st.stop()
41
-
42
- df = get_data() # ← original DataFrame
43
-
44
- # Build a quick lookup row-index → DataFrame row for later
45
  row_lookup = {row.id: i for i, row in df.iterrows()}
46
 
47
- # ────────────────────────────────────────────────────────────────────────
48
- # 2) Load embeddings + ids once per session (S3) – new, cached
49
- # ────────────────────────────────────────────────────────────────────────
50
  @st.cache_resource
51
- def get_st_model():
52
- return SentenceTransformer(
53
- "sentence-transformers/all-MiniLM-L6-v2",
54
- device="cpu",
55
- )
56
  def load_embeddings():
57
- # credentials already in env (HF secrets) – boto3 will pick them up
58
- BUCKET = "cgd-embeddings-bucket"
59
- KEY = "survey_info_embeddings.pt" # dict {'ids', 'embeddings'}
60
  buf = io.BytesIO()
61
  boto3.client("s3").download_fileobj(BUCKET, KEY, buf)
62
  buf.seek(0)
63
  ckpt = torch.load(buf, map_location="cpu")
64
  buf.close(); gc.collect()
65
-
66
- if not (isinstance(ckpt, dict) and {"ids","embeddings"} <= ckpt.keys()):
67
  st.error("Bad checkpoint format in survey_info_embeddings.pt"); st.stop()
68
-
69
  return ckpt["ids"], ckpt["embeddings"]
70
 
71
  ids_list, emb_tensor = load_embeddings()
72
 
73
- # ────────────────────────────────────────────────────────────────────────
74
- # 3) Streamlit UI – original filters + new semantic search
75
- # ────────────────────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
76
  st.title("🌍 CGD Survey Explorer (Live DB)")
77
 
78
  st.sidebar.header("🔎 Filter Questions")
 
 
79
 
 
 
 
 
80
 
81
- country_options = sorted(df["country"].dropna().unique())
82
- year_options = sorted(df["year"].dropna().unique())
83
-
84
- selected_countries = st.sidebar.multiselect("Select Country/Countries", country_options)
85
- selected_years = st.sidebar.multiselect("Select Year(s)", year_options)
86
- keyword = st.sidebar.text_input(
87
- "Keyword Search (Question text / Answer text / Question code)", ""
88
- )
89
- group_by_question = st.sidebar.checkbox("Group by Question Text")
90
-
91
- # ── new semantic search panel ───────────────────────────────────────────
92
  st.sidebar.markdown("---")
93
  st.sidebar.subheader("🧠 Semantic Search")
94
- sem_query = st.sidebar.text_input("Enter a natural-language query")
95
- if st.sidebar.button("Search", disabled=not sem_query.strip()):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
  with st.spinner("Embedding & searching…"):
97
- # 1) embed query
98
- model = get_st_model() # cached CPU model
99
  q_vec = model.encode(
100
  sem_query.strip(),
101
  convert_to_tensor=True,
102
  device="cpu"
103
  ).cpu()
104
 
105
- # 2) semantic similarity
106
  sims = util.cos_sim(q_vec, emb_tensor)[0]
107
- top_vals, top_idx = torch.topk(sims, k=50)
108
 
109
  sem_ids = [ids_list[i] for i in top_idx.tolist()]
110
  sem_rows = df.loc[df["id"].isin(sem_ids)].copy()
@@ -112,9 +115,9 @@ if st.sidebar.button("Search", disabled=not sem_query.strip()):
112
  sem_rows["Score"] = sem_rows["id"].map(score_map)
113
  sem_rows = sem_rows.sort_values("Score", ascending=False)
114
 
115
- # 3) keyword / dropdown remainder
116
  remainder = filtered.loc[~filtered["id"].isin(sem_ids)].copy()
117
- remainder["Score"] = "" # blank for keyword-only rows
118
 
119
  combined = pd.concat([sem_rows, remainder], ignore_index=True)
120
 
@@ -123,23 +126,13 @@ if st.sidebar.button("Search", disabled=not sem_query.strip()):
123
  combined[["Score", "country", "year", "question_text", "answer_text"]],
124
  use_container_width=True,
125
  )
126
- st.stop() # skip the old display logic below when semantic search ran
127
-
128
- # ── apply original filters ──────────────────────────────────────────────
129
- filtered = df[
130
- (df["country"].isin(selected_countries) if selected_countries else True) &
131
- (df["year"].isin(selected_years) if selected_years else True) &
132
- (
133
- df["question_text"].str.contains(keyword, case=False, na=False) |
134
- df["answer_text"].str.contains(keyword, case=False, na=False) |
135
- df["question_code"].astype(str).str.contains(keyword, case=False, na=False)
136
- )
137
- ]
138
 
139
- # ── original output logic ───────────────────────
140
- if group_by_question:
 
 
141
  st.subheader("📊 Grouped by Question Text")
142
-
143
  grouped = (
144
  filtered.groupby("question_text")
145
  .agg({
@@ -154,25 +147,17 @@ if group_by_question:
154
  "answer_text": "Sample Answers"
155
  })
156
  )
157
-
158
- st.dataframe(grouped)
159
-
160
  if grouped.empty:
161
  st.info("No questions found with current filters.")
162
-
163
  else:
164
-
165
- heading_parts = []
166
- if selected_countries:
167
- heading_parts.append("Countries: " + ", ".join(selected_countries))
168
- if selected_years:
169
- heading_parts.append("Years: " + ", ".join(map(str, selected_years)))
170
- st.markdown("### Results for " + (" | ".join(heading_parts) if heading_parts else "All Countries and Years"))
171
-
172
-
173
-
174
-
175
- st.dataframe(filtered[["country", "year", "question_text", "answer_text"]])
176
-
177
  if filtered.empty:
178
- st.info("No matching questions found.")
 
1
+ #!/usr/bin/env python3
2
+ # app.py – CGD Survey Explorer (keyword + semantic in one table)
3
+
4
  import os, io, json, gc
5
  import streamlit as st
6
  import pandas as pd
 
8
  import boto3, torch
9
  from sentence_transformers import SentenceTransformer, util
10
 
11
+ # ─────────────────────────────────────────────────────────────
12
+ # 1) Database credentials (HF Secrets or env vars)
13
+ # ─────────────────────────────────────────────────────────────
14
  DB_HOST = os.getenv("DB_HOST")
15
  DB_PORT = os.getenv("DB_PORT", "5432")
16
  DB_NAME = os.getenv("DB_NAME")
17
  DB_USER = os.getenv("DB_USER")
18
  DB_PASSWORD = os.getenv("DB_PASSWORD")
19
 
 
20
  @st.cache_data(ttl=600)
21
  def get_data() -> pd.DataFrame:
22
+ """Read survey_info once every 10 min."""
23
+ conn = psycopg2.connect(
24
+ host=DB_HOST, port=DB_PORT,
25
+ dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD,
26
+ sslmode="require",
27
+ )
28
+ df_ = pd.read_sql_query("""
29
+ SELECT id, country, year, section,
30
+ question_code, question_text,
31
+ answer_code, answer_text
32
+ FROM survey_info;
33
+ """, conn)
34
+ conn.close()
35
+ return df_
36
+
37
+ df = get_data()
 
 
 
 
 
 
 
 
 
38
  row_lookup = {row.id: i for i, row in df.iterrows()}
39
 
40
+ # ─────────────────────────────────────────────────────────────
41
+ # 2) Cached resources
42
+ # ─────────────────────────────────────────────────────────────
43
  @st.cache_resource
 
 
 
 
 
44
  def load_embeddings():
45
+ """Download ids + embedding tensor from S3 once per session."""
46
+ BUCKET, KEY = "cgd-embeddings-bucket", "survey_info_embeddings.pt"
 
47
  buf = io.BytesIO()
48
  boto3.client("s3").download_fileobj(BUCKET, KEY, buf)
49
  buf.seek(0)
50
  ckpt = torch.load(buf, map_location="cpu")
51
  buf.close(); gc.collect()
52
+ if not (isinstance(ckpt, dict) and {"ids", "embeddings"} <= ckpt.keys()):
 
53
  st.error("Bad checkpoint format in survey_info_embeddings.pt"); st.stop()
 
54
  return ckpt["ids"], ckpt["embeddings"]
55
 
56
  ids_list, emb_tensor = load_embeddings()
57
 
58
+ @st.cache_resource
59
+ def get_st_model():
60
+ """Mini-LM sentence-transformer pinned to CPU (avoids meta-tensor bug)."""
61
+ return SentenceTransformer(
62
+ "sentence-transformers/all-MiniLM-L6-v2",
63
+ device="cpu",
64
+ )
65
+
66
+ # ─────────────────────────────────────────────────────────────
67
+ # 3) Streamlit UI
68
+ # ─────────────────────────────────────────────────────────────
69
  st.title("🌍 CGD Survey Explorer (Live DB)")
70
 
71
  st.sidebar.header("🔎 Filter Questions")
72
+ country_opts = sorted(df["country"].dropna().unique())
73
+ year_opts = sorted(df["year"].dropna().unique())
74
 
75
+ sel_countries = st.sidebar.multiselect("Select Country/Countries", country_opts)
76
+ sel_years = st.sidebar.multiselect("Select Year(s)", year_opts)
77
+ keyword = st.sidebar.text_input("Keyword Search (Question / Answer / Code)")
78
+ group_by_q = st.sidebar.checkbox("Group by Question Text")
79
 
80
+ # ── Semantic search panel
 
 
 
 
 
 
 
 
 
 
81
  st.sidebar.markdown("---")
82
  st.sidebar.subheader("🧠 Semantic Search")
83
+ sem_query = st.sidebar.text_input("Enter a natural-language query")
84
+ search_clicked = st.sidebar.button("Search", disabled=not sem_query.strip())
85
+
86
+ # ── Always build the keyword/dropdown subset
87
+ filtered = df[
88
+ (df["country"].isin(sel_countries) if sel_countries else True) &
89
+ (df["year"].isin(sel_years) if sel_years else True) &
90
+ (
91
+ df["question_text"].str.contains(keyword, case=False, na=False) |
92
+ df["answer_text"].str.contains(keyword, case=False, na=False) |
93
+ df["question_code"].astype(str).str.contains(keyword, case=False, na=False)
94
+ )
95
+ ]
96
+
97
+ # ─────────────────────────────────────────────────────────────
98
+ # 4) Semantic Search → merged table
99
+ # ─────────────────────────────────────────────────────────────
100
+ if search_clicked:
101
  with st.spinner("Embedding & searching…"):
102
+ model = get_st_model()
 
103
  q_vec = model.encode(
104
  sem_query.strip(),
105
  convert_to_tensor=True,
106
  device="cpu"
107
  ).cpu()
108
 
 
109
  sims = util.cos_sim(q_vec, emb_tensor)[0]
110
+ top_vals, top_idx = torch.topk(sims, k=50) # 50 candidates
111
 
112
  sem_ids = [ids_list[i] for i in top_idx.tolist()]
113
  sem_rows = df.loc[df["id"].isin(sem_ids)].copy()
 
115
  sem_rows["Score"] = sem_rows["id"].map(score_map)
116
  sem_rows = sem_rows.sort_values("Score", ascending=False)
117
 
118
+ # rows that matched keyword/dropdown but not semantic
119
  remainder = filtered.loc[~filtered["id"].isin(sem_ids)].copy()
120
+ remainder["Score"] = "" # blank score
121
 
122
  combined = pd.concat([sem_rows, remainder], ignore_index=True)
123
 
 
126
  combined[["Score", "country", "year", "question_text", "answer_text"]],
127
  use_container_width=True,
128
  )
129
+ st.stop() # skip original display logic below when semantic ran
 
 
 
 
 
 
 
 
 
 
 
130
 
131
+ # ─────────────────────────────────────────────────────────────
132
+ # 5) Original display (keyword / filters only)
133
+ # ─────────────────────────────────────────────────────────────
134
+ if group_by_q:
135
  st.subheader("📊 Grouped by Question Text")
 
136
  grouped = (
137
  filtered.groupby("question_text")
138
  .agg({
 
147
  "answer_text": "Sample Answers"
148
  })
149
  )
150
+ st.dataframe(grouped, use_container_width=True)
 
 
151
  if grouped.empty:
152
  st.info("No questions found with current filters.")
 
153
  else:
154
+ hdr = []
155
+ if sel_countries: hdr.append("Countries: " + ", ".join(sel_countries))
156
+ if sel_years: hdr.append("Years: " + ", ".join(map(str, sel_years)))
157
+ st.markdown("### Results for " + (" | ".join(hdr) if hdr else "All Countries and Years"))
158
+ st.dataframe(
159
+ filtered[["country", "year", "question_text", "answer_text"]],
160
+ use_container_width=True,
161
+ )
 
 
 
 
 
162
  if filtered.empty:
163
+ st.info("No matching questions found.")