# Scraped page header (not code) — kept as comments so the file parses:
# new_streamlit / app.py
# JUNGU's picture
# Update app.py
# 1994aa6
# raw
# history blame
# 5.86 kB
from dataclasses import dataclass
from typing import Literal
import streamlit as st
from langchain import OpenAI
from langchain.callbacks import get_openai_callback
from langchain.chains import ConversationChain
from langchain.chains.conversation.memory import ConversationSummaryMemory
import streamlit.components.v1 as components
import streamlit as st
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.embeddings.cohere import CohereEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores.elastic_vector_search import ElasticVectorSearch
from langchain.vectorstores import Chroma
from langchain.chat_models import ChatOpenAI
from langchain.chains import RetrievalQAWithSourcesChain
from dotenv import load_dotenv
import os
import openai
import time
# Load variables from a local .env file (if any), then expose the API key
# under the name the OpenAI/LangChain clients read. 'my_secret' takes
# precedence as before; an already-set OPENAI_API_KEY is accepted as a
# fallback instead of crashing with a bare KeyError.
load_dotenv()
_openai_api_key = os.environ.get("my_secret") or os.environ.get("OPENAI_API_KEY")
if not _openai_api_key:
    raise RuntimeError(
        "No OpenAI API key found: set the 'my_secret' environment variable "
        "(or OPENAI_API_KEY)."
    )
os.environ["OPENAI_API_KEY"] = _openai_api_key
@dataclass()
class Message:
    """One turn of the conversation kept in the session history.

    ``origin`` records who produced the turn; ``message`` is its text.
    """

    # Producer of this turn: the user ("human") or the model ("ai").
    origin: Literal["human", "ai"]
    # Text content of the turn.
    message: str
def load_css(path="styles.css"):
    """Inject the stylesheet at *path* into the page inside a ``<style>`` tag.

    Args:
        path: CSS file to read. Defaults to ``styles.css`` resolved against
            the process working directory.
    """
    # Explicit UTF-8 so reading the file does not depend on the platform's
    # locale encoding (e.g. cp1252 / cp949 on Windows).
    with open(path, "r", encoding="utf-8") as f:
        css = f"<style>{f.read()}</style>"
    st.markdown(css, unsafe_allow_html=True)
def _build_qa_chain(pdf_path="/home/user/app/docs.pdf"):
    """Build the retrieval QA-with-sources chain over the PDF at *pdf_path*.

    Loads the PDF with ``PyPDFLoader``, splits it with
    ``CharacterTextSplitter(chunk_size=800, chunk_overlap=0)``, embeds the
    chunks with ``OpenAIEmbeddings`` into a ``Chroma`` store (no persist
    directory is given), and wraps a ``gpt-3.5-turbo-0613`` chat model in a
    ``RetrievalQAWithSourcesChain`` that retrieves the top 3 chunks.

    Returns:
        The constructed chain (configured with ``return_source_documents=True``).
    """
    # Imported here (as in the original) so the prompt classes are only
    # loaded when a chain actually has to be built.
    from langchain.prompts.chat import (
        ChatPromptTemplate,
        HumanMessagePromptTemplate,
        SystemMessagePromptTemplate,
    )

    llm = ChatOpenAI(model_name="gpt-3.5-turbo-0613", temperature=0)

    documents = PyPDFLoader(pdf_path).load()
    text_splitter = CharacterTextSplitter(chunk_size=800, chunk_overlap=0)
    texts = text_splitter.split_documents(documents)
    vector_store = Chroma.from_documents(texts, OpenAIEmbeddings())
    retriever = vector_store.as_retriever(search_kwargs={"k": 3})

    # `{summaries}` receives the retrieved chunks; `{question}` the user input.
    system_template="""You act like a successful teacher. Talk to students about the career path of a teacher.
Never reveal that it is an AI. Use the following pieces of context to answer the users question shortly.
Given the following summaries of a long document and a question, create a final answer with references.
If you don't know the answer, just say that "I don't know", don't try to make up an answer.
----------------
{summaries}
You MUST answer in Korean and in Markdown format"""
    prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate.from_template(system_template),
            HumanMessagePromptTemplate.from_template("{question}"),
        ]
    )

    return RetrievalQAWithSourcesChain.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
        chain_type_kwargs={"prompt": prompt},
        reduce_k_below_max_tokens=True,
        verbose=True,
    )


def initialize_session_state(pdf_path="/home/user/app/docs.pdf"):
    """Populate per-session defaults on first run.

    Sets ``history`` (empty list of messages), ``token_count`` (0) and
    ``chain`` (the retrieval chain over *pdf_path*) in ``st.session_state``
    only when each key is missing, so reruns keep the existing values.

    Args:
        pdf_path: PDF used as the retrieval corpus when the chain is built.
    """
    if "history" not in st.session_state:
        st.session_state.history = []
    if "token_count" not in st.session_state:
        st.session_state.token_count = 0
    if "chain" not in st.session_state:
        # Building the chain loads and embeds the whole PDF, so it is done
        # once per session and cached in session state.
        st.session_state["chain"] = _build_qa_chain(pdf_path)
def generate_response(user_input, chain=None):
    """Answer *user_input* and append numbered source citations.

    Args:
        user_input: The user's question.
        chain: Callable that takes the question and returns a dict with
            ``"answer"`` (str) and ``"source_documents"`` (documents with a
            ``metadata`` dict). Defaults to ``st.session_state['chain']``.

    Returns:
        The answer, followed (when any documents were returned) by a space
        and ``[n] source(page)`` citations, with 1-based page numbers.
    """
    if chain is None:
        chain = st.session_state['chain']
    result = chain(user_input)

    citations = []
    for number, doc in enumerate(result['source_documents'], start=1):
        source = doc.metadata.get('source', '?')
        page = doc.metadata.get('page')
        if isinstance(page, int):
            # PyPDFLoader stores 0-based page indices; show human page numbers.
            page = page + 1
        elif page is None:
            page = '?'
        citations.append(f"[{number}] {source}({page})")

    if not citations:
        return result['answer']
    # Separate the answer from the citation list so they don't run together.
    return result['answer'] + ' ' + ' '.join(citations)
def on_click_callback():
    """Handle a chat-form submission.

    Reads the text in ``st.session_state.human_prompt``, asks the model via
    ``generate_response``, appends the human and AI turns to
    ``st.session_state.history``, and adds the tokens reported by
    ``get_openai_callback`` to ``st.session_state.token_count``.
    Blank or whitespace-only submissions are ignored so they do not trigger
    a model call or record empty turns.
    """
    human_prompt = st.session_state.human_prompt
    if not human_prompt.strip():
        return
    with get_openai_callback() as cb:
        llm_response = generate_response(human_prompt)
        st.session_state.history.append(Message("human", human_prompt))
        st.session_state.history.append(Message("ai", llm_response))
        st.session_state.token_count += cb.total_tokens
load_css()
initialize_session_state()

# Title (Korean): "Try career counseling with a teacher — based on real
# interviews." The literal was previously mojibake (UTF-8 bytes decoded as
# windows-874/cp1252); restored to the intended Korean text.
st.title("교사와 진로상담을 해보세요, \n 실제 인터뷰를 기반으로 합니다. 🤖")

# Page layout: chat history on top, the input form below, and an empty slot
# for token-usage output (only referenced by commented-out code further down).
chat_placeholder = st.container()
prompt_placeholder = st.form("chat-form")
credit_card_placeholder = st.empty()
import html

# Render every stored turn as a chat bubble. User-typed text is escaped
# because it is injected into raw HTML rendered with unsafe_allow_html=True;
# AI turns are left as-is so their Markdown still renders.
with chat_placeholder:
    for chat in st.session_state.history:
        is_ai = chat.origin == 'ai'
        row_class = '' if is_ai else 'row-reverse'
        icon = 'ai_icon.png' if is_ai else 'user_icon.png'
        bubble_class = 'ai-bubble' if is_ai else 'human-bubble'
        body = chat.message if is_ai else html.escape(chat.message)
        # NOTE(review): icon URL is relative to the page; confirm it resolves
        # on the deployed Space.
        div = "\n".join(
            [
                f'<div class="chat-row {row_class}">',
                f'<img class="chat-icon" src="new_streamlit/resolve/main/static/{icon}" width=32 height=32>',
                f'<div class="chat-bubble {bubble_class}">',
                # Zero-width space kept from the original before the message.
                f"&#8203;{body}",
                "</div>",
                "</div>",
            ]
        )
        st.markdown(div, unsafe_allow_html=True)
    # Blank spacer lines below the last bubble.
    for _ in range(3):
        st.markdown("")
# Input row of the chat form: a wide text box next to a narrow submit button
# that hands the submission to on_click_callback.
with prompt_placeholder:
    st.markdown("**Chat**")
    input_col, submit_col = st.columns((6, 1))
    input_col.text_input(
        "Chat",
        key="human_prompt",
        value="Hello bot",
        label_visibility="collapsed",
    )
    submit_col.form_submit_button(
        "Submit",
        on_click=on_click_callback,
        type="primary",
    )
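# NOTE: disabled debug output. The chain is created without a `memory=`
# argument, so `st.session_state.chain.memory.buffer` would likely fail if
# re-enabled as written; only the token-count line is safe to restore as-is.
# credit_card_placeholder.caption(f"""
# Used {st.session_state.token_count} tokens \n
# Debug Langchain conversation:
# {st.session_state.chain.memory.buffer}
# """)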
# Zero-size iframe whose script makes Enter anywhere on the page click the
# form's "Submit" button in the parent Streamlit document.
# - The button is looked up at key-press time (not once at load) and
#   null-checked, so a re-rendered or missing button cannot raise.
# - Enter pressed while an IME is composing (e.g. Korean input) only commits
#   the composition and must not submit a half-typed message.
components.html("""
<script>
const streamlitDoc = window.parent.document;

function findSubmitButton() {
    const buttons = Array.from(
        streamlitDoc.querySelectorAll('.stButton > button')
    );
    return buttons.find(el => el.innerText === 'Submit');
}

streamlitDoc.addEventListener('keydown', function(e) {
    if (e.key !== 'Enter' || e.isComposing || e.keyCode === 229) {
        return;
    }
    const submitButton = findSubmitButton();
    if (submitButton) {
        submitButton.click();
    }
});
</script>
""",
    height=0,
    width=0,
)