Update app.py
Browse files
app.py
CHANGED
@@ -1,23 +1,38 @@
|
|
1 |
import os
|
2 |
-
import logging
|
3 |
import faiss
|
4 |
import streamlit as st
|
5 |
-
from
|
6 |
-
from langchain_community.embeddings import HuggingFaceEmbeddings
|
7 |
from langchain.vectorstores import FAISS
|
8 |
-
from langchain_community.llms import HuggingFacePipeline
|
9 |
from langchain.chains import RetrievalQA
|
|
|
|
|
|
|
|
|
|
|
10 |
|
11 |
-
|
12 |
-
|
13 |
-
logger = logging.getLogger(__name__)
|
14 |
|
15 |
-
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
-
@st.cache_resource
|
19 |
def load_llm():
|
20 |
-
"""
|
|
|
|
|
|
|
21 |
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
22 |
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
|
23 |
pipe = pipeline(
|
@@ -29,30 +44,23 @@ def load_llm():
|
|
29 |
temperature=0.3,
|
30 |
top_p=0.95
|
31 |
)
|
32 |
-
return
|
33 |
-
|
34 |
-
def load_faiss_index():
|
35 |
-
"""Load the FAISS index for vector search."""
|
36 |
-
index_path = "faiss_index/index.faiss"
|
37 |
-
if not os.path.exists(index_path):
|
38 |
-
st.error(f"FAISS index not found at {index_path}. Please ensure the file exists.")
|
39 |
-
raise RuntimeError(f"FAISS index not found at {index_path}.")
|
40 |
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
except Exception as e:
|
48 |
-
st.error(f"Failed to load FAISS index: {e}")
|
49 |
-
logger.exception("Exception in load_faiss_index")
|
50 |
-
raise
|
51 |
|
52 |
-
|
53 |
-
|
|
|
|
|
|
|
54 |
try:
|
55 |
-
|
|
|
56 |
llm = load_llm()
|
57 |
qa = RetrievalQA.from_chain_type(
|
58 |
llm=llm,
|
@@ -60,16 +68,14 @@ def process_answer(instruction):
|
|
60 |
retriever=retriever,
|
61 |
return_source_documents=True
|
62 |
)
|
63 |
-
|
64 |
-
answer =
|
65 |
-
return answer,
|
66 |
except Exception as e:
|
67 |
st.error(f"An error occurred while processing the answer: {e}")
|
68 |
-
logger.exception("Exception in process_answer")
|
69 |
return "An error occurred while processing your request.", {}
|
70 |
|
71 |
def main():
|
72 |
-
"""Main function to run the Streamlit application."""
|
73 |
st.title("Search Your PDF ππ")
|
74 |
|
75 |
with st.expander("About the App"):
|
@@ -90,7 +96,6 @@ def main():
|
|
90 |
st.write(metadata)
|
91 |
except Exception as e:
|
92 |
st.error(f"An unexpected error occurred: {e}")
|
93 |
-
logger.exception("Unexpected error in main function")
|
94 |
|
95 |
if __name__ == '__main__':
|
96 |
main()
|
|
|
1 |
import os
|
|
|
2 |
import faiss
|
3 |
import streamlit as st
|
4 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
|
|
5 |
from langchain.vectorstores import FAISS
|
|
|
6 |
from langchain.chains import RetrievalQA
|
7 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
8 |
+
|
9 |
+
def load_faiss_index(index_path):
    """Read a serialized FAISS index from disk, reporting status in the Streamlit UI.

    Parameters:
    - index_path (str): Filesystem location of the FAISS index file.

    Returns:
    - faiss.Index: The deserialized index object.

    Raises:
    - FileNotFoundError: if no file exists at ``index_path``.
    - Exception: whatever ``faiss.read_index`` raised, re-raised after
      surfacing the message in the UI.
    """
    # Guard clause: bail out early when the index file is missing.
    if os.path.exists(index_path):
        try:
            loaded_index = faiss.read_index(index_path)
            st.success("FAISS index loaded successfully.")
            return loaded_index
        except Exception as err:
            # Surface the failure to the user, then propagate it unchanged.
            st.error(f"Failed to load FAISS index: {err}")
            raise
    st.error(f"FAISS index not found at {index_path}. Please create the index first.")
    raise FileNotFoundError(f"FAISS index not found at {index_path}.")
|
30 |
|
|
|
31 |
def load_llm():
|
32 |
+
"""
|
33 |
+
Load the HuggingFace model for generating responses.
|
34 |
+
"""
|
35 |
+
checkpoint = "LaMini-T5-738M"
|
36 |
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
37 |
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
|
38 |
pipe = pipeline(
|
|
|
44 |
temperature=0.3,
|
45 |
top_p=0.95
|
46 |
)
|
47 |
+
return pipe
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
+
def process_answer(question):
|
50 |
+
"""
|
51 |
+
Process the user's question using the FAISS index and LLM.
|
52 |
+
|
53 |
+
Parameters:
|
54 |
+
- question (str): User's question to be processed.
|
|
|
|
|
|
|
|
|
55 |
|
56 |
+
Returns:
|
57 |
+
- str: The answer generated by the LLM.
|
58 |
+
"""
|
59 |
+
index_path = 'faiss_index/index.faiss'
|
60 |
+
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
61 |
try:
|
62 |
+
faiss_index = load_faiss_index(index_path)
|
63 |
+
retriever = FAISS(index=faiss_index, embeddings=embeddings)
|
64 |
llm = load_llm()
|
65 |
qa = RetrievalQA.from_chain_type(
|
66 |
llm=llm,
|
|
|
68 |
retriever=retriever,
|
69 |
return_source_documents=True
|
70 |
)
|
71 |
+
result = qa.invoke(question)
|
72 |
+
answer = result['result']
|
73 |
+
return answer, result
|
74 |
except Exception as e:
|
75 |
st.error(f"An error occurred while processing the answer: {e}")
|
|
|
76 |
return "An error occurred while processing your request.", {}
|
77 |
|
78 |
def main():
|
|
|
79 |
st.title("Search Your PDF ππ")
|
80 |
|
81 |
with st.expander("About the App"):
|
|
|
96 |
st.write(metadata)
|
97 |
except Exception as e:
|
98 |
st.error(f"An unexpected error occurred: {e}")
|
|
|
99 |
|
100 |
if __name__ == '__main__':
|
101 |
main()
|