Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from langchain import memory as lc_memory | |
| from langsmith import Client | |
| from streamlit_feedback import streamlit_feedback | |
| from utils import get_expression_chain, get_retriever | |
| from langchain_core.tracers.context import collect_runs | |
| from dotenv import load_dotenv | |
| import os | |
# Load secrets (API keys) from a local .env file into the process environment.
load_dotenv()

# Provider API keys. They are not used further in this script; presumably the
# project-local `utils` helpers read them -- TODO confirm.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
HF_API_KEY = os.getenv("HF_API_KEY")
COHERE_API_KEY = os.getenv("COHERE_API_KEY")

# LangSmith settings. Bare Python assignments are invisible to the LangSmith /
# LangChain tracing machinery, which is configured through environment
# variables, so export them. `setdefault` keeps any value already provided by
# the environment or the .env file, and the module-level names hold the value
# that is actually in effect.
LANGSMITH_TRACING = os.environ.setdefault("LANGSMITH_TRACING", "true")
LANGSMITH_ENDPOINT = os.environ.setdefault(
    "LANGSMITH_ENDPOINT", "https://api.smith.langchain.com"
)
LANGSMITH_API_KEY = os.getenv("LANGSMITH_API_KEY")
LANGSMITH_PROJECT = os.environ.setdefault("LANGSMITH_PROJECT", "pr-smug-rancher-51")

# LangSmith client; used further down to attach user feedback to traced runs.
client = Client()
# Page title and greeting shown above the chat.
st.set_page_config(page_title="MEDICAL CHATBOT")
st.subheader("Hello! How can I assist you today!")

# Buffer memory whose messages are kept in st.session_state["langchain_messages"]
# and handed back as message objects under the "chat_history" key.
memory = lc_memory.ConversationBufferMemory(
    chat_memory=lc_memory.StreamlitChatMessageHistory(key="langchain_messages"),
    memory_key="chat_history",
    return_messages=True,
)
# Sidebar: choose the feedback widget style (toggle off -> "faces").
st.sidebar.markdown("## Feedback Scale")
feedback_option = (
    "thumbs" if st.sidebar.toggle(label="`Faces` ⇄ `Thumbs`", value=False) else "faces"
)

# Sidebar: generation and retrieval settings.
with st.sidebar:
    model_name = st.selectbox(
        "**Model**",
        options=[
            "llama-3.1-70b-versatile",
            "gemma2-9b-it",
            "gemma-7b-it",
            "llama-3.2-3b-preview",
            "llama3-70b-8192",
            "mixtral-8x7b-32768",
        ],
    )
    temp = st.slider("**Temperature**", min_value=0.0, max_value=1.0, step=0.001)
    n_docs = st.number_input(
        "**Number of retrieved documents**", min_value=0, max_value=10, value=5, step=1
    )

if st.sidebar.button("Clear message history"):
    print("Clearing message history")
    memory.clear()
    # Also forget the last traced run; otherwise the feedback widget keeps
    # being rendered for an answer that is no longer in the history.
    st.session_state.pop("run_id", None)

# Built from the current sidebar values each time the script runs.
retriever = get_retriever(n_docs=n_docs)
chain = get_expression_chain(retriever, model_name, temp)
# Replay the stored conversation. Messages live in
# st.session_state["langchain_messages"] (the StreamlitChatMessageHistory key);
# AI turns get the parrot avatar, other turns the default one.
for msg in st.session_state.langchain_messages:
    avatar = "🦜" if msg.type == "ai" else None
    with st.chat_message(msg.type, avatar=avatar):
        st.markdown(msg.content)
prompt = st.chat_input(placeholder="Describe your symptoms or medical questions ?")
if prompt:
    with st.chat_message("user"):
        st.write(prompt)

    # Same avatar the AI messages get when the history is replayed on rerun,
    # so the answer does not change appearance after the next interaction.
    with st.chat_message("assistant", avatar="🦜"):
        message_placeholder = st.empty()
        full_response = ""
        query = prompt.lower()
        input_dict = {"input": query}
        # Retrieved separately (outside the traced chain call) so the consulted
        # documents can be offered for download below.
        used_docs = retriever.invoke(query)
        with collect_runs() as cb:
            for chunk in chain.stream(input_dict, config={"tags": ["MEDICAL CHATBOT"]}):
                full_response += chunk.content
                # Trailing block character acts as a typing cursor while streaming.
                message_placeholder.markdown(full_response + "▌")
            memory.save_context(input_dict, {"output": full_response})
            # Remember the collected run so feedback can be attached to it. If
            # nothing was collected, drop any previous id rather than attaching
            # feedback for this answer to an older run.
            if cb.traced_runs:
                st.session_state.run_id = cb.traced_runs[0].id
            else:
                st.session_state.pop("run_id", None)
        message_placeholder.markdown(full_response)

    if used_docs:
        # Metadata fields are read with .get so documents lacking a source or
        # title do not crash the page after the answer has been produced.
        docs_content = "\n\n".join(
            f"Doc {i}:\n"
            f"Source: {doc.metadata.get('source', 'unknown')}\n"
            f"Title: {doc.metadata.get('title', 'unknown')}\n"
            f"Content: {doc.page_content}\n"
            for i, doc in enumerate(used_docs, start=1)
        )
        with st.sidebar:
            st.download_button(
                label="Consulted Documents",
                data=docs_content,
                file_name="Consulted_documents.txt",
                mime="text/plain",
            )
# Show a feedback widget for the latest traced run and forward submissions
# to LangSmith.
if st.session_state.get("run_id"):
    run_id = st.session_state.run_id
    feedback_key = f"feedback_{run_id}"
    feedback = streamlit_feedback(
        feedback_type=feedback_option,
        optional_text_label="[Optional] Please provide an explanation",
        key=feedback_key,
    )
    # Emoji emitted by the feedback widget -> numeric score sent to LangSmith.
    # Keys must be the exact emoji characters; if they collide (e.g. after an
    # encoding mishap) every submission silently becomes "invalid".
    score_mappings = {
        "thumbs": {"👍": 1, "👎": 0},
        "faces": {"😀": 1, "🙂": 0.75, "😐": 0.5, "🙁": 0.25, "😞": 0},
    }
    scores = score_mappings[feedback_option]
    # The component's submitted value persists across reruns, so it comes back
    # on every later rerun; log each widget's submission only once.
    already_logged = st.session_state.get("feedback_logged_for") == feedback_key
    if feedback and not already_logged:
        score = scores.get(feedback["score"])
        if score is not None:
            feedback_type_str = f"{feedback_option} {feedback['score']}"
            feedback_record = client.create_feedback(
                run_id,
                feedback_type_str,
                score=score,
                comment=feedback.get("text"),
            )
            st.session_state.feedback = {
                "feedback_id": str(feedback_record.id),
                "score": score,
            }
            st.session_state.feedback_logged_for = feedback_key
        else:
            st.warning("Invalid feedback score.")