gigiliu12 committed on
Commit
c43bc47
Β·
verified Β·
1 Parent(s): b40e2c2

Update logic

Browse files
Files changed (1) hide show
  1. app.py +92 -103
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import os, io, json, gc
2
  import streamlit as st
3
  import pandas as pd
@@ -6,119 +8,84 @@ import boto3, torch
6
  from sentence_transformers import SentenceTransformer, util
7
 
8
  # ────────────────────────────────────────────────────────────────────────
9
- # 1) DB credentials (from HF secrets or env) – original
10
  # ────────────────────────────────────────────────────────────────────────
11
- DB_HOST = os.getenv("DB_HOST")
12
  DB_PORT = os.getenv("DB_PORT", "5432")
13
  DB_NAME = os.getenv("DB_NAME")
14
  DB_USER = os.getenv("DB_USER")
15
  DB_PASSWORD = os.getenv("DB_PASSWORD")
16
 
17
-
18
  @st.cache_data(ttl=600)
19
  def get_data() -> pd.DataFrame:
20
- try:
21
- conn = psycopg2.connect(
22
- host=DB_HOST,
23
- dbname=DB_NAME,
24
- user=DB_USER,
25
- password=DB_PASSWORD,
26
- sslmode="require",
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- )
29
- query = """
30
- SELECT id, country, year, section,
31
- question_code, question_text,
32
- answer_code, answer_text
33
- FROM survey_info;
34
- """
35
- df_ = pd.read_sql_query(query, conn)
36
- conn.close()
37
- return df_
38
- except Exception as e:
39
- st.error(f"Failed to connect to the database: {e}")
40
- st.stop()
41
-
42
- df = get_data() # ← original DataFrame
43
-
44
- # Build a quick lookup row-index β†’ DataFrame row for later
45
  row_lookup = {row.id: i for i, row in df.iterrows()}
46
 
47
  # ────────────────────────────────────────────────────────────────────────
48
- # 2) Load embeddings + ids once per session (S3) – new, cached
49
  # ────────────────────────────────────────────────────────────────────────
50
  @st.cache_resource
51
  def load_embeddings():
52
- # credentials already in env (HF secrets) – boto3 will pick them up
53
  BUCKET = "cgd-embeddings-bucket"
54
- KEY = "survey_info_embeddings.pt" # dict {'ids', 'embeddings'}
55
  buf = io.BytesIO()
56
  boto3.client("s3").download_fileobj(BUCKET, KEY, buf)
57
  buf.seek(0)
58
  ckpt = torch.load(buf, map_location="cpu")
59
  buf.close(); gc.collect()
60
 
61
- if not (isinstance(ckpt, dict) and {"ids","embeddings"} <= ckpt.keys()):
62
  st.error("Bad checkpoint format in survey_info_embeddings.pt"); st.stop()
63
-
64
  return ckpt["ids"], ckpt["embeddings"]
65
 
66
  ids_list, emb_tensor = load_embeddings()
67
 
68
  # ────────────────────────────────────────────────────────────────────────
69
- # 3) Streamlit UI – original filters + new semantic search
70
  # ────────────────────────────────────────────────────────────────────────
71
  st.title("🌍 CGD Survey Explorer (Live DB)")
72
 
73
  st.sidebar.header("πŸ”Ž Filter Questions")
74
 
75
-
76
  country_options = sorted(df["country"].dropna().unique())
77
  year_options = sorted(df["year"].dropna().unique())
78
 
79
- selected_countries = st.sidebar.multiselect("Select Country/Countries", country_options)
80
- selected_years = st.sidebar.multiselect("Select Year(s)", year_options)
81
  keyword = st.sidebar.text_input(
82
  "Keyword Search (Question text / Answer text / Question code)", ""
83
  )
84
  group_by_question = st.sidebar.checkbox("Group by Question Text")
85
 
86
- # ── new semantic search panel ───────────────────────────────────────────
87
  st.sidebar.markdown("---")
88
  st.sidebar.subheader("🧠 Semantic Search")
89
  sem_query = st.sidebar.text_input("Enter a natural-language query")
90
- if st.sidebar.button("Search", disabled=not sem_query.strip()):
91
- with st.spinner("Embedding & searching…"):
92
- model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
93
- q_vec = model.encode(sem_query.strip(), convert_to_tensor=True).cpu()
94
- scores = util.cos_sim(q_vec, emb_tensor)[0]
95
- top_vals, top_idx = torch.topk(scores, k=10) # grab extra
96
-
97
- results = []
98
- for score, emb_row in zip(top_vals.tolist(), top_idx.tolist()):
99
- db_id = ids_list[emb_row]
100
- if db_id in row_lookup:
101
- row = df.iloc[row_lookup[db_id]]
102
- if row["question_text"] and row["answer_text"]:
103
- results.append({
104
- "Score": f"{score:.3f}",
105
- "Country": row["country"],
106
- "Year": row["year"],
107
- "Question": row["question_text"],
108
- "Answer": row["answer_text"],
109
- })
110
- if results:
111
- st.subheader(f"πŸ” Semantic Results ({len(results)} found)")
112
- st.dataframe(pd.DataFrame(results).head(5))
113
- else:
114
- st.info("No semantic matches found.")
115
-
116
- st.markdown("---")
117
-
118
- # ── apply original filters ──────────────────────────────────────────────
119
- filtered = df[
120
- (df["country"].isin(selected_countries) if selected_countries else True) &
121
- (df["year"].isin(selected_years) if selected_years else True) &
122
  (
123
  df["question_text"].str.contains(keyword, case=False, na=False) |
124
  df["answer_text"].str.contains(keyword, case=False, na=False) |
@@ -126,43 +93,65 @@ filtered = df[
126
  )
127
  ]
128
 
129
- # ── original output logic ───────────────────────
130
- if group_by_question:
131
- st.subheader("πŸ“Š Grouped by Question Text")
132
-
133
- grouped = (
134
- filtered.groupby("question_text")
135
- .agg({
136
- "country": lambda x: sorted(set(x)),
137
- "year": lambda x: sorted(set(x)),
138
- "answer_text": lambda x: list(x)[:3]
139
- })
140
- .reset_index()
141
- .rename(columns={
142
- "country": "Countries",
143
- "year": "Years",
144
- "answer_text": "Sample Answers"
145
- })
146
- )
147
-
148
- st.dataframe(grouped)
149
-
150
- if grouped.empty:
151
- st.info("No questions found with current filters.")
152
-
153
- else:
154
 
155
- heading_parts = []
156
- if selected_countries:
157
- heading_parts.append("Countries: " + ", ".join(selected_countries))
158
- if selected_years:
159
- heading_parts.append("Years: " + ", ".join(map(str, selected_years)))
160
- st.markdown("### Results for " + (" | ".join(heading_parts) if heading_parts else "All Countries and Years"))
161
 
 
 
 
 
 
162
 
 
 
163
 
 
164
 
165
- st.dataframe(filtered[["country", "year", "question_text", "answer_text"]])
 
 
 
 
166
 
167
- if filtered.empty:
168
- st.info("No matching questions found.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # app.py – CGD Survey Explorer + merged semantic search
3
  import os, io, json, gc
4
  import streamlit as st
5
  import pandas as pd
 
8
  from sentence_transformers import SentenceTransformer, util
9
 
10
# ────────────────────────────────────────────────────────────────────────
# 1) Database credentials (provided via HF Secrets / env vars)
# ────────────────────────────────────────────────────────────────────────
DB_HOST = os.getenv("DB_HOST")  # set these in the Space's Secrets
DB_PORT = os.getenv("DB_PORT", "5432")  # default Postgres port
DB_NAME = os.getenv("DB_NAME")
DB_USER = os.getenv("DB_USER")
DB_PASSWORD = os.getenv("DB_PASSWORD")
18
 
 
19
@st.cache_data(ttl=600)
def get_data() -> pd.DataFrame:
    """Pull the full survey_info table (cached for 10 min).

    Returns:
        DataFrame with columns id, country, year, section,
        question_code, question_text, answer_code, answer_text.
    """
    conn = psycopg2.connect(
        host=DB_HOST,
        port=DB_PORT,
        dbname=DB_NAME,
        user=DB_USER,
        password=DB_PASSWORD,
        sslmode="require",
    )
    try:
        return pd.read_sql_query(
            """SELECT id, country, year, section,
                      question_code, question_text,
                      answer_code, answer_text
               FROM survey_info;""",
            conn,
        )
    finally:
        # Always release the connection — the previous version leaked it
        # whenever the query raised, because close() was not in a finally.
        conn.close()
39
 
40
df = get_data()

# Map each database id to its DataFrame index label. read_sql_query yields a
# fresh RangeIndex, so the label doubles as the positional row number.
row_lookup = dict(zip(df["id"], df.index))
42
 
43
# ────────────────────────────────────────────────────────────────────────
# 2) Pre-computed embeddings (ids + tensor) – download once per session
# ────────────────────────────────────────────────────────────────────────
@st.cache_resource
def load_embeddings():
    """Download the {'ids', 'embeddings'} checkpoint from S3 (once per session)."""
    BUCKET = "cgd-embeddings-bucket"
    KEY = "survey_info_embeddings.pt"  # contains {'ids', 'embeddings'}

    buffer = io.BytesIO()
    boto3.client("s3").download_fileobj(BUCKET, KEY, buffer)
    buffer.seek(0)
    # NOTE(review): torch.load unpickles arbitrary objects — acceptable only
    # while the S3 bucket is fully trusted; confirm access controls.
    ckpt = torch.load(buffer, map_location="cpu")
    buffer.close()
    gc.collect()

    well_formed = isinstance(ckpt, dict) and {"ids", "embeddings"} <= ckpt.keys()
    if not well_formed:
        st.error("Bad checkpoint format in survey_info_embeddings.pt")
        st.stop()
    return ckpt["ids"], ckpt["embeddings"]

ids_list, emb_tensor = load_embeddings()
61
 
62
# ────────────────────────────────────────────────────────────────────────
# 3) Streamlit UI
# ────────────────────────────────────────────────────────────────────────
sidebar = st.sidebar

st.title("🌍 CGD Survey Explorer (Live DB)")
sidebar.header("🔎 Filter Questions")

# Dropdown choices come straight from the data, ignoring NULLs.
country_options = sorted(df["country"].dropna().unique())
year_options = sorted(df["year"].dropna().unique())

sel_countries = sidebar.multiselect("Select Country/Countries", country_options)
sel_years = sidebar.multiselect("Select Year(s)", year_options)
keyword = sidebar.text_input(
    "Keyword Search (Question text / Answer text / Question code)", ""
)
group_by_question = sidebar.checkbox("Group by Question Text")

# --- Semantic-search input (kept in sidebar) ---------------------------
sidebar.markdown("---")
sidebar.subheader("🧠 Semantic Search")
sem_query = sidebar.text_input("Enter a natural-language query")
search_clicked = sidebar.button("Search", disabled=not sem_query.strip())
84
+
85
# ── base_filtered: dropdown + keyword filtering (computed on every rerun) ──
# An empty multiselect means "no restriction", so its mask degrades to True.
country_mask = df["country"].isin(sel_countries) if sel_countries else True
year_mask = df["year"].isin(sel_years) if sel_years else True
# Empty keyword "" matches every non-null field (str.contains of "" is True).
keyword_mask = (
    df["question_text"].str.contains(keyword, case=False, na=False)
    | df["answer_text"].str.contains(keyword, case=False, na=False)
    | df["question_code"].str.contains(keyword, case=False, na=False)
)
base_filtered = df[country_mask & year_mask & keyword_mask]
95
 
96
+ # ────────────────────────────────────────────────────────────────────────
97
+ # 4) When the Search button is clicked β†’ build merged table
98
+ # ────────────────────────────────────────────────────────────────────────
99
+ if search_clicked:
100
+ with st.spinner("Embedding & searching…"):
101
+ model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
102
+ q_vec = model.encode(sem_query.strip(), convert_to_tensor=True).cpu()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
+ sims = util.cos_sim(q_vec, emb_tensor)[0]
105
+ top_vals, top_idx = torch.topk(sims, k=50) # get 50 candidates
 
 
 
 
106
 
107
+ sem_ids = [ids_list[i] for i in top_idx.tolist()]
108
+ sem_rows = df.loc[df["id"].isin(sem_ids)].copy()
109
+ score_map = dict(zip(sem_ids, top_vals.tolist()))
110
+ sem_rows["Score"] = sem_rows["id"].map(score_map)
111
+ sem_rows = sem_rows.sort_values("Score", ascending=False)
112
 
113
+ remainder = base_filtered.loc[~base_filtered["id"].isin(sem_ids)].copy()
114
+ remainder["Score"] = "" # blank score for keyword-only rows
115
 
116
+ combined = pd.concat([sem_rows, remainder], ignore_index=True)
117
 
118
+ st.subheader(f"πŸ” Combined Results ({len(combined)})")
119
+ st.dataframe(
120
+ combined[["Score", "country", "year", "question_text", "answer_text"]],
121
+ use_container_width=True,
122
+ )
123
 
124
+ # ────────────────────────────────────────────────────────────────────────
125
+ # 5) No semantic query β†’ use original keyword filter logic / grouping
126
+ # ────────────────────────────────────────────────────────────────────────
127
+ else:
128
+ if group_by_question:
129
+ st.subheader("πŸ“Š Grouped by Question Text")
130
+ grouped = (
131
+ base_filtered.groupby("question_text")
132
+ .agg({
133
+ "country": lambda x: sorted(set(x)),
134
+ "year": lambda x: sorted(set(x)),
135
+ "answer_text": lambda x: list(x)[:3]
136
+ })
137
+ .reset_index()
138
+ .rename(columns={
139
+ "country": "Countries",
140
+ "year": "Years",
141
+ "answer_text": "Sample Answers"
142
+ })
143
+ )
144
+ st.dataframe(grouped, use_container_width=True)
145
+ if grouped.empty:
146
+ st.info("No questions found with current filters.")
147
+ else:
148
+ heading = []
149
+ if sel_countries: heading.append("Countries: " + ", ".join(sel_countries))
150
+ if sel_years: heading.append("Years: " + ", ".join(map(str, sel_years)))
151
+ st.markdown("### Results for " + (" | ".join(heading) if heading else "All Countries and Years"))
152
+ st.dataframe(
153
+ base_filtered[["country", "year", "question_text", "answer_text"]],
154
+ use_container_width=True,
155
+ )
156
+ if base_filtered.empty:
157
+ st.info("No matching questions found.")