piyushmadhukar committed on
Commit
7e9e399
·
verified ·
1 Parent(s): 7c22b31

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +220 -134
app.py CHANGED
@@ -1,134 +1,220 @@
1
- from sentence_transformers import SentenceTransformer
2
- from transformers import pipeline
3
- from pydantic import BaseModel
4
- import faiss
5
- import numpy as np
6
- import streamlit as st
7
- from typing import List
8
- import os
9
- from dotenv import load_dotenv
10
- import google.generativeai as genai
11
- import torch
12
- import asyncio
13
-
14
-
15
- try:
16
- asyncio.get_running_loop()
17
- except RuntimeError:
18
- asyncio.set_event_loop(asyncio.new_event_loop())
19
-
20
-
21
- device = torch.device("cpu")
22
- print("Device set to use CPU")
23
-
24
-
25
- embedding_model = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")
26
-
27
-
28
- summarizer = pipeline("summarization", model="facebook/bart-large-cnn", device=-1) # -1 forces CPU usage
29
-
30
-
31
- load_dotenv()
32
- api_key = os.getenv("API_KEY")
33
-
34
- genai.configure(api_key=api_key)
35
-
36
-
37
- gemini_model = genai.GenerativeModel(model_name="gemini-2.0-flash")
38
-
39
-
40
- class UserQuery(BaseModel):
41
- query: str
42
-
43
- class RetrievedSection(BaseModel):
44
- text: str
45
-
46
- class SummarizedResponse(BaseModel):
47
- summary: str
48
-
49
- class FinalLLMResponse(BaseModel):
50
- response: str
51
-
52
- # Query Agent
53
- def query_legal_documents(query: UserQuery, top_k=3) -> List[RetrievedSection]:
54
- if not os.path.exists("faiss_index.idx") or not os.path.exists("doc_texts.npy"):
55
- st.error("FAISS index or document data not found.")
56
- return []
57
-
58
-
59
- index = faiss.read_index("faiss_index.idx")
60
- doc_texts = np.load("doc_texts.npy", allow_pickle=True)
61
-
62
-
63
- query_embedding = embedding_model.encode([query.query], convert_to_numpy=True)
64
-
65
-
66
- distances, indices = index.search(query_embedding, top_k)
67
-
68
-
69
- retrieved_sections = [
70
- RetrievedSection(text=doc_texts[i]) for i in indices[0] if i < len(doc_texts)
71
- ]
72
-
73
- return retrieved_sections
74
-
75
- # Summarization Agent
76
- def summarize_text(text_sections: List[RetrievedSection]) -> List[SummarizedResponse]:
77
- summarized_results = [
78
- SummarizedResponse(
79
- summary=summarizer(section.text, max_length=100, min_length=30, do_sample=False)[0]["summary_text"]
80
- )
81
- for section in text_sections
82
- ]
83
- return summarized_results
84
-
85
- # LLM Agent to refine response
86
- def generate_llm_response(summary_text: str) -> FinalLLMResponse:
87
- response = gemini_model.generate_content(f"Provide a **brief** response. Do not use any special formatting like **. Here is the input:\n\n{summary_text}")
88
- return FinalLLMResponse(response=response.text)
89
-
90
-
91
- def main():
92
- st.set_page_config(page_title="Legal Chatbot", layout="wide")
93
-
94
-
95
- st.sidebar.title("Legal Chatbot Settings")
96
- st.sidebar.write("This chatbot helps with legal queries by retrieving relevant legal documents, summarizing them, and generating AI-enhanced responses.")
97
-
98
-
99
- st.title("πŸ§‘β€βš–οΈ Legal Chatbot")
100
- st.markdown("### Ask your legal question below:")
101
-
102
- user_query = st.text_input("Enter your legal query:")
103
-
104
- if st.button("Submit", use_container_width=True):
105
- if user_query:
106
- st.info("Processing your request...")
107
-
108
- query_obj = UserQuery(query=user_query)
109
- retrieved_sections = query_legal_documents(query_obj)
110
-
111
- if not retrieved_sections:
112
- st.warning("No relevant legal documents found. Try refining your query.")
113
- return
114
-
115
- summarized_sections = summarize_text(retrieved_sections)
116
-
117
- # Combine summaries for LLM
118
- combined_summary = "\n".join([res.summary for res in summarized_sections])
119
- llm_response = generate_llm_response(combined_summary)
120
-
121
- # Display results
122
- st.markdown("### πŸ“– Retrieved Data from Knowledge Base")
123
- for section in retrieved_sections:
124
- st.markdown(f"πŸ”Ή {section.text}")
125
-
126
- st.markdown("### ✨ Summarized Response")
127
- for res in summarized_sections:
128
- st.markdown(f"βœ… {res.summary}")
129
-
130
- st.markdown("### πŸ€– AI-Enhanced Response")
131
- st.text_area("Final Answer:", llm_response.response, height=150)
132
-
133
- if __name__ == "__main__":
134
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from sentence_transformers import SentenceTransformer
2
+ # from transformers import pipeline
3
+ # from pydantic import BaseModel
4
+ # import faiss
5
+ # import numpy as np
6
+ # import streamlit as st
7
+ # from typing import List
8
+ # import os
9
+ # from dotenv import load_dotenv
10
+ # import google.generativeai as genai
11
+ # import torch
12
+ # import asyncio
13
+
14
+
15
+ # try:
16
+ # asyncio.get_running_loop()
17
+ # except RuntimeError:
18
+ # asyncio.set_event_loop(asyncio.new_event_loop())
19
+
20
+
21
+ # device = torch.device("cpu")
22
+ # print("Device set to use CPU")
23
+
24
+
25
+ # embedding_model = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")
26
+
27
+
28
+ # summarizer = pipeline("summarization", model="facebook/bart-large-cnn", device=-1) # -1 forces CPU usage
29
+
30
+
31
+ # load_dotenv()
32
+ # api_key = os.getenv("API_KEY")
33
+
34
+ # genai.configure(api_key=api_key)
35
+
36
+
37
+ # gemini_model = genai.GenerativeModel(model_name="gemini-2.0-flash")
38
+
39
+
40
+ # class UserQuery(BaseModel):
41
+ # query: str
42
+
43
+ # class RetrievedSection(BaseModel):
44
+ # text: str
45
+
46
+ # class SummarizedResponse(BaseModel):
47
+ # summary: str
48
+
49
+ # class FinalLLMResponse(BaseModel):
50
+ # response: str
51
+
52
+ # # Query Agent
53
+ # def query_legal_documents(query: UserQuery, top_k=3) -> List[RetrievedSection]:
54
+ # if not os.path.exists("faiss_index.idx") or not os.path.exists("doc_texts.npy"):
55
+ # st.error("FAISS index or document data not found.")
56
+ # return []
57
+
58
+
59
+ # index = faiss.read_index("faiss_index.idx")
60
+ # doc_texts = np.load("doc_texts.npy", allow_pickle=True)
61
+
62
+
63
+ # query_embedding = embedding_model.encode([query.query], convert_to_numpy=True)
64
+
65
+
66
+ # distances, indices = index.search(query_embedding, top_k)
67
+
68
+
69
+ # retrieved_sections = [
70
+ # RetrievedSection(text=doc_texts[i]) for i in indices[0] if i < len(doc_texts)
71
+ # ]
72
+
73
+ # return retrieved_sections
74
+
75
+ # # Summarization Agent
76
+ # def summarize_text(text_sections: List[RetrievedSection]) -> List[SummarizedResponse]:
77
+ # summarized_results = [
78
+ # SummarizedResponse(
79
+ # summary=summarizer(section.text, max_length=100, min_length=30, do_sample=False)[0]["summary_text"]
80
+ # )
81
+ # for section in text_sections
82
+ # ]
83
+ # return summarized_results
84
+
85
+ # # LLM Agent to refine response
86
+ # def generate_llm_response(summary_text: str) -> FinalLLMResponse:
87
+ # response = gemini_model.generate_content(f"Provide a **brief** response. Do not use any special formatting like **. Here is the input:\n\n{summary_text}")
88
+ # return FinalLLMResponse(response=response.text)
89
+
90
+
91
+ # def main():
92
+ # st.set_page_config(page_title="Legal Chatbot", layout="wide")
93
+
94
+
95
+ # st.sidebar.title("Legal Chatbot Settings")
96
+ # st.sidebar.write("This chatbot helps with legal queries by retrieving relevant legal documents, summarizing them, and generating AI-enhanced responses.")
97
+
98
+
99
+ # st.title("πŸ§‘β€βš–οΈ Legal Chatbot")
100
+ # st.markdown("### Ask your legal question below:")
101
+
102
+ # user_query = st.text_input("Enter your legal query:")
103
+
104
+ # if st.button("Submit", use_container_width=True):
105
+ # if user_query:
106
+ # st.info("Processing your request...")
107
+
108
+ # query_obj = UserQuery(query=user_query)
109
+ # retrieved_sections = query_legal_documents(query_obj)
110
+
111
+ # if not retrieved_sections:
112
+ # st.warning("No relevant legal documents found. Try refining your query.")
113
+ # return
114
+
115
+ # summarized_sections = summarize_text(retrieved_sections)
116
+
117
+ # # Combine summaries for LLM
118
+ # combined_summary = "\n".join([res.summary for res in summarized_sections])
119
+ # llm_response = generate_llm_response(combined_summary)
120
+
121
+ # # Display results
122
+ # st.markdown("### πŸ“– Retrieved Data from Knowledge Base")
123
+ # for section in retrieved_sections:
124
+ # st.markdown(f"πŸ”Ή {section.text}")
125
+
126
+ # st.markdown("### ✨ Summarized Response")
127
+ # for res in summarized_sections:
128
+ # st.markdown(f"βœ… {res.summary}")
129
+
130
+ # st.markdown("### πŸ€– AI-Enhanced Response")
131
+ # st.text_area("Final Answer:", llm_response.response, height=150)
132
+
133
+ # if __name__ == "__main__":
134
+ # main()
135
+
136
+ from sentence_transformers import SentenceTransformer
137
+ from transformers import pipeline
138
+ import faiss
139
+ import numpy as np
140
+ import streamlit as st
141
+ import os
142
+ from dotenv import load_dotenv
143
+ import google.generativeai as genai
144
+ import torch
145
+
146
# Run everything on CPU; the summarization pipeline expresses the same
# choice with device=-1 below.
device = "cpu"

# On-disk retrieval artifacts (same file names the previous revision of this
# file read inside query_legal_documents).
INDEX_PATH = "faiss_index.idx"
DOC_TEXTS_PATH = "doc_texts.npy"

# Load the models at module level.
# NOTE(review): Streamlit re-executes the script on each interaction, so
# these loads presumably repeat per rerun; consider st.cache_resource.
#
# `normalize_embeddings` is an argument of `SentenceTransformer.encode()`,
# not of the constructor, so it must not be passed here (it raises TypeError).
# NOTE(review): the model ID is restored to "all-MiniLM-L6-v2", the ID the
# previous revision used with this same index file; the query encoder has to
# match whatever model built the index. Confirm before switching to a
# quantized variant.
embedding_model = SentenceTransformer("all-MiniLM-L6-v2", device=device)
summarizer = pipeline("summarization", model="facebook/bart-large-cnn", device=-1)

# Load the FAISS index and the document texts that query_legal_documents
# reads as module-level globals. Both stay None when either file is missing,
# which is the condition query_legal_documents reports to the user.
if os.path.exists(INDEX_PATH) and os.path.exists(DOC_TEXTS_PATH):
    faiss_index = faiss.read_index(INDEX_PATH)
    # SECURITY: allow_pickle=True unpickles arbitrary objects. Acceptable only
    # because this file ships with the app; never point it at user uploads.
    doc_texts = np.load(DOC_TEXTS_PATH, allow_pickle=True)
else:
    faiss_index = None
    doc_texts = None

# Configure the Gemini client with the key from the environment / .env file.
load_dotenv()
api_key = os.getenv("API_KEY")
genai.configure(api_key=api_key)
gemini_model = genai.GenerativeModel(model_name="gemini-2.0-flash")
158
+
159
+
160
+
161
+ # Query Legal Documents
162
def query_legal_documents(query: str, top_k=3):
    """Return the stored document texts nearest to ``query``.

    Reads the module-level ``faiss_index``, ``doc_texts`` and
    ``embedding_model`` objects.

    Args:
        query: The user's question as plain text.
        top_k: Maximum number of documents to retrieve.

    Returns:
        A list of up to ``top_k`` entries taken from ``doc_texts``. The list
        is empty (and an error is shown in the Streamlit UI) when the index
        or the document texts are unavailable.
    """
    if faiss_index is None or doc_texts is None:
        st.error("FAISS index or document data not found.")
        return []

    query_embedding = embedding_model.encode([query])
    _distances, indices = faiss_index.search(query_embedding, top_k)

    # FAISS fills missing neighbours with the label -1. Checking only the
    # upper bound would let -1 through and silently return the *last*
    # document, so the lower bound is checked as well.
    return [doc_texts[i] for i in indices[0] if 0 <= i < len(doc_texts)]
171
+
172
+ # Summarization Agent (Batch Processing)
173
def summarize_text(text_sections):
    """Summarize every section in one batched pipeline call.

    Uses the module-level ``summarizer`` pipeline.

    Args:
        text_sections: Iterable of section texts.

    Returns:
        A list with one summary string per input section, in input order.
        An empty input yields an empty list without calling the model.
    """
    texts = list(text_sections)
    if not texts:
        # Nothing to summarize; skip the model call entirely.
        return []

    summaries = summarizer(
        texts,
        max_length=100,
        min_length=30,
        do_sample=False,
        # Truncate sections that exceed the model's input limit instead of
        # failing on long documents.
        truncation=True,
    )
    return [summary["summary_text"] for summary in summaries]
177
+
178
+ # LLM Agent (Skip if Summaries are Sufficient)
179
def generate_llm_response(summary_text, *, min_llm_chars=200):
    """Return the final answer text for the combined summaries.

    Summaries shorter than ``min_llm_chars`` characters are returned
    unchanged; longer ones are sent to the module-level ``gemini_model``.

    Args:
        summary_text: The combined summary string.
        min_llm_chars: Length threshold below which the LLM call is skipped.
            Keyword-only; defaults to 200.

    Returns:
        The model's text, or ``summary_text`` itself when the LLM is skipped
        or its response carries no usable text.
    """
    if len(summary_text) < min_llm_chars:
        return summary_text  # Short enough to show as-is; skip the LLM.

    response = gemini_model.generate_content(summary_text)
    try:
        return response.text
    except ValueError:
        # The `.text` accessor raises ValueError when the response contains
        # no text part (for example a blocked response). Fall back to the
        # summary rather than crashing the app.
        return summary_text
184
+
185
+ # Streamlit App
186
def main():
    """Render the Streamlit page and answer one query per Submit click.

    Flow: read the query, retrieve matching documents, summarize them,
    produce the final response, then display all three stages.
    """
    st.set_page_config(page_title="Legal Chatbot", layout="wide")
    st.title("🧑‍⚖️ Legal Chatbot")
    user_query = st.text_input("Enter your legal query:")

    if not st.button("Submit"):
        return

    # Previously an empty submission was ignored without any feedback.
    query = user_query.strip() if user_query else ""
    if not query:
        st.warning("Please enter a legal query before submitting.")
        return

    st.info("Processing your request...")
    retrieved_sections = query_legal_documents(query)

    if not retrieved_sections:
        st.warning("No relevant legal documents found.")
        return

    summarized_sections = summarize_text(retrieved_sections)
    combined_summary = "\n".join(summarized_sections)
    final_response = generate_llm_response(combined_summary)

    st.markdown("### 📖 Retrieved Data")
    for section in retrieved_sections:
        st.markdown(f"🔹 {section}")

    st.markdown("### ✨ Summarized Response")
    for summary in summarized_sections:
        st.markdown(f"✅ {summary}")

    st.markdown("### 🤖 AI-Enhanced Response")
    st.text_area("Final Answer:", final_response, height=150)


if __name__ == "__main__":
    main()
217
+
218
+
219
+
220
+