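"""Streamlit front end for a PDF RAG pipeline: upload a PDF, extract its
texts, tables, and images, then ask questions about the content.

Assuming this file is saved as app.py, launch it with:

    streamlit run app.py
"""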
import os

import google.generativeai as genai
import streamlit as st
from PIL import Image

from utils.document_processing import process_pdf
from utils.models import load_models
from utils.rag import query_pipeline

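# st.set_page_config must be the first Streamlit call in the script.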
st.set_page_config(
    page_title="PDF RAG Pipeline",
    page_icon="📄",
    layout="wide"
)

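# Streamlit reruns this script top to bottom on every interaction, so any
# flag that must survive reruns lives in st.session_state.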
if "models_loaded" not in st.session_state:
    st.session_state.models_loaded = False
if "processed_docs" not in st.session_state:
    st.session_state.processed_docs = None

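# Sidebar: API credentials, model selection, and file upload.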
with st.sidebar:
    st.title("Configuration")

    groq_api_key = st.text_input("Groq API Key", type="password")
    google_api_key = st.text_input("Google API Key", type="password")

    embedding_model = st.selectbox(
        "Embedding Model",
        ["ibm-granite/granite-embedding-30m-english"],
        index=0
    )

    # Named llm_model_name so the loaded model object below does not shadow it
    llm_model_name = st.selectbox(
        "LLM Model",
        ["llama3-70b-8192"],
        index=0
    )

    uploaded_file = st.file_uploader(
        "Upload a PDF file",
        type=["pdf"],
        accept_multiple_files=False
    )

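    # load_models (from utils.models) is taken to return the 4-tuple unpacked
    # below: (embeddings model, its tokenizer, vision model, LLM client) --
    # inferred from this call site, not verified against utils/models.py.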
    if st.button("Initialize Models"):
        with st.spinner("Loading models..."):
            try:
                embeddings_model, embeddings_tokenizer, vision_model, llm_model = load_models(
                    embedding_model=embedding_model,
                    llm_model=llm_model_name,
                    google_api_key=google_api_key,
                    groq_api_key=groq_api_key
                )

                st.session_state.embeddings_model = embeddings_model
                st.session_state.embeddings_tokenizer = embeddings_tokenizer
                st.session_state.vision_model = vision_model
                st.session_state.llm_model = llm_model
                st.session_state.models_loaded = True

                st.success("Models loaded successfully!")
            except Exception as e:
                st.error(f"Error loading models: {e}")

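# Main area: processing status, document metrics, and Q&A.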
st.title("PDF RAG Pipeline")
st.write("Upload a PDF and ask questions about its content")

if uploaded_file and st.session_state.models_loaded:
    # Reprocess only when a new file arrives: every widget interaction (for
    # example clicking "Get Answer") reruns the whole script, and without a
    # guard the PDF would be parsed again on each rerun. The
    # processed_file_name session key is introduced here for that purpose.
    if st.session_state.get("processed_file_name") != uploaded_file.name:
        with st.spinner("Processing PDF..."):
            file_path = f"./temp_{uploaded_file.name}"
            try:
                # Persist the upload to disk so process_pdf can read it by path
                with open(file_path, "wb") as f:
                    f.write(uploaded_file.getbuffer())

                texts, tables, pictures = process_pdf(
                    file_path,
                    st.session_state.embeddings_tokenizer,
                    st.session_state.vision_model
                )

                st.session_state.processed_docs = {
                    "texts": texts,
                    "tables": tables,
                    "pictures": pictures
                }
                st.session_state.processed_file_name = uploaded_file.name

                st.success("PDF processed successfully!")

                col1, col2, col3 = st.columns(3)
                col1.metric("Text Chunks", len(texts))
                col2.metric("Tables", len(tables))
                col3.metric("Images", len(pictures))
            except Exception as e:
                st.error(f"Error processing PDF: {e}")
            finally:
                # Remove the temp file whether or not processing succeeded
                if os.path.exists(file_path):
                    os.remove(file_path)

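# Q&A: available once a document has been processed.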
if st.session_state.processed_docs:
    st.divider()
    st.subheader("Ask a Question")

    question = st.text_input("Enter your question about the document:")

    if question and st.button("Get Answer"):
        with st.spinner("Generating answer..."):
            try:
                answer = query_pipeline(
                    question=question,
                    texts=st.session_state.processed_docs["texts"],
                    tables=st.session_state.processed_docs["tables"],
                    pictures=st.session_state.processed_docs["pictures"],
                    embeddings_model=st.session_state.embeddings_model,
                    llm_model=st.session_state.llm_model
                )

                st.subheader("Answer")
                st.write(answer)
            except Exception as e:
                st.error(f"Error generating answer: {e}")