# Author: Ayesha352
# Commit 31d6d91 ("Update app.py", verified)
import functools

import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import convnext_tiny
from ultralytics import YOLO
import numpy as np
import cv2
import gradio as gr
from PIL import Image, ImageDraw
from fast_alpr import ALPR
# ------------------ Constants and Models ------------------
# Output labels of the colour classifier, indexed by logit position.
# NOTE(review): presumably this order matches the label order used when
# convnext_best_model.pth was trained — confirm before editing the list.
class_names = [
    'beige', 'black', 'blue', 'brown', 'gold',
    'green', 'grey', 'orange', 'pink', 'purple',
    'red', 'silver', 'tan', 'white', 'yellow'
]
# fast_alpr model identifiers for plate detection and plate OCR.
DETECTOR_MODEL = "yolo-v9-s-608-license-plate-end2end"
OCR_MODEL = "global-plates-mobile-vit-v2-model"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Colour classifier: ConvNeXt-Tiny built without pretrained weights
# (`weights=None` replaces the deprecated `pretrained=False`). Its final
# Linear layer is swapped for one with a logit per colour, then the
# fine-tuned checkpoint is loaded.
model = convnext_tiny(weights=None)
# Read the input width from the layer being replaced instead of
# hard-coding it (768 for convnext_tiny).
model.classifier[2] = nn.Linear(model.classifier[2].in_features, len(class_names))
model.load_state_dict(torch.load("convnext_best_model.pth", map_location=device))
model = model.to(device)
model.eval()

# Preprocessing for the colour classifier. The mean/std are the standard
# ImageNet statistics. NOTE(review): 512x512 presumably matches the training
# resolution — confirm.
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# General-purpose object detector, used only to locate vehicles.
yolo_model = YOLO("yolo11x.pt")
# ------------------ Unified Inference Function ------------------
@functools.lru_cache(maxsize=1)
def _get_alpr():
    """Return a single shared ALPR pipeline, building it on first use.

    The original code built a new pipeline on every request. Construction
    presumably loads (and possibly downloads) the detector/OCR models, so it
    is done once per process instead.
    """
    return ALPR(detector_model=DETECTOR_MODEL, ocr_model=OCR_MODEL)


def _annotate_plates(img_array, draw):
    """Detect and read licence plates, drawing the results in red.

    Args:
        img_array: RGB image as a numpy array, passed to ``ALPR.predict``.
        draw: ``ImageDraw`` used to draw plate boxes and OCR text.

    Returns:
        ``(num_detections, plate_texts)``: the number of ALPR results and the
        OCR text of every result that has an OCR component.
    """
    results = _get_alpr().predict(img_array)
    plate_texts = []
    for result in results:
        detection = getattr(result, 'detection', None)
        ocr = getattr(result, 'ocr', None)
        # getattr on None falls back to the default, so a missing detection
        # simply yields no bounding box.
        bbox_obj = getattr(detection, 'bounding_box', None)
        bbox = None  # reset per result so a previous plate's box is never reused
        if bbox_obj is not None:
            bbox = [int(bbox_obj.x1), int(bbox_obj.y1), int(bbox_obj.x2), int(bbox_obj.y2)]
            draw.rectangle(bbox, outline="red", width=3)
        if ocr is not None:
            # Guard against a None text so the later ', '.join cannot fail.
            text = getattr(ocr, 'text', '') or ''
            plate_texts.append(text)
            if bbox is not None:
                draw.text((bbox[0], max(bbox[1] - 10, 0)), text, fill="red")
    return len(results), plate_texts


def _classify_vehicle_color(img, draw):
    """Find the largest vehicle with YOLO and classify its colour.

    Args:
        img: RGB PIL image.
        draw: ``ImageDraw`` used to draw the vehicle box and label in blue.

    Returns:
        ``(cropped_img, color_text)``. When no vehicle is detected,
        ``cropped_img`` is *img* itself and ``color_text`` is
        "No vehicle detected".
    """
    img_bgr = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    boxes = yolo_model(img_bgr)[0].boxes
    vehicle_class_ids = {2, 3, 5, 7}  # car, motorcycle, bus, truck
    vehicle_boxes = [box for box in boxes if int(box.cls.item()) in vehicle_class_ids]
    if not vehicle_boxes:
        return img, "No vehicle detected"

    def _box_area(box):
        bx1, by1, bx2, by2 = box.xyxy[0].tolist()
        return (bx2 - bx1) * (by2 - by1)

    largest_vehicle = max(vehicle_boxes, key=_box_area)
    x1, y1, x2, y2 = map(int, largest_vehicle.xyxy[0].tolist())
    cropped_img = img.crop((x1, y1, x2, y2))
    input_tensor = transform(cropped_img).unsqueeze(0).to(device)
    with torch.no_grad():
        probs = torch.softmax(model(input_tensor), dim=1)[0]
    pred_idx = torch.argmax(probs).item()
    pred_class = class_names[pred_idx]
    confidence = probs[pred_idx].item()
    color_text = f"{pred_class} ({confidence*100:.1f}%)"
    draw.rectangle((x1, y1, x2, y2), outline="blue", width=3)
    draw.text((x1, max(y1 - 10, 0)), color_text, fill="blue")
    return cropped_img, color_text


def alpr_color_inference(image):
    """Detect licence plates and the largest vehicle's colour in one image.

    Args:
        image: PIL image to analyse, or None when nothing was uploaded.

    Returns:
        A 3-tuple, one value per Gradio output component:
        ``(annotated_img, cropped_img, summary_text)``. ``annotated_img`` has
        red plate boxes with OCR text and a blue box with the colour label.
        ``cropped_img`` is the largest vehicle's region, or the full image if
        no vehicle was found. When *image* is None, both images are None and
        the text asks for an upload.
    """
    if image is None:
        # Must be exactly three values, matching the three output components
        # (the previous four-value return broke the empty-input case).
        return None, None, "Please upload an image to continue."
    img = image.convert("RGB")
    annotated_img = img.copy()
    draw = ImageDraw.Draw(annotated_img)

    num_plates, plate_texts = _annotate_plates(np.array(img), draw)
    cropped_img, color_text = _classify_vehicle_color(img, draw)

    detection_results = (f"Detected {num_plates} license plate(s): {', '.join(plate_texts)}"
                         if num_plates else "No license plate detected 😔.")
    return annotated_img, cropped_img, f"{detection_results}\nVehicle Color: {color_text}"
# ------------------ Gradio UI ------------------
# Two-column layout: inputs and examples on the left, results on the right.
with gr.Blocks() as demo:
    gr.Markdown("# License Plate + Vehicle Color Detection")
    with gr.Row():
        # Input side: uploaded image, trigger button, bundled sample images.
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload an image")
            submit_btn = gr.Button("Run Detection")
            gr.Examples(
                examples=[
                    "examples/car5.jpg",
                    "examples/car2.jpg",
                    "examples/car3.jpg",
                    "examples/car4.jpg",
                    "examples/car6.jpg",
                    "examples/car7.jpg",
                ],
                inputs=[image_input],
                label="Example Images",
            )
        # Output side: one component per value returned by the handler.
        with gr.Column():
            plate_output = gr.Image(label="Combined Detection Output")
            cropped_output = gr.Image(label="(Optional) Cropped Vehicle Region")
            result_text = gr.Markdown(label="Results")

    # Listener is attached inside the Blocks context, as Gradio requires.
    submit_btn.click(
        fn=alpr_color_inference,
        inputs=[image_input],
        outputs=[plate_output, cropped_output, result_text],
    )

if __name__ == "__main__":
    demo.launch(share=True)