import os

import streamlit as st

from utils.document_processing import process_pdf
from utils.models import load_models
from utils.rag import query_pipeline
# Configure the app
st.set_page_config(
    page_title="PDF RAG Pipeline",
    page_icon="📄",
    layout="wide",
)

# Initialize session state
if "models_loaded" not in st.session_state:
    st.session_state.models_loaded = False
if "processed_docs" not in st.session_state:
    st.session_state.processed_docs = None
# Sidebar for configuration
with st.sidebar:
    st.title("Configuration")

    # API keys
    groq_api_key = st.text_input("Groq API Key", type="password")
    google_api_key = st.text_input("Google API Key", type="password")

    # Model selection
    embedding_model = st.selectbox(
        "Embedding Model",
        ["ibm-granite/granite-embedding-30m-english"],
        index=0,
    )
    llm_model = st.selectbox(
        "LLM Model",
        ["llama3-70b-8192"],
        index=0,
    )

    # File upload
    uploaded_file = st.file_uploader(
        "Upload a PDF file",
        type=["pdf"],
        accept_multiple_files=False,
    )
    if st.button("Initialize Models"):
        with st.spinner("Loading models..."):
            try:
                # Load models; the returned LLM handle is named `llm` so it
                # does not shadow the `llm_model` selectbox value above.
                embeddings_model, embeddings_tokenizer, vision_model, llm = load_models(
                    embedding_model=embedding_model,
                    llm_model=llm_model,
                    google_api_key=google_api_key,
                    groq_api_key=groq_api_key,
                )
                st.session_state.embeddings_model = embeddings_model
                st.session_state.embeddings_tokenizer = embeddings_tokenizer
                st.session_state.vision_model = vision_model
                st.session_state.llm_model = llm
                st.session_state.models_loaded = True
                st.success("Models loaded successfully!")
            except Exception as e:
                st.error(f"Error loading models: {e}")
# Main app interface
st.title("PDF RAG Pipeline")
st.write("Upload a PDF and ask questions about its content")
if uploaded_file and st.session_state.models_loaded:
    # Only process each file once; without this guard the PDF (including its
    # vision-model calls) would be reprocessed on every Streamlit rerun,
    # e.g. each time the user asks a question.
    if st.session_state.get("processed_file") != uploaded_file.name:
        with st.spinner("Processing PDF..."):
            # Save uploaded file temporarily
            file_path = f"./temp_{uploaded_file.name}"
            try:
                with open(file_path, "wb") as f:
                    f.write(uploaded_file.getbuffer())

                # Process the PDF
                texts, tables, pictures = process_pdf(
                    file_path,
                    st.session_state.embeddings_tokenizer,
                    st.session_state.vision_model,
                )
                st.session_state.processed_docs = {
                    "texts": texts,
                    "tables": tables,
                    "pictures": pictures,
                }
                st.session_state.processed_file = uploaded_file.name
                st.success("PDF processed successfully!")
            except Exception as e:
                st.error(f"Error processing PDF: {e}")
            finally:
                # Remove temp file even if processing failed
                if os.path.exists(file_path):
                    os.remove(file_path)

    # Display document stats
    if st.session_state.processed_docs:
        docs = st.session_state.processed_docs
        col1, col2, col3 = st.columns(3)
        col1.metric("Text Chunks", len(docs["texts"]))
        col2.metric("Tables", len(docs["tables"]))
        col3.metric("Images", len(docs["pictures"]))
# Question answering section
if st.session_state.processed_docs:
    st.divider()
    st.subheader("Ask a Question")
    question = st.text_input("Enter your question about the document:")

    if question and st.button("Get Answer"):
        with st.spinner("Generating answer..."):
            try:
                answer = query_pipeline(
                    question=question,
                    texts=st.session_state.processed_docs["texts"],
                    tables=st.session_state.processed_docs["tables"],
                    pictures=st.session_state.processed_docs["pictures"],
                    embeddings_model=st.session_state.embeddings_model,
                    llm_model=st.session_state.llm_model,
                )
                st.subheader("Answer")
                st.write(answer)
            except Exception as e:
                st.error(f"Error generating answer: {e}")