# app.py — "Update app.py" by jskinner215 (commit d0de0d8, 6.94 kB)
from copy import deepcopy
import streamlit as st
import pandas as pd
from io import StringIO
from transformers import AutoTokenizer, AutoModelForTableQuestionAnswering
import numpy as np
import weaviate
from weaviate.embedded import EmbeddedOptions
from weaviate import Client
from weaviate.util import generate_uuid5
# Streamlit re-executes this whole script on every widget interaction. The
# heavy resources are therefore created inside st.cache_resource functions
# (Streamlit >= 1.18), so the TAPAS weights are loaded and the embedded
# Weaviate client is built once per server process instead of on every rerun.
TAPAS_MODEL_NAME = "google/tapas-large-finetuned-wtq"


@st.cache_resource
def _load_tapas(model_name=TAPAS_MODEL_NAME):
    """Return ``(tokenizer, model)`` for the given TAPAS checkpoint."""
    return (
        AutoTokenizer.from_pretrained(model_name),
        AutoModelForTableQuestionAnswering.from_pretrained(model_name),
    )


@st.cache_resource
def _connect_weaviate():
    """Return a client connected to the embedded Weaviate instance."""
    return weaviate.Client(embedded_options=EmbeddedOptions())


# Initialize TAPAS model and tokenizer, and the Weaviate client.
tokenizer, model = _load_tapas()
client = _connect_weaviate()
def class_exists(class_name):
    """Return True if ``class_name`` is defined in the Weaviate schema.

    Only the "not found" path (Weaviate answering with a non-success status)
    is treated as False. Other failures, such as connection errors or
    programming errors, propagate instead of being silently reported as
    "class missing".
    """
    if not class_name:
        # An empty name would make schema.get() return the whole schema.
        return False
    try:
        client.schema.get(class_name)
    except weaviate.exceptions.UnexpectedStatusCodeException:
        return False
    return True
def map_dtype_to_weaviate(dtype):
    """
    Map a pandas/NumPy dtype to a Weaviate data type name.

    Uses pandas' dtype predicates instead of substring matching on
    ``str(dtype)``. Substring matching wrongly mapped e.g.
    ``interval[int64]`` to "int", and missed the nullable extension dtypes
    ``Int64``/``UInt8``/``Float64`` (capitalised names).

    Returns:
        "boolean", "int", "number", or "string" for anything else
        (objects, datetimes, categoricals, intervals, ...).
        NOTE(review): newer Weaviate versions prefer "text" over the
        deprecated "string"; kept as "string" for compatibility.
    """
    # Bool is checked first so boolean columns are never treated as numeric.
    if pd.api.types.is_bool_dtype(dtype):
        return "boolean"
    if pd.api.types.is_integer_dtype(dtype):
        return "int"
    if pd.api.types.is_float_dtype(dtype):
        return "number"
    return "string"
def _row_to_properties(record):
    """Make one ``DataFrame.to_dict(orient="records")`` row JSON-safe.

    NumPy scalars become plain Python values (the JSON encoder used for the
    HTTP request cannot serialise types such as ``np.int64``). Missing values
    are dropped, because NaN is not valid JSON.
    """
    properties = {}
    for key, value in record.items():
        if isinstance(value, np.generic):
            value = value.item()
        if pd.api.types.is_scalar(value) and pd.isna(value):
            continue
        properties[str(key)] = value
    return properties


def ingest_data_to_weaviate(dataframe, class_name, class_description):
    """Create (if needed) a Weaviate class for ``dataframe`` and upload its rows.

    Safe to call repeatedly with the same data, which matters because
    Streamlit re-runs the script on every interaction:
    - the class and its properties are created only when missing;
    - each row gets a deterministic UUID derived from the class name and the
      row index, so rows already stored are skipped, not duplicated.
    Consequently, uploading a *different* file into the same class does not
    overwrite stored rows that share a row index.

    Args:
        dataframe: Table to ingest; one Weaviate property per column.
        class_name: Target Weaviate class (must be non-empty).
        class_description: Used only when the class has to be created.

    Raises:
        ValueError: If ``class_name`` is empty.
    """
    if not class_name:
        raise ValueError("class_name must be a non-empty string")

    # One lookup answers both questions: does the class exist, and which
    # properties does it already have?
    try:
        class_schema = client.schema.get(class_name)
    except weaviate.exceptions.UnexpectedStatusCodeException:
        class_schema = None  # class not found

    if class_schema is None:
        client.schema.create({
            "classes": [{
                "class": class_name,
                "description": class_description,
                "properties": [],  # added column by column below
            }]
        })
        existing_properties = set()
    else:
        # Compared case-insensitively in case Weaviate normalises property
        # name casing (NOTE(review): confirm against the server version).
        existing_properties = {
            prop["name"].lower() for prop in class_schema.get("properties") or []
        }

    for column_name, data_type in zip(dataframe.columns, dataframe.dtypes):
        name = str(column_name)
        if name.lower() in existing_properties:
            continue
        property_schema = {
            "name": name,
            "description": f"Property for {column_name}",
            "dataType": [map_dtype_to_weaviate(data_type)],
        }
        try:
            client.schema.property.create(class_name, property_schema)
        except weaviate.exceptions.UnexpectedStatusCodeException as err:
            # Keep going so the other columns are still ingested, but tell the
            # user rather than silently dropping the column.
            st.warning(f"Could not add property {name!r} to {class_name}: {err}")

    # to_dict("records") keeps each column's own type; iterrows() would upcast
    # a row mixing int and float columns to all-float.
    records = dataframe.to_dict(orient="records")
    for index, record in zip(dataframe.index, records):
        try:
            client.data_object.create(
                data_object=_row_to_properties(record),
                class_name=class_name,
                # Weaviate object ids must be UUIDs. Derive a stable one from
                # the class and row index so reruns do not duplicate rows.
                uuid=generate_uuid5(f"{class_name}:{index}"),
            )
        except weaviate.exceptions.ObjectAlreadyExistsException:
            continue  # row already ingested on an earlier run
def query_weaviate(question, target_class=None, properties=None):
    """Run a nearText search for ``question`` against a Weaviate class.

    Args:
        question: Free-text query.
        target_class: Class to search. Defaults to the module-level
            ``class_name`` chosen in the Streamlit UI, which the original
            implementation used implicitly.
        properties: Property names to return. Defaults to every property in
            the class schema.

    Returns:
        The raw response dict returned by ``.do()``.

    NOTE(review): nearText needs a text vectorizer module configured for
    the class; confirm this holds for the embedded instance.
    """
    if target_class is None:
        target_class = class_name  # module-level UI selection
    if properties is None:
        class_schema = client.schema.get(target_class)
        properties = [prop["name"] for prop in class_schema.get("properties") or []]
    # with_near_text expects a dict with a "concepts" list, not a bare string.
    return (
        client.query.get(target_class, properties)
        .with_near_text({"concepts": [question]})
        .do()
    )
def ask_llm_chunk(chunk, questions):
    """Answer every question in ``questions`` against one table chunk with TAPAS.

    Args:
        chunk: DataFrame slice to query. It is converted to strings first
            because the TAPAS tokenizer works on text cells.
        questions: List of question strings.

    Returns:
        A list with exactly one entry per question, in the same order. Each
        entry is either:
        - the text of the selected cell(s) joined with ", ";
        - "" when the model selects no cell;
        - "Error occurred while tokenizing", used for every question when
          the tokenizer fails.
    """
    import torch  # only needed to switch off gradient tracking below

    chunk = chunk.astype(str)
    try:
        inputs = tokenizer(table=chunk, queries=questions, padding="max_length", truncation=True, return_tensors="pt")
    except Exception as e:
        st.write(f"An error occurred: {e}")
        return ["Error occurred while tokenizing"] * len(questions)

    # With padding="max_length" and truncation=True the sequence length is
    # always exactly the model maximum, so len(input_ids) cannot reveal an
    # overflow: TAPAS silently drops table rows instead. Detect that by
    # comparing the rows present in the encoding (token_type_ids[..., 2] holds
    # each token's row id, data rows counted from 1) with the chunk's rows.
    # Take the minimum across questions, since a longer question leaves less
    # room for the table.
    rows_encoded = int(inputs["token_type_ids"][:, :, 2].amax(dim=1).min())
    if rows_encoded < len(chunk):
        st.warning(
            f"Token limit exceeded for chunk: only {rows_encoded} of "
            f"{len(chunk)} rows were passed to the model"
        )

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(**inputs)
    # NOTE(review): the predicted aggregation operator is currently ignored;
    # only the selected cells are returned.
    predicted_answer_coordinates, _predicted_aggregation_indices = tokenizer.convert_logits_to_predictions(
        inputs,
        outputs.logits.detach(),
        outputs.logits_aggregation.detach(),
    )

    answers = []
    for coordinates in predicted_answer_coordinates:
        cell_values = []
        for row, col in coordinates:
            try:
                cell_values.append(chunk.iloc[row, col])
            except IndexError as e:
                st.write(f"An error occurred: {e}")
        # Always append, even when a lookup failed, so answers stay aligned
        # with questions.
        answers.append(", ".join(map(str, cell_values)))
    return answers
MAX_ROWS_PER_CHUNK = 200


def summarize_map_reduce(data, questions, max_rows_per_chunk=MAX_ROWS_PER_CHUNK, ask_chunk=None):
    """Answer ``questions`` over a CSV by querying it chunk by chunk.

    Map: the table is split into consecutive slices of at most
    ``max_rows_per_chunk`` rows, and each slice is passed to ``ask_chunk``.
    Reduce: for each question, the non-empty per-chunk answers are
    concatenated with ", ". This is a plain concatenation; e.g. a "largest
    value" question yields each chunk's own pick, not a single global one.

    Args:
        data: CSV text.
        questions: List of question strings.
        max_rows_per_chunk: Upper bound on rows sent to the model at once.
        ask_chunk: ``(chunk_df, questions) -> list of answers``; defaults to
            ``ask_llm_chunk``.

    Returns:
        One answer string per question, aligned with ``questions``
        ("" when no chunk produced an answer, including for an empty table).

    Raises:
        ValueError: If ``max_rows_per_chunk`` is smaller than 1.
    """
    if ask_chunk is None:
        ask_chunk = ask_llm_chunk
    if max_rows_per_chunk < 1:
        raise ValueError("max_rows_per_chunk must be at least 1")

    dataframe = pd.read_csv(StringIO(data))
    if dataframe.empty:
        return [""] * len(questions)

    per_question = [[] for _ in questions]
    for start in range(0, len(dataframe), max_rows_per_chunk):
        # iloc slicing replaces np.array_split on a DataFrame (deprecated in
        # recent pandas).
        chunk = dataframe.iloc[start:start + max_rows_per_chunk]
        for bucket, answer in zip(per_question, ask_chunk(chunk, questions)):
            text = str(answer)
            if text:
                bucket.append(text)
    return [", ".join(bucket) for bucket in per_question]
st.title("TAPAS Table Question Answering with Weaviate")

# Get existing classes from Weaviate
existing_classes = [cls["class"] for cls in client.schema.get().get("classes") or []]
class_options = existing_classes + ["New Class"]
selected_class = st.selectbox("Select a class or create a new one:", class_options)

if selected_class == "New Class":
    class_name = st.text_input("Enter the new class name:")
    class_description = st.text_input("Enter a description for the class:")
else:
    class_name = selected_class
    class_description = ""  # only used when a class has to be created

# Defined up front: the Submit handler below reads it even when no file
# has been uploaded.
data = None

# Upload CSV data
csv_file = st.file_uploader("Upload a CSV file", type=["csv"])
if csv_file is not None:
    # getvalue() returns the whole upload regardless of the stream position.
    data = csv_file.getvalue().decode("utf-8")
    dataframe = pd.read_csv(StringIO(data))

    # Display the schema if an existing class is selected
    if selected_class != "New Class":
        st.write(f"Schema for {selected_class}:")
        st.write(client.schema.get(selected_class))

    if not class_name:
        st.warning("Enter a class name to ingest the CSV into Weaviate.")
    else:
        # Streamlit re-runs this script on every interaction; ingest each
        # (class, file content) pair only once per session.
        ingest_key = (class_name, hash(data))
        if st.session_state.get("ingested") != ingest_key:
            ingest_data_to_weaviate(dataframe, class_name, class_description)
            st.session_state["ingested"] = ingest_key

# Questions live in a form: inside a form, Streamlit's text_area submits on
# Ctrl+Enter / Cmd+Enter natively. Streamlit does not execute <script> tags
# passed to st.markdown, so a JavaScript key handler cannot provide this.
with st.form("questions_form"):
    questions_text = st.text_area("Enter your questions (one per line)")
    submitted = st.form_submit_button("Submit")

# One question per non-blank line.
questions = [q.strip() for q in questions_text.splitlines() if q.strip()]

if submitted:
    if data and questions:
        answers = summarize_map_reduce(data, questions)
        st.write("Answers:")
        for q, a in zip(questions, answers):
            st.write(f"Question: {q}")
            st.write(f"Answer: {a}")
    else:
        st.warning("Upload a CSV file and enter at least one question first.")