priyanshu23456 committed
Commit a987525 · verified · 1 Parent(s): f800f49

Update app.py

Files changed (1):
  1. app.py +219 -82
app.py CHANGED
@@ -1,6 +1,6 @@
 from flask import Flask, request, jsonify
 from werkzeug.utils import secure_filename
-from flask_cors import CORS  # ✅ Add this line
+from flask_cors import CORS
 import os
 import torch
 import fitz  # PyMuPDF
@@ -12,6 +12,11 @@ import faiss
 import numpy as np
 import tempfile
 from PIL import Image
+import logging
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)

 # Fix caching issue on Hugging Face Spaces
 os.environ["TRANSFORMERS_CACHE"] = "/tmp"
@@ -19,16 +24,65 @@ os.environ["HF_HOME"] = "/tmp"
 os.environ["XDG_CACHE_HOME"] = "/tmp"

 app = Flask(__name__)
-CORS(app)  # Enable CORS for all routes
+CORS(app)  # Enable CORS for all routes

 UPLOAD_FOLDER = "/tmp/uploads"
 os.makedirs(UPLOAD_FOLDER, exist_ok=True)

 device = "cuda" if torch.cuda.is_available() else "cpu"
+logger.info(f"Using device: {device}")
+
+# Global model variables
+embedder = None
+qa_pipeline = None
+tokenizer = None
+model = None
+
+# Initialize models once on startup
+def initialize_models():
+    global embedder, qa_pipeline, tokenizer, model
+    try:
+        logger.info("Loading SentenceTransformer model...")
+        embedder = SentenceTransformer("all-MiniLM-L6-v2")
+
+        logger.info("Loading QA pipeline...")
+        qa_pipeline = pipeline(
+            "question-answering",
+            model="distilbert-base-cased-distilled-squad",
+            tokenizer="distilbert-base-cased",
+            device=0 if device == "cuda" else -1
+        )
+
+        logger.info("Loading language model...")
+        tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
+        model = AutoModelForCausalLM.from_pretrained(
+            "distilgpt2",
+            device_map="auto",
+            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+        )
+
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+            model.config.pad_token_id = model.config.eos_token_id
+
+        logger.info("Models initialized successfully")
+    except Exception as e:
+        logger.error(f"Error initializing models: {str(e)}")
+        raise
+
+# Cleanup function for temporary files
+def cleanup_temp_files(filepath):
+    try:
+        if os.path.exists(filepath):
+            os.remove(filepath)
+            logger.info(f"Removed temporary file: {filepath}")
+    except Exception as e:
+        logger.warning(f"Failed to clean up file {filepath}: {str(e)}")

 # Improved OCR function
 def ocr_pdf(pdf_path):
     try:
+        logger.info(f"Starting OCR for {pdf_path}")
         # Use a higher DPI for better quality
         images = convert_from_path(
             pdf_path,
@@ -39,17 +93,20 @@ def ocr_pdf(pdf_path):
         )

         text = ""
-        for img in images:
+        for i, img in enumerate(images):
+            logger.info(f"Processing page {i+1} of {len(images)}")
             # Preprocess the image for better OCR results
             preprocessed = preprocess_image_for_ocr(img)
             # Use tesseract with more options
-            text += pytesseract.image_to_string(
+            page_text = pytesseract.image_to_string(
                 preprocessed,
                 config='--psm 1 --oem 3 -l eng'  # Page segmentation mode 1 (auto), OCR Engine mode 3 (default)
             )
+            text += page_text
+        logger.info(f"OCR completed with {len(text)} characters extracted")
         return text
     except Exception as e:
-        print(f"OCR error: {str(e)}")
+        logger.error(f"OCR error: {str(e)}")
         return ""

 # Image preprocessing function for better OCR
@@ -66,27 +123,38 @@ def preprocess_image_for_ocr(img):

 # Improved extract_text function with better text detection
 def extract_text(pdf_path):
-    doc = fitz.open(pdf_path)
-    text = ""
-    for page in doc:
-        page_text = page.get_text()
-        text += page_text
-
-    # Check if the text is meaningful (more sophisticated check)
-    words = text.split()
-    unique_words = set(word.lower() for word in words if len(word) > 2)
-
-    # If we don't have enough meaningful text, try OCR
-    if len(unique_words) < 20 or len(text.strip()) < 100:
-        ocr_text = ocr_pdf(pdf_path)
-        # If OCR gave us more text, use it
-        if len(ocr_text.strip()) > len(text.strip()):
-            text = ocr_text
-
-    return text
+    try:
+        logger.info(f"Extracting text from {pdf_path}")
+        doc = fitz.open(pdf_path)
+        text = ""
+        for page_num, page in enumerate(doc):
+            page_text = page.get_text()
+            text += page_text
+            logger.info(f"Extracted {len(page_text)} characters from page {page_num+1}")
+
+        # Check if the text is meaningful (more sophisticated check)
+        words = text.split()
+        unique_words = set(word.lower() for word in words if len(word) > 2)
+
+        logger.info(f"PDF text extraction: {len(text)} chars, {len(words)} words, {len(unique_words)} unique words")
+
+        # If we don't have enough meaningful text, try OCR
+        if len(unique_words) < 20 or len(text.strip()) < 100:
+            logger.info("Text extraction yielded insufficient results, trying OCR...")
+            ocr_text = ocr_pdf(pdf_path)
+            # If OCR gave us more text, use it
+            if len(ocr_text.strip()) > len(text.strip()):
+                logger.info(f"Using OCR result: {len(ocr_text)} chars (better than {len(text)} chars)")
+                text = ocr_text
+
+        return text
+    except Exception as e:
+        logger.error(f"Text extraction error: {str(e)}")
+        return ""

-# Split into chunks
+# Split into chunks
 def split_into_chunks(text, max_tokens=300, overlap=50):
+    logger.info(f"Splitting text into chunks (max_tokens={max_tokens}, overlap={overlap})")
     sentences = text.split('.')
     chunks, current = [], ''
     for sentence in sentences:
@@ -102,60 +170,95 @@ def split_into_chunks(text, max_tokens=300, overlap=50):
             current = sentence
     if current:
         chunks.append(current.strip())
+    logger.info(f"Split text into {len(chunks)} chunks")
     return chunks

-# Setup FAISS
+# Setup FAISS
 def setup_faiss(chunks):
-    embedder = SentenceTransformer("all-MiniLM-L6-v2")
-    embeddings = embedder.encode(chunks)
-    dim = embeddings.shape[1]
-    index = faiss.IndexFlatL2(dim)
-    index.add(embeddings)
-    return index, embeddings, chunks
-
-# QA pipeline
+    try:
+        logger.info("Setting up FAISS index")
+        global embedder
+        if embedder is None:
+            embedder = SentenceTransformer("all-MiniLM-L6-v2")
+
+        embeddings = embedder.encode(chunks)
+        dim = embeddings.shape[1]
+        index = faiss.IndexFlatL2(dim)
+        index.add(embeddings)
+        logger.info(f"FAISS index created with {len(chunks)} chunks and dimension {dim}")
+        return index, embeddings, chunks
+    except Exception as e:
+        logger.error(f"FAISS setup error: {str(e)}")
+        raise
+
+# QA pipeline
 def answer_with_qa_pipeline(chunks, question):
-    qa_pipeline = pipeline(
-        "question-answering",
-        model="distilbert-base-cased-distilled-squad",
-        tokenizer="distilbert-base-cased",
-        device=0 if device == "cuda" else -1
-    )
-    context = " ".join(chunks[:5])
     try:
+        logger.info(f"Answering with QA pipeline: '{question}'")
+        global qa_pipeline
+        if qa_pipeline is None:
+            logger.info("QA pipeline not initialized, creating now...")
+            qa_pipeline = pipeline(
+                "question-answering",
+                model="distilbert-base-cased-distilled-squad",
+                tokenizer="distilbert-base-cased",
+                device=0 if device == "cuda" else -1
+            )
+
+        # Limit context size to avoid token length issues
+        context = " ".join(chunks[:5])
+        if len(context) > 5000:  # Approx token limit
+            context = context[:5000]
+
         result = qa_pipeline(question=question, context=context)
+        logger.info(f"QA pipeline answer: '{result['answer']}' (score: {result['score']})")
         return result["answer"]
-    except:
+    except Exception as e:
+        logger.error(f"QA pipeline error: {str(e)}")
         return ""

-# Modify your answer_with_generation function like this:
+# Generation-based answering
 def answer_with_generation(index, embeddings, chunks, question):
-    tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
-
-    # Fix for meta tensor error - load model with device_map="auto"
-    model = AutoModelForCausalLM.from_pretrained(
-        "distilgpt2",
-        device_map="auto",  # This handles device placement automatically
-        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32  # Use fp16 if possible
-    )
-
-    if tokenizer.pad_token is None:
-        tokenizer.pad_token = tokenizer.eos_token
-        model.config.pad_token_id = model.config.eos_token_id
-
-    embedder = SentenceTransformer("all-MiniLM-L6-v2")
-    q_embedding = embedder.encode([question])
-    _, top_k_indices = index.search(q_embedding, k=3)
-    relevant_chunks = [chunks[i] for i in top_k_indices[0]]
-    context = " ".join(relevant_chunks)
-
-    prompt = f"Answer the following question based on this information:\n\nInformation: {context}\n\nQuestion: {question}\n\nDetailed answer:"
-
-    # Handle inputs without explicit device placement
-    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
-    # Let the model handle device placement internally
-
     try:
+        logger.info(f"Answering with generation model: '{question}'")
+        global tokenizer, model
+
+        if tokenizer is None or model is None:
+            logger.info("Generation models not initialized, creating now...")
+            tokenizer = AutoTokenizer.from_pretrained("distilgpt2")
+            model = AutoModelForCausalLM.from_pretrained(
+                "distilgpt2",
+                device_map="auto",
+                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+            )
+
+            if tokenizer.pad_token is None:
+                tokenizer.pad_token = tokenizer.eos_token
+                model.config.pad_token_id = model.config.eos_token_id
+
+        # Get embeddings for question
+        q_embedding = embedder.encode([question])
+
+        # Find relevant chunks
+        _, top_k_indices = index.search(q_embedding, k=3)
+        relevant_chunks = [chunks[i] for i in top_k_indices[0]]
+        context = " ".join(relevant_chunks)
+
+        # Limit context size to avoid token length issues
+        if len(context) > 4000:
+            context = context[:4000]
+
+        # Create prompt
+        prompt = f"Answer the following question based on this information:\n\nInformation: {context}\n\nQuestion: {question}\n\nDetailed answer:"
+
+        # Handle inputs
+        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
+
+        # Move inputs to the right device if needed
+        if torch.cuda.is_available():
+            inputs = {k: v.to('cuda') for k, v in inputs.items()}
+
+        # Generate answer
         output = model.generate(
             **inputs,
             max_new_tokens=300,
@@ -165,15 +268,19 @@ def answer_with_generation(index, embeddings, chunks, question):
             num_beams=3,
             no_repeat_ngram_size=2
         )
+
+        # Decode and format answer
         answer = tokenizer.decode(output[0], skip_special_tokens=True)
         if "Detailed answer:" in answer:
-            return answer.split("Detailed answer:")[-1].strip()
+            answer = answer.split("Detailed answer:")[-1].strip()
+
+        logger.info(f"Generation answer: '{answer[:50]}...' (length: {len(answer)})")
         return answer.strip()
     except Exception as e:
-        print(f"Generation error: {str(e)}")
+        logger.error(f"Generation error: {str(e)}")
         return "I couldn't generate a good answer based on the PDF content."

-# API route
+# API route
 @app.route('/')
 def home():
     return jsonify({"message": "PDF QA API is running!"})
@@ -182,28 +289,58 @@ def home():
 def ask():
     file = request.files.get("pdf")
     question = request.form.get("question", "")
+    filepath = None

     if not file or not question:
         return jsonify({"error": "Both PDF file and question are required"}), 400

-    filename = secure_filename(file.filename)
-    filepath = os.path.join(UPLOAD_FOLDER, filename)
-    file.save(filepath)
-
     try:
-        # 🧠 Process PDF and generate answer
+        filename = secure_filename(file.filename)
+        filepath = os.path.join(UPLOAD_FOLDER, filename)
+        file.save(filepath)
+
+        logger.info(f"Processing file: {filename}, Question: '{question}'")
+
+        # Process PDF and generate answer
         text = extract_text(filepath)
+        if not text.strip():
+            return jsonify({"error": "Could not extract text from the PDF"}), 400
+
         chunks = split_into_chunks(text)
-        answer = answer_with_qa_pipeline(chunks, question)
-
-        if len(answer.strip()) < 20:
-            index, embeddings, chunks = setup_faiss(chunks)
-            answer = answer_with_generation(index, embeddings, chunks, question)
+        if not chunks:
+            return jsonify({"error": "PDF content couldn't be processed"}), 400
+
+        try:
+            answer = answer_with_qa_pipeline(chunks, question)
+        except Exception as e:
+            logger.warning(f"QA pipeline failed: {str(e)}")
+            answer = ""
+
+        # If QA pipeline didn't give a good answer, try generation
+        if not answer or len(answer.strip()) < 20:
+            try:
+                logger.info("QA pipeline answer insufficient, trying generation...")
+                index, embeddings, chunks = setup_faiss(chunks)
+                answer = answer_with_generation(index, embeddings, chunks, question)
+            except Exception as e:
+                logger.error(f"Generation fallback failed: {str(e)}")
+                return jsonify({"error": "Failed to generate answer from PDF content"}), 500

         return jsonify({"answer": answer})

     except Exception as e:
-        return jsonify({"error": str(e)}), 500
+        logger.error(f"Error processing request: {str(e)}")
+        return jsonify({"error": f"An error occurred processing your request: {str(e)}"}), 500
+    finally:
+        # Always clean up, even if errors occur
+        if filepath:
+            cleanup_temp_files(filepath)

 if __name__ == "__main__":
-    app.run(host="0.0.0.0", port=7860)
+    try:
+        # Initialize models at startup
+        initialize_models()
+        logger.info("Starting Flask application")
+        app.run(host="0.0.0.0", port=7860)
+    except Exception as e:
+        logger.critical(f"Failed to start application: {str(e)}")
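
For reference, a minimal client-side sketch of how the changed endpoint could be exercised once the app is running. This is not part of the commit: the route path /ask is assumed from the `ask()` view (its decorator falls outside the diff context), and the host, PDF path, and question below are placeholders. The handler expects a multipart file part named "pdf" and a "question" form field, and returns JSON with either an "answer" or an "error" key.

import requests

# Assumed endpoint: app.run binds to 0.0.0.0:7860, so point this at wherever
# the Space (or a local run of app.py) is reachable. The /ask path is an
# assumption based on the ask() view shown in the diff.
URL = "http://localhost:7860/ask"

# Placeholder PDF and question; the route reads them from request.files["pdf"]
# and request.form["question"].
with open("example.pdf", "rb") as f:
    response = requests.post(
        URL,
        files={"pdf": ("example.pdf", f, "application/pdf")},
        data={"question": "What is this document about?"},
        timeout=120,  # OCR plus generation can take a while on CPU
    )

print(response.status_code)
print(response.json())  # {"answer": "..."} on success, {"error": "..."} otherwise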