Spaces:
Build error
Build error
import torch | |
import torch.nn as nn | |
from torchvision import transforms | |
from torchvision.models import convnext_tiny | |
from ultralytics import YOLO | |
import numpy as np | |
import cv2 | |
import gradio as gr | |
from PIL import Image, ImageDraw | |
from fast_alpr import ALPR | |
# ------------------ Constants and Models ------------------
# Colour labels for the classifier; output logit i corresponds to class_names[i].
class_names = [
    'beige', 'black', 'blue', 'brown', 'gold',
    'green', 'grey', 'orange', 'pink', 'purple',
    'red', 'silver', 'tan', 'white', 'yellow'
]
# fast_alpr model identifiers: licence-plate detector and plate OCR model.
DETECTOR_MODEL = "yolo-v9-s-608-license-plate-end2end"
OCR_MODEL = "global-plates-mobile-vit-v2-model"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ConvNeXt-Tiny colour classifier. The architecture is built without pretrained
# weights (`weights=None` is the current spelling of the deprecated
# `pretrained=False`) because every parameter is loaded from the fine-tuned
# checkpoint below (load_state_dict is strict by default).
model = convnext_tiny(weights=None)
# Swap the final Linear (768 input features) for a len(class_names)-way head so
# the checkpoint's classifier weights fit.
model.classifier[2] = nn.Linear(768, len(class_names))
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict needs, and refuses to execute arbitrary pickled objects.
model.load_state_dict(
    torch.load("convnext_best_model.pth", map_location=device, weights_only=True)
)
model = model.to(device)
model.eval()

# Preprocessing for the colour classifier: fixed 512x512 resize, tensor
# conversion, and ImageNet mean/std normalisation
# (NOTE(review): presumably matches the training-time pipeline — confirm).
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# General-purpose YOLO detector, used to locate vehicles in the frame.
yolo_model = YOLO("yolo11x.pt")
# ------------------ Unified Inference Function ------------------ | |
def alpr_color_inference(image): | |
if image is None: | |
return None, None, None, "Please upload an image to continue." | |
img = image.convert("RGB") | |
img_array = np.array(img) | |
alpr = ALPR(detector_model=DETECTOR_MODEL, ocr_model=OCR_MODEL) | |
results = alpr.predict(img_array) | |
annotated_img = Image.fromarray(img_array.copy()) | |
draw = ImageDraw.Draw(annotated_img) | |
plate_texts = [] | |
for result in results: | |
detection = getattr(result, 'detection', None) | |
ocr = getattr(result, 'ocr', None) | |
if detection is not None: | |
bbox_obj = getattr(detection, 'bounding_box', None) | |
if bbox_obj is not None: | |
bbox = [int(bbox_obj.x1), int(bbox_obj.y1), int(bbox_obj.x2), int(bbox_obj.y2)] | |
draw.rectangle(bbox, outline="red", width=3) | |
if ocr is not None: | |
text = getattr(ocr, 'text', '') | |
plate_texts.append(text) | |
draw.text((bbox[0], max(bbox[1] - 10, 0)), text, fill="red") | |
# Color Detection | |
img_cv2 = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR) | |
yolo_results = yolo_model(img_cv2) | |
boxes = yolo_results[0].boxes | |
vehicle_class_ids = {2, 3, 5, 7} # car, motorcycle, bus, truck | |
vehicle_boxes = [box for box in boxes if int(box.cls.item()) in vehicle_class_ids] | |
if not vehicle_boxes: | |
color_text = "No vehicle detected" | |
cropped_img = img | |
else: | |
largest_vehicle = max(vehicle_boxes, key=lambda box: (box.xyxy[0][2] - box.xyxy[0][0]) * (box.xyxy[0][3] - box.xyxy[0][1])) | |
x1, y1, x2, y2 = map(int, largest_vehicle.xyxy[0].tolist()) | |
cropped_img = img.crop((x1, y1, x2, y2)) | |
input_tensor = transform(cropped_img).unsqueeze(0).to(device) | |
with torch.no_grad(): | |
output = model(input_tensor) | |
probs = torch.softmax(output, dim=1)[0] | |
pred_idx = torch.argmax(probs).item() | |
pred_class = class_names[pred_idx] | |
confidence = probs[pred_idx].item() | |
draw.rectangle((x1, y1, x2, y2), outline="blue", width=3) | |
draw.text((x1, max(y1 - 10, 0)), f"{pred_class} ({confidence*100:.1f}%)", fill="blue") | |
color_text = f"{pred_class} ({confidence*100:.1f}%)" | |
detection_results = (f"Detected {len(results)} license plate(s): {', '.join(plate_texts)}" | |
if results else "No license plate detected 😔.") | |
return annotated_img, cropped_img, f"{detection_results}\nVehicle Color: {color_text}" | |
# ------------------ Gradio UI ------------------
# Sample images offered under the upload widget.
_EXAMPLE_IMAGES = [
    "examples/car5.jpg",
    "examples/car2.jpg",
    "examples/car3.jpg",
    "examples/car4.jpg",
    "examples/car6.jpg",
    "examples/car7.jpg",
]

with gr.Blocks() as demo:
    gr.Markdown("# License Plate + Vehicle Color Detection")

    with gr.Row():
        # Left column: input image, trigger button, and clickable examples.
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload an image")
            submit_btn = gr.Button("Run Detection")
            gr.Examples(
                examples=_EXAMPLE_IMAGES,
                inputs=[image_input],
                label="Example Images",
            )

        # Right column: the three components the button's handler writes to.
        with gr.Column():
            plate_output = gr.Image(label="Combined Detection Output")
            cropped_output = gr.Image(label="(Optional) Cropped Vehicle Region")
            result_text = gr.Markdown(label="Results")

    # Wire the button: one input image in, (annotated, cropped, text) out.
    submit_btn.click(
        fn=alpr_color_inference,
        inputs=[image_input],
        outputs=[plate_output, cropped_output, result_text],
    )

if __name__ == "__main__":
    demo.launch(share=True)