import streamlit as st
import requests
import pandas as pd
from appStore.prep_data import process_giz_worldwide, remove_duplicates, get_max_end_year, extract_year
from appStore.prep_utils import create_documents, get_client
from appStore.embed import hybrid_embed_chunks
from appStore.search import hybrid_search
from appStore.region_utils import load_region_data, get_country_name, get_regions
from appStore.tfidf_extraction import extract_top_keywords
from torch import cuda
import json
import re
from datetime import datetime
#model_config = getconfig("model_params.cfg")
###########
# ToDo move to functions
# Configuration for the dedicated model
DEDICATED_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
DEDICATED_ENDPOINT = "https://qu2d8m6dmsollhly.us-east-1.aws.endpoints.huggingface.cloud"
# Write access token from the settings
WRITE_ACCESS_TOKEN = st.secrets["Llama_3_1"]
def get_rag_answer(query, top_results):
"""
Constructs a prompt from the query and the page contexts of the top results,
truncates the context to avoid exceeding the token limit, then sends it to the
dedicated endpoint and returns only the generated answer.
"""
# Combine the context from the top results (adjust the separator as needed)
context = "\n\n".join([res.payload["page_content"] for res in top_results])
    # Truncate the context to a maximum of 15000 characters to avoid exceeding the token limit
max_context_chars = 15000
if len(context) > max_context_chars:
context = context[:max_context_chars]
# Build the prompt, instructing the model to only output the final answer.
prompt = (
"Using the following context, answer the question concisely. "
"Only output the final answer below, without repeating the context or question.\n\n"
f"Context:\n{context}\n\n"
f"Question: {query}\n\n"
"Answer:"
)
headers = {"Authorization": f"Bearer {WRITE_ACCESS_TOKEN}"}
payload = {
"inputs": prompt,
"parameters": {
"max_new_tokens": 150 # Adjust max tokens as needed
}
}
response = requests.post(DEDICATED_ENDPOINT, headers=headers, json=payload)
if response.status_code == 200:
result = response.json()
answer = result[0]["generated_text"]
# If the model returns the full prompt, split and extract only the portion after "Answer:"
if "Answer:" in answer:
answer = answer.split("Answer:")[-1].strip()
return answer
else:
return f"Error in generating answer: {response.text}"
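# Illustrative only -- a hedged sketch of calling get_rag_answer once `client` and
# `collection_name` are initialised (see below); kept commented so it is not executed here:
#   top = hybrid_search(client, "rural electrification", collection_name)[0][:5]
#   answer = get_rag_answer("rural electrification", top) if top else "No context retrieved."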
#######
# Helper function: Format project id (e.g., "201940485" -> "2019.4048.5")
def format_project_id(pid):
s = str(pid)
if len(s) > 5:
return s[:4] + "." + s[4:-1] + "." + s[-1]
return s
# Helper function: Compute title from metadata using name.en (or name.de if empty)
def compute_title(metadata):
name_en = metadata.get("name.en", "").strip()
name_de = metadata.get("name.de", "").strip()
base = name_en if name_en else name_de
pid = metadata.get("id", "")
if base and pid:
return f"{base} [{format_project_id(pid)}]"
return base or "No Title"
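# Illustrative example (hypothetical metadata, not from the collection):
#   compute_title({"name.en": "Water Supply Programme", "id": "201940485"})
#   -> "Water Supply Programme [2019.4048.5]"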
# Helper function: Get CRS filter options from all documents in the collection.
# The leading underscore on "_client" tells st.cache_data to skip hashing the client object.
@st.cache_data
def get_crs_options(_client, collection_name):
    results = hybrid_search(_client, "", collection_name)
all_results = results[0] + results[1]
crs_set = set()
for res in all_results:
metadata = res.payload.get('metadata', {})
crs_key = metadata.get("crs_key", "").strip()
crs_value = metadata.get("crs_value", "").strip()
if crs_key or crs_value:
crs_combined = f"{crs_key}: {crs_value}"
crs_set.add(crs_combined)
return sorted(crs_set)
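# The returned options are "<crs_key>: <crs_value>" strings, e.g. (hypothetical values):
#   ["14030: Basic drinking water supply", "23210: Energy generation, renewable sources"]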
# Filter results by country, region, project end-year range, and CRS ("crs_key: crs_value").
def filter_results(results, country_filter, region_filter, end_year_range, crs_filter):
    filtered = []
    # Resolve the selected country's ISO code once; it does not vary per result.
    selected_iso_code = country_name_mapping.get(country_filter, None)
for r in results:
metadata = r.payload.get('metadata', {})
countries = metadata.get('countries', "[]")
year_str = metadata.get('end_year')
if year_str:
extracted = extract_year(year_str)
try:
end_year_val = int(extracted) if extracted != "Unknown" else 0
except ValueError:
end_year_val = 0
else:
end_year_val = 0
try:
c_list = json.loads(countries.replace("'", '"'))
c_list = [code.upper() for code in c_list if len(code) == 2]
except json.JSONDecodeError:
c_list = []
if region_filter != "All/Not allocated":
countries_in_region = [code for code in c_list if iso_code_to_sub_region.get(code) == region_filter]
else:
countries_in_region = c_list
# Filter by CRS: compute crs_combined and compare to the selected filter.
crs_key = metadata.get("crs_key", "").strip()
crs_value = metadata.get("crs_value", "").strip()
crs_combined = f"{crs_key}: {crs_value}" if (crs_key or crs_value) else ""
if crs_filter != "All/Not allocated" and crs_filter != crs_combined:
continue
if ((country_filter == "All/Not allocated" or selected_iso_code in c_list)
and (region_filter == "All/Not allocated" or countries_in_region)
and (end_year_range[0] <= end_year_val <= end_year_range[1])):
filtered.append(r)
return filtered
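# A hedged usage sketch (assumes the globals initialised further below): with every
# dropdown left at "All/Not allocated", only the end-year range is actually applied:
#   kept = filter_results(semantic_all, "All/Not allocated", "All/Not allocated",
#                         (2015, 2025), "All/Not allocated")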
#######
# Get the device to be used: either GPU or CPU
device = 'cuda' if cuda.is_available() else 'cpu'
st.set_page_config(page_title="SEARCH IATI", layout='wide')
st.title("GIZ Project Database (PROTOTYPE)")
var = st.text_input("Enter Search Question")
# Load the region lookup CSV
region_lookup_path = "docStore/regions_lookup.csv"
region_df = load_region_data(region_lookup_path)
#################### Create the embeddings collection and save ######################
# The steps below need to be performed only once, then commented out to avoid unnecessary compute over-runs.
##### First we process and create the chunks for the relevant data source
#chunks = process_giz_worldwide()
##### Convert to langchain documents
#temp_doc = create_documents(chunks,'chunks')
##### Embed and store the docs; if the collection already exists, it needs to be updated
collection_name = "giz_worldwide"
#hybrid_embed_chunks(docs=temp_doc, collection_name=collection_name, del_if_exists=True)
################### Hybrid Search #####################################################
client = get_client()
print(client.get_collections())
max_end_year = get_max_end_year(client, collection_name)
_, unique_sub_regions = get_regions(region_df)
@st.cache_data
def get_country_name_and_region_mapping(_client, collection_name, region_df):
results = hybrid_search(_client, "", collection_name)
country_set = set()
for res in results[0] + results[1]:
countries = res.payload.get('metadata', {}).get('countries', "[]")
try:
country_list = json.loads(countries.replace("'", '"'))
two_digit_codes = [code.upper() for code in country_list if len(code) == 2]
country_set.update(two_digit_codes)
except json.JSONDecodeError:
pass
country_name_to_code = {}
iso_code_to_sub_region = {}
for code in country_set:
name = get_country_name(code, region_df)
sub_region_row = region_df[region_df['alpha-2'] == code]
sub_region = sub_region_row['sub-region'].values[0] if not sub_region_row.empty else "Not allocated"
country_name_to_code[name] = code
iso_code_to_sub_region[code] = sub_region
return country_name_to_code, iso_code_to_sub_region
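# Illustrative shape of the two returned mappings (hypothetical entries):
#   country_name_to_code   -> {"Kenya": "KE", "Peru": "PE", ...}
#   iso_code_to_sub_region -> {"KE": "Sub-Saharan Africa", "PE": "Latin America and the Caribbean", ...}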
country_name_mapping, iso_code_to_sub_region = get_country_name_and_region_mapping(client, collection_name, region_df)
unique_country_names = sorted(country_name_mapping.keys()) # List of country names
# Layout filters in columns: add a new filter for CRS in col4.
col1, col2, col3, col4 = st.columns([1, 1, 1, 4])
with col1:
region_filter = st.selectbox("Region", ["All/Not allocated"] + sorted(unique_sub_regions))
with col2:
    country_filter = st.selectbox("Country", ["All/Not allocated"] + unique_country_names)
with col3:
current_year = datetime.now().year
default_start_year = current_year - 4
end_year_range = st.slider("Project End Year", min_value=2010, max_value=max_end_year, value=(default_start_year, max_end_year))
with col4:
crs_options = ["All/Not allocated"] + get_crs_options(client, collection_name)
crs_filter = st.selectbox("CRS", crs_options)
# Checkbox to control whether to show only exact matches
show_exact_matches = st.checkbox("Show only exact matches", value=False)
# Run the search
# 1) Raise the search limit so we retrieve far more candidates than the few results we display
results = hybrid_search(client, var, collection_name, limit=500)
# results is a tuple: (semantic_results, lexical_results)
semantic_all = results[0]
lexical_all = results[1]
# 2) Filter out very short content (< 5 chars), an interim fix for very short paragraphs receiving spuriously high similarity scores
semantic_all = [
r for r in semantic_all if len(r.payload["page_content"]) >= 5
]
lexical_all = [
r for r in lexical_all if len(r.payload["page_content"]) >= 5
]
# 3) Apply a score threshold to the SEMANTIC results (currently effectively disabled: score >= 0.0)
semantic_thresholded = [r for r in semantic_all if r.score >= 0.0]
filtered_semantic = filter_results(semantic_thresholded, country_filter, region_filter, end_year_range, crs_filter)
filtered_lexical = filter_results(lexical_all, country_filter, region_filter, end_year_range, crs_filter)
filtered_semantic_no_dupe = remove_duplicates(filtered_semantic) # ToDo remove duplicates again?
filtered_lexical_no_dupe = remove_duplicates(filtered_lexical)
# Define a helper function to format currency values
def format_currency(value):
try:
# Convert to float then int for formatting (assumes whole numbers)
return f"€{int(float(value)):,}"
except (ValueError, TypeError):
return value
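# Illustrative examples (hypothetical inputs):
#   format_currency("2500000.0") -> "€2,500,000"
#   format_currency(None)        -> None  (non-numeric values are returned unchanged)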
# Helper function to highlight query matches (case-insensitive)
def highlight_query(text, query):
pattern = re.compile(re.escape(query), re.IGNORECASE)
return pattern.sub(lambda m: f"**{m.group(0)}**", text)
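# Illustrative example: highlight_query("Solar power in Kenya", "solar")
# -> "**Solar** power in Kenya" (matches are bolded, original casing preserved)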
###############################
# Display Lexical Results Branch
###############################
if show_exact_matches:
    st.write(f"Showing **Top 5 Lexical Search results** for query: {var}")
query_substring = var.strip().lower()
lexical_substring_filtered = [r for r in lexical_all if query_substring in r.payload["page_content"].lower()]
filtered_lexical = filter_results(lexical_substring_filtered, country_filter, region_filter, end_year_range, crs_filter)
filtered_lexical_no_dupe = remove_duplicates(filtered_lexical)
if not filtered_lexical_no_dupe:
st.write('No exact matches, consider unchecking "Show only exact matches"')
else:
top_results = filtered_lexical_no_dupe[:5]
rag_answer = get_rag_answer(var, top_results)
st.markdown("### Generated Answer")
st.write(rag_answer)
st.divider()
for res in top_results:
metadata = res.payload.get('metadata', {})
# Compute new title if not already set
if "title" not in metadata:
metadata["title"] = compute_title(metadata)
# Use new title instead of project_name and highlight query if present
display_title = highlight_query(metadata["title"], var) if var.strip() else metadata["title"]
proj_id = metadata.get('id', 'Unknown')
st.markdown(f"#### {display_title} [{proj_id}]")
# Build snippet with potential highlighting
objectives = metadata.get("objectives", "")
desc_de = metadata.get("description.de", "")
desc_en = metadata.get("description.en", "")
description = desc_de if desc_de else desc_en
full_snippet = f"Objective: {objectives} Description: {description}"
words = full_snippet.split()
preview_word_count = 200
preview_text = " ".join(words[:preview_word_count])
remainder_text = " ".join(words[preview_word_count:])
preview_text = highlight_query(preview_text, var) if var.strip() else preview_text
st.write(preview_text)
if remainder_text:
with st.expander("Show more"):
st.write(remainder_text)
# Keywords
full_text = res.payload['page_content']
top_keywords = extract_top_keywords(full_text, top_n=5)
if top_keywords:
st.markdown(f"_{' · '.join(top_keywords)}_")
# Country info
try:
c_list = json.loads(metadata.get('countries', "[]").replace("'", '"'))
except json.JSONDecodeError:
c_list = []
matched_countries = []
for code in c_list:
if len(code) == 2:
resolved_name = get_country_name(code.upper(), region_df)
if resolved_name.upper() != code.upper():
matched_countries.append(resolved_name)
additional_text = f"Country: **{', '.join(matched_countries) if matched_countries else 'Unknown'}**"
# Add contact info if available and not [email protected]
contact = metadata.get("contact", "").strip()
if contact and contact.lower() != "[email protected]":
additional_text += f" | Contact: **{contact}**"
st.markdown(additional_text)
st.divider()
###############################
# Display Semantic Results Branch
###############################
else:
    st.write(f"Showing **Top 5 Semantic Search results** for query: {var}")
if not filtered_semantic_no_dupe:
st.write("No relevant results found.")
else:
top_results = filtered_semantic_no_dupe[:5]
rag_answer = get_rag_answer(var, top_results)
st.markdown("### Generated Answer")
st.write(rag_answer)
st.divider()
for res in top_results:
metadata = res.payload.get('metadata', {})
if "title" not in metadata:
metadata["title"] = compute_title(metadata)
display_title = metadata["title"]
st.markdown(f"#### {display_title} [{metadata.get('id', 'Unknown')}]")
objectives = metadata.get("objectives", "")
desc_de = metadata.get("description.de", "")
desc_en = metadata.get("description.en", "")
description = desc_de if desc_de else desc_en
full_snippet = f"Objective: {objectives} Description: {description}"
words = full_snippet.split()
preview_word_count = 200
preview_text = " ".join(words[:preview_word_count])
remainder_text = " ".join(words[preview_word_count:])
st.write(preview_text)
if remainder_text:
with st.expander("Show more"):
st.write(remainder_text)
top_keywords = extract_top_keywords(res.payload['page_content'], top_n=5)
if top_keywords:
st.markdown(f"_{' · '.join(top_keywords)}_")
try:
c_list = json.loads(metadata.get('countries', "[]").replace("'", '"'))
except json.JSONDecodeError:
c_list = []
matched_countries = []
for code in c_list:
if len(code) == 2:
resolved_name = get_country_name(code.upper(), region_df)
if resolved_name.upper() != code.upper():
matched_countries.append(resolved_name)
additional_text = f"Country: **{', '.join(matched_countries) if matched_countries else 'Unknown'}**"
contact = metadata.get("contact", "").strip()
if contact and contact.lower() != "[email protected]":
additional_text += f" | Contact: **{contact}**"
st.markdown(additional_text)
st.divider()
# for i in results:
# st.subheader(str(i.metadata['id'])+":"+str(i.metadata['title_main']))
# st.caption(f"Status:{str(i.metadata['status'])}, Country:{str(i.metadata['country_name'])}")
# st.write(i.page_content)
# st.divider()