piyushmadhukar committed on
Commit
eda5790
·
verified ·
1 Parent(s): d55ff4b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -88
app.py CHANGED
@@ -1,136 +1,136 @@
1
- # from sentence_transformers import SentenceTransformer
2
- # from transformers import pipeline
3
- # from pydantic import BaseModel
4
- # import faiss
5
- # import numpy as np
6
- # import streamlit as st
7
- # from typing import List
8
- # import os
9
- # from dotenv import load_dotenv
10
- # import google.generativeai as genai
11
- # import torch
12
- # import asyncio
13
 
14
 
15
- # try:
16
- # asyncio.get_running_loop()
17
- # except RuntimeError:
18
- # asyncio.set_event_loop(asyncio.new_event_loop())
19
 
20
 
21
- # device = torch.device("cpu")
22
- # print("Device set to use CPU")
23
 
24
 
25
- # embedding_model = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")
26
 
27
 
28
- # summarizer = pipeline("summarization", model="facebook/bart-large-cnn", device=-1) # -1 forces CPU usage
29
 
30
 
31
- # load_dotenv()
32
- # api_key = os.getenv("API_KEY")
33
 
34
- # genai.configure(api_key=api_key)
35
 
36
 
37
- # gemini_model = genai.GenerativeModel(model_name="gemini-2.0-flash")
38
 
39
 
40
- # class UserQuery(BaseModel):
41
- # query: str
42
 
43
- # class RetrievedSection(BaseModel):
44
- # text: str
45
 
46
- # class SummarizedResponse(BaseModel):
47
- # summary: str
48
 
49
- # class FinalLLMResponse(BaseModel):
50
- # response: str
51
 
52
- # # Query Agent
53
- # def query_legal_documents(query: UserQuery, top_k=3) -> List[RetrievedSection]:
54
- # if not os.path.exists("faiss_index.idx") or not os.path.exists("doc_texts.npy"):
55
- # st.error("FAISS index or document data not found.")
56
- # return []
57
 
58
 
59
- # index = faiss.read_index("faiss_index.idx")
60
- # doc_texts = np.load("doc_texts.npy", allow_pickle=True)
61
 
62
 
63
- # query_embedding = embedding_model.encode([query.query], convert_to_numpy=True)
64
 
65
 
66
- # distances, indices = index.search(query_embedding, top_k)
67
 
68
 
69
- # retrieved_sections = [
70
- # RetrievedSection(text=doc_texts[i]) for i in indices[0] if i < len(doc_texts)
71
- # ]
72
 
73
- # return retrieved_sections
74
 
75
- # # Summarization Agent
76
- # def summarize_text(text_sections: List[RetrievedSection]) -> List[SummarizedResponse]:
77
- # summarized_results = [
78
- # SummarizedResponse(
79
- # summary=summarizer(section.text, max_length=100, min_length=30, do_sample=False)[0]["summary_text"]
80
- # )
81
- # for section in text_sections
82
- # ]
83
- # return summarized_results
84
 
85
- # # LLM Agent to refine response
86
- # def generate_llm_response(summary_text: str) -> FinalLLMResponse:
87
- # response = gemini_model.generate_content(f"Provide a **brief** response. Do not use any special formatting like **. Here is the input:\n\n{summary_text}")
88
- # return FinalLLMResponse(response=response.text)
89
 
90
 
91
- # def main():
92
- # st.set_page_config(page_title="Legal Chatbot", layout="wide")
93
 
94
 
95
- # st.sidebar.title("Legal Chatbot Settings")
96
- # st.sidebar.write("This chatbot helps with legal queries by retrieving relevant legal documents, summarizing them, and generating AI-enhanced responses.")
97
 
98
 
99
- # st.title("πŸ§‘β€βš–οΈ Legal Chatbot")
100
- # st.markdown("### Ask your legal question below:")
101
 
102
- # user_query = st.text_input("Enter your legal query:")
103
 
104
- # if st.button("Submit", use_container_width=True):
105
- # if user_query:
106
- # st.info("Processing your request...")
107
 
108
- # query_obj = UserQuery(query=user_query)
109
- # retrieved_sections = query_legal_documents(query_obj)
110
 
111
- # if not retrieved_sections:
112
- # st.warning("No relevant legal documents found. Try refining your query.")
113
- # return
114
 
115
- # summarized_sections = summarize_text(retrieved_sections)
116
 
117
- # # Combine summaries for LLM
118
- # combined_summary = "\n".join([res.summary for res in summarized_sections])
119
- # llm_response = generate_llm_response(combined_summary)
120
 
121
- # # Display results
122
- # st.markdown("### πŸ“– Retrieved Data from Knowledge Base")
123
- # for section in retrieved_sections:
124
- # st.markdown(f"πŸ”Ή {section.text}")
125
 
126
- # st.markdown("### ✨ Summarized Response")
127
- # for res in summarized_sections:
128
- # st.markdown(f"βœ… {res.summary}")
129
 
130
- # st.markdown("### πŸ€– AI-Enhanced Response")
131
- # st.text_area("Final Answer:", llm_response.response, height=150)
132
 
133
- # if __name__ == "__main__":
134
- # main()
135
 
136
 
 
1
+ from sentence_transformers import SentenceTransformer
2
+ from transformers import pipeline
3
+ from pydantic import BaseModel
4
+ import faiss
5
+ import numpy as np
6
+ import streamlit as st
7
+ from typing import List
8
+ import os
9
+ from dotenv import load_dotenv
10
+ import google.generativeai as genai
11
+ import torch
12
+ import asyncio
13
 
14
 
15
# Make sure the current thread has an asyncio event loop.
# NOTE(review): presumably needed because the script runs in a thread without
# a loop and one of the imported libraries looks one up — confirm.
try:
    asyncio.get_running_loop()
except RuntimeError:
    asyncio.set_event_loop(asyncio.new_event_loop())

device = torch.device("cpu")
print("Device set to use CPU")


# Streamlit re-executes this script on every user interaction. Without
# caching, both models below would be re-instantiated on each rerun.
# show_spinner=False keeps the loaders from emitting UI elements before
# main() calls st.set_page_config().
@st.cache_resource(show_spinner=False)
def _load_embedding_model():
    """Build the sentence-embedding model once and reuse it across reruns."""
    return SentenceTransformer("all-MiniLM-L6-v2", device="cpu")


@st.cache_resource(show_spinner=False)
def _load_summarizer():
    """Build the BART summarization pipeline once and reuse it across reruns."""
    # device=-1 forces CPU usage.
    return pipeline("summarization", model="facebook/bart-large-cnn", device=-1)


embedding_model = _load_embedding_model()
summarizer = _load_summarizer()

# Read API_KEY from the environment (optionally populated from a .env file).
# NOTE(review): api_key is None when API_KEY is unset; nothing here checks
# for that before configuring the client.
load_dotenv()
api_key = os.getenv("API_KEY")

genai.configure(api_key=api_key)

gemini_model = genai.GenerativeModel(model_name="gemini-2.0-flash")
38
 
39
 
40
class UserQuery(BaseModel):
    """Validated wrapper around the raw question typed by the user."""

    # Free-text query string.
    query: str
42
 
43
class RetrievedSection(BaseModel):
    """One document passage returned by the similarity search."""

    # Text of the retrieved passage.
    text: str
45
 
46
class SummarizedResponse(BaseModel):
    """Summary produced for a single retrieved passage."""

    # Summarizer output text.
    summary: str
48
 
49
class FinalLLMResponse(BaseModel):
    """Final answer text returned by the Gemini model."""

    # Generated response text.
    response: str
51
 
52
+ # Query Agent
53
def query_legal_documents(query: UserQuery, top_k=3) -> List[RetrievedSection]:
    """Return the stored passages nearest to the query embedding.

    Args:
        query: The user's question.
        top_k: Number of neighbours to request from the FAISS index.

    Returns:
        Up to ``top_k`` retrieved sections. An empty list is returned (and an
        error is shown in the UI) when the index or text file is missing.
    """
    if not (os.path.exists("faiss_index.idx") and os.path.exists("doc_texts.npy")):
        st.error("FAISS index or document data not found.")
        return []

    index = faiss.read_index("faiss_index.idx")
    # NOTE(review): allow_pickle=True unpickles arbitrary objects from the
    # file; only safe while doc_texts.npy comes from a trusted source.
    doc_texts = np.load("doc_texts.npy", allow_pickle=True)

    query_embedding = embedding_model.encode([query.query], convert_to_numpy=True)

    _distances, indices = index.search(query_embedding, top_k)

    # FAISS pads the result with -1 when it finds fewer than top_k
    # neighbours. A bare upper-bound check would let -1 through and pick the
    # last document via negative indexing, so bound the index on both sides.
    return [
        RetrievedSection(text=doc_texts[i])
        for i in indices[0]
        if 0 <= i < len(doc_texts)
    ]
74
 
75
+ # Summarization Agent
76
def summarize_text(text_sections: List[RetrievedSection]) -> List[SummarizedResponse]:
    """Summarize each retrieved section with the module-level summarizer.

    Args:
        text_sections: Passages to summarize.

    Returns:
        One ``SummarizedResponse`` per input section, in the same order.
    """
    summarized_results = []
    for section in text_sections:
        # truncation=True clips passages that exceed the model's maximum
        # input length instead of letting the pipeline raise on them.
        output = summarizer(
            section.text,
            max_length=100,
            min_length=30,
            do_sample=False,
            truncation=True,
        )
        summarized_results.append(SummarizedResponse(summary=output[0]["summary_text"]))
    return summarized_results
84
 
85
+ # LLM Agent to refine response
86
def generate_llm_response(summary_text: str) -> FinalLLMResponse:
    """Ask the Gemini model for a brief answer based on the given summaries.

    Args:
        summary_text: Text appended to the prompt as the input.

    Returns:
        A ``FinalLLMResponse`` holding the text of the model's reply.
    """
    prompt = (
        "Provide a **brief** response. Do not use any special formatting like **. "
        f"Here is the input:\n\n{summary_text}"
    )
    reply = gemini_model.generate_content(prompt)
    return FinalLLMResponse(response=reply.text)
89
 
90
 
91
def main():
    """Render the chatbot page and run retrieve -> summarize -> generate on submit."""
    st.set_page_config(page_title="Legal Chatbot", layout="wide")

    st.sidebar.title("Legal Chatbot Settings")
    st.sidebar.write("This chatbot helps with legal queries by retrieving relevant legal documents, summarizing them, and generating AI-enhanced responses.")

    st.title("🧑‍⚖️ Legal Chatbot")
    st.markdown("### Ask your legal question below:")

    user_query = st.text_input("Enter your legal query:")

    # Nothing further to do until the user presses the button.
    if not st.button("Submit", use_container_width=True):
        return

    # Give explicit feedback on an empty or whitespace-only submission rather
    # than silently doing nothing or running the pipeline on blank input.
    cleaned_query = user_query.strip()
    if not cleaned_query:
        st.warning("Please enter a legal query before submitting.")
        return

    st.info("Processing your request...")

    query_obj = UserQuery(query=cleaned_query)
    retrieved_sections = query_legal_documents(query_obj)

    if not retrieved_sections:
        st.warning("No relevant legal documents found. Try refining your query.")
        return

    summarized_sections = summarize_text(retrieved_sections)

    # Combine summaries for LLM
    combined_summary = "\n".join(res.summary for res in summarized_sections)
    llm_response = generate_llm_response(combined_summary)

    # Display results
    st.markdown("### 📖 Retrieved Data from Knowledge Base")
    for section in retrieved_sections:
        st.markdown(f"🔹 {section.text}")

    st.markdown("### ✨ Summarized Response")
    for res in summarized_sections:
        st.markdown(f"✅ {res.summary}")

    st.markdown("### 🤖 AI-Enhanced Response")
    st.text_area("Final Answer:", llm_response.response, height=150)


if __name__ == "__main__":
    main()
135
 
136