# Page header captured along with the source (not part of the program):
# Spaces: Build error
import os

# Redirect the Hugging Face / transformers cache locations to sub-directories
# of the current working directory. This runs ahead of the imports that follow
# (presumably so that libraries loaded later see these values -- confirm).
cwd = os.getcwd()
os.environ.update(
    {
        'PYTORCH_TRANSFORMERS_CACHE': os.path.join(cwd, 'huggingface/transformers/'),
        'TRANSFORMERS_CACHE': os.path.join(cwd, 'huggingface/transformers/'),
        'HF_HOME': os.path.join(cwd, 'huggingface/'),
    }
)
# import sys | |
import logging | |
from json import JSONDecodeError | |
from pathlib import Path | |
# import zipfile | |
import pandas as pd | |
import streamlit as st | |
from markdown import markdown | |
from utils import get_backlink, get_pipelines, query, send_feedback, upload_doc | |
# Adjust to a question that you would like users to see in the search bar when they load the UI:
DEFAULT_QUESTION_AT_STARTUP = os.getenv(
    "DEFAULT_QUESTION_AT_STARTUP", "How to get TPS?")
DEFAULT_ANSWER_AT_STARTUP = os.getenv(
    "DEFAULT_ANSWER_AT_STARTUP", "You must file a Form I-765")

# Sliders
DEFAULT_DOCS_FROM_RETRIEVER = int(
    os.getenv("DEFAULT_DOCS_FROM_RETRIEVER", "5"))
DEFAULT_NUMBER_OF_ANSWERS = int(os.getenv("DEFAULT_NUMBER_OF_ANSWERS", "1"))


def env_flag(name, default):
    """Read the environment variable ``name`` as a boolean.

    Args:
        name: Name of the environment variable.
        default: Value returned when the variable is not set at all.

    Returns:
        ``default`` when the variable is unset; ``False`` when its value
        (ignoring case and surrounding whitespace) is empty or one of
        ``0``, ``false``, ``no``, ``off``; ``True`` for any other value.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    # bool("False") is True, so the text has to be interpreted explicitly.
    return raw.strip().lower() not in ("", "0", "false", "no", "off")


# Whether the file upload should be enabled or not
DISABLE_FILE_UPLOAD = env_flag("DISABLE_FILE_UPLOAD", True)

# NOTE: the lowercase "russian" key/value is kept exactly as written.
LANG_MAP = {"English": "English", "Ukrainian": "Ukrainian", "russian": "russian"}
# Pipelines are created at module load; main() hands this object to query().
pipelines = get_pipelines()
def set_state_if_absent(key, value):
    """Store ``value`` under ``key`` in ``st.session_state`` unless the key exists.

    An existing entry is left untouched, whatever its current value is.
    """
    if key in st.session_state:
        return
    st.session_state[key] = value
def _reset_results(*args):
    """Clear the stored answer, result list and raw response.

    Positional arguments are accepted and ignored so the function can be
    passed directly as a widget ``on_change`` callback.
    """
    st.session_state.answer = None
    st.session_state.results = None
    st.session_state.raw_json = None


def _render_sidebar():
    """Draw the sidebar options.

    Returns:
        A tuple ``(language, debug, top_k_reader, top_k_retriever)``.
    """
    st.sidebar.header("Options")
    language = st.sidebar.selectbox(
        "Select language: ",
        ("English", "Ukrainian", "Spanish", "French", "Italian", "Arabic", "Hindi",
         "Portuguese", "Mandarin Chinese", "Japanese", "russian"))

    # Debug widgets are switched off. To expose them again, replace the
    # assignment with: st.sidebar.checkbox("Show debug info")
    debug = False

    if debug:
        top_k_reader = st.sidebar.slider(
            "Max. number of answers",
            min_value=1,
            max_value=100,
            value=DEFAULT_NUMBER_OF_ANSWERS,
            step=1,
            on_change=_reset_results,
        )
        top_k_retriever = st.sidebar.slider(
            "Max. number of documents from retriever",
            min_value=1,
            max_value=100,
            value=DEFAULT_DOCS_FROM_RETRIEVER,
            step=1,
            on_change=_reset_results,
        )
    else:
        top_k_reader = DEFAULT_NUMBER_OF_ANSWERS
        top_k_retriever = DEFAULT_DOCS_FROM_RETRIEVER
    return language, debug, top_k_reader, top_k_retriever


def _render_file_upload(debug):
    """Draw the sidebar uploader and pass every selected file to ``upload_doc``."""
    st.sidebar.write("## File Upload:")
    data_files = st.sidebar.file_uploader(
        "", type=["pdf", "txt", "docx"], accept_multiple_files=True)
    for data_file in data_files:
        if not data_file:
            continue
        raw_json = upload_doc(data_file)
        # \u2705 is the check-mark emoji (was mojibake in the source text).
        st.sidebar.write(str(data_file.name) + " \u2705 ")
        if debug:
            st.subheader("REST API JSON response")
            st.sidebar.write(raw_json)


def _run_query(question, language, top_k_reader, top_k_retriever):
    """Run ``query`` for ``question`` and store its output in session state.

    Failures are logged and/or reported to the user with ``st.error``; they
    are not re-raised.

    Returns:
        ``True`` when the query completed, ``False`` when an error was shown.
    """
    # Emoji are written as escapes so this file stays plain ASCII:
    # U+1F9E0 brain, U+1F41E lady beetle, U+1F9D1 U+200D U+1F33E farmer.
    with st.spinner("\U0001F9E0 Performing neural search on documents... \n "):
        try:
            st.session_state.results, st.session_state.raw_json = query(
                pipelines, question, top_k_reader=top_k_reader,
                top_k_retriever=top_k_retriever, language=language
            )
        except JSONDecodeError:
            st.error(
                "\U0001F41E An error occurred reading the results. Is the document store working?")
            return False
        except Exception as e:
            # Top-level UI boundary: log the traceback, show a friendly message.
            logging.exception(e)
            message = str(e)
            if "The server is busy processing requests" in message or "503" in message:
                st.error(
                    "\U0001F9D1\u200D\U0001F33E All our workers are busy! Try again later.")
            else:
                st.error(
                    "\U0001F41E An error occurred during the request.")
            return False
    return True


def _render_results(debug):
    """Render every entry of ``st.session_state.results`` with its source."""
    st.write("## Results:")
    for result in st.session_state.results:
        answer = result["answer"]
        if answer:
            # Hack due to this bug: https://github.com/streamlit/streamlit/issues/3190
            st.write(
                markdown(f"**Answer:** {answer}"), unsafe_allow_html=True)
            # Highlighting the answer inside its context is disabled, so the
            # context text and the answer offsets are no longer computed.
            url, title = get_backlink(result)
            if url and title:
                # NOTE(review): presumably `url`/`title` equal these meta
                # fields -- confirm against utils.get_backlink before
                # simplifying this expression.
                source = f"[{result['document']['meta']['title']}]({result['document']['meta']['url']})"
            else:
                source = f"{result['source']}"
            st.markdown(f"**Source:** {source}")
        else:
            # U+1F914 is the thinking-face emoji.
            st.info(
                "\U0001F914 Unsure whether any of the documents contain an answer to your question. Try to reformulate it!"
            )
        st.write("___")
    if debug:
        st.subheader("REST API JSON response")
        st.write(st.session_state.raw_json)


def main():
    """Render the whole page: sidebar, optional upload, search bar and results.

    Seeds the session state, draws the widgets, runs a query when the "Run"
    button is pressed or the question text differs from the stored one, and
    finally renders whatever results are held in session state. Returns early
    (after showing an error) when the query fails.
    """
    st.set_page_config(page_title="AI advisor")

    # Persistent state
    for key, initial in (
        ("question", DEFAULT_QUESTION_AT_STARTUP),
        ("answer", DEFAULT_ANSWER_AT_STARTUP),
        ("results", None),
        ("raw_json", None),
        ("random_question_requested", False),
    ):
        set_state_if_absent(key, initial)

    # Title
    st.write("# AI Immigration advisor")

    # Sidebar
    language, debug, top_k_reader, top_k_retriever = _render_sidebar()

    # File upload block
    if not DISABLE_FILE_UPLOAD:
        _render_file_upload(debug)

    # A sidebar footer (custom CSS plus a "Data crawled from USCIS" credit
    # linking to https://www.uscis.gov) used to be rendered here; it is disabled.

    # Search bar; editing the text clears the previous results.
    question = st.text_input(
        "", value=st.session_state.question, max_chars=100, on_change=_reset_results)
    col1, col2 = st.columns(2)
    col1.markdown(
        "<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True)
    col2.markdown(
        "<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True)

    # Run button
    run_pressed = col1.button("Run")
    run_query = (
        run_pressed or question != st.session_state.question
    ) and not st.session_state.random_question_requested

    # Get results for query
    if run_query and question:
        _reset_results()
        st.session_state.question = question
        if not _run_query(question, language, top_k_reader, top_k_retriever):
            return

    if st.session_state.results:
        _render_results(debug)
# Script entry point: build the page whenever this file is executed.
main()