|
import streamlit as st |
|
import os |
|
import pandas as pd |
|
from command_center import CommandCenter |
|
from process_documents import process_documents |
|
from embed_documents import create_retriever |
|
import json |
|
from langchain.callbacks import get_openai_callback |
|
from langchain.chains import ConversationalRetrievalChain |
|
from langchain_openai import ChatOpenAI |
|
|
|
st.set_page_config(layout="wide")

# SECURITY: an OpenAI API key was previously hard-coded on this line, which
# leaks the credential to anyone who can read the repository. Require it from
# the environment instead, and surface a warning so the UI still renders.
if not os.environ.get("OPENAI_API_KEY"):
    st.warning("OPENAI_API_KEY is not set; LLM queries will fail.")
|
|
|
def get_references(relevant_docs):
    """Format the chunk ids of *relevant_docs* as a reference string.

    Each document's ``metadata["chunk_id"]`` is collected, sorted, wrapped in
    square brackets, and joined with spaces, e.g. ``"[1] [3] [7]"``. Returns
    an empty string for an empty input.

    (Converted from a lambda assignment to a ``def`` per PEP 8 / E731.)
    """
    chunk_ids = sorted(doc.metadata["chunk_id"] for doc in relevant_docs)
    return " ".join(f"[{cid}]" for cid in chunk_ids)
|
def session_state_2_llm_chat_history(session_state):
    """Convert stored ``(human, ai, references)`` message tuples into the
    ``(human, ai)`` pairs LangChain expects as chat history.

    Turns whose human message starts with ``"/"`` (slash commands such as
    ``/upload`` or ``/cost``) are excluded from the history.

    (Converted from a lambda assignment to a ``def`` per PEP 8 / E731.)
    """
    return [entry[:2] for entry in session_state if not entry[0].startswith("/")]
|
def ai_message_format(message, references):
    """Render an AI chat message, appending *references* under a horizontal
    rule when present; return *message* unchanged when *references* is "".

    (Converted from a lambda assignment to a ``def`` per PEP 8 / E731.)
    """
    if references == "":
        return message
    return f"{message}\n\n---\n\n{references}"
|
|
|
|
|
def process_documents_wrapper(inputs):
    """Fetch and chunk the documents named in *inputs*, build a retriever
    over the chunks, and log the upload in the conversation history.

    Stores in ``st.session_state``: the retriever, the source URLs, and an
    ``index`` list of each chunk's ``metadata["header"]``.
    """
    chunks = process_documents(inputs)
    headers = []
    for chunk in chunks:
        headers.append(chunk.metadata["header"])
    st.session_state.retriever = create_retriever(chunks)
    st.session_state.source_doc_urls = inputs
    st.session_state.index = headers
    response = f"Uploaded and processed documents {inputs}"
    st.session_state.messages.append((f"/upload {inputs}", response, ""))
    return response
|
|
|
|
|
def index_documents_wrapper(inputs=None):
    """Render the stored document-snippet headers as a markdown table and
    record the ``/index`` exchange in the message history.

    *inputs* is unused; it exists so the command dispatcher can call every
    handler with one argument.
    """
    headers = pd.Series(st.session_state.index, name="references")
    response = headers.to_markdown()
    st.session_state.messages.append(("/index", response, ""))
    return response
|
|
|
|
|
def calculate_cost_wrapper(inputs=None):
    """Summarize per-query token usage and dollar cost as a markdown table.

    Appends a "total" row summing every column. When no LLM calls have been
    made yet, pandas raises ValueError on the empty frame, which we translate
    into a friendly message. *inputs* is unused (dispatcher signature).
    """
    try:
        usage = pd.DataFrame(st.session_state.costing)
        usage.loc["total"] = usage.sum()
        response = usage.to_markdown()
    except ValueError:
        response = "No costing incurred yet"
    st.session_state.messages.append(("/cost", response, ""))
    return response
|
|
|
|
|
def download_conversation_wrapper(inputs=None):
    """Serialize the session (document URLs, snippet index, conversation,
    costs) to JSON and offer it via a sidebar download button.

    Fixes two defects in the previous version: the snippet index is stored
    under the ``index`` session key (not ``headers``), so the old guard was
    always False; and it is a plain list, so the old ``.to_list()`` call
    would have raised AttributeError had the branch ever run.
    *inputs* is unused (dispatcher signature).
    """
    # Hoist the costing lookup so the total can be derived from it once.
    costing = st.session_state.costing if "costing" in st.session_state else []
    conversation_data = json.dumps(
        {
            "document_urls": (
                st.session_state.source_doc_urls
                if "source_doc_urls" in st.session_state
                else []
            ),
            "document_snippets": (
                list(st.session_state.index)
                if "index" in st.session_state
                else []
            ),
            "conversation": [
                {"human": human, "ai": ai, "references": refs}
                for human, ai, refs in st.session_state.messages
            ],
            "costing": costing,
            # Column-wise totals across all recorded queries; empty when
            # nothing has been spent yet.
            "total_cost": (
                {key: sum(entry[key] for entry in costing) for key in costing[0]}
                if costing
                else {}
            ),
        }
    )
    st.sidebar.download_button(
        "Download Conversation",
        conversation_data,
        file_name="conversation_data.json",
        mime="application/json",
    )
    st.session_state.messages.append(("/download", "Conversation data downloaded", ""))
|
|
|
|
|
def query_llm_wrapper(inputs):
    """Answer the user question *inputs* against the indexed documents.

    Runs a ConversationalRetrievalChain over the session retriever with the
    prior (non-command) turns as chat history, records the exchange plus its
    chunk references in the message history, and tracks token usage and cost
    via the OpenAI callback.

    Returns a ``(answer, references)`` tuple.
    """
    retriever = st.session_state.retriever
    chain = ConversationalRetrievalChain.from_llm(
        llm=ChatOpenAI(model="gpt-4-0125-preview", temperature=0),
        retriever=retriever,
        return_source_documents=True,
        chain_type="stuff",
    )
    relevant_docs = retriever.get_relevant_documents(inputs)
    chat_history = session_state_2_llm_chat_history(st.session_state.messages)
    with get_openai_callback() as usage:
        answer = chain({"question": inputs, "chat_history": chat_history})["answer"]
        # Read the accumulated counters after the chain call, while the
        # callback is still in scope.
        prompt_tokens = usage.prompt_tokens
        completion_tokens = usage.completion_tokens
        total_cost = usage.total_cost
    references = get_references(relevant_docs)
    st.session_state.messages.append((inputs, answer, references))
    st.session_state.costing.append(
        {
            "prompt tokens": prompt_tokens,
            "completion tokens": completion_tokens,
            "cost": total_cost,
        }
    )
    return answer, references
|
|
|
|
|
def boot(command_center):
    """Render the chat UI: replay the stored conversation, then route any new
    input through *command_center* and display the result.

    Replaces the non-idiomatic ``type(response) == tuple`` check with
    ``isinstance`` (also accepts tuple subclasses).
    """
    st.title("Agent Xi - An ArXiv Chatbot")
    # Initialize the per-session stores on first run.
    if "costing" not in st.session_state:
        st.session_state.costing = []
    if "messages" not in st.session_state:
        st.session_state.messages = []
    # Replay the (human, ai, references) history.
    for human, ai, refs in st.session_state.messages:
        st.chat_message("human").write(human)
        st.chat_message("ai").write(ai_message_format(ai, refs))
    if query := st.chat_input():
        st.chat_message("human").write(query)
        response = command_center.execute_command(query)
        if response is None:
            # Commands such as /download render their own output.
            pass
        elif isinstance(response, tuple):
            # LLM queries return (answer, references).
            result, references = response
            st.chat_message("ai").write(ai_message_format(result, references))
        else:
            st.chat_message("ai").write(response)
|
|
|
|
|
if __name__ == "__main__":
    # Each entry: (command, input type, handler, menu description).
    all_commands = [
        ("/upload", list, process_documents_wrapper, "Upload and process documents"),
        ("/index", None, index_documents_wrapper, "View index of processed documents"),
        ("/cost", None, calculate_cost_wrapper, "Calculate cost of conversation"),
        ("/download", None, download_conversation_wrapper, "Download conversation data"),
    ]
    # Sidebar help table: command name next to its description.
    st.sidebar.title("Commands Menu")
    menu = pd.DataFrame(
        {
            "Command": [cmd for cmd, _, _, _ in all_commands],
            "Description": [desc for _, _, _, desc in all_commands],
        }
    )
    st.sidebar.write(menu)
    # Anything that is not a slash command is treated as a plain-text LLM query.
    command_center = CommandCenter(
        default_input_type=str,
        default_function=query_llm_wrapper,
        all_commands=[command[:3] for command in all_commands],
    )
    boot(command_center)
|
|