ppsingh's picture
sources and metadata filtering
f1afeff
raw
history blame
4.74 kB
import streamlit as st
from utils.retriever import retrieve_paragraphs
from utils.generator import build_messages, _call_llm
from utils.utils import meetings_list, countries_list, projects_list
import ast
import time
import asyncio
import re
import logging
logging.basicConfig(level=logging.INFO)
########### Function for getting response #######################
def chat_response(query, filter_metadata=None):
    """Retrieve context for *query* and ask the LLM to answer it.

    Args:
        query: The user's question.
        filter_metadata: Optional metadata filter forwarded unchanged to
            ``retrieve_paragraphs``; ``None`` means no filtering.

    Returns:
        A ``(answer, context_retrieved)`` tuple. ``context_retrieved`` is the
        list of documents parsed from the retriever output (each is indexed
        with ``doc['answer']`` for its text). On any failure ``answer`` is an
        ``"Error processing request: ..."`` message and ``context_retrieved``
        is an empty list, so callers can always unpack two values.
    """
    try:
        retrieved_paragraphs = retrieve_paragraphs(query, filter_metadata=filter_metadata)
        # The retriever output is a Python-literal string; literal_eval parses
        # it without executing code (unlike eval).
        context_retrieved = ast.literal_eval(retrieved_paragraphs)
        # Only the document text goes into the prompt; the full documents
        # (with metadata) are returned for rendering the sources.
        context_retrieved_lst = [doc['answer'] for doc in context_retrieved]
        logging.info("Context Retrieval done")
        messages = build_messages(query, context_retrieved_lst)
        answer = asyncio.run(_call_llm(messages))
        return answer, context_retrieved
    except Exception as e:
        # UI boundary: keep the traceback in the logs, but hand back a
        # displayable message instead of crashing the page. Returning a
        # 2-tuple keeps the shape callers unpack.
        logging.exception("chat_response failed")
        error_message = f"Error processing request: {str(e)}"
        return error_message, []
############## UI related functions #####################
def reset_page():
    """Send pagination back to page 1.

    Registered as the ``on_change`` callback of the query box and filter
    widgets, so any change restarts the results from the first page.
    """
    st.session_state.page = 1
def contruct_metadata_filter():
    """Build the retriever metadata filter from the current widget selections.

    Every filter selection in ``st.session_state`` whose value is not
    ``'All'`` contributes one entry, keyed by the metadata field it filters.

    Returns:
        dict: Possibly empty mapping of metadata field -> selected value.
    """
    # (session_state key, metadata field) pairs, in insertion order.
    # TODO: the Countries and Projects filters still need to be changed to
    # list form.
    state_to_field = (
        ('meetings_filter', 'meeting_id'),
        ('country_filter', 'Countries'),
        ('project_filter', 'Projects'),
    )
    return {
        field: st.session_state[state_key]
        for state_key, field in state_to_field
        if st.session_state[state_key] != 'All'
    }
def _split_preview(snippet, max_words=90):
    """Return ``(first_words, rest)`` for *snippet*, cut after *max_words* words."""
    # str.split(maxsplit=n) yields at most n words plus one trailing element
    # holding the rest of the text, so nothing is lost or repeated between
    # preview and remainder, whatever whitespace the snippet contains.
    parts = snippet.split(maxsplit=max_words)
    if len(parts) > max_words:
        return parts[:max_words], parts[max_words]
    return parts, ""


def render_sources(chunks, query):
    """Render the retrieved source documents below the answer.

    Args:
        chunks: Iterable of retrieved documents (dicts) with an ``'answer'``
            text and an optional ``'answer_metadata'`` dict.
        query: The user's question. Currently unused; kept so existing
            callers keep working.
    """
    st.write("Sources")
    st.write("======================================")
    for idx, doc in enumerate(chunks, start=1):
        # `or {}` / `or ''` also cover keys that are present but None.
        meta = doc.get('answer_metadata') or {}
        title = meta.get('Decision Number', 'Unknown Project')
        agencies = meta.get('Agencies', 'Unknown Agencies')
        # NOTE(review): the search filter uses the 'Countries' field but this
        # reads 'country' — confirm which key the documents actually carry.
        country = meta.get('country', 'Unknown Country')
        snippet = doc.get('answer') or ''
        preview, remainder = _split_preview(snippet)

        # Title + metadata
        st.markdown(f"#### {idx}. {title}", unsafe_allow_html=True)
        st.markdown(f"**Agencies:** {agencies} | **Country:** {country}")
        # First words are always visible; the rest sits behind an expander.
        # NOTE(review): unsafe_allow_html renders any HTML inside the document
        # text — confirm the indexed corpus is trusted.
        st.markdown(" ".join(preview), unsafe_allow_html=True)
        if remainder:
            with st.expander("Show more"):
                st.markdown(remainder, unsafe_allow_html=True)
        st.divider()
st.set_page_config(page_title="Montreal AI Decisions (MVP)")

# Seed session-state defaults only when absent: Streamlit reruns this script
# on every interaction, and existing selections must survive those reruns.
_SESSION_DEFAULTS = {
    'meetings_filter': 'All',
    'country_filter': 'All',
    'project_filter': 'All',
    'page': 1,
}
for _state_key, _default_value in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default_value
col_title, col_about = st.columns([8, 2])
with col_title:
    st.markdown(
        "<h1 style='text-align:center;'> Montreal AI Decisions (MVP)</h1>",
        unsafe_allow_html=True,
    )

# 10.1. Question input — editing it sends pagination back to page 1.
query = st.text_input(
    label="Enter your question:",
    key="query",
    on_change=reset_page,
)

# 10.2. Filter widgets: one select box per metadata facet, with 'All'
# prepended to the sorted choices. (label, choices, session_state key)
FILTER_WIDGETS = (
    ("Meeting", meetings_list, 'meetings_filter'),
    ("Country", countries_list, 'country_filter'),
    ("Projects", projects_list, 'project_filter'),
)
# Four columns are laid out; zip fills the first three and the fourth stays empty.
for filter_column, (widget_label, choices, state_key) in zip(st.columns(4), FILTER_WIDGETS):
    with filter_column:
        st.selectbox(
            widget_label,
            options=['All'] + sorted(choices),
            key=state_key,
            on_change=reset_page,
        )
# Only run search & display if user has entered something
if not query.strip():
    st.info("Please enter a question to see results.")
    st.stop()  # ends this script run; nothing below executes

filter_metadata = contruct_metadata_filter()
# An empty filter is passed as None, i.e. exactly as if the argument were
# omitted, so the retriever sees "no filter".
result = chat_response(query, filter_metadata or None)
# chat_response is expected to return (answer, sources); guard against a bare
# error string so a failed request is displayed instead of crashing on unpack.
if isinstance(result, tuple):
    answer, context_retrieved = result
else:
    answer, context_retrieved = result, []
st.write(answer)
render_sources(context_retrieved, query)