Spaces:
				
			
			
	
			
			
		Build error
		
	
	
	
			
			
	
	
	
	
		
		
		Build error
		
	Create app.py
Browse files
    	
        app.py
    ADDED
    
    | @@ -0,0 +1,165 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
from dataclasses import dataclass
from typing import Literal

import streamlit as st
import streamlit.components.v1 as components
from langchain import OpenAI
from langchain.callbacks import get_openai_callback
from langchain.chains import ConversationChain, RetrievalQAWithSourcesChain
from langchain.chains.conversation.memory import ConversationSummaryMemory
from langchain.chat_models import ChatOpenAI
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.prompts.chat import (
    ChatPromptTemplate,
    HumanMessagePromptTemplate,
    SystemMessagePromptTemplate,
)
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import Chroma
| 9 | 
            +
             | 
@dataclass
class Message:
    """One turn of the conversation: who spoke and what was said."""
    # Which side produced the text: the user ("human") or the model ("ai").
    origin: Literal["human", "ai"]
    # Raw text of the turn; rendered later as HTML/markdown.
    message: str
| 15 | 
            +
             | 
def load_css():
    """Inject the app stylesheet (static/styles.css) into the Streamlit page."""
    with open("static/styles.css", "r") as stylesheet:
        st.markdown(f"<style>{stylesheet.read()}</style>", unsafe_allow_html=True)
| 20 | 
            +
             | 
def initialize_session_state():
    """Seed st.session_state on the first script run.

    Creates three keys:
      * history     -- list of Message objects, the chat transcript
      * token_count -- int, cumulative OpenAI tokens consumed
      * chain       -- RetrievalQAWithSourcesChain over the bundled PDF

    Loading, splitting and embedding the PDF is expensive, so the chain is
    built only once per session. The original code referenced ChatOpenAI,
    PyPDFLoader, CharacterTextSplitter, OpenAIEmbeddings, Chroma and
    RetrievalQAWithSourcesChain without importing them (NameError on first
    run); they are now imported at module level, and the mid-function
    prompt-template import has been hoisted there as well.
    """
    if "history" not in st.session_state:
        st.session_state.history = []
    if "token_count" not in st.session_state:
        st.session_state.token_count = 0
    if "chain" not in st.session_state:
        # temperature=0: deterministic answers for a document-grounded Q&A bot.
        llm = ChatOpenAI(model_name="gpt-3.5-turbo-0613", temperature=0)

        documents = PyPDFLoader("/home/user/app/docs.pdf").load()
        # chunk_overlap=0 keeps chunks disjoint; 800 chars bounds prompt size.
        splitter = CharacterTextSplitter(chunk_size=800, chunk_overlap=0)
        texts = splitter.split_documents(documents)

        vector_store = Chroma.from_documents(texts, OpenAIEmbeddings())
        retriever = vector_store.as_retriever(search_kwargs={"k": 3})

        system_template = """You act like a successful pharmacist. Talk to students about the career path of a pharmacist. 
        Never reveal that it is an AI. Use the following pieces of context to answer the users question shortly.
        Given the following summaries of a long document and a question, create a final answer with references.
        If you don't know the answer, just say that "I don't know", don't try to make up an answer.
        ----------------
        {summaries}
        You MUST answer in Korean and in Markdown format"""

        prompt = ChatPromptTemplate.from_messages([
            SystemMessagePromptTemplate.from_template(system_template),
            HumanMessagePromptTemplate.from_template("{question}"),
        ])

        st.session_state["chain"] = RetrievalQAWithSourcesChain.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": prompt},
            reduce_k_below_max_tokens=True,
            verbose=True,
        )
| 70 | 
            +
             | 
def generate_response(user_input):
    """Run the QA chain on user_input; return the answer with numbered
    "[n] source(page) " citations appended for each retrieved document."""
    result = st.session_state["chain"](user_input)

    citations = "".join(
        f"[{idx}] {doc.metadata['source']}({doc.metadata['page']}) "
        for idx, doc in enumerate(result["source_documents"], start=1)
    )
    return result["answer"] + citations
| 80 | 
            +
             | 
def on_click_callback():
    """Form-submit handler: query the chain with the typed prompt, record
    both turns in the transcript, and accumulate the OpenAI token usage."""
    with get_openai_callback() as cb:
        prompt_text = st.session_state.human_prompt
        answer = generate_response(prompt_text)
        st.session_state.history.extend(
            [Message("human", prompt_text), Message("ai", answer)]
        )
        st.session_state.token_count += cb.total_tokens
| 92 | 
            +
             | 
# ---- App bootstrap: inject the stylesheet and seed session state ----
load_css()
initialize_session_state()

st.title("Hello Custom CSS Chatbot 🤖")

# Layout containers, rendered in page order: chat transcript, input form,
# then the token-usage caption at the bottom.
chat_placeholder = st.container()
prompt_placeholder = st.form("chat-form")
credit_card_placeholder = st.empty()
| 101 | 
            +
             | 
# Transcript rendering: one HTML row per Message. AI rows keep the default
# direction; human rows get `row-reverse` so they align to the other side.
with chat_placeholder:
    for chat in st.session_state.history:
        # NOTE(review): the character just before {chat.message} below is a
        # zero-width space (U+200B) — presumably there to keep markdown from
        # reprocessing the bubble content; confirm before removing it.
        div = f"""
<div class="chat-row 
    {'' if chat.origin == 'ai' else 'row-reverse'}">
    <img class="chat-icon" src="app/static/{
        'ai_icon.png' if chat.origin == 'ai' 
                      else 'user_icon.png'}"
         width=32 height=32>
    <div class="chat-bubble
    {'ai-bubble' if chat.origin == 'ai' else 'human-bubble'}">
        ​{chat.message}
    </div>
</div>
        """
        st.markdown(div, unsafe_allow_html=True)

    # Three empty markdown blocks as vertical spacing before the input form.
    for _ in range(3):
        st.markdown("")
| 121 | 
            +
             | 
# Input form: a wide text box next to a narrow Submit button. The text value
# is bound to session state under the key "human_prompt", which
# on_click_callback reads when the form is submitted.
with prompt_placeholder:
    st.markdown("**Chat**")
    text_col, submit_col = st.columns((6, 1))
    text_col.text_input(
        "Chat",
        value="Hello bot",
        label_visibility="collapsed",
        key="human_prompt",
    )
    submit_col.form_submit_button(
        "Submit",
        type="primary",
        on_click=on_click_callback,
    )
| 136 | 
            +
             | 
# Usage footer. BUG FIX: the original interpolated
# st.session_state.conversation.memory.buffer, but initialize_session_state()
# only ever creates the keys "history", "token_count" and "chain" — there is
# no "conversation" key, so every rerun raised AttributeError here. The debug
# line now reports the transcript length, which session state does provide.
credit_card_placeholder.caption(f"""
Used {st.session_state.token_count} tokens \n
Debug Langchain conversation: 
{len(st.session_state.history)} messages in history
""")
| 142 | 
            +
             | 
# Keyboard shortcut: this script runs inside the (invisible, 0x0) component
# iframe but reaches into the parent Streamlit document, finds the button
# labelled "Submit", and clicks it whenever Enter is pressed anywhere on the
# page — so the user does not have to click the button by hand.
components.html("""
<script>
const streamlitDoc = window.parent.document;

const buttons = Array.from(
    streamlitDoc.querySelectorAll('.stButton > button')
);
const submitButton = buttons.find(
    el => el.innerText === 'Submit'
);

streamlitDoc.addEventListener('keydown', function(e) {
    switch (e.key) {
        case 'Enter':
            submitButton.click();
            break;
    }
});
</script>
""", 
    height=0,
    width=0,
)