# app.py
import os
import io
import tempfile
import fitz  # PyMuPDF
import pytesseract
from pdf2image import convert_from_bytes, convert_from_path
import numpy as np
import faiss
import torch
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer

app = Flask(__name__)
CORS(app)  # Enable CORS for cross-origin requests

# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load models at startup (only once)
try:
    print("Loading models...")

    # Embedding model for semantic search
    embedding_model = SentenceTransformer("all-MiniLM-L6-v2").to(device)

    # QA pipeline for direct question answering
    qa_pipeline = pipeline(
        "question-answering",
        model="distilbert-base-cased-distilled-squad",
        tokenizer="distilbert-base-cased",
        device=0 if device == "cuda" else -1
    )

    # Generation model for more detailed responses
    gen_model_name = "distilgpt2"  # Lightweight model suitable for free tier
    gen_tokenizer = AutoTokenizer.from_pretrained(gen_model_name)
    gen_model = AutoModelForCausalLM.from_pretrained(gen_model_name).to(device)

    # Ensure pad token is set for the tokenizer
    if gen_tokenizer.pad_token is None:
        gen_tokenizer.pad_token = gen_tokenizer.eos_token
        gen_model.config.pad_token_id = gen_model.config.eos_token_id

    print("✅ Models loaded successfully")
except Exception as e:
    print(f"❌ Error loading models: {e}")
    raise
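
# Note: on the first start of the Space, the three models above are downloaded
# from the Hugging Face Hub, so startup can take a few minutes on a free CPU tier;
# later restarts reuse the local cache.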

# OCR fallback for scanned PDFs
def ocr_pdf(pdf_bytes):
    try:
        images = convert_from_bytes(pdf_bytes)
        text = ""
        for img in images:
            text += pytesseract.image_to_string(img)
        return text
    except Exception as e:
        print(f"❌ OCR error: {e}")
        return ""

# Text extraction from standard PDFs
def extract_text(pdf_bytes):
    try:
        doc = fitz.open(stream=io.BytesIO(pdf_bytes), filetype="pdf")
        text = ""
        for page in doc:
            text += page.get_text()
        if len(text.strip()) < 50:
            print("⚠️ Not enough text, using OCR fallback...")
            text = ocr_pdf(pdf_bytes)
        print("✅ Text extraction complete")
        return text
    except Exception as e:
        print(f"❌ Text extraction error: {e}")
        return ""

# Split into chunks with overlap for better context
# (note: max_tokens is compared against character counts, not model tokens)
def split_into_chunks(text, max_tokens=300, overlap=50):
    sentences = text.split('.')
    chunks, current = [], ''
    for sentence in sentences:
        sentence = sentence.strip() + '.'
        if len(current) + len(sentence) < max_tokens:
            current += sentence
        else:
            chunks.append(current.strip())
            # Keep some overlap for context continuity
            words = current.split()
            if len(words) > overlap:
                current = ' '.join(words[-overlap:]) + ' ' + sentence
            else:
                current = sentence
    if current:
        chunks.append(current.strip())
    return chunks
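
# Illustrative use (assuming the PDF text has already been extracted):
#   chunks = split_into_chunks(extract_text(pdf_bytes))
#   print(f"{len(chunks)} chunks, first chunk starts with: {chunks[0][:80]!r}")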

# Setup FAISS index for semantic search
def setup_faiss(chunks):
    try:
        # FAISS expects float32 vectors; SentenceTransformer.encode already returns
        # float32 numpy arrays, so the cast below is a cheap safeguard.
        embeddings = np.asarray(embedding_model.encode(chunks), dtype="float32")
        dimension = embeddings.shape[1]
        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings)
        return index, embeddings, chunks
    except Exception as e:
        print(f"❌ FAISS setup error: {e}")
        raise
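
# Note: IndexFlatL2 is an exact (brute-force) Euclidean index. For the handful of
# chunks produced from a single PDF that is perfectly adequate; approximate indexes
# only pay off at much larger collection sizes.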

# Get answer using QA pipeline
def answer_with_qa_pipeline(chunks, question):
    try:
        # Join relevant chunks for context
        context = " ".join(chunks[:5])  # Using first 5 chunks for simplicity
        result = qa_pipeline(question=question, context=context)
        return result['answer']
    except Exception as e:
        print(f"❌ QA pipeline error: {e}")
        return "Could not generate answer with QA pipeline."

# Get answer using generation model
def answer_with_generation(index, embeddings, chunks, question):
    try:
        # Get embeddings for the question (float32 for FAISS)
        q_embedding = np.asarray(embedding_model.encode([question]), dtype="float32")

        # Search in FAISS index for the most relevant chunks
        _, top_k_indices = index.search(q_embedding, 3)
        relevant_chunks = [chunks[i] for i in top_k_indices[0]]
        context = " ".join(relevant_chunks)

        # Create a clear prompt
        prompt = f"Answer the following question based on this information:\n\nInformation: {context}\n\nQuestion: {question}\n\nDetailed answer:"

        # Generate response
        inputs = gen_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
        output = gen_model.generate(
            **inputs,
            max_new_tokens=300,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            num_beams=3,
            no_repeat_ngram_size=2
        )
        answer = gen_tokenizer.decode(output[0], skip_special_tokens=True)

        # Extract the answer part
        if "Detailed answer:" in answer:
            return answer.split("Detailed answer:")[-1].strip()
        return answer
    except Exception as e:
        print(f"❌ Generation error: {e}")
        return "Could not generate answer."

# Process PDF and answer question
def process_pdf_and_answer(pdf_bytes, question):
    try:
        # Extract text from PDF
        text = extract_text(pdf_bytes)
        if not text:
            return "Could not extract text from the PDF."

        # Split into chunks
        chunks = split_into_chunks(text)
        if not chunks:
            return "Could not process the PDF content."

        # Try QA pipeline first
        print("Attempting to answer with QA pipeline...")
        qa_answer = answer_with_qa_pipeline(chunks, question)

        # If the QA answer is too short or empty, try the generation approach
        if len(qa_answer) < 20:
            print("QA answer too short, trying generation approach...")
            index, embeddings, chunks = setup_faiss(chunks)
            gen_answer = answer_with_generation(index, embeddings, chunks, question)
            return gen_answer

        return qa_answer
    except Exception as e:
        print(f"❌ Processing error: {e}")
        return f"An error occurred: {str(e)}"

# API Endpoints
# NOTE: the route paths below are assumed, typical choices (they are not specified
# elsewhere in this file); adjust them to whatever the frontend expects.
@app.route("/health", methods=["GET"])
def health_check():
    """Health check endpoint"""
    return jsonify({"status": "healthy"})

@app.route("/ask", methods=["POST"])
def ask_question():
    """Endpoint for asking questions about a PDF"""
    try:
        if 'file' not in request.files:
            return jsonify({"error": "No file provided"}), 400

        file = request.files['file']
        if not file or file.filename == '':
            return jsonify({"error": "Invalid file"}), 400

        if 'question' not in request.form:
            return jsonify({"error": "No question provided"}), 400

        question = request.form['question']
        pdf_bytes = file.read()

        answer = process_pdf_and_answer(pdf_bytes, question)
        return jsonify({
            "answer": answer,
            "success": True
        })
    except Exception as e:
        print(f"❌ API error: {e}")
        return jsonify({"error": str(e), "success": False}), 500

if __name__ == "__main__":
    # Get port from environment variable or use 7860 (Hugging Face Spaces default)
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port)
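
# Example request against the (assumed) /ask route, e.g. from a local shell:
#   curl -X POST http://localhost:7860/ask \
#        -F "file=@document.pdf" \
#        -F "question=What is this document about?"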