Vlm-test / app.py
mike23415's picture
Update app.py
035a6f9 verified
from flask import Flask, request, jsonify
from flask_cors import CORS
from PIL import Image
import io
import os
# Option 1: Using TrOCR (Transformer-based OCR)
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import torch
# Option 2: Using EasyOCR (commented out - uncomment if you prefer this)
# import easyocr
# Option 3: Using Tesseract (commented out - uncomment if you prefer this)
# import pytesseract
import fitz # PyMuPDF
# Initialize Flask
# Create the application object and wrap it with flask_cors' CORS() so the
# API can be called from pages served by another origin.
app = Flask(__name__)
CORS(app)
# Load TrOCR model and processor (better for text extraction)
# These are loaded once, at import time, and shared by every request as
# module-level globals (`device`, `processor`, `model`).
# Use the GPU when torch reports CUDA as available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed").to(device)
# Switch the model to evaluation mode; this service only runs inference.
model.eval()
# Alternative: Initialize EasyOCR reader (uncomment if using EasyOCR)
# extract_text_easyocr() reads this module-level `reader`; while the line
# below stays commented out, that function cannot find it.
# reader = easyocr.Reader(['en'])
def convert_pdf_to_image(file_stream):
    """Render the first page of a PDF as an RGB PIL image.

    Args:
        file_stream: Binary file-like object holding the PDF data. It is
            read in full via ``file_stream.read()``.

    Returns:
        A PIL image in ``RGB`` mode, rendered at 300 DPI for better OCR.

    Raises:
        ValueError: If the PDF contains no pages.
        Exception: Any error PyMuPDF raises while opening or rendering the
            document is propagated to the caller.
    """
    doc = fitz.open(stream=file_stream.read(), filetype="pdf")
    try:
        if doc.page_count == 0:
            raise ValueError("PDF has no pages")
        page = doc.load_page(0)
        # Higher DPI improves text recognition. alpha=False keeps the
        # pixmap at three bytes per pixel, which is what "RGB" expects.
        pix = page.get_pixmap(dpi=300, alpha=False)
        return Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
    finally:
        # Always release the document, including when rendering fails.
        doc.close()
def preprocess_image(image, min_side=1000):
    """Normalize an image for OCR: force RGB and upscale small images.

    Args:
        image: PIL image to prepare.
        min_side: Minimum width and height, in pixels, that the returned
            image should have. Defaults to 1000.

    Returns:
        The image in ``RGB`` mode. If either dimension was below
        ``min_side``, the image is enlarged with LANCZOS resampling,
        preserving aspect ratio, so that both dimensions reach at least
        ``min_side``. Otherwise the (possibly converted) image is returned
        without resizing.

    Raises:
        ValueError: If the image has a zero (or negative) width or height,
            for which no scale factor can be computed.
    """
    # The OCR pipeline expects three-channel input, so convert anything
    # that is not already RGB.
    if image.mode != "RGB":
        image = image.convert("RGB")

    width, height = image.size
    if width <= 0 or height <= 0:
        raise ValueError(
            f"Cannot preprocess an image with empty dimensions {width}x{height}"
        )

    # Already large enough in both directions: nothing to do.
    if width >= min_side and height >= min_side:
        return image

    # Scale by the larger ratio so that BOTH sides reach min_side.
    # NOTE(review): for very elongated images this can produce a very large
    # result (e.g. 5000x50 becomes 100000x1000); consider a cap if uploads
    # are untrusted.
    scale_factor = max(min_side / width, min_side / height)
    # Clamp to min_side: float rounding can make width * scale_factor land
    # just under the target, which int() would truncate to min_side - 1.
    new_width = max(min_side, int(width * scale_factor))
    new_height = max(min_side, int(height * scale_factor))
    return image.resize((new_width, new_height), Image.Resampling.LANCZOS)
def extract_text_trocr(image):
    """Run TrOCR over horizontal strips of ``image`` and join the results.

    The image is cut into full-width strips of up to 400 pixels in height
    (TrOCR works better on smaller sections). Each strip is decoded
    separately and the non-empty results are joined with newlines.

    Args:
        image: PIL image to read text from.

    Returns:
        The recognized text, one line per non-empty strip. Returns ``""``
        if nothing was recognized or if any error occurred; errors are
        printed rather than raised so the caller can fall back.
    """
    strip_height = 400  # Process in chunks
    try:
        img_width, img_height = image.size
        recognized = []
        top = 0
        while top < img_height:
            bottom = min(top + strip_height, img_height)
            strip = image.crop((0, top, img_width, bottom))

            # Process with TrOCR
            encoded = processor(strip, return_tensors="pt")
            pixel_values = encoded.pixel_values.to(device)
            with torch.no_grad():
                token_ids = model.generate(pixel_values, max_length=512)
            decoded = processor.batch_decode(token_ids, skip_special_tokens=True)

            strip_text = decoded[0].strip()
            if strip_text:
                recognized.append(strip_text)
            top += strip_height
        return "\n".join(recognized)
    except Exception as e:
        print(f"TrOCR error: {e}")
        return ""
def extract_text_easyocr(image):
    """Extract text using EasyOCR, keeping detections above 0.5 confidence.

    Requires the module-level ``reader`` (an ``easyocr.Reader``); uncomment
    the EasyOCR import and reader initialization earlier in this file
    before using this function.

    Args:
        image: PIL image to read text from.

    Returns:
        The accepted text fragments joined with newlines. Returns ``""`` if
        nothing was accepted or if any error occurred (including ``reader``
        not being initialized); errors are printed rather than raised.
    """
    try:
        # Local import so this optional code path adds no module-level
        # dependency.
        import numpy as np

        # readtext() accepts a file path/URL, raw bytes, or a numpy array,
        # so hand it the pixel data rather than the PIL object itself.
        # NOTE(review): EasyOCR appears to treat 3-channel arrays as
        # OpenCV-style BGR; confirm whether the RGB order matters here.
        pixels = np.asarray(image)
        results = reader.readtext(pixels)
        # Each result is (bounding box, text, confidence); filter out
        # low-confidence detections.
        return "\n".join(
            text for _bbox, text, confidence in results if confidence > 0.5
        )
    except Exception as e:
        print(f"EasyOCR error: {e}")
        return ""
def extract_text_tesseract(image):
    """Extract text using Tesseract through ``pytesseract``.

    Requires ``pytesseract`` to be importable at module level; uncomment
    its import earlier in this file before using this function.

    Args:
        image: PIL image to read text from.

    Returns:
        The recognized text with surrounding whitespace removed. Returns
        ``""`` if any error occurred; errors are printed rather than raised.
    """
    try:
        # Tesseract is fed a grayscale ("L" mode) copy for better OCR.
        grayscale = image.convert("L")
        recognized = pytesseract.image_to_string(grayscale, config="--psm 6")
        return recognized.strip()
    except Exception as e:
        print(f"Tesseract error: {e}")
        return ""
@app.route("/ocr", methods=["POST"])
def ocr():
    """Handle ``POST /ocr``: run OCR on an uploaded image or PDF.

    Expects a multipart upload under the form field ``file``. Files whose
    name ends in ``.pdf`` (case-insensitive) are rendered via
    ``convert_pdf_to_image``; anything else is opened as an image.

    Returns:
        A JSON response:
        * 400 with ``error`` when no file field or no filename is present.
        * 200 with ``text`` and ``message`` on success, or with an empty
          ``text`` and an explanatory ``message`` when nothing was
          extracted.
        * 500 with ``error`` when processing raises.
    """
    upload = request.files.get("file")
    if upload is None:
        return jsonify({"error": "No file uploaded"}), 400
    if not upload.filename:
        return jsonify({"error": "No file selected"}), 400

    is_pdf = upload.filename.lower().endswith(".pdf")
    try:
        # Convert input to PIL image
        if is_pdf:
            image = convert_pdf_to_image(upload)
        else:
            raw_bytes = upload.read()
            image = Image.open(io.BytesIO(raw_bytes)).convert("RGB")

        image = preprocess_image(image)

        # TrOCR is the primary extraction method.
        extracted_text = extract_text_trocr(image)
        if not extracted_text:
            print("TrOCR failed, trying alternative methods...")
            # Uncomment one of these if you have the libraries installed:
            # extracted_text = extract_text_easyocr(image)
            # extracted_text = extract_text_tesseract(image)

        if extracted_text:
            payload = {
                "text": extracted_text,
                "message": "Text extracted successfully"
            }
        else:
            payload = {
                "text": "",
                "message": "No text could be extracted from the image. The image might be too blurry, have low contrast, or contain handwritten text."
            }
        return jsonify(payload)
    except Exception as e:
        print(f"OCR processing error: {e}")
        return jsonify({"error": f"Failed to process file: {str(e)}"}), 500
@app.route("/", methods=["GET"])
def index():
    """Handle ``GET /``: return a plain-text banner naming the service."""
    banner = "Smart OCR Flask API (TrOCR-based)"
    return banner
if __name__ == "__main__":
    # Script entry point: serve on all interfaces, port 7860.
    # Debug mode is opt-in through the FLASK_DEBUG environment variable.
    # Running the Werkzeug debugger while bound to 0.0.0.0 would let any
    # client that can reach the server execute arbitrary code, so it must
    # not be enabled by default.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").strip().lower() in {
        "1",
        "true",
        "yes",
        "on",
    }
    app.run(host="0.0.0.0", port=7860, debug=debug_enabled)