File size: 6,940 Bytes
45ee012
b0e4f45
 
 
 
512f2de
46ad3c2
 
4758881
 
b0e4f45
 
 
 
 
0e62360
27ba167
 
 
0e62360
04376ef
 
 
 
 
 
 
 
a4a1c61
 
 
 
 
 
 
 
 
 
 
 
 
b9d05c0
203bb0f
 
 
 
 
b9d05c0
 
203bb0f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e62360
b9d05c0
203bb0f
0e62360
 
203bb0f
 
0e62360
 
 
b9d05c0
0e62360
 
b0e4f45
 
414bc96
 
 
 
 
512f2de
66f9f66
 
 
c1ca766
 
 
 
 
 
 
515be2e
c1ca766
 
45ee012
 
c1ca766
45ee012
 
072f6c1
515be2e
66f9f66
45ee012
c1ca766
66f9f66
45ee012
 
 
 
 
 
 
 
 
 
c1ca766
515be2e
b0e4f45
512f2de
b0e4f45
 
512f2de
b0e4f45
45ee012
b0e4f45
 
 
 
5e71278
b0e4f45
04376ef
 
 
 
 
 
b9d05c0
04376ef
 
 
 
 
 
b0e4f45
 
 
 
 
0e62360
04376ef
 
 
 
 
 
 
 
d0de0d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
from copy import deepcopy
import streamlit as st
import pandas as pd
from io import StringIO
from transformers import AutoTokenizer, AutoModelForTableQuestionAnswering
import numpy as np
import weaviate
from weaviate.embedded import EmbeddedOptions
from weaviate import Client
from weaviate.util import generate_uuid5

# TAPAS checkpoint used for both tokenizer and model; they must come from the same
# checkpoint so the vocabulary and table encoding match what the model expects.
_TAPAS_CHECKPOINT = "google/tapas-large-finetuned-wtq"

tokenizer = AutoTokenizer.from_pretrained(_TAPAS_CHECKPOINT)
model = AutoModelForTableQuestionAnswering.from_pretrained(_TAPAS_CHECKPOINT)

# Weaviate client talking to a locally started embedded instance (EmbeddedOptions()).
client = weaviate.Client(embedded_options=EmbeddedOptions())

# Function to check if a class already exists in Weaviate
def class_exists(class_name):
    try:
        client.schema.get_class(class_name)
        return True
    except:
        return False

def map_dtype_to_weaviate(dtype):
    """
    Map a pandas/numpy dtype to the name of a Weaviate data type.

    Uses pandas' dtype predicates rather than substring checks on ``str(dtype)``.
    As a result, nullable extension dtypes (``Int64``, ``Float64``, ``boolean``)
    map correctly, and dtypes whose names merely contain "int" (e.g.
    ``interval[int64]``) are not mistaken for integers.

    Returns "boolean", "int" or "number" for boolean, integer and float dtypes;
    anything else (objects, strings, datetimes, ...) maps to "string".
    """
    # Check booleans first: they must never fall into the numeric branches.
    if pd.api.types.is_bool_dtype(dtype):
        return "boolean"
    if pd.api.types.is_integer_dtype(dtype):
        return "int"
    if pd.api.types.is_float_dtype(dtype):
        return "number"
    return "string"

def _find_weaviate_class(class_name):
    """Return the schema dict Weaviate holds for ``class_name``, or None if absent."""
    # NOTE(review): Weaviate is expected to store class names with an upper-case
    # first letter; accept either spelling so a match is found in both cases.
    candidates = {class_name, class_name[:1].upper() + class_name[1:]}
    for cls in client.schema.get().get("classes") or []:
        if cls.get("class") in candidates:
            return cls
    return None


def _to_weaviate_value(value):
    """Convert one DataFrame cell into a JSON-serializable value, or None if missing."""
    if isinstance(value, np.generic):
        value = value.item()  # numpy scalars are not JSON-serializable
    try:
        if pd.isna(value):
            return None  # NaN / NA / None; NaN is not valid JSON
    except (TypeError, ValueError):
        pass  # non-scalar cell: pd.isna does not yield a single bool
    return value


def ingest_data_to_weaviate(dataframe, class_name, class_description):
    """
    Ensure a Weaviate class exists for ``dataframe`` and upload its rows.

    - The class is created (with ``class_description``) only if Weaviate does
      not already have it.
    - One property per column is added, skipping properties the class
      already defines; the type comes from ``map_dtype_to_weaviate``.
    - Every row becomes an object whose id is a deterministic UUID built from
      the class name, the row index and the row values. Re-ingesting the same
      data (Streamlit reruns the script often) therefore skips existing
      objects instead of failing or duplicating them.

    Missing cells (NaN/NA) are omitted from the object's properties.
    """
    existing_class = _find_weaviate_class(class_name)

    if existing_class is None:
        class_schema = {
            "class": class_name,
            "description": class_description,
            "properties": []  # properties are added one by one below
        }
        try:
            client.schema.create({"classes": [class_schema]})
        except weaviate.exceptions.SchemaValidationException:
            # Best effort: continue; the property/object calls below surface real problems.
            pass

    known_properties = {
        prop.get("name") for prop in (existing_class or {}).get("properties") or []
    }

    # Add one property per column that the class does not define yet.
    for column_name, data_type in zip(dataframe.columns, dataframe.dtypes):
        if column_name in known_properties:
            continue
        property_schema = {
            "name": column_name,
            "description": f"Property for {column_name}",
            "dataType": [map_dtype_to_weaviate(data_type)]
        }
        try:
            client.schema.property.create(class_name, property_schema)
        except weaviate.exceptions.SchemaValidationException:
            # Property might already exist, so we can continue
            pass

    # to_dict("records") keeps each column's own dtype (iterrows would upcast a
    # row of ints and floats to floats).
    for index, record in zip(dataframe.index, dataframe.to_dict(orient="records")):
        properties = {}
        for key, raw_value in record.items():
            value = _to_weaviate_value(raw_value)
            if value is not None:
                properties[key] = value
        # Weaviate object ids must be UUIDs; derive a stable one so reruns are idempotent.
        object_id = generate_uuid5(f"{index}:{properties}", class_name)
        try:
            client.data_object.create(properties, class_name, uuid=object_id)
        except weaviate.exceptions.ObjectAlreadyExistsException:
            pass  # this exact row was ingested earlier


def query_weaviate(question, target_class=None):
    """
    Run a near-text search for ``question`` against a Weaviate class.

    ``target_class`` defaults to the module-level ``class_name`` chosen in the
    Streamlit UI, which keeps the original one-argument call working.
    Returns the raw GraphQL response dict from ``.do()``.

    NOTE(review): near-text search needs a vectorizer module configured for the
    class; the schema created by ``ingest_data_to_weaviate`` does not set one —
    confirm the embedded instance's default vectorizer.
    """
    if target_class is None:
        target_class = class_name
    # with_near_text expects a dict with a "concepts" list, not a bare string.
    near_text = {"concepts": [question]}
    return client.query.get(target_class).with_near_text(near_text).do()

def _read_cells(chunk, coordinates):
    """Return the values of ``chunk`` at each positional ``(row, col)`` in ``coordinates``.

    Cells that cannot be read are reported in the UI and skipped, so the result
    may be shorter than ``coordinates``.
    """
    values = []
    for row, col in coordinates:
        try:
            values.append(chunk.iloc[row, col])
        except Exception as e:
            st.write(f"An error occurred: {e}")
    return values


def ask_llm_chunk(chunk, questions):
    """
    Answer every question in ``questions`` against the DataFrame ``chunk`` with TAPAS.

    Returns a list with exactly one entry per question, in order. Each entry is
    one of:
    - the predicted cell value when TAPAS selects exactly one cell;
    - a ", "-joined string of the selected cells otherwise (empty if none);
    - an error/placeholder message.

    One entry per question matters because callers line answers up with
    questions positionally.
    """
    import torch  # used only to disable autograd during inference

    # The TAPAS tokenizer expects a table of strings.
    chunk = chunk.astype(str)
    try:
        inputs = tokenizer(table=chunk, queries=questions, padding="max_length", truncation=True, return_tensors="pt")
    except Exception as e:
        st.write(f"An error occurred: {e}")
        return ["Error occurred while tokenizing"] * len(questions)

    # NOTE(review): with padding="max_length" and truncation=True the sequence
    # length is presumably capped at the tokenizer maximum, so this guard may
    # never fire — confirm.
    if inputs["input_ids"].shape[1] > 512:
        st.warning("Token limit exceeded for chunk")
        return ["Token limit exceeded for chunk"] * len(questions)

    # Inference only: avoid building an autograd graph for a large model.
    with torch.no_grad():
        outputs = model(**inputs)
    # The aggregation-operator predictions are currently unused.
    predicted_answer_coordinates, _predicted_aggregation_indices = tokenizer.convert_logits_to_predictions(
        inputs,
        outputs.logits.detach(),
        outputs.logits_aggregation.detach()
    )

    answers = []
    for coordinates in predicted_answer_coordinates:
        cell_values = _read_cells(chunk, coordinates)
        if len(coordinates) == 1:
            # Keep a placeholder on failure so answers stay aligned with questions.
            answers.append(cell_values[0] if cell_values else "Error occurred while reading cell")
        else:
            answers.append(", ".join(map(str, cell_values)))

    return answers

MAX_ROWS_PER_CHUNK = 200


def _split_rows(dataframe, num_chunks):
    """Split ``dataframe`` positionally into ``num_chunks`` contiguous, near-equal row slices."""
    # Same sizing as np.array_split: the first `extra` chunks get one more row.
    base, extra = divmod(len(dataframe), num_chunks)
    chunks = []
    start = 0
    for i in range(num_chunks):
        stop = start + base + (1 if i < extra else 0)
        chunks.append(dataframe.iloc[start:stop])
        start = stop
    return chunks


def summarize_map_reduce(data, questions, max_rows_per_chunk=MAX_ROWS_PER_CHUNK):
    """
    Answer ``questions`` over CSV text ``data`` by querying TAPAS chunk by chunk.

    The CSV is split into the fewest near-equal chunks with at most
    ``max_rows_per_chunk`` rows each (at least one chunk, even for an empty
    table). ``ask_llm_chunk`` is run on each chunk.

    Returns the per-chunk answer lists concatenated in chunk order. With one
    answer per question per chunk, that is ``num_chunks * len(questions)``
    entries.

    Raises ValueError if ``max_rows_per_chunk`` is smaller than 1.
    """
    if max_rows_per_chunk < 1:
        raise ValueError("max_rows_per_chunk must be at least 1")
    dataframe = pd.read_csv(StringIO(data))
    # Ceiling division: avoids the extra chunk that `len // size + 1` produced for
    # exact multiples, while never exceeding the per-chunk row limit.
    num_chunks = max(1, (len(dataframe) + max_rows_per_chunk - 1) // max_rows_per_chunk)
    # Positional iloc slicing instead of np.array_split, which relies on the
    # deprecated DataFrame.swapaxes when given a DataFrame.
    all_answers = []
    for chunk in _split_rows(dataframe, num_chunks):
        all_answers.extend(ask_llm_chunk(chunk, questions))
    return all_answers

st.title("TAPAS Table Question Answering with Weaviate")

# Get existing classes from Weaviate. Fall back to an empty list in case the
# schema response has no (or a null) "classes" entry, instead of raising KeyError.
weaviate_schema = client.schema.get()
existing_classes = [cls["class"] for cls in (weaviate_schema.get("classes") or [])]
class_options = existing_classes + ["New Class"]
selected_class = st.selectbox("Select a class or create a new one:", class_options)

if selected_class == "New Class":
    # Surrounding whitespace is never meaningful in a class name; drop it.
    class_name = st.text_input("Enter the new class name:").strip()
    class_description = st.text_input("Enter a description for the class:")
else:
    class_name = selected_class
    class_description = ""  # We can fetch the description from Weaviate if needed

# Upload CSV data
csv_file = st.file_uploader("Upload a CSV file", type=["csv"])
if csv_file is not None:
    data = csv_file.read().decode("utf-8")
    dataframe = pd.read_csv(StringIO(data))

    # Display the schema if an existing class is selected
    if selected_class != "New Class":
        st.write(f"Schema for {selected_class}:")
        # client.schema.get(<name>) returns the schema of that single class.
        class_schema = client.schema.get(selected_class)
        st.write(class_schema)

    # A new class needs a name before anything can be written to Weaviate.
    if not class_name:
        st.warning("Enter a class name before uploading data.")
        st.stop()

    # Streamlit re-executes this script on every interaction. Remember which
    # (class, file contents) pairs were already ingested in this session so the
    # same rows are not uploaded again on each rerun.
    ingested_uploads = st.session_state.setdefault("ingested_uploads", set())
    upload_key = (class_name, len(data), hash(data))
    if upload_key not in ingested_uploads:
        ingest_data_to_weaviate(dataframe, class_name, class_description)
        ingested_uploads.add(upload_key)

    # Collect questions in a form: typing no longer triggers a rerun, and the
    # text area's Ctrl+Enter shortcut submits the form (in recent Streamlit
    # versions — confirm for the installed one). This replaces an injected
    # <script>, which is presumably never executed when rendered via st.markdown.
    with st.form("questions_form"):
        questions_text = st.text_area("Enter your questions (one per line)")
        submitted = st.form_submit_button("Submit")

    questions = [q for q in questions_text.split("\n") if q]  # one question per non-empty line

    if submitted and data and questions:
        answers = summarize_map_reduce(data, questions)
        st.write("Answers:")
        num_questions = len(questions)
        for q_idx, q in enumerate(questions):
            # summarize_map_reduce concatenates the per-chunk answer lists, so this
            # question's answers sit at a stride of len(questions) (assuming each
            # chunk yields one answer per question). Show every non-empty one.
            found = [str(a) for a in answers[q_idx::num_questions] if str(a)]
            st.write(f"Question: {q}")
            st.write(f"Answer: {'; '.join(found)}")