Goodnight7 committed
Commit d87ca70 · verified · 1 Parent(s): ae7660b

Create main.py

Files changed (1)
  1. main.py +132 -0
main.py ADDED
@@ -0,0 +1,132 @@
+ import streamlit as st
+ from langchain import memory as lc_memory
+ from langsmith import Client
+ from streamlit_feedback import streamlit_feedback
+ from utils1 import get_expression_chain, get_retriever
+ from langchain_core.tracers.context import collect_runs
+ from dotenv import load_dotenv
+ import os
+
+ load_dotenv()
+
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY")
+ HF_API_KEY = os.getenv("HF_API_KEY")
+ COHERE_API_KEY = os.getenv("COHERE_API_KEY")
+
+ # LangSmith reads its settings from environment variables, so export them
+ # here rather than leaving them as unused module-level names.
+ os.environ["LANGSMITH_TRACING"] = "true"
+ os.environ["LANGSMITH_ENDPOINT"] = "https://api.smith.langchain.com"
+ os.environ["LANGSMITH_API_KEY"] = os.getenv("LANGSMITH_API_KEY", "")
+ os.environ["LANGSMITH_PROJECT"] = "pr-smug-rancher-51"
+
+ client = Client()
+ st.set_page_config(page_title="MEDICAL CHATBOT")
+ st.subheader("Hello! How can I assist you today?")
+
+ # Conversation memory backed by Streamlit's session state.
+ memory = lc_memory.ConversationBufferMemory(
+     chat_memory=lc_memory.StreamlitChatMessageHistory(key="langchain_messages"),
+     return_messages=True,
+     memory_key="chat_history",
+ )
+
+ st.sidebar.markdown("## Feedback Scale")
+ feedback_option = (
+     "thumbs" if st.sidebar.toggle(label="`Faces` ⇄ `Thumbs`", value=False) else "faces"
+ )
+
+ # Sidebar controls for the model, sampling temperature, and retrieval depth.
+ with st.sidebar:
+     model_name = st.selectbox("**Model**", options=["llama-3.1-70b-versatile", "gemma2-9b-it", "gemma-7b-it", "llama-3.2-3b-preview", "llama3-70b-8192", "mixtral-8x7b-32768"])
+     temp = st.slider("**Temperature**", min_value=0.0, max_value=1.0, step=0.001)
+     n_docs = st.number_input("**Number of retrieved documents**", min_value=0, max_value=10, value=5, step=1)
+
+ if st.sidebar.button("Clear message history"):
+     print("Clearing message history")
+     memory.clear()
+
+ retriever = get_retriever(n_docs=n_docs)
+ chain = get_expression_chain(retriever, model_name, temp)
+
+ # Replay the conversation stored in session state.
+ for msg in st.session_state.langchain_messages:
+     avatar = "🦜" if msg.type == "ai" else None
+     with st.chat_message(msg.type, avatar=avatar):
+         st.markdown(msg.content)
+
+ prompt = st.chat_input(placeholder="Describe your symptoms or ask a medical question")
+
+ if prompt:
+     with st.chat_message("user"):
+         st.write(prompt)
+
+     with st.chat_message("assistant", avatar="💁"):
+         message_placeholder = st.empty()
+         full_response = ""
+         input_dict = {"input": prompt.lower()}
+         used_docs = retriever.get_relevant_documents(prompt.lower())
+
+         # Stream the answer and keep the LangSmith run id for later feedback.
+         with collect_runs() as cb:
+             for chunk in chain.stream(input_dict, config={"tags": ["MEDICAL CHATBOT"]}):
+                 full_response += chunk.content
+                 message_placeholder.markdown(full_response + "▌")
+             memory.save_context(input_dict, {"output": full_response})
+
+             st.session_state.run_id = cb.traced_runs[0].id
+         message_placeholder.markdown(full_response)
+
+     # Offer the documents consulted for this answer as a download.
+     if used_docs:
+         docs_content = "\n\n".join(
+             [
+                 f"Doc {i+1}:\n"
+                 f"Source: {doc.metadata['source']}\n"
+                 f"Title: {doc.metadata['title']}\n"
+                 f"Content: {doc.page_content}\n"
+                 for i, doc in enumerate(used_docs)
+             ]
+         )
+         with st.sidebar:
+             st.download_button(
+                 label="Consulted Documents",
+                 data=docs_content,
+                 file_name="Consulted_documents.txt",
+                 mime="text/plain",
+             )
+
+ # Collect user feedback on the most recent run and log it to LangSmith.
+ if st.session_state.get("run_id"):
+     run_id = st.session_state.run_id
+     feedback = streamlit_feedback(
+         feedback_type=feedback_option,
+         optional_text_label="[Optional] Please provide an explanation",
+         key=f"feedback_{run_id}",
+     )
+
+     score_mappings = {
+         "thumbs": {"👍": 1, "👎": 0},
+         "faces": {"😀": 1, "🙂": 0.75, "😐": 0.5, "🙁": 0.25, "😞": 0},
+     }
+
+     scores = score_mappings[feedback_option]
+
+     if feedback:
+         score = scores.get(feedback["score"])
+
+         if score is not None:
+             feedback_type_str = f"{feedback_option} {feedback['score']}"
+
+             feedback_record = client.create_feedback(
+                 run_id,
+                 feedback_type_str,
+                 score=score,
+                 comment=feedback.get("text"),
+             )
+             st.session_state.feedback = {
+                 "feedback_id": str(feedback_record.id),
+                 "score": score,
+             }
+         else:
+             st.warning("Invalid feedback score.")
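
Note: main.py imports get_expression_chain and get_retriever from utils1, which is not part of this commit, and it expects GROQ_API_KEY, HF_API_KEY, COHERE_API_KEY, and LANGSMITH_API_KEY in a local .env file. A minimal sketch of what utils1.py might look like, assuming a local FAISS index, sentence-transformers embeddings, and the Groq chat models listed in the sidebar (the index path, embedding model, and prompt below are illustrative, not the repository's actual implementation):

# utils1.py (hypothetical sketch; the real module is not included in this commit)
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq


def get_retriever(n_docs: int = 5):
    """Load a local FAISS index and expose it as a retriever returning n_docs documents."""
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    vectorstore = FAISS.load_local(
        "faiss_index",  # assumed index path, built separately
        embeddings,
        allow_dangerous_deserialization=True,
    )
    return vectorstore.as_retriever(search_kwargs={"k": n_docs})


def get_expression_chain(retriever, model_name: str, temp: float):
    """Build a retrieval-augmented LCEL chain that ends with the chat model,
    so main.py can read chunk.content while streaming."""
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", "You are a medical assistant. Answer using the context below.\n\n{context}"),
            ("human", "{input}"),
        ]
    )
    llm = ChatGroq(model_name=model_name, temperature=temp)  # reads GROQ_API_KEY from the environment

    def format_docs(docs):
        return "\n\n".join(doc.page_content for doc in docs)

    return (
        {
            "context": lambda x: format_docs(retriever.invoke(x["input"])),
            "input": lambda x: x["input"],
        }
        | prompt
        | llm
    )

This sketch ignores the chat_history that main.py stores in memory; a fuller version could thread it through a MessagesPlaceholder in the prompt.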