import streamlit as st
from PyPDF2 import PdfReader
import pdfplumber  # PyPDF2 has no table extraction, so pdfplumber handles tables
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
import pandas as pd
import os
import requests

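# --- Assumed setup (a sketch, not part of the original script; adjust to your environment) ---
#   pip install streamlit PyPDF2 pdfplumber pandas requests \
#       langchain langchain-community sentence-transformers faiss-cpu
#   export GROQ_API_KEY="..."    # required by generate_answer_with_groq() below
#   streamlit run app.py         # assuming this file is saved as app.py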
# --- 1. Data Loading and Preprocessing ---

@st.cache_data()
def load_and_process_pdfs_from_folder(docs_folder="docs"):
    """Loads text and tables from every PDF file in the specified folder."""
    all_text = ""
    all_tables = []
    if not os.path.isdir(docs_folder):
        st.error(f"Folder '{docs_folder}' was not found next to this script.")
        return all_text, all_tables
    for filename in os.listdir(docs_folder):
        if filename.endswith(".pdf"):
            filepath = os.path.join(docs_folder, filename)
            try:
                # PyPDF2 handles plain-text extraction.
                with open(filepath, 'rb') as file:
                    pdf_reader = PdfReader(file)
                    for page in pdf_reader.pages:
                        # extract_text() returns None for image-only pages.
                        all_text += (page.extract_text() or "") + "\n"
                # PyPDF2 pages have no extract_tables(); pdfplumber does.
                with pdfplumber.open(filepath) as pdf:
                    for page in pdf.pages:
                        try:
                            for table in page.extract_tables():
                                all_tables.append(pd.DataFrame(table))
                        except Exception as e:
                            print(f"Could not extract tables from a page in {filename}. Error: {e}")
            except Exception as e:
                st.error(f"Error reading PDF {filename}: {e}")
    return all_text, all_tables

@st.cache_data()
def split_text_into_chunks(text):
    """Splits the text into smaller, manageable chunks."""
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    chunks = text_splitter.split_text(text)
    return chunks

@st.cache_resource()
def create_vectorstore(chunks, metadatas=None):
    """Creates a FAISS vectorstore from the text chunks using HuggingFace embeddings."""
    # cache_resource (not cache_data) because the FAISS index is not serializable.
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
    return FAISS.from_texts(chunks, embeddings, metadatas=metadatas)

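# Optional sanity check once the index is built (illustrative only; the query
# string is made up):
#   docs = vectorstore.similarity_search("What is the return policy?", k=3)
#   print([d.page_content[:80] for d in docs])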
# --- 2. Question Answering with Groq ---

def generate_answer_with_groq(question, context):
    """Generates an answer using the Groq API."""
    url = "https://api.groq.com/openai/v1/chat/completions"
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        st.error("GROQ_API_KEY environment variable not found. Please set it.")
        return None  # Indicate failure

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    }
    prompt = (
        f"Customer asked: '{question}'\n\n"
        f"Here is the relevant product or policy info to help:\n{context}\n\n"
        f"Respond in a friendly and helpful tone as a toy shop support agent."
    )
    payload = {
        "model": "llama3-8b-8192",
        "messages": [
            {
                "role": "system",
                "content": (
                    "You are ToyBot, a friendly and helpful WhatsApp assistant for an online toy shop. "
                    "Your goal is to politely answer customer questions, help them choose the right toys, "
                    "provide order or delivery information, explain return policies, and guide them through purchases."
                )
            },
            {"role": "user", "content": prompt},
        ],
        "temperature": 0.5,
        "max_tokens": 300,
    }
    try:
        response = requests.post(url, headers=headers, json=payload)
        response.raise_for_status()  # Raise an exception for bad status codes
        return response.json()['choices'][0]['message']['content'].strip()
    except requests.exceptions.RequestException as e:
        st.error(f"Error communicating with Groq API: {e}")
        return "An error occurred while trying to get the answer."

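# For reference, generate_answer_with_groq() parses an OpenAI-compatible chat
# completions response of roughly this shape (illustrative, not real output):
#   {"choices": [{"message": {"role": "assistant", "content": "..."}}], ...}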
def perform_rag_groq(vectorstore, query):
    """Performs retrieval and generates an answer using Groq."""
    retriever = vectorstore.as_retriever()
    relevant_docs = retriever.get_relevant_documents(query)
    context = "\n\n".join([doc.page_content for doc in relevant_docs])
    answer = generate_answer_with_groq(query, context)
    return {"answer": answer, "sources": [doc.metadata['source'] for doc in relevant_docs] if relevant_docs else []}

# --- 3. Streamlit UI ---

def main():
    st.title("PDF Q&A with Local Docs (Powered by Groq)")
    st.info("Make sure you have a 'docs' folder in the same directory as this script containing your PDF files.")

    with st.spinner("Loading and processing PDF(s)..."):
        all_text, all_tables = load_and_process_pdfs_from_folder()

    if all_text:
        with st.spinner("Creating knowledge base..."):
            chunks = split_text_into_chunks(all_text)
            # We need to add metadata (source) to the chunks for accurate source tracking
            metadatas = [{"source": f"doc_{i+1}"} for i in range(len(chunks))] # Basic source tracking
            embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
            vectorstore = FAISS.from_texts(chunks, embeddings, metadatas=metadatas)

        query = st.text_input("Ask a question about the documents:")
        if query:
            with st.spinner("Searching for answer..."):
                result = perform_rag_groq(vectorstore, query)
                if result and result.get("answer"):
                    st.subheader("Answer:")
                    st.write(result["answer"])
                    if "sources" in result and result["sources"]:
                        st.subheader("Source:")
                        st.write(", ".join(result["sources"]))
                else:
                    st.warning("Could not generate an answer.")

    if all_tables:
        st.subheader("Extracted Tables:")
        for i, table_df in enumerate(all_tables):
            st.write(f"Table {i+1}:")
            st.dataframe(table_df)
    if not all_text and not all_tables:
        st.warning("No readable PDF content was found in the 'docs' folder.")

if __name__ == "__main__":
    main()