"""Streamlit front-end for the AI immigration advisor.

Shows a question box, forwards the question to the pipelines returned by
``utils.get_pipelines`` through ``utils.query``, and renders each answer
with a link back to (or the name of) its source document.
"""
import os

# Point the Hugging Face caches at the working directory. This is done before
# the imports below, presumably so that libraries loaded via ``utils`` pick the
# locations up at import time -- keep it above them.
cwd = os.getcwd()
os.environ['PYTORCH_TRANSFORMERS_CACHE'] = os.path.join(cwd, 'huggingface/transformers/')
os.environ['TRANSFORMERS_CACHE'] = os.path.join(cwd, 'huggingface/transformers/')
os.environ['HF_HOME'] = os.path.join(cwd, 'huggingface/')

import logging
from json import JSONDecodeError
from pathlib import Path

import pandas as pd
import streamlit as st
from markdown import markdown

from utils import get_backlink, get_pipelines, query, send_feedback, upload_doc


def _env_flag(name: str, default: bool) -> bool:
    """Read a boolean flag from the environment variable ``name``.

    ``bool(os.getenv(...))`` is True for every non-empty string, including
    "False" and "0", so the value is parsed explicitly: "1", "true", "yes"
    and "on" (case-insensitive) are True, anything else is False. An unset
    variable yields ``default``.
    """
    raw = os.getenv(name)
    if raw is None:
        return default
    return raw.strip().lower() in ("1", "true", "yes", "on")


# Adjust to a question that you would like users to see in the search bar when they load the UI:
DEFAULT_QUESTION_AT_STARTUP = os.getenv("DEFAULT_QUESTION_AT_STARTUP", "How to get TPS?")
DEFAULT_ANSWER_AT_STARTUP = os.getenv("DEFAULT_ANSWER_AT_STARTUP", "You must file a Form I-765")

# Sliders
DEFAULT_DOCS_FROM_RETRIEVER = int(os.getenv("DEFAULT_DOCS_FROM_RETRIEVER", "5"))
DEFAULT_NUMBER_OF_ANSWERS = int(os.getenv("DEFAULT_NUMBER_OF_ANSWERS", "1"))

# Whether the file upload should be disabled (default: disabled).
DISABLE_FILE_UPLOAD = _env_flag("DISABLE_FILE_UPLOAD", default=True)

LANG_MAP = {"English": "English", "Ukrainian": "Ukrainian", "russian": "russian"}

pipelines = get_pipelines()


def set_state_if_absent(key, value):
    """Store ``value`` under ``key`` in the Streamlit session state unless the key already exists."""
    if key not in st.session_state:
        st.session_state[key] = value


def _reset_results(*args):
    """Clear the previous answer/results; used as a widget ``on_change`` callback."""
    st.session_state.answer = None
    st.session_state.results = None
    st.session_state.raw_json = None


def _retrieval_settings(debug):
    """Return ``(top_k_reader, top_k_retriever)``.

    In debug mode both values come from sidebar sliders (changing either one
    clears the previous results); otherwise the environment defaults are used.
    """
    if not debug:
        return DEFAULT_NUMBER_OF_ANSWERS, DEFAULT_DOCS_FROM_RETRIEVER
    top_k_reader = st.sidebar.slider(
        "Max. number of answers",
        min_value=1,
        max_value=100,
        value=DEFAULT_NUMBER_OF_ANSWERS,
        step=1,
        on_change=_reset_results,
    )
    top_k_retriever = st.sidebar.slider(
        "Max. number of documents from retriever",
        min_value=1,
        max_value=100,
        value=DEFAULT_DOCS_FROM_RETRIEVER,
        step=1,
        on_change=_reset_results,
    )
    return top_k_reader, top_k_retriever


def _render_file_upload(debug):
    """Render the sidebar uploader and pass every selected file to ``upload_doc``."""
    st.sidebar.write("## File Upload:")
    data_files = st.sidebar.file_uploader(
        "", type=["pdf", "txt", "docx"], accept_multiple_files=True
    )
    for data_file in data_files or []:
        if not data_file:
            continue
        raw_json = upload_doc(data_file)
        st.sidebar.write(str(data_file.name) + " ✅ ")
        if debug:
            # Keep the heading next to the JSON it labels (both in the sidebar).
            st.sidebar.subheader("REST API JSON response")
            st.sidebar.write(raw_json)


def _run_query(question, language, top_k_reader, top_k_retriever):
    """Run ``question`` through the pipelines and store the outcome in session state.

    Returns True on success. On failure an error message is shown, the
    exception is logged, and False is returned.
    """
    with st.spinner("🧠 Performing neural search on documents... \n "):
        try:
            st.session_state.results, st.session_state.raw_json = query(
                pipelines,
                question,
                top_k_reader=top_k_reader,
                top_k_retriever=top_k_retriever,
                language=language,
            )
        except JSONDecodeError:
            logging.exception("Could not decode the query response")
            st.error("👓 An error occurred reading the results. Is the document store working?")
            return False
        except Exception as e:
            logging.exception(e)
            # An overloaded backend is only recognisable from the error text.
            if "The server is busy processing requests" in str(e) or "503" in str(e):
                st.error("🧑🌾 All our workers are busy! Try again later.")
            else:
                st.error("🐞 An error occurred during the request.")
            return False
    return True


def _render_result(result):
    """Render one result: the answer plus its source, or a hint when there is no answer."""
    answer = result["answer"]
    if not answer:
        st.info(
            "🤔 Unsure whether any of the documents contain an answer to your question. "
            "Try to reformulate it!"
        )
        return
    # Hack due to this bug: https://github.com/streamlit/streamlit/issues/3190
    st.write(markdown(f"**Answer:** {answer}"), unsafe_allow_html=True)
    url, title = get_backlink(result)
    if url and title:
        source = f"[{title}]({url})"
    else:
        source = f"{result['source']}"
    st.markdown(f"**Source:** {source}")


def main():
    """Build the page, run the query when requested, and render the results."""
    st.set_page_config(page_title="AI advisor")

    # Persistent state
    set_state_if_absent("question", DEFAULT_QUESTION_AT_STARTUP)
    set_state_if_absent("answer", DEFAULT_ANSWER_AT_STARTUP)
    set_state_if_absent("results", None)
    set_state_if_absent("raw_json", None)
    set_state_if_absent("random_question_requested", False)

    # Title
    st.write("# AI Immigration advisor")

    # Sidebar
    st.sidebar.header("Options")
    language = st.sidebar.selectbox(
        "Select language: ",
        (
            "English",
            "Ukrainian",
            "Spanish",
            "French",
            "Italian",
            "Arabic",
            "Hindi",
            "Portuguese",
            "Mandarin Chinese",
            "Japanese",
            "russian",
        ),
    )

    # Debug mode is off. To expose it in the UI, use:
    #     debug = st.sidebar.checkbox("Show debug info")
    debug = False
    top_k_reader, top_k_retriever = _retrieval_settings(debug)

    if not DISABLE_FILE_UPLOAD:
        _render_file_upload(debug)

    # Search bar; editing the text clears stale results.
    question = st.text_input(
        "", value=st.session_state.question, max_chars=100, on_change=_reset_results
    )
    col1, col2 = st.columns(2)
    # NOTE(review): these markdown calls are empty; any styling they carried
    # appears to have been lost -- confirm whether they are still needed.
    col1.markdown("", unsafe_allow_html=True)
    col2.markdown("", unsafe_allow_html=True)

    # Run button
    run_pressed = col1.button("Run")
    run_query = (
        run_pressed or question != st.session_state.question
    ) and not st.session_state.random_question_requested

    # Get results for query
    if run_query and question:
        _reset_results()
        st.session_state.question = question
        if not _run_query(question, language, top_k_reader, top_k_retriever):
            return

    if st.session_state.results:
        st.write("## Results:")
        for result in st.session_state.results:
            _render_result(result)
            st.write("___")
        if debug:
            st.subheader("REST API JSON response")
            st.write(st.session_state.raw_json)


# Streamlit executes the script as ``__main__``, so the app still renders.
if __name__ == "__main__":
    main()