# ai_advisor / app.py
# Author: rodrigomasini
# Duplicated from bondares/ai_advisor
# Revision: b1f727f
import os
cwd = os.getcwd()
# Redirect the Hugging Face / transformers cache locations into a directory
# under the current working directory. This runs before the remaining imports,
# presumably so that libraries pulled in via `utils` see these values — confirm.
_transformers_cache_dir = os.path.join(cwd, 'huggingface/transformers/')
_hf_home_dir = os.path.join(cwd, 'huggingface/')
for _cache_var in ('PYTORCH_TRANSFORMERS_CACHE', 'TRANSFORMERS_CACHE'):
    os.environ[_cache_var] = _transformers_cache_dir
os.environ['HF_HOME'] = _hf_home_dir
# import sys
import logging
from json import JSONDecodeError
from pathlib import Path
# import zipfile
import pandas as pd
import streamlit as st
from markdown import markdown
from utils import get_backlink, get_pipelines, query, send_feedback, upload_doc
# Adjust to a question that you would like users to see in the search bar when they load the UI:
DEFAULT_QUESTION_AT_STARTUP = os.getenv(
    "DEFAULT_QUESTION_AT_STARTUP", "How to get TPS?")
# Answer stored in session state when the UI first loads (overridable via env var).
DEFAULT_ANSWER_AT_STARTUP = os.getenv(
    "DEFAULT_ANSWER_AT_STARTUP", "You must file a Form I-765")

# Sliders: defaults for the number of retrieved documents / returned answers.
DEFAULT_DOCS_FROM_RETRIEVER = int(
    os.getenv("DEFAULT_DOCS_FROM_RETRIEVER", "5"))
DEFAULT_NUMBER_OF_ANSWERS = int(os.getenv("DEFAULT_NUMBER_OF_ANSWERS", "1"))

# Case-insensitive spellings accepted as "on" for boolean env flags.
_TRUTHY_ENV_VALUES = frozenset({"1", "true", "yes", "on"})


def _env_flag(name, default):
    """Return whether env var *name* (or *default* when unset) is truthy.

    ``bool(os.getenv(...))`` is True for ANY non-empty string, including
    ``"False"`` and ``"0"``, so such a flag could never be switched off.
    This parses the common spellings instead (surrounding whitespace and
    case are ignored); anything else, including ``""``, is False.
    """
    return os.getenv(name, default).strip().lower() in _TRUTHY_ENV_VALUES


# Whether the file upload should be disabled (default: disabled).
# Set DISABLE_FILE_UPLOAD=false (or 0/no/off) to enable it.
DISABLE_FILE_UPLOAD = _env_flag("DISABLE_FILE_UPLOAD", "True")

# NOTE(review): not referenced elsewhere in this file — confirm whether another
# module uses it or it is leftover.
LANG_MAP = {"English": "English", "Ukrainian": "Ukrainian", "russian": "russian"}
# Build the QA pipelines once when the module is executed; the result is passed
# as the first argument to `query()` whenever a question is run.
# NOTE(review): confirm whether get_pipelines() caches its result — otherwise the
# pipelines are rebuilt every time this script is re-executed.
pipelines = get_pipelines()
def set_state_if_absent(key, value):
    """Store *value* under *key* in ``st.session_state`` only if *key* is missing.

    An entry that already exists is never overwritten.
    """
    session = st.session_state
    if key in session:
        return
    session[key] = value
def _reset_results(*args):
    """Clear the answer, results and raw JSON kept in ``st.session_state``.

    ``*args`` is accepted and ignored so this can serve as a widget
    ``on_change`` callback as well as be called directly before a new query.
    """
    st.session_state.answer = None
    st.session_state.results = None
    st.session_state.raw_json = None


def _render_sidebar_options():
    """Draw the "Options" sidebar section.

    Returns:
        tuple: ``(language, debug, top_k_reader, top_k_retriever)``.
    """
    st.sidebar.header("Options")
    language = st.sidebar.selectbox(
        "Select language: ",
        ("English", "Ukrainian", "Spanish", "French", "Italian", "Arabic",
         "Hindi", "Portuguese", "Mandarin Chinese", "Japanese", "russian"),
    )
    # Debug controls are hard-disabled. To expose them again use:
    # debug = st.sidebar.checkbox("Show debug info")
    debug = False
    if debug:
        top_k_reader = st.sidebar.slider(
            "Max. number of answers",
            min_value=1,
            max_value=100,
            value=DEFAULT_NUMBER_OF_ANSWERS,
            step=1,
            on_change=_reset_results,
        )
        top_k_retriever = st.sidebar.slider(
            "Max. number of documents from retriever",
            min_value=1,
            max_value=100,
            value=DEFAULT_DOCS_FROM_RETRIEVER,
            step=1,
            on_change=_reset_results,
        )
    else:
        top_k_reader = DEFAULT_NUMBER_OF_ANSWERS
        top_k_retriever = DEFAULT_DOCS_FROM_RETRIEVER
    return language, debug, top_k_reader, top_k_retriever


def _render_file_upload(debug):
    """Offer sidebar document upload (unless disabled) and send each file to ``upload_doc``."""
    if DISABLE_FILE_UPLOAD:
        return
    st.sidebar.write("## File Upload:")
    data_files = st.sidebar.file_uploader(
        "", type=["pdf", "txt", "docx"], accept_multiple_files=True)
    # NOTE(review): every file still held by the uploader is passed to
    # upload_doc each time this runs — confirm upload_doc is idempotent.
    for data_file in data_files or []:
        if not data_file:
            continue
        raw_json = upload_doc(data_file)
        st.sidebar.write(str(data_file.name) + "    ✅ ")
        if debug:
            # Header and payload both go to the sidebar so they stay together.
            st.sidebar.subheader("REST API JSON response")
            st.sidebar.write(raw_json)


def _run_query(question, top_k_reader, top_k_retriever, language):
    """Run *question* through the pipelines and store the outcome in session state.

    Returns:
        bool: True on success; False after an error has been shown to the user.
    """
    _reset_results()
    st.session_state.question = question
    with st.spinner("🧠 &nbsp;&nbsp; Performing neural search on documents... \n "):
        try:
            st.session_state.results, st.session_state.raw_json = query(
                pipelines,
                question,
                top_k_reader=top_k_reader,
                top_k_retriever=top_k_retriever,
                language=language,
            )
        except JSONDecodeError:
            logging.exception("Could not decode the query response")
            st.error(
                "👓 &nbsp;&nbsp; An error occurred reading the results. Is the document store working?")
            return False
        except Exception as e:
            logging.exception(e)
            if "The server is busy processing requests" in str(e) or "503" in str(e):
                st.error(
                    "🧑‍🌾 &nbsp;&nbsp; All our workers are busy! Try again later.")
            else:
                st.error(
                    "🐞 &nbsp;&nbsp; An error occurred during the request.")
            return False
    return True


def _render_results(debug):
    """Render the answers stored in ``st.session_state.results`` with their sources."""
    if not st.session_state.results:
        return
    st.write("## Results:")
    for result in st.session_state.results:
        answer = result["answer"]
        if answer:
            # Hack due to this bug: https://github.com/streamlit/streamlit/issues/3190
            st.write(
                markdown(f"**Answer:** {answer}"), unsafe_allow_html=True)
            # Use the url/title returned by get_backlink rather than re-indexing
            # result['document']['meta'], which may not exist.
            url, title = get_backlink(result)
            if url and title:
                source = f"[{title}]({url})"
            else:
                source = f"{result['source']}"
            st.markdown(f"**Source:** {source}")
        else:
            st.info(
                "🤔 &nbsp;&nbsp; Unsure whether any of the documents contain an answer to your question. Try to reformulate it!"
            )
        st.write("___")
    if debug:
        st.subheader("REST API JSON response")
        st.write(st.session_state.raw_json)


def main():
    """Render the AI Immigration advisor page and answer the user's question.

    Draws the sidebar options and optional upload, then the search bar and
    Run button. Queries the pipelines when the button is pressed or the
    question text changed, and finally renders the stored results.
    """
    st.set_page_config(page_title="AI advisor")

    # Persistent state
    set_state_if_absent("question", DEFAULT_QUESTION_AT_STARTUP)
    set_state_if_absent("answer", DEFAULT_ANSWER_AT_STARTUP)
    set_state_if_absent("results", None)
    set_state_if_absent("raw_json", None)
    set_state_if_absent("random_question_requested", False)

    # Title
    st.write("# AI Immigration advisor")

    # Sidebar
    language, debug, top_k_reader, top_k_retriever = _render_sidebar_options()
    _render_file_upload(debug)

    # Search bar; editing the text clears stale results.
    question = st.text_input(
        "", value=st.session_state.question, max_chars=100, on_change=_reset_results)
    col1, col2 = st.columns(2)
    # Stretch buttons to the full width of their column.
    col1.markdown(
        "<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True)
    col2.markdown(
        "<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True)

    # Run button
    run_pressed = col1.button("Run")
    # Query when Run is pressed or the text differs from the last stored
    # question, unless a random question was requested (this file never sets
    # that flag to True).
    run_query = (
        run_pressed or question != st.session_state.question
    ) and not st.session_state.random_question_requested

    if run_query and question:
        if not _run_query(question, top_k_reader, top_k_retriever, language):
            return

    _render_results(debug)
# Entry point: render the page whenever this module is executed.
# NOTE(review): this runs unconditionally, even on a plain import. An
# `if __name__ == "__main__":` guard is the usual pattern — confirm how the
# Streamlit runner sets __name__ before adding one.
main()