jskinner215's picture
Update app.py
aa88d2a
raw
history blame
4.69 kB
import sys
sys.path.append('.')
from copy import deepcopy
from langchain.callbacks import StreamlitCallbackHandler
import streamlit as st
import logging
import ui_utils
import weaviate_utils
import tapas_utils
from weaviate_utils import *
from tapas_utils import *
from ui_utils import *
# --- Backend initialisation -------------------------------------------------
# Weaviate client shared by the schema lookups and the data ingestion below.
client = initialize_weaviate_client()

# TAPAS tokenizer/model pair used to answer questions about the uploaded table.
# NOTE(review): Streamlit re-executes this script on every interaction, so this
# load presumably repeats each time unless initialize_tapas caches internally.
(tokenizer, model) = initialize_tapas()

# Intended store for debugging messages; its contents are rendered by the
# "Show Debugging Information" checkbox.
DEBUG_LOGS = []
class StreamlitCallbackHandler(logging.Handler):
    """Logging handler that renders each formatted record in the Streamlit page.

    NOTE: this definition shadows the ``StreamlitCallbackHandler`` imported from
    ``langchain.callbacks`` at the top of the file; everywhere below, the name
    refers to this logging handler, not the langchain callback.
    """

    def emit(self, record):
        """Format *record* and write it to the page with ``st.write``.

        Follows the ``logging.Handler.emit`` contract: an error raised while
        formatting or writing is passed to ``handleError`` rather than
        propagating into the code that issued the log call.
        """
        try:
            log_entry = self.format(record)
            st.write(log_entry)
        except RecursionError:
            # Same policy as logging.StreamHandler: never mask recursion errors.
            raise
        except Exception:
            self.handleError(record)
def log_debug_info(message):
    """Record a debugging message and, in debug mode, echo it to the page.

    The message is always appended to the module-level ``DEBUG_LOGS`` list so
    the "Show Debugging Information" checkbox has entries to display. When
    ``st.session_state.debug`` is truthy, the message is also logged at DEBUG
    level through this module's logger, whose ``StreamlitCallbackHandler``
    writes it into the Streamlit page.

    Args:
        message: Text to record.
    """
    DEBUG_LOGS.append(message)

    # `.get` avoids an AttributeError if this is called before the "debug"
    # flag has been initialised in session state.
    if not st.session_state.get("debug", False):
        return

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)

    # Streamlit re-executes the script on every interaction, which redefines
    # StreamlitCallbackHandler as a brand-new class object. Handlers attached on
    # earlier runs are instances of the *previous* class, so an isinstance()
    # check never matches them and a duplicate handler would be added on every
    # rerun (each message then printed once per past rerun). Matching on the
    # class name keeps exactly one such handler attached.
    handler_name = StreamlitCallbackHandler.__name__
    if not any(type(h).__name__ == handler_name for h in logger.handlers):
        logger.addHandler(StreamlitCallbackHandler())
    logger.debug(message)
# Initialize session state attributes before any UI helper runs, so a helper
# that reads `st.session_state.debug` on the first run never finds it missing.
if "debug" not in st.session_state:
    st.session_state.debug = False

# UI components
ui_utils.display_initial_buttons()
# NOTE(review): this `selected_class` is overwritten by the "class_selector"
# selectbox further down the script; confirm both class pickers are meant to
# be shown.
selected_class = ui_utils.display_class_dropdown(client)
ui_utils.handle_new_class_selection(client, selected_class)
ui_utils.csv_upload_and_ingestion(client, selected_class)
ui_utils.display_query_input()
st.title("TAPAS Table Question Answering with Weaviate")

# Offer every class currently defined in Weaviate, plus a "New Class" sentinel
# that switches the form into class-creation mode.
existing_classes = [entry["class"] for entry in client.schema.get()["classes"]]
class_options = [*existing_classes, "New Class"]
selected_class = st.selectbox(
    "Select a class or create a new one:",
    class_options,
    key="class_selector",
)

if selected_class != "New Class":
    # Reuse the chosen class as-is; its description is not looked up here.
    class_name = selected_class
    class_description = ""  # could be fetched from Weaviate if needed
else:
    # Collect the name and description for the class to be created.
    class_name = st.text_input("Enter the new class name:")
    class_description = st.text_input("Enter a description for the class:")
# CSV upload widget; `csv_file` stays None until the user uploads a file.
csv_file = st.file_uploader("Upload a CSV file", type=["csv"], key="csv_uploader")

# For an existing class, show its schema as a table of property names and
# data types. `class_schema` remains None for "New Class" or a falsy lookup.
class_schema = None
if selected_class != "New Class":
    st.write(f"Schema for {selected_class}:")
    class_schema = get_class_schema(client, selected_class)
    if class_schema:
        properties = class_schema["properties"]
        schema_df = pd.DataFrame(properties)
        st.table(schema_df.loc[:, ["name", "dataType"]])
# Before ingesting data into Weaviate, check that the CSV columns match the
# class schema.
if csv_file is not None:
    # getvalue() returns the whole upload regardless of the stream's current
    # position, so a rerun never sees an already-consumed file. "utf-8-sig"
    # strips the byte-order mark Excel writes at the start of UTF-8 CSVs; plain
    # "utf-8" keeps it glued to the first column name, which then never
    # matches the schema.
    data = csv_file.getvalue().decode("utf-8-sig")
    dataframe = pd.read_csv(StringIO(data))

    # Log CSV upload information
    log_debug_info(f"CSV uploaded with shape: {dataframe.shape}")

    # Display the uploaded CSV data
    st.write("Uploaded CSV Data:")
    st.write(dataframe)

    # Column check only applies when an existing class schema was found.
    # NOTE(review): for "New Class" `class_schema` is None, so nothing is
    # ingested on this path; presumably ui_utils.csv_upload_and_ingestion
    # handles new classes — confirm.
    if class_schema:
        schema_columns = [prop["name"] for prop in class_schema["properties"]]
        if set(dataframe.columns) != set(schema_columns):
            st.error("The columns in the uploaded CSV do not match the schema of the selected class. Please check and upload the correct CSV or create a new class.")
        else:
            # Streamlit re-executes this script on every widget interaction
            # (typing a question, clicking Submit, ...) while the upload stays
            # in place. Remember what was last ingested so the same file is not
            # written to Weaviate again on every rerun. The key is only stored
            # after a successful ingest, so a failed attempt is retried.
            ingest_key = (class_name, csv_file.name, hash(data))
            if st.session_state.get("last_ingested_csv") != ingest_key:
                ingest_data_to_weaviate(client, dataframe, class_name, class_description)
                st.session_state["last_ingested_csv"] = ingest_key
# Input for questions
questions = st.text_area("Enter your questions (one per line)")
# One question per line; surrounding whitespace is trimmed and blank or
# whitespace-only lines are dropped so they are never sent to the model.
questions = [line.strip() for line in questions.splitlines() if line.strip()]

if st.button("Submit"):
    # `data` is assigned only inside the `csv_file is not None` branch above,
    # so test `csv_file` first: short-circuiting prevents a NameError when
    # Submit is clicked before any CSV has been uploaded.
    if csv_file is not None and data and questions:
        answers = summarize_map_reduce(tokenizer, model, data, questions)
        st.write("Answers:")
        for q, a in zip(questions, answers):
            st.write(f"Question: {q}")
            st.write(f"Answer: {a}")
    else:
        st.warning("Please upload a CSV file and enter at least one question before submitting.")
# Opt-in panel that renders every entry currently held in DEBUG_LOGS,
# one st.write call per entry.
if st.checkbox("Show Debugging Information"):
    st.write("Debugging Logs:")
    for entry in DEBUG_LOGS:
        st.write(entry)
# Ctrl+Enter shortcut intended to click the Submit button.
# NOTE(review): st.markdown injects this HTML into an already-loaded page.
# Browsers do not execute <script> elements inserted that way, and
# DOMContentLoaded has normally fired long before this element renders, so this
# listener most likely never runs — confirm in a browser. Separately,
# `.stButton button` matches the *first* button on the page, which may not be
# "Submit". Consider streamlit.components.v1.html or an st.form instead.
_CTRL_ENTER_SUBMIT_HTML = """
<script>
document.addEventListener("DOMContentLoaded", function(event) {
document.addEventListener("keydown", function(event) {
if (event.ctrlKey && event.key === "Enter") {
document.querySelector(".stButton button").click();
}
});
});
</script>
"""
st.markdown(_CTRL_ENTER_SUBMIT_HTML, unsafe_allow_html=True)