# new_streamlit/app.py
import os
import time
from dataclasses import dataclass
from typing import Literal

import streamlit as st
import streamlit.components.v1 as components
from dotenv import load_dotenv
from langchain.callbacks import get_openai_callback
from langchain.chains import RetrievalQAWithSourcesChain
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
load_dotenv()
# Expose the Space secret "my_secret" to langchain as the standard OPENAI_API_KEY.
os.environ["OPENAI_API_KEY"] = os.environ["my_secret"]
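# A minimal sketch of the .env file this expects (the key name comes from the
# lookup above; the value shown is a placeholder):
#
#   my_secret=sk-...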
@dataclass
class Message:
"""Class for keeping track of a chat message."""
origin: Literal["human", "ai"]
message: str
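# Illustrative construction, matching how the history list is populated below:
#   Message(origin="human", message="...")
#   Message(origin="ai", message="...")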
def load_css():
    """Inject the stylesheet that styles the chat bubbles rendered below."""
    with open("styles.css", "r") as f:
        css = f"<style>{f.read()}</style>"
    st.markdown(css, unsafe_allow_html=True)
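# A hedged sketch of the selectors styles.css is expected to define (class
# names are taken from the chat markup rendered below; the rules themselves
# are illustrative):
#
#   .chat-row     { display: flex; margin: 5px; }
#   .row-reverse  { flex-direction: row-reverse; }
#   .chat-icon    { border-radius: 50%; }
#   .chat-bubble  { padding: 8px 12px; border-radius: 12px; }
#   .ai-bubble    { background: #eee; }
#   .human-bubble { background: #cde; }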
def initialize_session_state():
    # Seed the conversation with a greeting ("Hello, how can I help you?").
    if "history" not in st.session_state:
        st.session_state.history = [Message("ai", "안녕하세요, 어떻게 도와드릴까요?")]
    if "token_count" not in st.session_state:
        st.session_state.token_count = 0
    # Build the retrieval chain once per session: load the PDF, split it into
    # chunks, embed them, and index them in an in-memory Chroma store.
    if "chain" not in st.session_state:
        llm = ChatOpenAI(model_name="gpt-3.5-turbo-0613", temperature=0)
        loader = PyPDFLoader("/home/user/app/docs.pdf")
        documents = loader.load()
        text_splitter = CharacterTextSplitter(chunk_size=800, chunk_overlap=0)
        texts = text_splitter.split_documents(documents)
        embeddings = OpenAIEmbeddings()
        vector_store = Chroma.from_documents(texts, embeddings)
        retriever = vector_store.as_retriever(search_kwargs={"k": 3})
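        # A standalone sketch of what the retriever returns on its own (the
        # query string is illustrative; this is not part of the app flow):
        #
        #   docs = retriever.get_relevant_documents("교사 자격")
        #   # -> up to k=3 Document objects with .page_content and .metadata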
        # Prompt: the {summaries} slot receives the retrieved chunks, and the
        # persona instructions keep replies short and in Korean.
        system_template = """You act like a successful teacher. Talk to students about the career path of a teacher.
Never reveal that you are an AI. Use the following pieces of context to answer the user's question concisely.
Given the following summaries of a long document and a question, create a final answer.
If you don't know the answer, just say "I don't know"; don't try to make up an answer.
----------------
{summaries}

You MUST answer in Korean and in Markdown format."""
messages = [
SystemMessagePromptTemplate.from_template(system_template),
HumanMessagePromptTemplate.from_template("{question}")
]
prompt = ChatPromptTemplate.from_messages(messages)
        chain_type_kwargs = {"prompt": prompt}
        # "stuff" packs all retrieved chunks into a single prompt call;
        # reduce_k_below_max_tokens drops chunks that would overflow the
        # model's context window.
        st.session_state["chain"] = RetrievalQAWithSourcesChain.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True,
            chain_type_kwargs=chain_type_kwargs,
            reduce_k_below_max_tokens=True,
            verbose=True,
        )
def generate_response(user_input):
    """Run the retrieval chain; a bare string is mapped to the chain's single
    input key ("question")."""
    result = st.session_state["chain"](user_input)
    bot_message = result["answer"]
    return bot_message
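# A minimal sketch of the chain's full output dict (keys follow from
# RetrievalQAWithSourcesChain with return_source_documents=True; the query
# is illustrative):
#
#   result = st.session_state["chain"]("교사가 되려면 무엇을 해야 하나요?")
#   result["answer"]            # the model's reply text
#   result["sources"]           # sources string parsed out of the reply
#   result["source_documents"]  # the retrieved Document chunks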
def on_click_callback():
    """Handle a Submit click: run the chain and append both sides of the
    exchange to the history, accumulating token usage along the way."""
    with get_openai_callback() as cb:
        human_prompt = st.session_state.human_prompt
        llm_response = generate_response(human_prompt)
        st.session_state.history.append(Message("human", human_prompt))
        st.session_state.history.append(Message("ai", llm_response))
        st.session_state.token_count += cb.total_tokens
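# A small usage sketch of the callback object (attribute names are from
# langchain's get_openai_callback; it only counts OpenAI calls made inside
# the "with" block):
#
#   with get_openai_callback() as cb:
#       ...  # any OpenAI-backed call
#   cb.prompt_tokens, cb.completion_tokens, cb.total_tokens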
load_css()
initialize_session_state()
# Title: "Have a career consultation with a teacher; it is based on real interviews."
st.title("교사와 진로상담을 해보세요, \n 실제 인터뷰를 기반으로 합니다. 🤖")
chat_placeholder = st.container()
prompt_placeholder = st.form("chat-form")
credit_card_placeholder = st.empty()
with chat_placeholder:
    # Render every message except the newest; the newest one is animated
    # with a typewriter effect further down.
    for chat in st.session_state.history[:-1]:
div = f"""
<div class="chat-row
{'' if chat.origin == 'ai' else 'row-reverse'}">
<img class="chat-icon" src="https://cdn-icons-png.flaticon.com/{
'/512/3058/3058838.png' if chat.origin == 'ai'
else '512/1177/1177568.png'}"
width=32 height=32>
<div class="chat-bubble
{'ai-bubble' if chat.origin == 'ai' else 'human-bubble'}">
&#8203;{chat.message}
</div>
</div>
"""
st.markdown(div, unsafe_allow_html=True)
if st.session_state.history:
last_chat = st.session_state.history[-1]
div_start = f"""
<div class="chat-row
{'' if last_chat.origin == 'ai' else 'row-reverse'}">
<img class="chat-icon" src="https://cdn-icons-png.flaticon.com/{
'/512/3058/3058838.png' if last_chat.origin == 'ai'
else '512/1177/1177568.png'}"
width=32 height=32>
<div class="chat-bubble
{'ai-bubble' if last_chat.origin == 'ai' else 'human-bubble'}">
&#8203;"""
div_end = """
</div>
</div>
"""
    # Typewriter effect: redraw the newest bubble one character at a time.
    new_placeholder = st.empty()
    for j in range(len(last_chat.message)):
        new_placeholder.markdown(div_start + last_chat.message[: j + 1] + div_end, unsafe_allow_html=True)
        time.sleep(0.05)
# Vertical spacing between the chat area and the input form.
for _ in range(3):
    st.markdown("")
with prompt_placeholder:
st.markdown("**Chat**")
cols = st.columns((6, 1))
    # Default question: "What do I need to do to become a teacher?"
    cols[0].text_input(
        "Chat",
        value="교사가 되려면 무엇을 해야 하나요?",
        label_visibility="collapsed",
        key="human_prompt",
    )
cols[1].form_submit_button(
"Submit",
type="primary",
on_click=on_click_callback,
)
# Debug caption (disabled). Note: this chain has no .memory attribute, so only
# the running token count is meaningful here.
# credit_card_placeholder.caption(f"Used {st.session_state.token_count} tokens")
components.html("""
<script>
const streamlitDoc = window.parent.document;
const buttons = Array.from(
streamlitDoc.querySelectorAll('.stButton > button')
);
const submitButton = buttons.find(
el => el.innerText === 'Submit'
);
streamlitDoc.addEventListener('keydown', function(e) {
switch (e.key) {
case 'Enter':
submitButton.click();
break;
}
});
</script>
""",
height=0,
width=0,
)
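# To run locally (assuming styles.css sits next to this file and the PDF is
# available at the hard-coded /home/user/app/docs.pdf path, or that path is
# adjusted):
#   streamlit run app.py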