# app.py
import os
import io

import fitz  # PyMuPDF
import pytesseract
from pdf2image import convert_from_bytes
import numpy as np
import faiss
import torch
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer

app = Flask(__name__)
CORS(app)  # Enable CORS for cross-origin requests

# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
# Load models at startup (only once)
try:
    print("Loading models...")

    # Embedding model for semantic search
    embedding_model = SentenceTransformer("all-MiniLM-L6-v2").to(device)

    # QA pipeline for direct question answering
    # (use the checkpoint's own tokenizer so it matches the fine-tuned model)
    qa_pipeline = pipeline(
        "question-answering",
        model="distilbert-base-cased-distilled-squad",
        tokenizer="distilbert-base-cased-distilled-squad",
        device=0 if device == "cuda" else -1,
    )

    # Generation model for more detailed responses
    gen_model_name = "distilgpt2"  # Lightweight model suitable for free tier
    gen_tokenizer = AutoTokenizer.from_pretrained(gen_model_name)
    gen_model = AutoModelForCausalLM.from_pretrained(gen_model_name).to(device)

    # Ensure pad token is set for the tokenizer
    if gen_tokenizer.pad_token is None:
        gen_tokenizer.pad_token = gen_tokenizer.eos_token
        gen_model.config.pad_token_id = gen_model.config.eos_token_id

    print("✅ Models loaded successfully")
except Exception as e:
    print(f"❌ Error loading models: {e}")
    raise
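
# Note: loading at import time means every server process holds its own copy
# of the three models. All three (MiniLM, DistilBERT, distilgpt2) are small,
# each well under 1 GB, which is what makes this viable on a free CPU tier.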

# OCR fallback for scanned PDFs
def ocr_pdf(pdf_bytes):
    try:
        images = convert_from_bytes(pdf_bytes)
        text = ""
        for img in images:
            text += pytesseract.image_to_string(img)
        return text
    except Exception as e:
        print(f"❌ OCR error: {e}")
        return ""

# Text extraction from standard PDFs
def extract_text(pdf_bytes):
    try:
        doc = fitz.open(stream=io.BytesIO(pdf_bytes), filetype="pdf")
        text = ""
        for page in doc:
            text += page.get_text()
        if len(text.strip()) < 50:
            print("⚠️ Not enough text, using OCR fallback...")
            text = ocr_pdf(pdf_bytes)
        print("✅ Text extraction complete")
        return text
    except Exception as e:
        print(f"❌ Text extraction error: {e}")
        return ""

# Split into chunks with overlap for better context
def split_into_chunks(text, max_chars=300, overlap=50):
    # Note: chunk size is measured in characters, not model tokens;
    # overlap is measured in words.
    sentences = [s.strip() for s in text.split('.') if s.strip()]
    chunks, current = [], ''
    for sentence in sentences:
        sentence = sentence + '.'
        if len(current) + len(sentence) < max_chars:
            current += sentence
        else:
            if current:
                chunks.append(current.strip())
            # Keep some overlap for context continuity
            words = current.split()
            if len(words) > overlap:
                current = ' '.join(words[-overlap:]) + ' ' + sentence
            else:
                current = sentence
    if current:
        chunks.append(current.strip())
    return chunks
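
# Usage sketch: chunks = split_into_chunks(text) cuts at sentence boundaries
# and, when a finished chunk is long enough, repeats its trailing words at the
# start of the next chunk, so an answer that straddles a chunk boundary still
# appears intact in at least one chunk.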

# Setup FAISS index for semantic search
def setup_faiss(chunks):
    try:
        # FAISS expects a float32 matrix of shape (n_chunks, dim)
        embeddings = np.asarray(embedding_model.encode(chunks), dtype="float32")
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings)
        return index, embeddings, chunks
    except Exception as e:
        print(f"❌ FAISS setup error: {e}")
        raise
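
# Note: IndexFlatL2 ranks chunks by Euclidean distance. The MiniLM embeddings
# are not normalized here, so this approximates rather than equals a cosine
# ranking; normalizing the vectors (or switching to faiss.IndexFlatIP) would
# make the two orderings agree exactly.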

# Get answer using QA pipeline
def answer_with_qa_pipeline(chunks, question):
    try:
        # Join relevant chunks for context (first 5 chunks for simplicity)
        context = " ".join(chunks[:5])
        result = qa_pipeline(question=question, context=context)
        return result['answer']
    except Exception as e:
        print(f"❌ QA pipeline error: {e}")
        return "Could not generate answer with QA pipeline."

# Get answer using generation model
def answer_with_generation(index, embeddings, chunks, question):
    try:
        # Embed the question and search the FAISS index for the most relevant chunks
        q_embedding = np.asarray(embedding_model.encode([question]), dtype="float32")
        _, top_k_indices = index.search(q_embedding, 3)
        relevant_chunks = [chunks[i] for i in top_k_indices[0]]
        context = " ".join(relevant_chunks)

        # Create a clear prompt
        prompt = (
            "Answer the following question based on this information:\n\n"
            f"Information: {context}\n\n"
            f"Question: {question}\n\n"
            "Detailed answer:"
        )

        # Generate response (truncate the prompt so prompt + new tokens
        # stay within distilgpt2's 1024-token context window)
        inputs = gen_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
        output = gen_model.generate(
            **inputs,
            max_new_tokens=300,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            num_beams=3,
            no_repeat_ngram_size=2,
        )
        answer = gen_tokenizer.decode(output[0], skip_special_tokens=True)

        # The decoded text includes the prompt; keep only the answer part
        if "Detailed answer:" in answer:
            return answer.split("Detailed answer:")[-1].strip()
        return answer
    except Exception as e:
        print(f"❌ Generation error: {e}")
        return "Could not generate answer."

# Process PDF and answer question
def process_pdf_and_answer(pdf_bytes, question):
    try:
        # Extract text from PDF
        text = extract_text(pdf_bytes)
        if not text:
            return "Could not extract text from the PDF."

        # Split into chunks
        chunks = split_into_chunks(text)
        if not chunks:
            return "Could not process the PDF content."

        # Try QA pipeline first
        print("Attempting to answer with QA pipeline...")
        qa_answer = answer_with_qa_pipeline(chunks, question)

        # If the QA answer is too short or empty, try the generation approach
        if len(qa_answer) < 20:
            print("QA answer too short, trying generation approach...")
            index, embeddings, chunks = setup_faiss(chunks)
            return answer_with_generation(index, embeddings, chunks, question)
        return qa_answer
    except Exception as e:
        print(f"❌ Processing error: {e}")
        return f"An error occurred: {str(e)}"

# API Endpoints
@app.route("/health", methods=["GET"])
def health_check():
    """Health check endpoint"""
    return jsonify({"status": "healthy"})


@app.route("/api/ask", methods=["POST"])
def ask_question():
    """Endpoint for asking questions about a PDF"""
    try:
        if 'file' not in request.files:
            return jsonify({"error": "No file provided"}), 400
        file = request.files['file']
        if not file or file.filename == '':
            return jsonify({"error": "Invalid file"}), 400
        if 'question' not in request.form:
            return jsonify({"error": "No question provided"}), 400

        question = request.form['question']
        pdf_bytes = file.read()
        answer = process_pdf_and_answer(pdf_bytes, question)
        return jsonify({
            "answer": answer,
            "success": True
        })
    except Exception as e:
        print(f"❌ API error: {e}")
        return jsonify({"error": str(e), "success": False}), 500

if __name__ == "__main__":
    # Get port from environment variable or use 7860 (Hugging Face Spaces default)
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port)
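
# Example requests (assuming the server is reachable at localhost:7860;
# document.pdf is a placeholder filename):
#   curl http://localhost:7860/health
#   curl -X POST http://localhost:7860/api/ask \
#        -F "file=@document.pdf" \
#        -F "question=What is this document about?"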