File size: 4,689 Bytes
58edde1
 
45ee012
3b3c852
b0e4f45
3b3c852
58edde1
 
 
862e59b
 
 
 
 
 
 
 
 
 
d304ae4
 
33d813d
3b3c852
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d304ae4
 
 
 
 
 
3b3c852
d304ae4
 
 
6b58ffb
04376ef
 
 
 
 
aa88d2a
b9d05c0
04376ef
 
 
 
 
 
b0e4f45
 
aa88d2a
6b58ffb
 
a17d0ff
6b58ffb
 
d304ae4
6b58ffb
 
 
 
 
 
f790556
b0e4f45
0e62360
3b3c852
 
 
04376ef
f790556
 
 
 
6b58ffb
f790556
 
 
 
 
 
d304ae4
d0de0d8
 
 
 
 
 
 
 
d304ae4
d0de0d8
 
 
 
 
3b3c852
 
 
 
 
 
d0de0d8
 
 
 
 
 
 
 
 
 
 
d304ae4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import sys
sys.path.append('.')
from copy import deepcopy
from langchain.callbacks import StreamlitCallbackHandler
import streamlit as st
import logging
import ui_utils
import weaviate_utils
import tapas_utils 
from weaviate_utils import *
from tapas_utils import *
from ui_utils import *

# Initialize the Weaviate client used by the rest of this script.
# initialize_weaviate_client is neither defined nor explicitly imported in this
# file; presumably it is provided by `from weaviate_utils import *` - TODO confirm.
client = initialize_weaviate_client()

# Initialize the TAPAS tokenizer and model that are later passed to
# summarize_map_reduce. initialize_tapas presumably comes from the tapas_utils
# star import - TODO confirm.
# NOTE(review): as top-level script code this presumably runs again on every
# Streamlit rerun - confirm whether the helpers cache what they create.
tokenizer, model = initialize_tapas()

# Global list of debugging entries, shown when "Show Debugging Information" is ticked.
# NOTE(review): nothing in the visible code appends to this list.
DEBUG_LOGS = []

class StreamlitCallbackHandler(logging.Handler):
    """Logging handler that renders each formatted record in the Streamlit app.

    NOTE: this class reuses the name ``StreamlitCallbackHandler`` that is also
    imported from ``langchain.callbacks`` at the top of the file; from this
    definition onwards the name refers to this logging handler.
    """

    def emit(self, record):
        """Format *record* and write it to the page with ``st.write``.

        Failures while formatting or writing are handed to
        ``logging.Handler.handleError`` instead of propagating into the code
        that issued the log call, mirroring the standard library handlers.

        Args:
            record: The ``logging.LogRecord`` to display.
        """
        try:
            log_entry = self.format(record)
            st.write(log_entry)
        except RecursionError:
            # Same policy as logging.StreamHandler: never mask runaway recursion.
            raise
        except Exception:
            self.handleError(record)

def log_debug_info(message):
    """Log *message* at DEBUG level to the Streamlit page when debugging is on.

    Nothing is logged unless ``st.session_state.debug`` exists and is truthy.
    When it is, the module logger is set to DEBUG, exactly one
    ``StreamlitCallbackHandler`` is attached to it, and the message is emitted.

    Args:
        message: Text to log.
    """
    # Tolerate being called before the "debug" flag has been initialized.
    if "debug" not in st.session_state or not st.session_state.debug:
        return

    logger = logging.getLogger(__name__)
    logger.setLevel(logging.DEBUG)

    # The logger object outlives a single execution of this script, whereas the
    # handler class is defined anew whenever the script is executed again
    # (Streamlit reruns it on user interaction). Handlers left over from an
    # earlier run fail the isinstance check below and would pile up, printing
    # every message once more per run, so drop them by class name first.
    handler_name = StreamlitCallbackHandler.__name__
    for existing in list(logger.handlers):
        is_stale = (
            type(existing).__name__ == handler_name
            and not isinstance(existing, StreamlitCallbackHandler)
        )
        if is_stale:
            logger.removeHandler(existing)

    # Attach the handler only once to avoid duplicate output within a run.
    if not any(isinstance(h, StreamlitCallbackHandler) for h in logger.handlers):
        logger.addHandler(StreamlitCallbackHandler())

    logger.debug(message)

# UI components rendered through the ui_utils helpers, called in display order.
# NOTE(review): the selected_class returned here is overwritten further down by
# the st.selectbox result, and the inline widgets that follow (class selector,
# CSV uploader, question box) look like they may duplicate what these helpers
# render - confirm against ui_utils and keep only one of the two paths.
ui_utils.display_initial_buttons()
selected_class = ui_utils.display_class_dropdown(client)
ui_utils.handle_new_class_selection(client, selected_class)
ui_utils.csv_upload_and_ingestion(client, selected_class)
ui_utils.display_query_input()

# Initialize session state attributes.
# "debug" is the flag log_debug_info reads; it defaults to off.
if "debug" not in st.session_state:
    st.session_state.debug = False

# Page title (rendered after the ui_utils components above).
st.title("TAPAS Table Question Answering with Weaviate")

# Get existing classes from Weaviate. A schema without a "classes" entry is
# treated the same as one with no classes.
existing_classes = [cls["class"] for cls in (client.schema.get().get("classes") or [])]
class_options = existing_classes + ["New Class"]
selected_class = st.selectbox("Select a class or create a new one:", class_options, key="class_selector")

# For a new class the user supplies name and description; for an existing one
# the selection itself is the class name.
if selected_class == "New Class":
    class_name = st.text_input("Enter the new class name:")
    class_description = st.text_input("Enter a description for the class:")
else:
    class_name = selected_class
    class_description = ""  # We can fetch the description from Weaviate if needed

# Upload CSV data
csv_file = st.file_uploader("Upload a CSV file", type=["csv"], key="csv_uploader")

# Display the schema if an existing class is selected.
# class_schema stays None for "New Class" (and is whatever get_class_schema
# returns otherwise); the CSV handling below relies on that.
class_schema = None
if selected_class != "New Class":
    st.write(f"Schema for {selected_class}:")
    class_schema = get_class_schema(client, selected_class)
    if class_schema:
        properties = class_schema["properties"]
        if properties:
            schema_df = pd.DataFrame(properties)
            st.table(schema_df[["name", "dataType"]])  # Display only the name and dataType columns
        else:
            # An empty property list yields a DataFrame without columns, and
            # selecting "name"/"dataType" from it would raise KeyError.
            st.info("This class has no properties defined yet.")

# Before ingesting data into Weaviate, check if CSV columns match the class schema
if csv_file is not None:
    # getvalue() returns the complete upload regardless of the buffer's current
    # read position, so the content is still available if the buffer has
    # already been consumed. "utf-8-sig" decodes plain UTF-8 and additionally
    # drops a leading byte-order mark, which would otherwise become part of the
    # first column name and break the schema comparison below.
    data = csv_file.getvalue().decode("utf-8-sig")
    dataframe = pd.read_csv(StringIO(data))

    # Log CSV upload information
    log_debug_info(f"CSV uploaded with shape: {dataframe.shape}")

    # Display the uploaded CSV data
    st.write("Uploaded CSV Data:")
    st.write(dataframe)

    # Check if columns match.
    # NOTE(review): class_schema is None when "New Class" is selected, so this
    # section never ingests data for a newly named class - confirm whether the
    # ui_utils helpers cover that path.
    if class_schema:  # Ensure class_schema is not None
        schema_columns = [prop["name"] for prop in class_schema["properties"]]
        if set(dataframe.columns) != set(schema_columns):
            st.error("The columns in the uploaded CSV do not match the schema of the selected class. Please check and upload the correct CSV or create a new class.")
        else:
            # Streamlit re-executes this script on every widget interaction
            # while the upload stays selected. Remember what was ingested so
            # the same content is not written to the same class again on each
            # rerun (typing a question, pressing Submit, ...).
            ingestion_key = (class_name, len(data), hash(data))
            already_ingested = (
                "ingested_csv_key" in st.session_state
                and st.session_state.ingested_csv_key == ingestion_key
            )
            if not already_ingested:
                # Ingest data into Weaviate
                ingest_data_to_weaviate(client, dataframe, class_name, class_description)
                # Recorded only after a successful call so a failed ingestion is retried.
                st.session_state.ingested_csv_key = ingestion_key

    # Input for questions: one per line, ignoring blank or whitespace-only
    # lines and stray carriage returns.
    questions_text = st.text_area("Enter your questions (one per line)")
    questions = [line.strip() for line in questions_text.splitlines() if line.strip()]

    if st.button("Submit"):
        if data and questions:
            answers = summarize_map_reduce(tokenizer, model, data, questions)
            st.write("Answers:")
            for q, a in zip(questions, answers):
                st.write(f"Question: {q}")
                st.write(f"Answer: {a}")
        elif not questions:
            st.warning("Please enter at least one question.")

# Optional debugging panel: when the box is ticked, print a heading followed by
# every entry currently held in DEBUG_LOGS, one st.write call per entry.
# NOTE(review): this checkbox has no key and is not tied to
# st.session_state.debug, and no visible code appends to DEBUG_LOGS - confirm
# how the panel is meant to be populated.
show_debug_logs = st.checkbox("Show Debugging Information")
if show_debug_logs:
    st.write("Debugging Logs:")
    for entry in DEBUG_LOGS:
        st.write(entry)
        
# Add Ctrl+Enter functionality for submitting the questions.
# The injected script registers a keydown listener (once DOMContentLoaded
# fires) that clicks the first element matching ".stButton button" whenever
# Ctrl+Enter is pressed.
# NOTE(review): <script> tags inserted through st.markdown may not be executed
# by the browser, and DOMContentLoaded has probably already fired by the time
# this markup is added - verify that the shortcut actually works.
# NOTE(review): the selector picks the first button on the page, which is only
# the Submit button if no other st.button is rendered before it - confirm.
st.markdown("""
    <script>
    document.addEventListener("DOMContentLoaded", function(event) {
        document.addEventListener("keydown", function(event) {
            if (event.ctrlKey && event.key === "Enter") {
                document.querySelector(".stButton button").click();
            }
        });
    });
    </script>
    """, unsafe_allow_html=True)