gigiliu12 commited on
Commit
c863afd
Β·
verified Β·
1 Parent(s): 9cb0a2b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -71
app.py CHANGED
@@ -1,104 +1,156 @@
1
-
 
 
2
  import streamlit as st
3
- import pandas as pd
4
- import psycopg2
5
- import os
6
-
7
- # Load DB credentials from Hugging Face secrets or environment variables
8
- DB_HOST = os.getenv("DB_HOST")
9
- DB_PORT = os.getenv("DB_PORT", "5432")
10
- DB_NAME = os.getenv("DB_NAME")
11
- DB_USER = os.getenv("DB_USER")
 
 
 
 
 
 
 
 
 
 
 
12
  DB_PASSWORD = os.getenv("DB_PASSWORD")
13
 
14
  @st.cache_data(ttl=600)
15
- def get_data():
16
- try:
17
- conn = psycopg2.connect(
18
- host=DB_HOST,
19
- port=DB_PORT,
20
- dbname=DB_NAME,
21
- user=DB_USER,
22
- password=DB_PASSWORD,
23
- sslmode="require"
24
-
25
- )
26
- query = "SELECT country, year, section, question_code, question_text, answer_code, answer_text FROM survey_info;"
27
- df = pd.read_sql_query(query, conn)
28
- conn.close()
29
- return df
30
- except Exception as e:
31
- st.error(f"Failed to connect to the database: {e}")
32
- st.stop()
33
-
34
- # Load data
35
- df = get_data()
36
-
37
- # Streamlit UI
38
- st.title("🌍 CGD Survey Explorer (Live DB)")
39
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  st.sidebar.header("πŸ”Ž Filter Questions")
41
 
42
- # Multiselect filters with default = show all
43
- country_options = sorted(df["country"].dropna().unique())
44
- year_options = sorted(df["year"].dropna().unique())
45
 
46
- selected_countries = st.sidebar.multiselect("Select Country/Countries", country_options)
47
- selected_years = st.sidebar.multiselect("Select Year(s)", year_options)
48
- keyword = st.sidebar.text_input(
49
- "Keyword Search (Question text / Answer text / Question code)", ""
50
- ) #NEW
51
- group_by_question = st.sidebar.checkbox("Group by Question Text")
52
 
53
- # Apply filters
54
  filtered = df[
55
- (df["country"].isin(selected_countries) if selected_countries else True) &
56
- (df["year"].isin(selected_years) if selected_years else True) &
57
  (
58
  df["question_text"].str.contains(keyword, case=False, na=False) |
59
- df["answer_text"].str.contains(keyword, case=False, na=False) |
60
- df["question_code"].astype(str).str.contains(keyword, case=False, na=False) # NEW
61
  )
62
  ]
63
 
64
- # Output
65
- if group_by_question:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  st.subheader("πŸ“Š Grouped by Question Text")
67
-
68
  grouped = (
69
  filtered.groupby("question_text")
70
  .agg({
71
  "country": lambda x: sorted(set(x)),
72
- "year": lambda x: sorted(set(x)),
73
- "answer_text": lambda x: list(x)[:3] # preview up to 3 answers
74
  })
75
  .reset_index()
76
  .rename(columns={
77
  "country": "Countries",
78
- "year": "Years",
79
  "answer_text": "Sample Answers"
80
  })
81
  )
82
-
83
  st.dataframe(grouped)
84
-
85
  if grouped.empty:
86
  st.info("No questions found with current filters.")
87
-
88
  else:
89
- # Context-aware heading
90
- heading_parts = []
91
- if selected_countries:
92
- heading_parts.append("Countries: " + ", ".join(selected_countries))
93
- if selected_years:
94
- heading_parts.append("Years: " + ", ".join(map(str, selected_years)))
95
- if heading_parts:
96
- st.markdown("### Results for " + " | ".join(heading_parts))
97
- else:
98
- st.markdown("### Results for All Countries and Years")
99
-
100
  st.dataframe(filtered[["country", "year", "question_text", "answer_text"]])
101
-
102
  if filtered.empty:
103
  st.info("No matching questions found.")
104
-
 
1
+ #!/usr/bin/env python3
2
+ import os, io, json, gc
3
+ import boto3, psycopg2, pandas as pd, torch
4
  import streamlit as st
5
+ from sentence_transformers import SentenceTransformer, util
6
+
7
+ # ────────────────────────────────────────────────────────────────────────
8
+ # 0) Hugging Face secrets β†’ env vars (already set inside Spaces)
9
+ # DB_HOST / DB_PORT / DB_NAME / DB_USER / DB_PASSWORD
10
+ # AWS creds must be in aws_creds.json pushed with the app repo
11
+ # ────────────────────────────────────────────────────────────────────────
12
+ with open("aws_creds.json") as f:
13
+ creds = json.load(f)
14
+ os.environ["AWS_ACCESS_KEY_ID"] = creds["AccessKey"]
15
+ os.environ["AWS_SECRET_ACCESS_KEY"] = creds["SecretAccessKey"]
16
+ os.environ["AWS_DEFAULT_REGION"] = "us-east-2"
17
+
18
+ # ────────────────────────────────────────────────────────────────────────
19
+ # 1) DB β†’ DataFrame (cached 10 min) |
20
+ # ────────────────────────────────────────────────────────────────────────
21
+ DB_HOST = os.getenv("DB_HOST")
22
+ DB_PORT = os.getenv("DB_PORT", "5432")
23
+ DB_NAME = os.getenv("DB_NAME")
24
+ DB_USER = os.getenv("DB_USER")
25
  DB_PASSWORD = os.getenv("DB_PASSWORD")
26
 
27
  @st.cache_data(ttl=600)
28
+ def load_survey_dataframe() -> pd.DataFrame:
29
+ conn = psycopg2.connect(
30
+ host=DB_HOST, port=DB_PORT,
31
+ dbname=DB_NAME, user=DB_USER, password=DB_PASSWORD,
32
+ sslmode="require",
33
+ )
34
+ df = pd.read_sql_query(
35
+ """SELECT id, country, year, section,
36
+ question_code, question_text,
37
+ answer_code, answer_text
38
+ FROM survey_info
39
+ """,
40
+ conn,
41
+ )
42
+ conn.close()
43
+ return df
44
+
45
+ df = load_survey_dataframe()
46
+
47
+ # ────────────────────────────────────────────────────────────────────────
48
+ # 2) S3 β†’ ids + embeddings (cached for session) |
49
+ # ────────────────────────────────────────────────────────────────────────
50
+ @st.cache_resource
51
+ def load_embeddings():
52
+ BUCKET = "cgd-embeddings-bucket"
53
+ KEY = "survey_info_embeddings.pt" # contains {'ids', 'embeddings'}
54
+ bio = io.BytesIO()
55
+ boto3.client("s3").download_fileobj(BUCKET, KEY, bio)
56
+ bio.seek(0)
57
+ ckpt = torch.load(bio, map_location="cpu")
58
+ bio.close(); gc.collect()
59
+ if not (isinstance(ckpt, dict) and {"ids","embeddings"} <= ckpt.keys()):
60
+ st.error("Bad checkpoint format"); st.stop()
61
+ return ckpt["ids"], ckpt["embeddings"]
62
+
63
+ ids_list, emb_tensor = load_embeddings()
64
+
65
+ # build quick lookup from id β†’ row index in DataFrame
66
+ row_lookup = {row_id: i for i, row_id in enumerate(df["id"])}
67
+
68
+ # ────────────────────────────────────────────────────────────────────────
69
+ # 3) Streamlit UI |
70
+ # ────────────────────────────────────────────────────────────────────────
71
+ st.title("🌍 CGD Survey Explorer (Live DB + Semantic Search)")
72
+
73
+ # ── 3a) Sidebar filters (original UI) ────────────���──────────────────────
74
  st.sidebar.header("πŸ”Ž Filter Questions")
75
 
76
+ country_opts = sorted(df["country"].dropna().unique())
77
+ year_opts = sorted(df["year"].dropna().unique())
 
78
 
79
+ sel_countries = st.sidebar.multiselect("Select Country/Countries", country_opts)
80
+ sel_years = st.sidebar.multiselect("Select Year(s)", year_opts)
81
+ keyword = st.sidebar.text_input(
82
+ "Keyword Search (Question / Answer / Code)", ""
83
+ )
84
+ group_by_q = st.sidebar.checkbox("Group by Question Text")
85
 
86
+ # Apply keyword & dropdown filters
87
  filtered = df[
88
+ (df["country"].isin(sel_countries) if sel_countries else True) &
89
+ (df["year"].isin(sel_years) if sel_years else True) &
90
  (
91
  df["question_text"].str.contains(keyword, case=False, na=False) |
92
+ df["answer_text"] .str.contains(keyword, case=False, na=False) |
93
+ df["question_code"].astype(str).str.contains(keyword, case=False, na=False)
94
  )
95
  ]
96
 
97
+ # ── 3b) Semantic-search panel ───────────────────────────────────────────
98
+ st.sidebar.markdown("---")
99
+ st.sidebar.subheader("🧠 Semantic Search")
100
+ sem_query = st.sidebar.text_input("Enter a natural-language query")
101
+ if st.sidebar.button("Search", disabled=not sem_query.strip()):
102
+ with st.spinner("Embedding & searching…"):
103
+ model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
104
+ q_vec = model.encode(sem_query.strip(), convert_to_tensor=True).cpu()
105
+ scores = util.cos_sim(q_vec, emb_tensor)[0]
106
+ top_vals, top_idx = torch.topk(scores, k=10)
107
+ results = []
108
+ for score, emb_row in zip(top_vals.tolist(), top_idx.tolist()):
109
+ db_id = ids_list[emb_row]
110
+ if db_id in row_lookup:
111
+ row = df.iloc[row_lookup[db_id]]
112
+ results.append({
113
+ "score": f"{score:.3f}",
114
+ "country": row["country"],
115
+ "year": row["year"],
116
+ "question": row["question_text"],
117
+ "answer": row["answer_text"],
118
+ })
119
+ if results:
120
+ st.subheader("πŸ” Semantic Results")
121
+ st.write(f"Showing top {len(results)} for **{sem_query}**")
122
+ st.dataframe(pd.DataFrame(results))
123
+ else:
124
+ st.info("No semantic matches found.")
125
+
126
+ st.markdown("---")
127
+
128
+ # ── 3c) Original results table / grouped view ───────────────────────────
129
+ if group_by_q:
130
  st.subheader("πŸ“Š Grouped by Question Text")
 
131
  grouped = (
132
  filtered.groupby("question_text")
133
  .agg({
134
  "country": lambda x: sorted(set(x)),
135
+ "year": lambda x: sorted(set(x)),
136
+ "answer_text": lambda x: list(x)[:3]
137
  })
138
  .reset_index()
139
  .rename(columns={
140
  "country": "Countries",
141
+ "year": "Years",
142
  "answer_text": "Sample Answers"
143
  })
144
  )
 
145
  st.dataframe(grouped)
 
146
  if grouped.empty:
147
  st.info("No questions found with current filters.")
 
148
  else:
149
+ # contextual heading
150
+ hdr = []
151
+ if sel_countries: hdr.append("Countries: " + ", ".join(sel_countries))
152
+ if sel_years: hdr.append("Years: " + ", ".join(map(str, sel_years)))
153
+ st.markdown("### Results for " + (" | ".join(hdr) if hdr else "All Countries and Years"))
 
 
 
 
 
 
154
  st.dataframe(filtered[["country", "year", "question_text", "answer_text"]])
 
155
  if filtered.empty:
156
  st.info("No matching questions found.")