File size: 4,411 Bytes
791a72e
7d48d44
 
61e6b08
 
 
 
20d640d
61e6b08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d48d44
61e6b08
 
 
7d48d44
61e6b08
 
 
 
 
7d48d44
 
61e6b08
 
 
 
 
7d48d44
61e6b08
 
 
 
 
 
7d48d44
61e6b08
 
 
 
 
 
 
 
 
 
7d48d44
61e6b08
 
 
 
 
7d48d44
61e6b08
 
 
20d640d
61e6b08
 
 
20d640d
61e6b08
 
 
 
 
 
 
7d48d44
61e6b08
 
 
 
 
7d48d44
 
61e6b08
 
 
 
 
7d48d44
61e6b08
 
 
 
 
 
 
7d48d44
61e6b08
 
 
 
 
20d640d
61e6b08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import streamlit as st
import os
from PIL import Image
import google.generativeai as genai
from utils.document_processing import process_pdf
from utils.models import load_models
from utils.rag import query_pipeline

# Page configuration — per Streamlit docs, set_page_config should be the
# first Streamlit command executed in the script.
_PAGE_CONFIG = {
    "page_title": "PDF RAG Pipeline",
    "page_icon": "📄",
    "layout": "wide",
}
st.set_page_config(**_PAGE_CONFIG)

# Seed session state with defaults on the first run only; existing values
# survive Streamlit reruns untouched.
for _key, _default in (("models_loaded", False), ("processed_docs", None)):
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Sidebar for configuration: API keys, model choices, file upload, and the
# model-initialization trigger. Widget values (uploaded_file in particular)
# are read by the main section below on every rerun.
with st.sidebar:
    st.title("Configuration")
    
    # API keys (masked in the UI; required by load_models below)
    groq_api_key = st.text_input("Groq API Key", type="password")
    google_api_key = st.text_input("Google API Key", type="password")
    
    # Model selection
    embedding_model = st.selectbox(
        "Embedding Model",
        ["ibm-granite/granite-embedding-30m-english"],
        index=0
    )
    
    llm_model = st.selectbox(
        "LLM Model",
        ["llama3-70b-8192"],
        index=0
    )
    
    # File upload
    uploaded_file = st.file_uploader(
        "Upload a PDF file", 
        type=["pdf"],
        accept_multiple_files=False
    )
    
    if st.button("Initialize Models"):
        # Fail fast with a clear message instead of passing empty keys
        # through to load_models and surfacing an opaque backend error.
        if not (groq_api_key and google_api_key):
            st.error("Please provide both the Groq and Google API keys.")
        else:
            with st.spinner("Loading models..."):
                try:
                    # Bind the loaded LLM client to a fresh name so it does
                    # not shadow the `llm_model` selectbox string above
                    # (the original reused `llm_model` for both).
                    (embeddings_model,
                     embeddings_tokenizer,
                     vision_model,
                     llm_client) = load_models(
                        embedding_model=embedding_model,
                        llm_model=llm_model,
                        google_api_key=google_api_key,
                        groq_api_key=groq_api_key
                    )
                    
                    # Persist loaded models across reruns via session state.
                    st.session_state.embeddings_model = embeddings_model
                    st.session_state.embeddings_tokenizer = embeddings_tokenizer
                    st.session_state.vision_model = vision_model
                    st.session_state.llm_model = llm_client
                    st.session_state.models_loaded = True
                    
                    st.success("Models loaded successfully!")
                except Exception as e:
                    st.error(f"Error loading models: {str(e)}")

# Main app interface
st.title("PDF RAG Pipeline")
st.write("Upload a PDF and ask questions about its content")

# Process the uploaded PDF once models are ready. Guarded on the file name so
# the expensive extraction does not re-run on every Streamlit rerun triggered
# by unrelated widget interactions (e.g. typing a question below).
if uploaded_file and st.session_state.models_loaded:
    if st.session_state.get("processed_file_name") != uploaded_file.name:
        with st.spinner("Processing PDF..."):
            # process_pdf reads from disk, so spill the upload to a temp path.
            file_path = f"./temp_{uploaded_file.name}"
            try:
                with open(file_path, "wb") as f:
                    f.write(uploaded_file.getbuffer())
                
                # Extract text chunks, tables, and images from the PDF.
                texts, tables, pictures = process_pdf(
                    file_path,
                    st.session_state.embeddings_tokenizer,
                    st.session_state.vision_model
                )
                
                st.session_state.processed_docs = {
                    "texts": texts,
                    "tables": tables,
                    "pictures": pictures
                }
                # Remember which file was processed so reruns skip this branch.
                st.session_state.processed_file_name = uploaded_file.name
                
                st.success("PDF processed successfully!")
                
                # Display document stats
                col1, col2, col3 = st.columns(3)
                col1.metric("Text Chunks", len(texts))
                col2.metric("Tables", len(tables))
                col3.metric("Images", len(pictures))
                
            except Exception as e:
                st.error(f"Error processing PDF: {str(e)}")
            finally:
                # Always remove the temp file — the original only removed it
                # on success, leaking it whenever process_pdf raised.
                if os.path.exists(file_path):
                    os.remove(file_path)

# Question-answering section — rendered only once a document has been
# processed and stored in session state.
if st.session_state.processed_docs:
    st.divider()
    st.subheader("Ask a Question")
    
    user_question = st.text_input("Enter your question about the document:")
    
    if user_question and st.button("Get Answer"):
        docs = st.session_state.processed_docs
        with st.spinner("Generating answer..."):
            try:
                # Run retrieval + generation over the processed document parts.
                answer = query_pipeline(
                    question=user_question,
                    texts=docs["texts"],
                    tables=docs["tables"],
                    pictures=docs["pictures"],
                    embeddings_model=st.session_state.embeddings_model,
                    llm_model=st.session_state.llm_model
                )
            except Exception as e:
                st.error(f"Error generating answer: {str(e)}")
            else:
                # Only rendered when query_pipeline succeeded.
                st.subheader("Answer")
                st.write(answer)