|
from flask import Flask, request, jsonify |
|
from flask_cors import CORS |
|
from PIL import Image |
|
import io |
|
import os |
|
|
|
|
|
from transformers import TrOCRProcessor, VisionEncoderDecoderModel |
|
import torch |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import fitz |
|
|
|
|
|
app = Flask(__name__)
# Enable CORS on every route using flask_cors defaults so a browser front end
# served from another origin can call this API.
CORS(app)

# Hugging Face model id shared by the processor and the model; keeping it in one
# place prevents the two from silently drifting apart.
MODEL_NAME = "microsoft/trocr-base-printed"

# Run inference on the GPU when PyTorch can see one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Loaded once at import time and shared by every request.
processor = TrOCRProcessor.from_pretrained(MODEL_NAME)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME).to(device)
# Inference only: put layers such as dropout into evaluation mode.
model.eval()
|
|
|
|
|
|
|
|
|
def convert_pdf_to_image(file_stream, dpi=300):
    """Render the first page of a PDF upload as an RGB PIL image.

    The page is rasterised at ``dpi`` (300 by default) because a higher
    resolution gives the OCR model more detail to work with.

    Args:
        file_stream: Readable binary file-like object (e.g. a Werkzeug
            ``FileStorage``) holding the complete PDF; it is read to the end.
        dpi: Rendering resolution in dots per inch.

    Returns:
        PIL.Image.Image: The rendered first page in RGB mode.

    Raises:
        ValueError: If the PDF contains no pages.
        Any exception PyMuPDF raises when the data is not a valid PDF.
    """
    # The context manager closes the document even if rendering raises.
    with fitz.open(stream=file_stream.read(), filetype="pdf") as doc:
        if doc.page_count == 0:
            raise ValueError("PDF contains no pages")

        page = doc.load_page(0)
        # alpha=False yields 3-byte RGB samples, matching the "RGB" mode
        # passed to Image.frombytes below.
        pix = page.get_pixmap(dpi=dpi, alpha=False)
        # Image.frombytes copies the buffer, so the image remains valid after
        # the document is closed.
        return Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
|
|
|
def preprocess_image(image, min_side=1000, max_side=10000):
    """Normalise an image before OCR.

    The image is converted to RGB. If either side is shorter than
    ``min_side`` pixels, the image is upscaled with Lanczos resampling,
    preserving aspect ratio, so that both sides reach ``min_side``.

    The upscale factor is capped so the longer side never exceeds
    ``max_side``. Without the cap, a long thin upload (e.g. 5000x10 px)
    would be enlarged 100x to 500000x1000 px, about 1.5 GB of RGB data.

    Args:
        image: PIL image to preprocess.
        min_side: Desired minimum length, in pixels, of each side.
        max_side: Upper bound, in pixels, for the longer side after upscaling.

    Returns:
        PIL.Image.Image: The RGB, possibly upscaled, image. Images are never
        downscaled.

    Raises:
        ValueError: If the image has zero width or height.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    width, height = image.size
    if width <= 0 or height <= 0:
        raise ValueError(f"Cannot preprocess an image of size {width}x{height}")

    if width < min_side or height < min_side:
        # Factor that brings the shorter side up to min_side...
        scale_factor = max(min_side / width, min_side / height)
        # ...limited so extreme aspect ratios cannot explode memory use.
        scale_factor = min(scale_factor, max_side / max(width, height))
        if scale_factor > 1:
            new_size = (int(width * scale_factor), int(height * scale_factor))
            image = image.resize(new_size, Image.Resampling.LANCZOS)

    return image
|
|
|
def extract_text_trocr(image, chunk_height=400, max_length=512):
    """Run TrOCR over an image in full-width horizontal strips.

    The image is cut top-to-bottom into strips of ``chunk_height`` pixels
    (the last strip may be shorter). Each strip is recognised independently,
    and each non-empty, whitespace-stripped result becomes one output line.

    NOTE(review): TrOCR checkpoints are trained on single text-line crops, so
    strips containing several lines may be recognised poorly. Consider real
    line segmentation.

    Args:
        image: RGB PIL image to read.
        chunk_height: Height in pixels of each strip; must be positive.
        max_length: Maximum generated sequence length per strip, passed to
            ``model.generate``.

    Returns:
        str: Recognised text joined with newlines, or ``""`` when nothing was
        recognised or any error occurred. Errors are printed, not raised, so
        callers can fall back to other engines.

    Raises:
        ValueError: If ``chunk_height`` is not positive (a caller bug, so it is
            raised instead of being swallowed).
    """
    if chunk_height <= 0:
        raise ValueError("chunk_height must be a positive integer")

    try:
        width, height = image.size
        extracted_texts = []

        for top in range(0, height, chunk_height):
            bottom = min(top + chunk_height, height)
            chunk = image.crop((0, top, width, bottom))

            pixel_values = processor(chunk, return_tensors="pt").pixel_values.to(device)

            # Inference only: skip autograd bookkeeping.
            with torch.no_grad():
                generated_ids = model.generate(pixel_values, max_length=max_length)

            # One image in the batch, so take the single decoded string.
            text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
            if text:
                extracted_texts.append(text)

        return "\n".join(extracted_texts)

    except Exception as e:
        # Deliberate best-effort: report and return "" so the caller decides
        # what to do next.
        print(f"TrOCR error: {e}")
        return ""
|
|
|
def extract_text_easyocr(image, min_confidence=0.5):
    """Extract text with EasyOCR, keeping only confident detections.

    Requires a module-level EasyOCR ``reader`` object. If it has not been
    initialised, the resulting ``NameError`` is caught like any other error
    and ``""`` is returned.

    Args:
        image: PIL image to read. It is converted to a NumPy array because
            EasyOCR's ``readtext`` accepts file paths, bytes, URLs or NumPy
            arrays rather than arbitrary PIL images.
        min_confidence: Detections with confidence at or below this value
            are discarded.

    Returns:
        str: Accepted text fragments joined with newlines, or ``""`` if any
        error occurred (the error is printed).
    """
    try:
        import numpy as np

        results = reader.readtext(np.asarray(image))
        # Each result is (bounding_box, text, confidence); the box is unused.
        return "\n".join(
            text for _bbox, text, confidence in results if confidence > min_confidence
        )
    except Exception as e:
        print(f"EasyOCR error: {e}")
        return ""
|
|
|
def extract_text_tesseract(image):
    """Read text from an image with Tesseract via ``pytesseract``.

    ``pytesseract`` must be available as a module-level name. If it is not,
    the resulting error is caught like any other failure.

    The image is converted to greyscale ("L" mode) and recognised with
    ``--psm 6`` (Tesseract page segmentation mode 6: assume a single uniform
    block of text).

    Returns:
        str: The recognised text with surrounding whitespace removed, or
        ``""`` if anything fails (the failure is printed).
    """
    try:
        grey = image.convert('L')
        raw = pytesseract.image_to_string(grey, config='--psm 6')
        cleaned = raw.strip()
    except Exception as exc:
        print(f"Tesseract error: {exc}")
        cleaned = ""
    return cleaned
|
|
|
@app.route("/ocr", methods=["POST"])
def ocr():
    """Run OCR on an uploaded image, or on the first page of a PDF.

    Expects a multipart form upload under the field name ``file``. Files whose
    name ends in ``.pdf`` (case-insensitive) are rendered with
    ``convert_pdf_to_image``; anything else is opened with Pillow. TrOCR is
    tried first, then the EasyOCR and Tesseract extractors as best-effort
    fallbacks.

    JSON responses:
        200 ``{"text": ..., "message": ...}``. ``text`` is empty when no
            engine recognised anything.
        400 ``{"error": ...}``. No file, an empty filename, or an upload
            Pillow cannot identify as an image.
        500 ``{"error": ...}``. Any other processing failure.
    """
    # Local import keeps this handler independent of the file's import block.
    from PIL import UnidentifiedImageError

    if "file" not in request.files:
        return jsonify({"error": "No file uploaded"}), 400

    file = request.files["file"]
    if not file.filename:
        return jsonify({"error": "No file selected"}), 400

    filename = file.filename.lower()

    try:
        if filename.endswith(".pdf"):
            image = convert_pdf_to_image(file)
        else:
            image = Image.open(io.BytesIO(file.read())).convert("RGB")

        image = preprocess_image(image)

        extracted_text = extract_text_trocr(image)

        if not extracted_text:
            print("TrOCR failed, trying alternative methods...")
            # Each fallback catches its own errors (including a missing engine)
            # and returns "", so trying them in turn is safe.
            for fallback in (extract_text_easyocr, extract_text_tesseract):
                extracted_text = fallback(image)
                if extracted_text:
                    break

        if not extracted_text:
            return jsonify({
                "text": "",
                "message": "No text could be extracted from the image. The image might be too blurry, have low contrast, or contain handwritten text."
            })

        return jsonify({
            "text": extracted_text,
            "message": "Text extracted successfully"
        })

    except UnidentifiedImageError:
        # The upload is not an image format Pillow recognises. That is a
        # client error, not a server fault.
        return jsonify({"error": "Unsupported or corrupt image file"}), 400
    except Exception as e:
        print(f"OCR processing error: {e}")
        return jsonify({"error": f"Failed to process file: {str(e)}"}), 500
|
|
|
@app.route("/", methods=["GET"])
def index():
    """Root endpoint: return a plain-text banner identifying the service."""
    banner = "Smart OCR Flask API (TrOCR-based)"
    return banner
|
|
|
if __name__ == "__main__":
    # The Werkzeug interactive debugger can execute arbitrary code, so it must
    # not be exposed on all interfaces by default. Opt in with FLASK_DEBUG=1.
    debug_enabled = os.environ.get("FLASK_DEBUG", "0").strip().lower() in ("1", "true", "yes")
    # The listening port is overridable via PORT; 7860 remains the default.
    port = int(os.environ.get("PORT", "7860"))
    app.run(host="0.0.0.0", port=port, debug=debug_enabled)