IZERE HIRWA Roger committed on
Commit
b7d66ac
·
1 Parent(s): 1cc833a
Files changed (4)
  1. app.py +139 -5
  2. requirements.txt +4 -1
  3. static/index.html +8 -3
  4. static/script.js +7 -3
app.py CHANGED
@@ -17,6 +17,11 @@ from datetime import datetime, timedelta
 import jwt
 import sqlite3
 import tempfile
+import base64
+from io import BytesIO
+from transformers import AutoModelForImageTextToText, AutoProcessor, AutoTokenizer, TextIteratorStreamer
+from threading import Thread
+import time
 
 app = Flask(__name__)
 app.config['SECRET_KEY'] = 'your-secret-key-change-this-in-production'
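Note on the new imports: `TextIteratorStreamer`, `Thread`, and `time` are pulled in here but not used by the hunks below, whose `generate()` call is synchronous. For reference, a minimal sketch of the streaming pattern those imports usually serve (the `stream_generate` name is illustrative, not part of this commit):

    from threading import Thread
    from transformers import TextIteratorStreamer

    def stream_generate(model, tokenizer, inputs, max_new_tokens=4096):
        # Standard transformers streaming: generate() runs in a worker thread
        # while the streamer yields decoded text chunks as they arrive.
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs = dict(**inputs, streamer=streamer, max_new_tokens=max_new_tokens, do_sample=False)
        Thread(target=model.generate, kwargs=gen_kwargs).start()
        for chunk in streamer:
            yield chunk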
@@ -109,6 +114,29 @@ except Exception as e:
     clip_model = None
     preprocess = None
 
+# Initialize Nanonets OCR model
+ocr_model = None
+ocr_processor = None
+ocr_tokenizer = None
+
+try:
+    model_path = "nanonets/Nanonets-OCR-s"
+    print("Loading Nanonets OCR model...")
+    ocr_model = AutoModelForImageTextToText.from_pretrained(
+        model_path,
+        torch_dtype="auto",
+        device_map="auto",
+        trust_remote_code=True
+    )
+    ocr_model.eval()
+
+    ocr_processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
+    ocr_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    print("✅ Nanonets OCR model loaded successfully!")
+except Exception as e:
+    print(f"❌ Failed to load Nanonets OCR model: {e}")
+    print("📝 Falling back to pytesseract for OCR")
+
 # Helper functions
 def save_index():
     try:
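Because the model is loaded with `device_map="auto"` and `torch_dtype="auto"`, its placement (CPU on a free Space, GPU when available) and dtype are decided at load time. A quick check against the `ocr_model` handle above (illustrative, not part of the commit):

    if ocr_model is not None:
        param = next(ocr_model.parameters())
        print(param.device, param.dtype)  # e.g. "cpu torch.float32" or "cuda:0 torch.bfloat16"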
@@ -151,7 +179,95 @@ def image_from_pdf(pdf_bytes):
         print(f"❌ PDF conversion error: {e}")
         return None
 
-def extract_text(image):
+def process_tags(content: str) -> str:
+    """Process special tags from Nanonets OCR output"""
+    content = content.replace("<img>", "&lt;img&gt;")
+    content = content.replace("</img>", "&lt;/img&gt;")
+    content = content.replace("<watermark>", "&lt;watermark&gt;")
+    content = content.replace("</watermark>", "&lt;/watermark&gt;")
+    content = content.replace("<page_number>", "&lt;page_number&gt;")
+    content = content.replace("</page_number>", "&lt;/page_number&gt;")
+    content = content.replace("<signature>", "&lt;signature&gt;")
+    content = content.replace("</signature>", "&lt;/signature&gt;")
+    return content
+
+def encode_image(image: Image) -> str:
+    """Encode image to base64 for Nanonets OCR"""
+    buffered = BytesIO()
+    image.save(buffered, format="JPEG")
+    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+    return img_str
+
+def nanonets_ocr_extract(image):
+    """Extract text using Nanonets OCR model"""
+    try:
+        if ocr_model is None or ocr_processor is None or ocr_tokenizer is None:
+            # Fallback to pytesseract
+            return extract_text_pytesseract(image)
+
+        if image.mode != 'RGB':
+            image = image.convert('RGB')
+
+        # Resize image for optimal processing
+        image = image.resize((2048, 2048))
+
+        # Prepare prompt for OCR extraction
+        user_prompt = """Extract the text from the above document as if you were reading it naturally. Return the tables in html format. Return the equations in LaTeX representation. Watermarks should be wrapped in brackets. Ex: <watermark>OFFICIAL COPY</watermark>. Page numbers should be wrapped in brackets. Ex: <page_number>14</page_number> or <page_number>9/22</page_number>. Prefer using ☐ and ☑ for check boxes."""
+
+        # Format messages for the model
+        formatted_messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": [
+                {"type": "image", "image": image},
+                {"type": "text", "text": user_prompt},
+            ]},
+        ]
+
+        # Apply chat template
+        text = ocr_processor.apply_chat_template(
+            formatted_messages,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+
+        # Process inputs
+        inputs = ocr_processor(
+            text=[text],
+            images=[image],
+            padding=True,
+            return_tensors="pt"
+        )
+
+        # Move inputs to model device
+        inputs = {k: v.to(ocr_model.device) if hasattr(v, 'to') else v for k, v in inputs.items()}
+
+        # Generate text
+        with torch.no_grad():
+            generated_ids = ocr_model.generate(
+                **inputs,
+                max_new_tokens=4096,
+                do_sample=False,
+                pad_token_id=ocr_tokenizer.eos_token_id,
+            )
+
+        # Decode the generated text
+        generated_text = ocr_tokenizer.decode(
+            generated_ids[0][inputs['input_ids'].shape[1]:],
+            skip_special_tokens=True
+        )
+
+        # Process special tags
+        processed_text = process_tags(generated_text)
+
+        return processed_text.strip() if processed_text.strip() else "❓ No text detected"
+
+    except Exception as e:
+        print(f"❌ Nanonets OCR error: {e}")
+        # Fallback to pytesseract
+        return extract_text_pytesseract(image)
+
+def extract_text_pytesseract(image):
+    """Fallback OCR using pytesseract"""
     try:
         if image.mode != 'RGB':
             image = image.convert('RGB')
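One caveat in `nanonets_ocr_extract`: `image.resize((2048, 2048))` forces every page into a square, distorting the aspect ratio of A4 scans and receipts. (Also note `encode_image` is defined here but never referenced by the new path, since the processor accepts PIL images directly.) If the distortion hurts accuracy, an aspect-preserving variant is straightforward — a sketch, not part of this commit:

    from PIL import Image

    def resize_longest_side(image: Image.Image, target: int = 2048) -> Image.Image:
        # Scale so the longest side equals `target`, keeping the aspect ratio.
        scale = target / max(image.size)
        new_size = (max(1, round(image.width * scale)), max(1, round(image.height * scale)))
        return image.resize(new_size, Image.LANCZOS)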
@@ -161,6 +277,13 @@ def extract_text(image):
     except Exception as e:
         return f"❌ OCR error: {str(e)}"
 
+def extract_text(image):
+    """Main OCR function - tries Nanonets first, falls back to pytesseract"""
+    if ocr_model is not None:
+        return nanonets_ocr_extract(image)
+    else:
+        return extract_text_pytesseract(image)
+
 def get_clip_embedding(image):
     try:
         if clip_model is None:
@@ -297,10 +420,10 @@ def classify_document():
         sim = float(1 - D[0][i])
         matches.append({"category": labels[I[0][i]], "similarity": round(sim, 3)})
 
-    # Save classified document to SQLite
+    # Save classified document to SQLite with enhanced OCR
     if similarity >= confidence_threshold:
         saved_filename = save_uploaded_file(file_content, file.filename)
-        ocr_text = extract_text(image)
+        ocr_text = extract_text(image)  # Now uses Nanonets OCR
 
         document_id = str(uuid.uuid4())
         conn = sqlite3.connect(DATABASE_PATH)
@@ -320,7 +443,8 @@ def classify_document():
             "confidence": "high",
             "matches": matches,
             "document_saved": True,
-            "document_id": document_id
+            "document_id": document_id,
+            "ocr_preview": ocr_text[:200] + "..." if len(ocr_text) > 200 else ocr_text
         })
     else:
         return jsonify({
@@ -452,8 +576,18 @@ def ocr_document():
         if image is None:
             return jsonify({"error": "Failed to process image"}), 400
 
+        # Use enhanced Nanonets OCR
         text = extract_text(image)
-        return jsonify({"text": text, "status": "success"})
+
+        # Determine OCR method used
+        ocr_method = "Nanonets OCR-s" if ocr_model is not None else "Pytesseract"
+
+        return jsonify({
+            "text": text,
+            "status": "success",
+            "ocr_method": ocr_method,
+            "enhanced_features": ocr_model is not None
+        })
     except Exception as e:
         return jsonify({"error": str(e)}), 500
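With the richer `/api/ocr` response, a client can report which engine produced the text. A hedged client sketch using only the fields added in this hunk (the JWT bearer header is an assumption inferred from the app's `authenticatedFetch` and PyJWT usage):

    import requests

    def run_ocr(path, token, base_url="http://localhost:5000"):
        with open(path, "rb") as f:
            resp = requests.post(
                f"{base_url}/api/ocr",
                headers={"Authorization": f"Bearer {token}"},  # assumed auth scheme
                files={"file": f},
            )
        data = resp.json()
        if resp.ok:
            print(f"{data['ocr_method']} (enhanced: {data['enhanced_features']})")
            print(data["text"])
        else:
            print("Error:", data.get("error"))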
 
 
17
  import jwt
18
  import sqlite3
19
  import tempfile
20
+ import base64
21
+ from io import BytesIO
22
+ from transformers import AutoModelForImageTextToText, AutoProcessor, AutoTokenizer, TextIteratorStreamer
23
+ from threading import Thread
24
+ import time
25
 
26
  app = Flask(__name__)
27
  app.config['SECRET_KEY'] = 'your-secret-key-change-this-in-production'
 
114
  clip_model = None
115
  preprocess = None
116
 
117
+ # Initialize Nanonets OCR model
118
+ ocr_model = None
119
+ ocr_processor = None
120
+ ocr_tokenizer = None
121
+
122
+ try:
123
+ model_path = "nanonets/Nanonets-OCR-s"
124
+ print("Loading Nanonets OCR model...")
125
+ ocr_model = AutoModelForImageTextToText.from_pretrained(
126
+ model_path,
127
+ torch_dtype="auto",
128
+ device_map="auto",
129
+ trust_remote_code=True
130
+ )
131
+ ocr_model.eval()
132
+
133
+ ocr_processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
134
+ ocr_tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
135
+ print("βœ… Nanonets OCR model loaded successfully!")
136
+ except Exception as e:
137
+ print(f"❌ Failed to load Nanonets OCR model: {e}")
138
+ print("πŸ“ Falling back to pytesseract for OCR")
139
+
140
  # Helper functions
141
  def save_index():
142
  try:
 
179
  print(f"❌ PDF conversion error: {e}")
180
  return None
181
 
182
+ def process_tags(content: str) -> str:
183
+ """Process special tags from Nanonets OCR output"""
184
+ content = content.replace("<img>", "&lt;img&gt;")
185
+ content = content.replace("</img>", "&lt;/img&gt;")
186
+ content = content.replace("<watermark>", "&lt;watermark&gt;")
187
+ content = content.replace("</watermark>", "&lt;/watermark&gt;")
188
+ content = content.replace("<page_number>", "&lt;page_number&gt;")
189
+ content = content.replace("</page_number>", "&lt;/page_number&gt;")
190
+ content = content.replace("<signature>", "&lt;signature&gt;")
191
+ content = content.replace("</signature>", "&lt;/signature&gt;")
192
+ return content
193
+
194
+ def encode_image(image: Image) -> str:
195
+ """Encode image to base64 for Nanonets OCR"""
196
+ buffered = BytesIO()
197
+ image.save(buffered, format="JPEG")
198
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
199
+ return img_str
200
+
201
+ def nanonets_ocr_extract(image):
202
+ """Extract text using Nanonets OCR model"""
203
+ try:
204
+ if ocr_model is None or ocr_processor is None or ocr_tokenizer is None:
205
+ # Fallback to pytesseract
206
+ return extract_text_pytesseract(image)
207
+
208
+ if image.mode != 'RGB':
209
+ image = image.convert('RGB')
210
+
211
+ # Resize image for optimal processing
212
+ image = image.resize((2048, 2048))
213
+
214
+ # Prepare prompt for OCR extraction
215
+ user_prompt = """Extract the text from the above document as if you were reading it naturally. Return the tables in html format. Return the equations in LaTeX representation. Watermarks should be wrapped in brackets. Ex: <watermark>OFFICIAL COPY</watermark>. Page numbers should be wrapped in brackets. Ex: <page_number>14</page_number> or <page_number>9/22</page_number>. Prefer using ☐ and β˜‘ for check boxes."""
216
+
217
+ # Format messages for the model
218
+ formatted_messages = [
219
+ {"role": "system", "content": "You are a helpful assistant."},
220
+ {"role": "user", "content": [
221
+ {"type": "image", "image": image},
222
+ {"type": "text", "text": user_prompt},
223
+ ]},
224
+ ]
225
+
226
+ # Apply chat template
227
+ text = ocr_processor.apply_chat_template(
228
+ formatted_messages,
229
+ tokenize=False,
230
+ add_generation_prompt=True
231
+ )
232
+
233
+ # Process inputs
234
+ inputs = ocr_processor(
235
+ text=[text],
236
+ images=[image],
237
+ padding=True,
238
+ return_tensors="pt"
239
+ )
240
+
241
+ # Move inputs to model device
242
+ inputs = {k: v.to(ocr_model.device) if hasattr(v, 'to') else v for k, v in inputs.items()}
243
+
244
+ # Generate text
245
+ with torch.no_grad():
246
+ generated_ids = ocr_model.generate(
247
+ **inputs,
248
+ max_new_tokens=4096,
249
+ do_sample=False,
250
+ pad_token_id=ocr_tokenizer.eos_token_id,
251
+ )
252
+
253
+ # Decode the generated text
254
+ generated_text = ocr_tokenizer.decode(
255
+ generated_ids[0][inputs['input_ids'].shape[1]:],
256
+ skip_special_tokens=True
257
+ )
258
+
259
+ # Process special tags
260
+ processed_text = process_tags(generated_text)
261
+
262
+ return processed_text.strip() if processed_text.strip() else "❓ No text detected"
263
+
264
+ except Exception as e:
265
+ print(f"❌ Nanonets OCR error: {e}")
266
+ # Fallback to pytesseract
267
+ return extract_text_pytesseract(image)
268
+
269
+ def extract_text_pytesseract(image):
270
+ """Fallback OCR using pytesseract"""
271
  try:
272
  if image.mode != 'RGB':
273
  image = image.convert('RGB')
 
277
  except Exception as e:
278
  return f"❌ OCR error: {str(e)}"
279
 
280
+ def extract_text(image):
281
+ """Main OCR function - tries Nanonets first, falls back to pytesseract"""
282
+ if ocr_model is not None:
283
+ return nanonets_ocr_extract(image)
284
+ else:
285
+ return extract_text_pytesseract(image)
286
+
287
  def get_clip_embedding(image):
288
  try:
289
  if clip_model is None:
 
420
  sim = float(1 - D[0][i])
421
  matches.append({"category": labels[I[0][i]], "similarity": round(sim, 3)})
422
 
423
+ # Save classified document to SQLite with enhanced OCR
424
  if similarity >= confidence_threshold:
425
  saved_filename = save_uploaded_file(file_content, file.filename)
426
+ ocr_text = extract_text(image) # Now uses Nanonets OCR
427
 
428
  document_id = str(uuid.uuid4())
429
  conn = sqlite3.connect(DATABASE_PATH)
 
443
  "confidence": "high",
444
  "matches": matches,
445
  "document_saved": True,
446
+ "document_id": document_id,
447
+ "ocr_preview": ocr_text[:200] + "..." if len(ocr_text) > 200 else ocr_text
448
  })
449
  else:
450
  return jsonify({
 
576
  if image is None:
577
  return jsonify({"error": "Failed to process image"}), 400
578
 
579
+ # Use enhanced Nanonets OCR
580
  text = extract_text(image)
581
+
582
+ # Determine OCR method used
583
+ ocr_method = "Nanonets OCR-s" if ocr_model is not None else "Pytesseract"
584
+
585
+ return jsonify({
586
+ "text": text,
587
+ "status": "success",
588
+ "ocr_method": ocr_method,
589
+ "enhanced_features": ocr_model is not None
590
+ })
591
  except Exception as e:
592
  return jsonify({"error": str(e)}), 500
593
 
requirements.txt CHANGED
@@ -8,4 +8,7 @@ torchvision
 Pillow
 PyJWT
 git+https://github.com/openai/CLIP.git
-poppler-utils
+poppler-utils
+transformers
+accelerate
+spaces
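A packaging note (an observation, not part of the commit): `poppler-utils` is a Debian package rather than a PyPI distribution, so pip cannot resolve it from requirements.txt. On Hugging Face Spaces, system dependencies like this normally go in a separate packages.txt:

    poppler-utils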
static/index.html CHANGED
@@ -255,10 +255,14 @@
             <div id="ocr" class="tab-content">
                 <div class="card">
                     <div class="card-header bg-warning text-dark">
-                        <h4><i class="fas fa-eye me-2"></i>OCR Text Extraction</h4>
+                        <h4><i class="fas fa-eye me-2"></i>Advanced OCR Text Extraction</h4>
                     </div>
                     <div class="card-body">
-                        <p class="text-muted">Extract text from documents using Optical Character Recognition.</p>
+                        <div class="alert alert-info" role="alert">
+                            <i class="fas fa-info-circle me-2"></i>
+                            <strong>Enhanced OCR Features:</strong> Our advanced Nanonets OCR-s model supports table extraction (HTML), LaTeX equations, watermark detection, signature recognition, and checkbox handling.
+                        </div>
+                        <p class="text-muted">Extract text from documents using advanced Optical Character Recognition with AI-powered document understanding.</p>
 
                         <form id="ocrForm" class="row g-3">
                             <div class="col-12">
@@ -266,12 +270,13 @@
                                 <div class="file-upload border rounded p-4 text-center" id="ocrUpload">
                                     <i class="fas fa-file-alt fa-3x text-warning mb-3"></i>
                                     <p class="mb-0">Click to select or drag & drop files here</p>
+                                    <small class="text-muted">Supports: PDF, JPEG, PNG, TIFF</small>
                                     <input type="file" id="ocrFile" accept="image/*,.pdf" class="d-none">
                                 </div>
                             </div>
                             <div class="col-12">
                                 <button type="submit" class="btn btn-warning">
-                                    <i class="fas fa-search me-2"></i>Extract Text
+                                    <i class="fas fa-robot me-2"></i>Extract Text with AI OCR
                                 </button>
                             </div>
                         </form>
static/script.js CHANGED
@@ -625,7 +625,7 @@ document.getElementById('ocrForm').addEventListener('submit', async (e) => {
     const formData = new FormData();
     formData.append('file', fileInput.files[0]);
 
-    showResult(resultDiv, '<div class="loading"></div> Extracting text...', 'info');
+    showResult(resultDiv, '<div class="loading"></div> Extracting text with advanced OCR...', 'info');
 
     try {
         const response = await authenticatedFetch('/api/ocr', {
@@ -636,9 +636,13 @@ document.getElementById('ocrForm').addEventListener('submit', async (e) => {
         const result = await response.json();
 
         if (response.ok) {
-            showResult(resultDiv, result.text, 'success');
+            const ocrInfo = result.enhanced_features ?
+                `🤖 Processed with ${result.ocr_method} (Enhanced Features: Tables, LaTeX, Watermarks)\n\n` :
+                `📝 Processed with ${result.ocr_method}\n\n`;
+
+            showResult(resultDiv, ocrInfo + result.text, 'success');
         } else {
-            showResult(resultDiv, result.detail, 'error');
+            showResult(resultDiv, result.error || result.detail, 'error');
         }
     } catch (error) {
         showResult(resultDiv, 'OCR failed: ' + error.message, 'error');