priyanshu23456 committed
Commit c971d0d · verified · 1 Parent(s): 0b5a679

Create app.py

Files changed (1)
app.py +226 -0
app.py ADDED
@@ -0,0 +1,226 @@
# app.py
import os
import io
import fitz  # PyMuPDF
import pytesseract
from pdf2image import convert_from_bytes
import numpy as np
import faiss
import torch
from flask import Flask, request, jsonify
from flask_cors import CORS
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer

app = Flask(__name__)
CORS(app)  # Enable CORS for cross-origin requests

# Device setup
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load models at startup (only once)
try:
    print("Loading models...")
    # Embedding model for semantic search
    embedding_model = SentenceTransformer("all-MiniLM-L6-v2").to(device)

    # QA pipeline for direct question answering; the tokenizer must match
    # the fine-tuned checkpoint, not the base model
    qa_pipeline = pipeline(
        "question-answering",
        model="distilbert-base-cased-distilled-squad",
        tokenizer="distilbert-base-cased-distilled-squad",
        device=0 if device == "cuda" else -1
    )

    # Generation model for more detailed responses
    gen_model_name = "distilgpt2"  # Lightweight model suitable for free tier
    gen_tokenizer = AutoTokenizer.from_pretrained(gen_model_name)
    gen_model = AutoModelForCausalLM.from_pretrained(gen_model_name).to(device)

    # GPT-2 has no pad token by default; reuse EOS so generation works
    if gen_tokenizer.pad_token is None:
        gen_tokenizer.pad_token = gen_tokenizer.eos_token
        gen_model.config.pad_token_id = gen_model.config.eos_token_id

    print("✅ Models loaded successfully")
except Exception as e:
    print(f"❌ Error loading models: {e}")
    raise

# OCR fallback for scanned PDFs
def ocr_pdf(pdf_bytes):
    try:
        images = convert_from_bytes(pdf_bytes)
        text = ""
        for img in images:
            text += pytesseract.image_to_string(img)
        return text
    except Exception as e:
        print(f"❌ OCR error: {e}")
        return ""

# Text extraction from standard PDFs
def extract_text(pdf_bytes):
    try:
        doc = fitz.open(stream=io.BytesIO(pdf_bytes), filetype="pdf")
        text = ""
        for page in doc:
            text += page.get_text()
        doc.close()
        # Fall back to OCR when the PDF carries little or no embedded text
        if len(text.strip()) < 50:
            print("⚠️ Not enough text, using OCR fallback...")
            text = ocr_pdf(pdf_bytes)
        print("✅ Text extraction complete")
        return text
    except Exception as e:
        print(f"❌ Text extraction error: {e}")
        return ""

# Split text into overlapping chunks for better context.
# Note: max_chars limits characters, not model tokens.
def split_into_chunks(text, max_chars=300, overlap=50):
    sentences = text.split('.')
    chunks, current = [], ''
    for sentence in sentences:
        sentence = sentence.strip() + '.'
        if len(current) + len(sentence) < max_chars:
            current += sentence
        else:
            chunks.append(current.strip())
            # Carry the last `overlap` words forward for context continuity
            words = current.split()
            if len(words) > overlap:
                current = ' '.join(words[-overlap:]) + ' ' + sentence
            else:
                current = sentence
    if current:
        chunks.append(current.strip())
    return chunks

# Set up a FAISS index over the chunk embeddings for semantic search
def setup_faiss(chunks):
    try:
        # FAISS expects a float32 array
        embeddings = np.asarray(embedding_model.encode(chunks), dtype="float32")
        dimension = embeddings.shape[1]

        index = faiss.IndexFlatL2(dimension)
        index.add(embeddings)
        return index, embeddings, chunks
    except Exception as e:
        print(f"❌ FAISS setup error: {e}")
        raise

# Get answer using the extractive QA pipeline
def answer_with_qa_pipeline(chunks, question):
    try:
        # Join relevant chunks for context
        context = " ".join(chunks[:5])  # first 5 chunks, for simplicity

        result = qa_pipeline(question=question, context=context)
        return result['answer']
    except Exception as e:
        print(f"❌ QA pipeline error: {e}")
        return "Could not generate answer with QA pipeline."

# Get answer using the generation model over FAISS-retrieved chunks
def answer_with_generation(index, embeddings, chunks, question):
    try:
        # Embed the question for retrieval
        q_embedding = np.asarray(embedding_model.encode([question]), dtype="float32")

        # Search the FAISS index for the most relevant chunks
        k = min(3, len(chunks))
        _, top_k_indices = index.search(q_embedding, k)
        relevant_chunks = [chunks[i] for i in top_k_indices[0]]

        # Cap the context so right-sided truncation below cannot cut off
        # the question itself
        context = " ".join(relevant_chunks)[:1200]

        # Create a clear prompt
        prompt = f"Answer the following question based on this information:\n\nInformation: {context}\n\nQuestion: {question}\n\nDetailed answer:"

        # Generate response
        inputs = gen_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
        output = gen_model.generate(
            **inputs,
            max_new_tokens=300,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            num_beams=3,
            no_repeat_ngram_size=2
        )
        answer = gen_tokenizer.decode(output[0], skip_special_tokens=True)

        # The decoded text includes the prompt; keep only the answer part
        if "Detailed answer:" in answer:
            return answer.split("Detailed answer:")[-1].strip()
        return answer
    except Exception as e:
        print(f"❌ Generation error: {e}")
        return "Could not generate answer."

# Process PDF and answer question
def process_pdf_and_answer(pdf_bytes, question):
    try:
        # Extract text from PDF
        text = extract_text(pdf_bytes)
        if not text:
            return "Could not extract text from the PDF."

        # Split into chunks
        chunks = split_into_chunks(text)
        if not chunks:
            return "Could not process the PDF content."

        # Try the extractive QA pipeline first
        print("Attempting to answer with QA pipeline...")
        qa_answer = answer_with_qa_pipeline(chunks, question)

        # If the QA answer is too short or failed, fall back to generation
        if len(qa_answer) < 20 or qa_answer.startswith("Could not"):
            print("QA answer too short, trying generation approach...")
            index, embeddings, chunks = setup_faiss(chunks)
            return answer_with_generation(index, embeddings, chunks, question)

        return qa_answer
    except Exception as e:
        print(f"❌ Processing error: {e}")
        return f"An error occurred: {str(e)}"

# API Endpoints

@app.route("/health", methods=["GET"])
def health_check():
    """Health check endpoint"""
    return jsonify({"status": "healthy"})

@app.route("/api/ask", methods=["POST"])
def ask_question():
    """Endpoint for asking questions about a PDF"""
    try:
        if 'file' not in request.files:
            return jsonify({"error": "No file provided"}), 400

        file = request.files['file']
        if not file or file.filename == '':
            return jsonify({"error": "Invalid file"}), 400

        if 'question' not in request.form:
            return jsonify({"error": "No question provided"}), 400

        question = request.form['question']
        pdf_bytes = file.read()

        answer = process_pdf_and_answer(pdf_bytes, question)

        return jsonify({
            "answer": answer,
            "success": True
        })
    except Exception as e:
        print(f"❌ API error: {e}")
        return jsonify({"error": str(e), "success": False}), 500

if __name__ == "__main__":
    # Get port from environment variable or use 7860 (Hugging Face Spaces default)
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port)
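
For reference, a minimal client sketch against the two endpoints this file defines. It assumes the `requests` package and a server already running locally on port 7860; the file name sample.pdf and the question text are placeholders:

# client_example.py: hypothetical smoke test for the API above
import requests

BASE = "http://localhost:7860"  # adjust for a deployed Space

# Health check
print(requests.get(f"{BASE}/health").json())  # expect: {"status": "healthy"}

# Ask a question about a PDF (sample.pdf is a placeholder path)
with open("sample.pdf", "rb") as f:
    resp = requests.post(
        f"{BASE}/api/ask",
        files={"file": f},  # matches request.files['file']
        data={"question": "What is this document about?"},  # matches request.form['question']
    )
print(resp.json())

The multipart field names must match what ask_question() reads: the PDF under "file" and the question under "question".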