File size: 2,433 Bytes
7bc1e69
 
0698f2b
2b25983
0698f2b
 
7bc1e69
 
 
0698f2b
7bc1e69
2b25983
0698f2b
7bc1e69
 
 
 
 
 
 
 
 
0698f2b
 
 
 
7bc1e69
 
 
 
 
 
 
 
 
 
 
 
 
0698f2b
7bc1e69
 
2b25983
7bc1e69
 
 
0698f2b
 
 
 
7bc1e69
0698f2b
 
 
 
 
7bc1e69
 
 
 
 
 
 
 
 
 
0698f2b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import cv2
import numpy as np
import torch
from PIL import Image
import gradio as gr
import re
from ultralytics import YOLO
import easyocr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# Load models
# Every model below is constructed at import time, so importing this module
# loads all weights (from_pretrained may also fetch them if not cached).

# Inference device; only the TrOCR model and its pixel inputs are moved here.
device = "cuda" if torch.cuda.is_available() else "cpu"

# YOLOv5 for digit detection (Pre-trained model)
# NOTE(review): "yolov5s.pt" looks like the stock COCO-pretrained checkpoint,
# whose classes presumably do not include digits — confirm that a
# digit-trained weights file was intended here.
yolo_model = YOLO("yolov5s.pt")  

# OCR Models
ocr_reader = easyocr.Reader(["en"])  # EasyOCR with the English language model; does not receive `device`
# TrOCR: processor handles image preprocessing and token decoding; the
# encoder-decoder model generates token ids and is placed on `device`.
# NOTE(review): "trocr-base-stage1" appears to be the stage-1 pre-trained
# checkpoint rather than a fine-tuned one — confirm it is the intended model.
trocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-stage1")
trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-stage1").to(device)

# Image Preprocessing (Sharpen & Threshold)
def enhance_image(image, threshold=150):
    """Return a sharpened, binarized grayscale copy of *image* for OCR.

    Args:
        image: Image as a NumPy array. A 2-D array (or H x W x 1) is treated
            as already grayscale, a 4-channel array as RGBA, and a 3-channel
            array as RGB, which is what ``gr.Image(type="numpy")`` supplies.
        threshold: Gray level a pixel must exceed to become white (255);
            every other pixel becomes 0. Defaults to 150, the value that was
            previously hard-coded.

    Returns:
        A 2-D array whose pixels are either 0 or 255.
    """
    # Collapse to a single channel. The input comes from Gradio in RGB order,
    # so RGB2GRAY gives the correct luminance weights (BGR2GRAY swapped R/B).
    if image.ndim == 2:
        gray = image
    elif image.shape[2] == 1:
        gray = image[:, :, 0]
    elif image.shape[2] == 4:
        gray = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
    else:
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    # Standard 3x3 sharpening kernel: boosts the centre pixel against its
    # 4-neighbours so character edges survive the thresholding below.
    kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
    sharpened = cv2.filter2D(gray, -1, kernel)

    _, thresholded = cv2.threshold(sharpened, threshold, 255, cv2.THRESH_BINARY)
    return thresholded

# Detect Digits Using YOLOv5
def detect_digits(image, conf_threshold=0.5):
    """Return bounding boxes of YOLO detections more confident than *conf_threshold*.

    Args:
        image: Image passed directly to ``yolo_model`` (a NumPy array here).
            NOTE(review): ultralytics treats NumPy input as BGR while Gradio
            supplies RGB — confirm whether a channel swap is wanted.
        conf_threshold: Detections with confidence strictly greater than this
            are kept. Defaults to 0.5, the previously hard-coded cutoff.

    Returns:
        A list of ``[x1, y1, x2, y2]`` float lists, one per kept detection;
        an empty list when nothing is detected.
    """
    # ultralytics' YOLO.__call__ returns a list of Results objects (one per
    # input image). The old YOLOv5 torch.hub ``results.pred`` attribute does
    # not exist on it, so boxes are read from ``Results.boxes`` instead.
    results = yolo_model(image)
    if not results:
        return []
    boxes = results[0].boxes
    if boxes is None:
        return []
    keep = boxes.conf > conf_threshold
    return boxes.xyxy[keep].tolist()

# Extract Text Using EasyOCR
def extract_text_easyocr(image):
    """Run the module-level EasyOCR reader on *image*.

    Returns every recognized text fragment, in the order EasyOCR reports
    them, joined by single spaces (an empty string if nothing is found).
    """
    # detail=0 makes readtext return bare strings instead of (box, text, conf).
    fragments = ocr_reader.readtext(image, detail=0)
    separator = " "
    return separator.join(fragments)

# Extract Text Using TrOCR
def extract_text_trocr(image):
    """Recognize text in *image* with the module-level TrOCR model.

    Args:
        image: Image as a NumPy array, either 2-D grayscale (as produced by
            ``enhance_image``) or H x W x C.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens removed.
    """
    # The TrOCR image processor normalizes 3-channel RGB input; the binary
    # image from enhance_image is single-channel ("L" mode), so convert it.
    pil_image = Image.fromarray(image).convert("RGB")
    pixel_values = trocr_processor(images=pil_image, return_tensors="pt").pixel_values.to(device)
    generated_ids = trocr_model.generate(pixel_values)
    return trocr_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

# Extract Weight Using Regex
def extract_weight(text):
    """Return the first number (integer or decimal) in *text*, as a string.

    A decimal such as ``12.5`` is preferred over its integer prefix. When
    *text* contains no digits, the sentinel "Weight not detected" is returned.
    """
    found = re.search(r'\d+\.\d+|\d+', text)
    if found is None:
        return "Weight not detected"
    return found.group(0)

# Full Processing Pipeline
def process_image(image):
    """Read a weight value from *image* via OCR.

    The image is enhanced once. EasyOCR is tried first, and TrOCR is run only
    as a fallback when EasyOCR's text contains no number.

    Args:
        image: Image as a NumPy array, or ``None`` when the Gradio form is
            submitted without an image.

    Returns:
        The first number found (as a string), or "Weight not detected".
    """
    not_detected = "Weight not detected"
    if image is None:
        return not_detected

    enhanced = enhance_image(image)

    # The YOLO detection result was never used by this pipeline, so it is no
    # longer run here; call detect_digits() directly if boxes are needed.
    weight = extract_weight(extract_text_easyocr(enhanced))
    if weight == not_detected:
        # Fallback only: TrOCR is the slower model, so skip it when EasyOCR
        # already produced a number.
        weight = extract_weight(extract_text_trocr(enhanced))
    return weight

# Gradio Interface
# type="numpy" makes Gradio pass process_image an RGB NumPy array (or None).
iface = gr.Interface(fn=process_image, inputs=gr.Image(type="numpy"), outputs="text")

# Launch only when run as a script, so importing this module (e.g. for
# tests or reuse of the helpers) does not start a web server.
if __name__ == "__main__":
    iface.launch()