# Zeta / app.py

import json
import os

import pandas as pd
import streamlit as st
from langchain.callbacks import get_openai_callback
from langchain.chains import ConversationalRetrievalChain
from langchain_openai import ChatOpenAI

from command_center import CommandCenter
from embed_documents import create_retriever
from process_documents import process_documents

st.set_page_config(layout="wide")
# Read the OpenAI API key from the environment; export OPENAI_API_KEY before
# launching the app rather than hard-coding secrets in source.

def get_references(relevant_docs):
    """Format the chunk ids of the retrieved snippets as bracketed references."""
    return " ".join(
        f"[{chunk_id}]"
        for chunk_id in sorted(doc.metadata["chunk_id"] for doc in relevant_docs)
    )


def session_state_2_llm_chat_history(session_state):
    """Reduce stored (human, ai, references) tuples to (human, ai) pairs,
    skipping slash-command exchanges."""
    return [ss[:2] for ss in session_state if not ss[0].startswith("/")]


def ai_message_format(message, references):
    """Append the reference list to an AI message, separated by a rule."""
    return f"{message}\n\n---\n\n{references}" if references else message

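# --- Command handlers ---------------------------------------------------------
# Each wrapper below backs one slash command: it performs its action, records
# the exchange in st.session_state.messages as a (human, ai, references)
# tuple, and returns the text for the chat loop to render.
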
def process_documents_wrapper(inputs):
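    # Handle /upload: split the documents at the given URLs into snippets,
    # embed them into a retriever, and keep the snippet headers so /index
    # can list them.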
snippets = process_documents(inputs)
st.session_state.retriever = create_retriever(snippets)
st.session_state.source_doc_urls = inputs
st.session_state.index = [snip.metadata["header"] for snip in snippets]
response = f"Uploaded and processed documents {inputs}"
st.session_state.messages.append((f"/upload {inputs}", response, ""))
return response
def index_documents_wrapper(inputs=None):
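    # Handle /index: render the headers of the processed snippets as a
    # markdown table.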
response = pd.Series(st.session_state.index, name="references").to_markdown()
st.session_state.messages.append(("/index", response, ""))
return response
def calculate_cost_wrapper(inputs=None):
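    # Handle /cost: tabulate the per-query token counts and dollar costs
    # logged by query_llm_wrapper, adding a total row; the ValueError
    # fallback covers the case where nothing has been logged yet.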
try:
stats_df = pd.DataFrame(st.session_state.costing)
stats_df.loc["total"] = stats_df.sum()
response = stats_df.to_markdown()
except ValueError:
response = "No costing incurred yet"
st.session_state.messages.append(("/cost", response, ""))
return response
def download_conversation_wrapper(inputs=None):
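    # Handle /download: serialize the source URLs, snippet index, transcript,
    # and cost log to JSON, and expose it via a sidebar download button.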
conversation_data = json.dumps(
{
"document_urls": (
st.session_state.source_doc_urls
if "source_doc_urls" in st.session_state
else []
),
"document_snippets": (
st.session_state.index.to_list()
if "headers" in st.session_state
else []
),
"conversation": [
{"human": message[0], "ai": message[1], "references": message[2]}
for message in st.session_state.messages
],
"costing": (
st.session_state.costing if "costing" in st.session_state else []
),
"total_cost": (
{
k: sum(d[k] for d in st.session_state.costing)
for k in st.session_state.costing[0]
}
if "costing" in st.session_state and len(st.session_state.costing) > 0
else {}
),
}
)
st.sidebar.download_button(
"Download Conversation",
conversation_data,
file_name="conversation_data.json",
mime="application/json",
)
st.session_state.messages.append(("/download", "Conversation data downloaded", ""))
def query_llm_wrapper(inputs):
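    # Default handler for free-form input: retrieval-augmented QA over the
    # uploaded documents, with prior chat turns threaded into the chain.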
retriever = st.session_state.retriever
qa_chain = ConversationalRetrievalChain.from_llm(
llm=ChatOpenAI(model="gpt-4-0125-preview", temperature=0),
retriever=retriever,
return_source_documents=True,
chain_type="stuff",
)
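    # Retrieve the supporting snippets separately as well, so the answer can
    # cite their chunk ids; the chain performs its own retrieval internally.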
relevant_docs = retriever.get_relevant_documents(inputs)
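    # get_openai_callback tracks token usage and total cost for the OpenAI
    # calls made inside this block.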
with get_openai_callback() as cb:
        result = qa_chain.invoke(
{
"question": inputs,
"chat_history": session_state_2_llm_chat_history(
st.session_state.messages
),
}
)
stats = cb
result = result["answer"]
references = get_references(relevant_docs)
st.session_state.messages.append((inputs, result, references))
st.session_state.costing.append(
{
"prompt tokens": stats.prompt_tokens,
"completion tokens": stats.completion_tokens,
"cost": stats.total_cost,
}
)
return result, references
def boot(command_center):
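    # Replay the stored conversation, then route each new chat input through
    # the command center.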
st.title("Agent Xi - An ArXiv Chatbot")
if "costing" not in st.session_state:
st.session_state.costing = []
if "messages" not in st.session_state:
st.session_state.messages = []
for message in st.session_state.messages:
st.chat_message("human").write(message[0])
st.chat_message("ai").write(ai_message_format(message[1], message[2]))
if query := st.chat_input():
st.chat_message("human").write(query)
response = command_center.execute_command(query)
        if isinstance(response, tuple):
            result, references = response
            st.chat_message("ai").write(ai_message_format(result, references))
        elif response is not None:
            st.chat_message("ai").write(response)
if __name__ == "__main__":
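    # Command registry: (trigger, argument type, handler, description).
    # Input matching no trigger falls through to query_llm_wrapper via the
    # CommandCenter default_function.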
all_commands = [
("/upload", list, process_documents_wrapper, "Upload and process documents"),
("/index", None, index_documents_wrapper, "View index of processed documents"),
("/cost", None, calculate_cost_wrapper, "Calculate cost of conversation"),
(
"/download",
None,
download_conversation_wrapper,
"Download conversation data",
),
]
st.sidebar.title("Commands Menu")
st.sidebar.write(
pd.DataFrame(
{
"Command": [command[0] for command in all_commands],
"Description": [command[3] for command in all_commands],
}
)
)
command_center = CommandCenter(
default_input_type=str,
default_function=query_llm_wrapper,
all_commands=[command[:3] for command in all_commands],
)
boot(command_center)