masadonline committed on
Commit
d186b8d
·
verified ·
1 Parent(s): f4e7b4f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -18
app.py CHANGED
@@ -3,11 +3,10 @@ from PyPDF2 import PdfReader
3
  from langchain.text_splitter import RecursiveCharacterTextSplitter
4
  from langchain.embeddings import HuggingFaceEmbeddings
5
  from langchain.vectorstores import FAISS
6
- from langchain.llms import HuggingFaceHub
7
- from langchain.chains import RetrievalQAWithSourcesChain
8
  import pandas as pd
9
  import os
10
  import io
 
11
 
12
  # --- 1. Data Loading and Preprocessing ---
13
 
@@ -48,44 +47,85 @@ def create_vectorstore(chunks):
48
  vectorstore = FAISS.from_texts(chunks, embeddings)
49
  return vectorstore
50
 
51
- # --- 2. Question Answering with RAG ---
52
 
53
@st.cache_resource()
def setup_llm():
    """Instantiate and cache the Hugging Face Hub LLM used for answering queries."""
    # flan-t5-xxl hosted on the Hub; moderate temperature, capped output length.
    model_config = {"temperature": 0.5, "max_length": 512}
    return HuggingFaceHub(repo_id="google/flan-t5-xxl", model_kwargs=model_config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
def perform_rag(vectorstore, llm, query):
    """Run retrieval-augmented generation over the vector store.

    Args:
        vectorstore: FAISS store exposing ``as_retriever()``.
        llm: The language model used to compose the answer.
        query: The user's question.

    Returns:
        The chain's result mapping (contains at least "answer"; may contain "sources").
    """
    retriever = vectorstore.as_retriever()
    chain = RetrievalQAWithSourcesChain.from_llm(llm, retriever=retriever)
    return chain({"question": query})
 
 
64
 
65
  # --- 3. Streamlit UI ---
66
 
67
  def main():
68
- st.title("PDF Q&A with Local Docs")
69
  st.info("Make sure you have a 'docs' folder in the same directory as this script containing your PDF files.")
70
 
 
 
 
 
 
 
71
  with st.spinner("Loading and processing PDF(s)..."):
72
  all_text, all_tables = load_and_process_pdfs_from_folder()
73
 
74
  if all_text:
75
  with st.spinner("Creating knowledge base..."):
76
  chunks = split_text_into_chunks(all_text)
77
- vectorstore = create_vectorstore(chunks)
78
- llm = setup_llm()
 
 
79
 
80
  query = st.text_input("Ask a question about the documents:")
81
  if query:
82
  with st.spinner("Searching for answer..."):
83
- result = perform_rag(vectorstore, llm, query)
84
  st.subheader("Answer:")
85
  st.write(result["answer"])
86
  if "sources" in result:
87
  st.subheader("Source:")
88
- st.write(result["sources"])
89
 
90
  if all_tables:
91
  st.subheader("Extracted Tables:")
 
3
  from langchain.text_splitter import RecursiveCharacterTextSplitter
4
  from langchain.embeddings import HuggingFaceEmbeddings
5
  from langchain.vectorstores import FAISS
 
 
6
  import pandas as pd
7
  import os
8
  import io
9
+ import requests
10
 
11
  # --- 1. Data Loading and Preprocessing ---
12
 
 
47
  vectorstore = FAISS.from_texts(chunks, embeddings)
48
  return vectorstore
49
 
50
+ # --- 2. Question Answering with Groq ---
51
 
52
def generate_answer_with_groq(question, context):
    """Generates an answer using the Groq API.

    Args:
        question: The customer's question.
        context: Retrieved document text used to ground the answer.

    Returns:
        The model's reply as a stripped string, or a fixed error-message
        string if the API key is missing, the request fails, or the
        response body has an unexpected shape.
    """
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        # Fail fast rather than sending "Bearer None" and getting a confusing 401.
        st.error("GROQ_API_KEY is not set.")
        return "An error occurred while trying to get the answer."
    url = "https://api.groq.com/openai/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    prompt = (
        f"Customer asked: '{question}'\n\n"
        f"Here is the relevant product or policy info to help:\n{context}\n\n"
        f"Respond in a friendly and helpful tone as a toy shop support agent."
    )
    payload = {
        "model": "llama3-8b-8192",
        "messages": [
            {
                "role": "system",
                "content": (
                    "You are ToyBot, a friendly and helpful WhatsApp assistant for an online toy shop. "
                    "Your goal is to politely answer customer questions, help them choose the right toys, "
                    "provide order or delivery information, explain return policies, and guide them through purchases."
                ),
            },
            {"role": "user", "content": prompt},
        ],
        "temperature": 0.5,
        "max_tokens": 300,
    }
    try:
        # Timeout prevents the Streamlit app from hanging indefinitely on a stalled request.
        response = requests.post(url, headers=headers, json=payload, timeout=30)
        response.raise_for_status()  # Raise an exception for bad status codes
        return response.json()['choices'][0]['message']['content'].strip()
    except requests.exceptions.RequestException as e:
        st.error(f"Error communicating with Groq API: {e}")
        return "An error occurred while trying to get the answer."
    except (KeyError, IndexError, ValueError) as e:
        # The HTTP call succeeded but the body was not the expected chat-completion shape.
        st.error(f"Unexpected response from Groq API: {e}")
        return "An error occurred while trying to get the answer."
88
 
89
def perform_rag_groq(vectorstore, query):
    """Performs retrieval and generates an answer using Groq.

    Args:
        vectorstore: A FAISS vector store built over the document chunks.
        query: The user's question.

    Returns:
        Dict with "answer" (str) and "sources" (list of source labels, one
        per retrieved document).
    """
    retriever = vectorstore.as_retriever()
    relevant_docs = retriever.get_relevant_documents(query)
    context = "\n\n".join(doc.page_content for doc in relevant_docs)
    # .get with a fallback: documents indexed without a "source" metadata entry
    # (e.g. stores built via FAISS.from_texts with no metadatas) must not crash the app.
    sources = [doc.metadata.get("source", "unknown") for doc in relevant_docs]
    answer = generate_answer_with_groq(query, context)
    return {"answer": answer, "sources": sources}
96
 
97
  # --- 3. Streamlit UI ---
98
 
99
  def main():
100
+ st.title("PDF Q&A with Local Docs (Powered by Groq)")
101
  st.info("Make sure you have a 'docs' folder in the same directory as this script containing your PDF files.")
102
 
103
+ groq_api_key = st.text_input("Enter your Groq API Key:", type="password")
104
+ if not groq_api_key:
105
+ st.warning("Please enter your Groq API key to ask questions.")
106
+ return
107
+ os.environ["GROQ_API_KEY"] = groq_api_key
108
+
109
  with st.spinner("Loading and processing PDF(s)..."):
110
  all_text, all_tables = load_and_process_pdfs_from_folder()
111
 
112
  if all_text:
113
  with st.spinner("Creating knowledge base..."):
114
  chunks = split_text_into_chunks(all_text)
115
+ # We need to add metadata (source) to the chunks for accurate source tracking
116
+ metadatas = [{"source": f"doc_{i+1}"} for i in range(len(chunks))] # Basic source tracking
117
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
118
+ vectorstore = FAISS.from_texts(chunks, embeddings, metadatas=metadatas)
119
 
120
  query = st.text_input("Ask a question about the documents:")
121
  if query:
122
  with st.spinner("Searching for answer..."):
123
+ result = perform_rag_groq(vectorstore, query)
124
  st.subheader("Answer:")
125
  st.write(result["answer"])
126
  if "sources" in result:
127
  st.subheader("Source:")
128
+ st.write(", ".join(result["sources"])) # Display sources
129
 
130
  if all_tables:
131
  st.subheader("Extracted Tables:")