import configparser
import html
import json
import math
import re
from datetime import datetime

import pandas as pd
import requests
import streamlit as st
from torch import cuda

# Import existing modules from appStore
from appStore.prep_data import process_giz_worldwide, remove_duplicates, get_max_end_year, extract_year
from appStore.prep_utils import create_documents, get_client
from appStore.embed import hybrid_embed_chunks
from appStore.search import hybrid_search
from appStore.region_utils import (
    load_region_data,
    clean_country_code,
    get_country_name,
    get_regions,
    get_country_name_and_region_mapping,
)
# TF-IDF part (excluded from the app for now)
# from appStore.tfidf_extraction import extract_top_keywords
# Import helper modules
from appStore.rag_utils import (
    highlight_query,
    get_rag_answer,
    compute_title,
)
from appStore.filter_utils import (
    parse_budget,
    filter_results,
    get_crs_options,
)
from appStore.crs_utils import lookup_crs_value
###########################################
# Model Config
###########################################
# Model settings come from the [MODEL] section of model_params.cfg; the
# endpoint access token comes from Streamlit secrets (key "Llama_3_1").
config = configparser.ConfigParser()
config.read('model_params.cfg')
DEDICATED_MODEL, DEDICATED_ENDPOINT = (
    config.get('MODEL', option_name)
    for option_name in ('DEDICATED_MODEL', 'DEDICATED_ENDPOINT')
)
WRITE_ACCESS_TOKEN = st.secrets["Llama_3_1"]
st.set_page_config(page_title="SEARCH IATI", layout='wide')
###########################################
# Cache the project data
###########################################
@st.cache_data
def load_project_data():
    """Return the processed GIZ worldwide project table (a pandas DataFrame).

    Decorated with ``st.cache_data`` so the processing is not repeated on every
    Streamlit rerun.
    """
    processed = process_giz_worldwide()
    return processed
project_data = load_project_data()

# Budget slider bounds in million euros. Cells that are not numeric are coerced
# to NaN and dropped before the extremes are taken.
budget_series = pd.to_numeric(project_data['total_project'], errors='coerce').dropna()
min_budget_val, max_budget_val = (
    float(extreme / 1e6) for extreme in (budget_series.min(), budget_series.max())
)

###########################################
# Prepare region data
###########################################
region_lookup_path = "docStore/regions_lookup.csv"
region_df = load_region_data(region_lookup_path)

###########################################
# Get device
###########################################
# Use the GPU when torch reports one as available.
device = "cuda" if cuda.is_available() else "cpu"
###########################################
# Streamlit App Layout
###########################################
col_title, col_about = st.columns([8, 2])
with col_title:
    # NOTE(review): unsafe_allow_html is set but this string contains no HTML;
    # check whether the heading markup was lost.
    st.markdown("\nGIZ Project Search (PROTOTYPE)\n", unsafe_allow_html=True)
with col_about, st.expander("ℹ️ About"):
    st.markdown(
        "\n"
        "This app is a prototype for testing purposes using publicly available project data\n"
        "from the German International Cooperation Society (GIZ) as of 23rd February 2025.\n"
        "**Please do NOT enter sensitive or personal information.**\n"
        "**Note**: The answers are AI-generated and may be wrong or misleading.\n",
        unsafe_allow_html=True,
    )

# The query box is keyed ("query") so reset_filters can clear it.
var = st.text_input("Enter Question", key="query")
###########################################
# Create or load the embeddings collection
###########################################
collection_name = "giz_worldwide"
client = get_client()
# Debug output: prints the client's collections to stdout on every rerun.
print(client.get_collections())

# Uncomment to (re)build the collection from the raw project data
# (del_if_exists=True replaces an existing collection of the same name):
# chunks = process_giz_worldwide()
# temp_doc = create_documents(chunks, 'chunks')
# hybrid_embed_chunks(docs=temp_doc, collection_name=collection_name, del_if_exists=True)

max_end_year = get_max_end_year(client, collection_name)
_, unique_sub_regions = get_regions(region_df)

# country name -> country code, and country code -> sub-region.
country_name_mapping, iso_code_to_sub_region = get_country_name_and_region_mapping(
    client,
    collection_name,
    region_df,
    hybrid_search,
    clean_country_code,
    get_country_name,
)
unique_country_names = sorted(country_name_mapping)
###########################################
# Define reset_filters function using session_state
###########################################
def reset_filters():
    """Restore all filter widgets, the query box and pagination to their defaults.

    Writes the defaults straight into ``st.session_state`` under the widget
    keys. It is wired up as a button ``on_click`` callback (the row-2
    "Reset Filters" button).
    """
    not_allocated = "All/Not allocated"
    start_year = datetime.now().year - 4
    defaults = {
        "region_filter": not_allocated,
        "country_filter": not_allocated,
        "end_year_range": (start_year, max_end_year),
        "crs_filter": not_allocated,
        "min_budget": min_budget_val,
        "client_filter": not_allocated,
        "query": "",
        "show_exact_matches": False,
        "page": 1,
    }
    for state_key, default_value in defaults.items():
        st.session_state[state_key] = default_value
###########################################
# Filter Controls - Row 1
###########################################
col1, col2, col3, col4, col5 = st.columns([1, 1, 1, 1, 1])
with col1:
    region_filter = st.selectbox(
        "Region", ["All/Not allocated"] + sorted(unique_sub_regions), key="region_filter"
    )
    # Restrict the country dropdown to countries of the selected sub-region.
    if region_filter == "All/Not allocated":
        filtered_country_names = unique_country_names
    else:
        filtered_country_names = [
            name for name, code in country_name_mapping.items()
            if iso_code_to_sub_region.get(code) == region_filter
        ]
with col2:
    country_filter = st.selectbox(
        "Country", ["All/Not allocated"] + filtered_country_names, key="country_filter"
    )
with col3:
    current_year = datetime.now().year
    # Default window starts four years before the current year; capped at
    # max_end_year so the default can never lie outside the slider range.
    default_start_year = min(current_year - 4, max_end_year)
    # Seed the default through session state rather than value=: reset_filters
    # (a button callback) writes this key, and a widget that has both a
    # default value and a Session-State-set value triggers a Streamlit warning.
    if "end_year_range" not in st.session_state:
        st.session_state["end_year_range"] = (default_start_year, max_end_year)
    end_year_range = st.slider(
        "Project End Year",
        min_value=2010,
        max_value=max_end_year,
        key="end_year_range",
    )
with col4:
    crs_options = ["All/Not allocated"] + get_crs_options(client, collection_name)
    crs_filter = st.selectbox("CRS", crs_options, key="crs_filter")
with col5:
    # Same session-state seeding as the year slider (reset_filters writes this key too).
    if "min_budget" not in st.session_state:
        st.session_state["min_budget"] = min_budget_val
    min_budget = st.slider(
        "Minimum Project Budget (Million €)",
        min_value=min_budget_val,
        max_value=max_budget_val,
        key="min_budget",
    )
###########################################
# Filter Controls - Row 2 (Additional Filters)
###########################################
col1_2, col2_2, col3_2, col4_2, col5_2 = st.columns(5)
with col1_2:
    client_options = sorted(project_data["client"].dropna().unique().tolist())
    client_filter = st.selectbox(
        "Client", ["All/Not allocated", *client_options], key="client_filter"
    )
# Columns 2-4 are empty spacers.
for spacer_column in (col2_2, col3_2, col4_2):
    spacer_column.empty()
with col5_2:
    # Plain reset button; reset_filters runs as an on_click callback.
    st.button("Reset Filters", on_click=reset_filters, key="reset_button_row2")
###########################################
# Filter Controls - Row 3 (Remaining Filter)
###########################################
col1_3, col2_3, col3_3, col4_3, col5_3 = st.columns(5)
with col1_3:
    # "Show only exact matches" switches the results to the lexical view.
    show_exact_matches = st.checkbox("Show only exact matches", key="show_exact_matches")
with col2_3:
    st.empty()
with col3_3:
    st.empty()
with col4_3:
    st.empty()
with col5_3:
    # Right-aligned, more prominent reset button.
    with st.container():
        # NOTE(review): these markdown strings contain no markup although
        # unsafe_allow_html is set; check whether the wrapper HTML was lost.
        st.markdown("", unsafe_allow_html=True)
        # reset_filters overwrites the session-state values of widgets that are
        # already rendered above in this run. Calling it inline after the click
        # raises StreamlitAPIException; as an on_click callback it runs before
        # the next rerun instantiates those widgets.
        st.button("**Reset Filters**", key="reset_button_row3", on_click=reset_filters)
        st.markdown("\n", unsafe_allow_html=True)
###########################################
# Main Search / Results
###########################################
RESULTS_PAGE_SIZE = 15  # results per page in both the lexical and the semantic view
PREVIEW_WORD_COUNT = 90  # description words shown before the "Show more" expander


def format_currency(value):
    """Format *value* as whole euros with thousands separators, e.g. ``"€1,234"``.

    Anything that cannot be converted to a finite number (e.g. the
    ``"Unknown"`` placeholder) is returned unchanged.
    """
    try:
        return f"€{int(float(value)):,}"
    except (ValueError, TypeError, OverflowError):
        # OverflowError: float("inf") parses but int() cannot convert it.
        return value


def _meta_text(metadata, key):
    """Return ``metadata[key]`` as a stripped string.

    A missing key, ``None`` or a float NaN (e.g. an empty pandas cell) yields
    ``""`` instead of crashing on ``.strip()``; other non-string values are
    converted with ``str``.
    """
    value = metadata.get(key)
    if value is None or (isinstance(value, float) and math.isnan(value)):
        return ""
    return str(value).strip()


def _page_count(total_results):
    """Return the number of pages needed to show *total_results* (>= 1) hits."""
    return (total_results - 1) // RESULTS_PAGE_SIZE + 1


def _current_page(total_pages):
    """Return the stored page number clamped to ``[1, total_pages]`` and store it back.

    ``st.session_state.page`` survives reruns, so after a new query or tighter
    filters it can point past the last page; unclamped, it would produce an
    invalid selectbox ``index``.
    """
    page = min(max(int(st.session_state.get("page", 1)), 1), total_pages)
    st.session_state.page = page
    return page


def _on_page_selected(widget_key):
    """Page-selectbox ``on_change`` callback: adopt the page chosen in *widget_key*."""
    st.session_state.page = st.session_state[widget_key]


def _page_selector(total_pages, current_page, widget_key):
    """Render a right-aligned (1/7 width) page selectbox bound to *widget_key*.

    The page is changed in an ``on_change`` callback, which Streamlit runs
    before the next rerun, so the results rendered in that rerun match the
    selection no matter which selector (top or bottom) was used.
    """
    column = st.columns([6, 1])[1]
    column.selectbox(
        "Select Page",
        list(range(1, total_pages + 1)),
        index=current_page - 1,
        key=widget_key,
        on_change=_on_page_selected,
        args=(widget_key,),
    )


def _details_markdown(metadata, query, *, lexical):
    """Build the markdown shown in the right-hand details column of one hit.

    With ``lexical=True`` the objective (default ``"None"``) is passed through
    ``highlight_query``; otherwise it is shown as-is (default ``""``).
    """
    if lexical:
        objective = highlight_query(metadata.get("objective", "None"), query)
    else:
        objective = metadata.get('objective', '')

    start_year_str = extract_year(metadata.get('start_year', None)) or "Unknown"
    end_year_str = extract_year(metadata.get('end_year', None)) or "Unknown"
    formatted_project_budget = format_currency(metadata.get('total_project', "Unknown"))
    formatted_total_volume = format_currency(metadata.get('total_volume', "Unknown"))
    country_raw = metadata.get('country', "Unknown")

    # A trailing ".0" (float-style key/value) is stripped for display.
    crs_key_clean = re.sub(r'\.0$', '', _meta_text(metadata, "crs_key"))
    if crs_key_clean:
        new_crs_value_clean = re.sub(r'\.0$', '', str(lookup_crs_value(crs_key_clean)))
        crs_combined = f"{crs_key_clean}: {new_crs_value_clean}"
    else:
        crs_combined = "Unknown"

    # Predecessor/Successor lines only when the IDs are present.
    extra_line = ""
    predecessor = _meta_text(metadata, "predecessor_id")
    if predecessor:
        extra_line += f"\n**Predecessor Project:** {predecessor}"
    successor = _meta_text(metadata, "successor_id")
    if successor:
        extra_line += f"\n**Successor Project:** {successor}"

    details = (
        f"**Objective:** {objective}\n"
        f"**Commissioned by:** {metadata.get('client', 'Unknown Client')}\n"
        f"**Projekt duration:** {start_year_str}-{end_year_str}\n"
        f"**Budget:** Project: {formatted_project_budget}, Total volume: {formatted_total_volume}\n"
        + extra_line
        + f"\n**Country:** {country_raw}\n"
        f"**Sector:** {crs_combined}"
    )
    # Any other contact address is shown masked; transparenz@giz.de gets no line.
    contact = _meta_text(metadata, "contact")
    if contact and contact.lower() != "transparenz@giz.de":
        details += "\n**Contact:** xxx@giz.de"
    return details


def _render_result(rank, res, query, *, lexical):
    """Render one hit: numbered title, description preview, details column, divider.

    ``lexical=True`` (exact-match view) highlights *query* in the title; the
    semantic view shows the title unhighlighted.
    """
    metadata = res.payload.get('metadata', {})
    if "title" not in metadata:
        metadata["title"] = compute_title(metadata)

    title = metadata["title"]
    if lexical and query.strip():
        title = highlight_query(title, query)
    # NOTE(review): r'|' matches only the empty string, so this substitution
    # leaves the title unchanged; confirm the intended tag-stripping pattern.
    title_clean = re.sub(r'|', '', title)
    st.markdown(f"#### {rank}. **{title_clean}**", unsafe_allow_html=True)

    description = (
        _meta_text(metadata, "description.en")
        or _meta_text(metadata, "description.de")
        or "No project description available"
    )
    words = description.split()
    preview_text = " ".join(words[:PREVIEW_WORD_COUNT])
    remainder_text = " ".join(words[PREVIEW_WORD_COUNT:])

    col_left, col_right = st.columns(2)
    with col_left:
        st.markdown(highlight_query(preview_text, query), unsafe_allow_html=True)
        if remainder_text:
            with st.expander("Show more"):
                st.markdown(highlight_query(remainder_text, query), unsafe_allow_html=True)
    with col_right:
        st.markdown(_details_markdown(metadata, query, lexical=lexical), unsafe_allow_html=True)
    st.divider()


if not var.strip():
    st.info("Please enter a question to see results.")
else:
    # 1) Hybrid search: results[0] are semantic hits, results[1] lexical hits.
    results = hybrid_search(client, var, collection_name, limit=500)
    semantic_all, lexical_all = results[0], results[1]
    # Drop hits whose page content is shorter than 5 characters.
    semantic_all = [r for r in semantic_all if len(r.payload["page_content"]) >= 5]
    lexical_all = [r for r in lexical_all if len(r.payload["page_content"]) >= 5]
    # Score threshold for semantic hits (0.0 keeps all non-negative scores).
    semantic_thresholded = [r for r in semantic_all if r.score >= 0.0]

    # 2) Apply the user's filter selections (same positional arguments for both lists).
    filter_args = (
        country_filter,
        region_filter,
        end_year_range,
        crs_filter,
        min_budget,
        region_df,
        iso_code_to_sub_region,
        clean_country_code,
        get_country_name,
    )
    filtered_semantic = filter_results(semantic_thresholded, *filter_args)
    filtered_lexical = filter_results(lexical_all, *filter_args)

    if client_filter != "All/Not allocated":
        filtered_semantic = [
            r for r in filtered_semantic
            if r.payload.get("metadata", {}).get("client", "Unknown Client") == client_filter
        ]
        filtered_lexical = [
            r for r in filtered_lexical
            if r.payload.get("metadata", {}).get("client", "Unknown Client") == client_filter
        ]

    # Echo the query. It is raw user input rendered with unsafe_allow_html, so escape it.
    st.markdown(f"Query: {html.escape(var)}\n", unsafe_allow_html=True)

    # 3) Display results
    if show_exact_matches:
        # Lexical view: only hits whose page content contains the query verbatim.
        st.write("Showing **Top Lexical Search results**")
        query_substring = var.strip().lower()
        lexical_substring_filtered = [
            r for r in filtered_lexical
            if query_substring in r.payload["page_content"].lower()
        ]
        filtered_lexical_no_dupe = remove_duplicates(lexical_substring_filtered)
        if not filtered_lexical_no_dupe:
            st.write('No exact matches, consider unchecking "Show only exact matches"')
        else:
            top_results = filtered_lexical_no_dupe  # every exact match, paginated
            total_pages = _page_count(len(top_results))
            current_page = _current_page(total_pages)
            _page_selector(total_pages, current_page, "page_top")

            start_index = (current_page - 1) * RESULTS_PAGE_SIZE
            paged_results = top_results[start_index:start_index + RESULTS_PAGE_SIZE]
            for i, res in enumerate(paged_results, start=start_index + 1):
                _render_result(i, res, var, lexical=True)

            _page_selector(total_pages, current_page, "page_bot")
    else:
        # Semantic view: paginated hits plus a RAG answer over the current page.
        filtered_semantic_no_dupe = remove_duplicates(filtered_semantic)
        if not filtered_semantic_no_dupe:
            st.write("No relevant results found.")
        else:
            total_pages = _page_count(len(filtered_semantic_no_dupe))
            current_page = _current_page(total_pages)
            _page_selector(total_pages, current_page, "page_top_sem")

            start_index = (current_page - 1) * RESULTS_PAGE_SIZE
            top_results = filtered_semantic_no_dupe[start_index:start_index + RESULTS_PAGE_SIZE]

            # NOTE(review): the intended bold/green page highlight is not applied
            # (the original conditional produced the same string for every page).
            st.markdown(
                f"Showing **{len(top_results)}** Semantic Search results "
                f"(Page {current_page} of {total_pages})",
                unsafe_allow_html=True,
            )

            # NOTE(review): the LLM is called on every rerun, including page and
            # filter changes. Its output is escaped before being rendered with
            # unsafe_allow_html.
            rag_answer = get_rag_answer(var, top_results, DEDICATED_ENDPOINT, WRITE_ACCESS_TOKEN)
            # NOTE(review): substituting r'(\d+)' with r'\1' is a no-op, so numbers
            # are not actually bolded, and the lines are joined with no separator;
            # check whether the list/bold markup was lost.
            bullet_lines = [
                re.sub(r'(\d+)', r'\1', html.escape(line))
                for line in rag_answer.splitlines()
                if line.strip()
            ]
            formatted_rag_answer = "".join(bullet_lines) + "\n"
            st.markdown(formatted_rag_answer, unsafe_allow_html=True)
            st.divider()

            for i, res in enumerate(top_results, start=start_index + 1):
                _render_result(i, res, var, lexical=False)

            _page_selector(total_pages, current_page, "page_bot_sem")