# Weight / app.py
# Author: Rammohan0504
# Last commit: "Update app.py" (4eacc77, verified)
import cv2
import numpy as np
import torch
from PIL import Image
import gradio as gr
import re
from ultralytics import YOLO
import easyocr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from scipy.signal import medfilt
# --- Model loading (executed once, when this module is imported) ---
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Pre-trained YOLOv5-small checkpoint, used for digital meter detection.
yolo_model = YOLO("yolov5s.pt")
# OCR engines: EasyOCR (English recognizer) and Microsoft's TrOCR.
ocr_reader = easyocr.Reader(["en"])
_TROCR_CHECKPOINT = "microsoft/trocr-base-stage1"
trocr_processor = TrOCRProcessor.from_pretrained(_TROCR_CHECKPOINT)
trocr_model = VisionEncoderDecoderModel.from_pretrained(_TROCR_CHECKPOINT).to(device)
# Image Preprocessing (Adaptive Threshold & Sharpening)
def enhance_image(image, block_size=11, offset=2):
    """Binarize a meter photo so the digits stand out for OCR.

    The image is reduced to grayscale, sharpened with a 3x3 kernel and then
    binarized with Gaussian adaptive thresholding.

    Args:
        image: 8-bit image array. Gradio's ``gr.Image`` supplies RGB
            ``(H, W, 3)``; RGBA ``(H, W, 4)`` and single-channel ``(H, W)`` /
            ``(H, W, 1)`` inputs are accepted as well.
        block_size: Neighbourhood size passed to ``cv2.adaptiveThreshold``;
            OpenCV requires an odd value greater than 1.
        offset: Constant subtracted from the Gaussian-weighted neighbourhood
            mean (OpenCV's ``C`` argument).

    Returns:
        A single-channel array whose pixels are either 0 or 255.
    """
    if image.ndim == 3 and image.shape[2] == 1:
        image = image[:, :, 0]
    if image.ndim == 2:
        gray = image  # already single-channel; cvtColor would reject it
    elif image.shape[2] == 4:
        gray = cv2.cvtColor(image, cv2.COLOR_RGBA2GRAY)
    else:
        # Gradio delivers RGB. BGR2GRAY would swap the red/blue luminance
        # weights (pure red would map to ~29 instead of ~76).
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    # Sharpen: emphasize each pixel against its four direct neighbours.
    kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
    sharpened = cv2.filter2D(gray, -1, kernel)
    return cv2.adaptiveThreshold(
        sharpened,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY,
        block_size,
        offset,
    )
# Convert Grayscale to RGB (Fix for TrOCR)
def convert_to_rgb(image):
    """Return *image* with three channels so PIL/TrOCR treat it as RGB.

    Grayscale input, either 2-D ``(H, W)`` or single-channel ``(H, W, 1)``,
    has its one channel copied into R, G and B. For 2-D input this matches
    ``cv2.COLOR_GRAY2RGB``. Any other shape is returned unchanged, on the
    assumption that it is already RGB.
    """
    if image.ndim == 2:
        image = image[:, :, np.newaxis]
    if image.ndim == 3 and image.shape[2] == 1:
        # (H, W, 1) is not a shape Image.fromarray reliably maps to a mode,
        # so expand it to three identical channels.
        image = np.repeat(image, 3, axis=2)
    return image
# Detect Digital Meter Using YOLOv5
def detect_meter(image, conf_threshold=0.25):
    """Run the YOLO model on *image* and collect boxes above a confidence threshold.

    No class filtering is applied: a box of any class the model knows is kept.
    NOTE(review): the stock yolov5s weights are presumably COCO-pretrained and
    have no dedicated "meter" class; confirm this is the intended detector.

    Args:
        image: Anything ultralytics accepts as a source. For numpy arrays
            ultralytics expects BGR channel order.
        conf_threshold: A box is kept only if its confidence exceeds this.
            The value is also forwarded to the predictor as ``conf``.

    Returns:
        A list with one ``box.xyxy.tolist()`` entry per kept box, i.e. a
        nested ``[[x1, y1, x2, y2]]`` (ultralytics reports these in the input
        image's pixel coordinates).
    """
    # Forward the threshold to the predictor: ultralytics already drops boxes
    # below its own default conf (0.25), so post-filtering alone could never
    # admit lower-confidence boxes.
    results = yolo_model(image, conf=conf_threshold)
    detected_meters = []
    for result in results:
        boxes = getattr(result, "boxes", None)
        if boxes is None:  # e.g. non-detection models leave boxes unset
            continue
        for box in boxes:
            if box.conf > conf_threshold:
                detected_meters.append(box.xyxy.tolist())
    return detected_meters
# Extract Text Using EasyOCR
def extract_text_easyocr(image):
    """Return every text fragment EasyOCR finds in *image*, joined by spaces.

    ``detail=0`` asks ``readtext`` for the recognized strings only. They are
    concatenated in the order EasyOCR returns them.
    """
    fragments = ocr_reader.readtext(image, detail=0)
    return " ".join(fragments)
# Extract Text Using TrOCR
def extract_text_trocr(image):
    """Transcribe *image* with TrOCR and return the first decoded string.

    The array is first forced to three channels, since TrOCR's processor
    works on RGB PIL images. The model then runs on ``device``.
    """
    pil_image = Image.fromarray(convert_to_rgb(image))
    encoded = trocr_processor(images=pil_image, return_tensors="pt")
    output_ids = trocr_model.generate(encoded.pixel_values.to(device))
    decoded = trocr_processor.batch_decode(output_ids, skip_special_tokens=True)
    return decoded[0]
# Extract Weight Using Regex
def extract_weight(text):
    """Return the leftmost number in *text* as a string, or None if none.

    A decimal such as ``12.50`` is preferred over the bare integer part at
    the same position; otherwise a run of digits is returned.
    """
    found = re.search(r'\d+\.\d+|\d+', text)
    if found is None:
        return None
    return found.group(0)
# Apply Statistical Filtering for Stability
def filter_weight_values(weights):
    """Combine several OCR weight readings into a single value.

    Args:
        weights: Numeric strings (e.g. ``["12.5", "12.5"]``), one per OCR
            engine that produced a reading.

    Returns:
        ``None`` when there are no readings; the single reading unchanged
        when there is exactly one; otherwise the median of all readings,
        rounded to two decimals and formatted with ``str``. With two
        readings the median is their mean.
    """
    if not weights:
        return None
    if len(weights) == 1:
        return weights[0]
    # Use a true median over all readings. scipy's medfilt(..., 3)[-1]
    # zero-pads the edge, so it yields median(w[-2], w[-1], 0), which is the
    # minimum of the last two readings rather than a median of the data.
    median = float(np.median([float(w) for w in weights]))
    return str(round(median, 2))
# Full Processing Pipeline (Dynamic Feedback)
def _crop_to_first_detection(image, detections):
    """Return the part of *image* inside the first detected box, or None.

    Each detection entry is flattened, so both ``[x1, y1, x2, y2]`` and the
    nested ``[[x1, y1, x2, y2]]`` produced by ``box.xyxy.tolist()`` work.
    Coordinates are clamped to the image bounds; an empty box yields None.
    """
    if not detections:
        return None
    coords = np.asarray(detections[0], dtype=float).reshape(-1)
    if coords.size < 4:
        return None
    height, width = image.shape[:2]
    x1 = max(0, int(np.floor(coords[0])))
    y1 = max(0, int(np.floor(coords[1])))
    x2 = min(width, int(np.ceil(coords[2])))
    y2 = min(height, int(np.ceil(coords[3])))
    if x2 <= x1 or y2 <= y1:
        return None
    return np.ascontiguousarray(image[y1:y2, x1:x2])


def _read_weight(image):
    """OCR an enhanced copy of *image* with both engines and combine the readings."""
    enhanced = enhance_image(image)
    readings = [
        extract_weight(extract_text_easyocr(enhanced)),
        extract_weight(extract_text_trocr(enhanced)),
    ]
    return filter_weight_values([w for w in readings if w])


def process_image(image):
    """Read the weight shown on a meter photo.

    Args:
        image: Value supplied by Gradio. This is presumably an RGB
            ``(H, W, 3)`` uint8 array (the ``gr.Image`` default), or None when
            the user submits without an image.

    Returns:
        The weight as a string, or a user-facing hint when no image was given
        or no number could be read.
    """
    if image is None:
        return "Please upload an image of the meter."
    # ultralytics interprets numpy input as BGR, while Gradio delivers RGB.
    yolo_input = image
    if image.ndim == 3 and image.shape[2] == 3:
        yolo_input = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    detected_meters = detect_meter(yolo_input)
    # OCR the first detected region. Fall back to the whole frame when nothing
    # was detected or the crop yields no number. NOTE(review): assumes the
    # first box is the most confident one; ultralytics typically orders them
    # that way, so verify.
    final_weight = None
    crop = _crop_to_first_detection(image, detected_meters)
    if crop is not None:
        final_weight = _read_weight(crop)
    if not final_weight:
        final_weight = _read_weight(image)
    # Handle failed detection cases dynamically
    if not final_weight:
        return "Try adjusting image clarity or detection thresholds."
    return final_weight
# Gradio Interface: one image in, the extracted weight (or a hint) out as text.
iface = gr.Interface(
    fn=process_image,
    inputs="image",
    outputs="text",
)
iface.launch()