import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import convnext_tiny
from ultralytics import YOLO
import numpy as np
import cv2
import gradio as gr
from PIL import Image, ImageDraw
from fast_alpr import ALPR

# ------------------ Constants and Models ------------------
class_names = [
    'beige', 'black', 'blue', 'brown', 'gold', 'green', 'grey', 'orange',
    'pink', 'purple', 'red', 'silver', 'tan', 'white', 'yellow'
]

DETECTOR_MODEL = "yolo-v9-s-608-license-plate-end2end"
OCR_MODEL = "global-plates-mobile-vit-v2-model"

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ConvNeXt-Tiny fine-tuned for vehicle color classification; weights come from
# the checkpoint below, so no pretrained weights are requested here.
model = convnext_tiny(weights=None)
model.classifier[2] = nn.Linear(768, len(class_names))
model.load_state_dict(torch.load("convnext_best_model.pth", map_location=device))
model = model.to(device)
model.eval()

transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

yolo_model = YOLO("yolo11x.pt")

# ------------------ Unified Inference Function ------------------
def alpr_color_inference(image):
    if image is None:
        return None, None, "Please upload an image to continue."

    img = image.convert("RGB")
    img_array = np.array(img)

    # License plate detection + OCR (ALPR models are instantiated per call)
    alpr = ALPR(detector_model=DETECTOR_MODEL, ocr_model=OCR_MODEL)
    results = alpr.predict(img_array)

    annotated_img = Image.fromarray(img_array.copy())
    draw = ImageDraw.Draw(annotated_img)
    plate_texts = []

    for result in results:
        detection = getattr(result, 'detection', None)
        ocr = getattr(result, 'ocr', None)
        if detection is not None:
            bbox_obj = getattr(detection, 'bounding_box', None)
            if bbox_obj is not None:
                bbox = [int(bbox_obj.x1), int(bbox_obj.y1),
                        int(bbox_obj.x2), int(bbox_obj.y2)]
                draw.rectangle(bbox, outline="red", width=3)
                if ocr is not None:
                    text = getattr(ocr, 'text', '')
                    plate_texts.append(text)
                    draw.text((bbox[0], max(bbox[1] - 10, 0)), text, fill="red")

    # Color Detection
    img_cv2 = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    yolo_results = yolo_model(img_cv2)
    boxes = yolo_results[0].boxes

    vehicle_class_ids = {2, 3, 5, 7}  # COCO: car, motorcycle, bus, truck
    vehicle_boxes = [box for box in boxes if int(box.cls.item()) in vehicle_class_ids]

    if not vehicle_boxes:
        color_text = "No vehicle detected"
        cropped_img = img
    else:
        # Classify the color of the largest detected vehicle (by bounding-box area).
        largest_vehicle = max(
            vehicle_boxes,
            key=lambda box: (box.xyxy[0][2] - box.xyxy[0][0]) * (box.xyxy[0][3] - box.xyxy[0][1])
        )
        x1, y1, x2, y2 = map(int, largest_vehicle.xyxy[0].tolist())
        cropped_img = img.crop((x1, y1, x2, y2))

        input_tensor = transform(cropped_img).unsqueeze(0).to(device)
        with torch.no_grad():
            output = model(input_tensor)
            probs = torch.softmax(output, dim=1)[0]
            pred_idx = torch.argmax(probs).item()

        pred_class = class_names[pred_idx]
        confidence = probs[pred_idx].item()

        draw.rectangle((x1, y1, x2, y2), outline="blue", width=3)
        draw.text((x1, max(y1 - 10, 0)), f"{pred_class} ({confidence*100:.1f}%)", fill="blue")
        color_text = f"{pred_class} ({confidence*100:.1f}%)"

    detection_results = (f"Detected {len(results)} license plate(s): {', '.join(plate_texts)}"
                         if results else "No license plate detected 😔.")

    return annotated_img, cropped_img, f"{detection_results}\nVehicle Color: {color_text}"

# ------------------ Gradio UI ------------------
with gr.Blocks() as demo:
    gr.Markdown("# License Plate + Vehicle Color Detection")
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil", label="Upload an image")
            submit_btn = gr.Button("Run Detection")
            gr.Examples(
                examples=[
                    "examples/car5.jpg",
                    "examples/car2.jpg",
                    "examples/car3.jpg",
                    "examples/car4.jpg",
                    "examples/car6.jpg",
                    "examples/car7.jpg",
                ],
                inputs=[image_input],
                label="Example Images"
            )
        with gr.Column():
            plate_output = gr.Image(label="Combined Detection Output")
            cropped_output = gr.Image(label="(Optional) Cropped Vehicle Region")
            result_text = gr.Markdown(label="Results")

    submit_btn.click(
        alpr_color_inference,
        inputs=[image_input],
        outputs=[plate_output, cropped_output, result_text]
    )

if __name__ == "__main__":
    demo.launch(share=True)