adarsh-maurya committed
Commit 0dae560 · verified · 1 parent: a1f5731

Update app.py

Files changed (1)
  1. app.py +26 -32
app.py CHANGED
@@ -1,5 +1,6 @@
-import time
+
 import os
+import time
 import streamlit as st
 from langchain_community.vectorstores import FAISS
 from langchain_community.embeddings import HuggingFaceEmbeddings
@@ -7,17 +8,16 @@ from langchain.prompts import PromptTemplate
 from langchain.memory import ConversationBufferWindowMemory
 from langchain.chains import ConversationalRetrievalChain
 from langchain_together import Together
-
 from footer import footer
 
-# Set the Streamlit page configuration and theme
 st.set_page_config(page_title="BharatLAW", layout="centered")
 
-# Display the logo image
+# Display logo
 col1, col2, col3 = st.columns([1, 30, 1])
 with col2:
-    st.image("https://github.com/Nike-one/BharatLAW/blob/master/images/banner.png?raw=true", use_column_width=True)
+    st.image("https://github.com/Nike-one/BharatLAW/blob/master/images/banner.png?raw=true", use_container_width=True)
 
+# Hide hamburger and default footer
 def hide_hamburger_menu():
     st.markdown("""
         <style>
@@ -28,7 +28,7 @@ def hide_hamburger_menu():
 
 hide_hamburger_menu()
 
-# Initialize session state for messages and memory
+# Session state init
 if "messages" not in st.session_state:
     st.session_state.messages = []
 
@@ -37,23 +37,27 @@ if "memory" not in st.session_state:
 
 @st.cache_resource
 def load_embeddings():
-    """Load and cache the embeddings model."""
     return HuggingFaceEmbeddings(model_name="law-ai/InLegalBERT")
 
 embeddings = load_embeddings()
-db = FAISS.load_local("ipc_embed_db", embeddings, allow_dangerous_deserialization=True)
+
+# ✅ Add FAISS index check before loading
+index_dir = "ipc_embed_db"
+index_faiss = os.path.join(index_dir, "index.faiss")
+index_pkl = os.path.join(index_dir, "index.pkl")
+
+if os.path.exists(index_faiss) and os.path.exists(index_pkl):
+    db = FAISS.load_local(index_dir, embeddings, allow_dangerous_deserialization=True)
+else:
+    st.error("❌ FAISS index files not found! Please run the FAISS indexing script first.")
+    st.stop()
+
 db_retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 3})
 
+# Prompt template for QA
 prompt_template = """
 <s>[INST]
-As a legal chatbot specializing in the Indian Penal Code, you are tasked with providing highly accurate and contextually appropriate responses. Ensure your answers meet these criteria:
-- Respond in a bullet-point format to clearly delineate distinct aspects of the legal query.
-- Each point should accurately reflect the breadth of the legal provision in question, avoiding over-specificity unless directly relevant to the user's query.
-- Clarify the general applicability of the legal rules or sections mentioned, highlighting any common misconceptions or frequently misunderstood aspects.
-- Limit responses to essential information that directly addresses the user's question, providing concise yet comprehensive explanations.
-- Avoid assuming specific contexts or details not provided in the query, focusing on delivering universally applicable legal interpretations unless otherwise specified.
-- Conclude with a brief summary that captures the essence of the legal discussion and corrects any common misinterpretations related to the topic.
-
+As a legal chatbot specializing in the Indian Penal Code, you are tasked with providing highly accurate and contextually appropriate responses...
 CONTEXT: {context}
 CHAT HISTORY: {chat_history}
 QUESTION: {question}
@@ -66,21 +70,17 @@ ANSWER:
 </s>[INST]
 """
 
-
-
-prompt = PromptTemplate(template=prompt_template,
-                        input_variables=['context', 'question', 'chat_history'])
+prompt = PromptTemplate(template=prompt_template, input_variables=['context', 'question', 'chat_history'])
 
 api_key = os.getenv('TOGETHER_API_KEY')
 if not api_key:
     st.error("API key for Together is missing. Please set the TOGETHER_API_KEY environment variable.")
 
 llm = Together(model="mistralai/Mixtral-8x22B-Instruct-v0.1", temperature=0.5, max_tokens=1024, together_api_key=api_key)
-
 qa = ConversationalRetrievalChain.from_llm(llm=llm, memory=st.session_state.memory, retriever=db_retriever, combine_docs_chain_kwargs={'prompt': prompt})
 
+# Helper functions
 def extract_answer(full_response):
-    """Extracts the answer from the LLM's full response by removing the instructional text."""
     try:
         answer_start = full_response.find("Response:")
         if answer_start != -1:
@@ -94,29 +94,27 @@ def reset_conversation():
     st.session_state.messages = []
     st.session_state.memory.clear()
 
+# Chat Interface
 for message in st.session_state.messages:
     with st.chat_message(message["role"]):
         st.write(message["content"])
 
-
 input_prompt = st.chat_input("Say something...")
 if input_prompt:
     with st.chat_message("user"):
         st.markdown(f"**You:** {input_prompt}")
-
     st.session_state.messages.append({"role": "user", "content": input_prompt})
+
     with st.chat_message("assistant"):
         with st.spinner("Thinking 💡..."):
             result = qa.invoke(input=input_prompt)
             message_placeholder = st.empty()
             answer = extract_answer(result["answer"])
 
-            # Initialize the response message
             full_response = "⚠️ **_Gentle reminder: We generally ensure precise information, but do double-check._** \n\n\n"
             for chunk in answer:
-                # Simulate typing by appending chunks of the response over time
                 full_response += chunk
-                time.sleep(0.02)  # Adjust the sleep time to control the "typing" speed
+                time.sleep(0.02)
                 message_placeholder.markdown(full_response + " |", unsafe_allow_html=True)
 
     st.session_state.messages.append({"role": "assistant", "content": answer})
@@ -124,9 +122,5 @@ if input_prompt:
 if st.button('🗑️ Reset All Chat', on_click=reset_conversation):
     st.experimental_rerun()
 
-
-
-# Define the CSS to style the footer
+# Footer styling
 footer()
-
-
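Note: the new guard assumes ipc_embed_db/ contains the two files that FAISS.save_local() writes, index.faiss and index.pkl. The indexing script the error message refers to is not part of this commit; below is a minimal sketch of what such a script might look like. The file name build_index.py, the corpus path data/ipc_sections.txt, and the chunking parameters are assumptions for illustration, not taken from the repo.

```python
# build_index.py: hypothetical indexing script, NOT part of this commit.
from langchain_community.document_loaders import TextLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Assumed corpus location; the real repo may keep the IPC text elsewhere.
docs = TextLoader("data/ipc_sections.txt", encoding="utf-8").load()

# Chunk size/overlap are illustrative defaults, not values from the repo.
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
chunks = splitter.split_documents(docs)

# Same embedding model app.py loads, so query and index vectors match.
embeddings = HuggingFaceEmbeddings(model_name="law-ai/InLegalBERT")
db = FAISS.from_documents(chunks, embeddings)

# save_local() writes exactly the two files app.py now checks for:
# ipc_embed_db/index.faiss and ipc_embed_db/index.pkl.
db.save_local("ipc_embed_db")
```

Because the index is produced locally by the same project, loading it back with allow_dangerous_deserialization=True only acknowledges that index.pkl is unpickled. Separately, the st.image change tracks Streamlit's deprecation of use_column_width in favor of use_container_width.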