manik-hossain commited on
Commit
c1a63a0
·
1 Parent(s): 2b7f356
Files changed (1) hide show
  1. app.py +74 -0
app.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os, tempfile, streamlit as st
3
+ from langchain.prompts import PromptTemplate
4
+ from langchain.chains.combine_documents import create_stuff_documents_chain
5
+ from langchain.chains import create_retrieval_chain
6
+ from langchain_chroma import Chroma
7
+ from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
8
+ from langchain_community.document_loaders import PyPDFLoader
9
+
10
+ # Streamlit app config
11
+ st.subheader("Generative Q&A with LangChain, Gemini and Chroma")
12
+ with st.sidebar:
13
+ google_api_key = st.text_input("Google API key", type="password")
14
+ source_doc = st.file_uploader("Source document", type="pdf")
15
+ col1, col2 = st.columns([4,1])
16
+ query = col1.text_input("Query", label_visibility="collapsed")
17
+ os.environ['GOOGLE_API_KEY'] = google_api_key
18
+
19
+ # Session state initialization for documents and retrievers
20
+ if 'retriever' not in st.session_state or 'loaded_doc' not in st.session_state:
21
+ st.session_state.retriever = None
22
+ st.session_state.loaded_doc = None
23
+
24
+ submit = col2.button("Submit")
25
+
26
+ if submit:
27
+ # Validate inputs
28
+ if not google_api_key or not query:
29
+ st.warning("Please provide the missing fields.")
30
+ elif not source_doc:
31
+ st.warning("Please upload the source document.")
32
+ else:
33
+ with st.spinner("Please wait..."):
34
+ # Check if it's the same document; if not or if retriever isn't set, reload and recompute
35
+ if st.session_state.loaded_doc != source_doc:
36
+ try:
37
+ # Save uploaded file temporarily to disk, load and split the file into pages, delete temp file
38
+ with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
39
+ tmp_file.write(source_doc.read())
40
+ loader = PyPDFLoader(tmp_file.name)
41
+ pages = loader.load_and_split()
42
+ os.remove(tmp_file.name)
43
+
44
+ # Generate embeddings for the pages, and store in Chroma vector database
45
+ embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
46
+ vectorstore = Chroma.from_documents(pages, embeddings)
47
+
48
+ #Configure Chroma as a retriever with top_k=5
49
+ st.session_state.retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
50
+
51
+ # Store the uploaded file in session state to prevent reloading
52
+ st.session_state.loaded_doc = source_doc
53
+ except Exception as e:
54
+ st.error(f"An error occurred: {e}")
55
+
56
+ try:
57
+ # Initialize the ChatGoogleGenerativeAI module, create and invoke the retrieval chain
58
+ llm = ChatGoogleGenerativeAI(model="gemini-pro")
59
+
60
+ template = """
61
+ You are a helpful AI assistant. Answer based on the context provided.
62
+ context: {context}
63
+ input: {input}
64
+ answer:
65
+ """
66
+ prompt = PromptTemplate.from_template(template)
67
+
68
+ combine_docs_chain = create_stuff_documents_chain(llm, prompt)
69
+ retrieval_chain = create_retrieval_chain(st.session_state.retriever, combine_docs_chain)
70
+ response = retrieval_chain.invoke({"input": query})
71
+
72
+ st.success(response['answer'])
73
+ except Exception as e:
74
+ st.error(f"An error occurred: {e}")