Ayesha352 committed on
Commit
3578ff9
·
verified ·
1 Parent(s): c7e53db

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -0
app.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import transforms
4
+ from torchvision.models import convnext_tiny
5
+ from ultralytics import YOLO
6
+ from PIL import Image
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import cv2
10
+ import gradio as gr
11
+
12
+ # ---------- 1. Class labels ----------
13
# Color labels. The prediction code looks a label up by the index of the
# highest classifier output, so position in this list is significant.
class_names = [
    'beige',
    'black',
    'blue',
    'brown',
    'gold',
    'green',
    'grey',
    'orange',
    'pink',
    'purple',
    'red',
    'silver',
    'tan',
    'white',
    'yellow',
]
18
+
19
+ # ---------- 2. Load ConvNeXt-Tiny Model ----------
20
# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the bare architecture; every parameter is supplied by the checkpoint
# loaded below. `weights=None` replaces the deprecated `pretrained=False` flag.
model = convnext_tiny(weights=None)

# Replace the final linear layer with one sized to our label set. The input
# width is read from the existing head rather than hard-coded.
model.classifier[2] = nn.Linear(model.classifier[2].in_features, len(class_names))

# map_location remaps the stored tensors onto whichever device was selected above.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (or pass weights_only=True where the installed PyTorch supports it).
model.load_state_dict(torch.load("convnext_best_model.pth", map_location=device))
model = model.to(device)
model.eval()  # evaluation mode; this script only runs inference
26
+
27
+ # ---------- 3. Image Transform ----------
28
# Preprocessing applied to a vehicle crop before it is fed to the classifier:
# resize to 512x512, convert to a tensor, then normalize each channel.
# NOTE(review): the mean/std values look like the usual ImageNet statistics;
# presumably they match what the checkpoint was trained with - confirm.
transform = transforms.Compose([
    transforms.Resize(size=(512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
34
+
35
# ---------- 4. Load YOLO11 Model ----------
# The checkpoint is "yolo11x.pt" (a YOLO11 weights file, not YOLOv8).
yolo_model = YOLO("yolo11x.pt")
37
+
38
+ # ---------- Gradio Inference Function ----------
39
def detect_vehicle_color(input_img):
    """Find the largest detected vehicle in an image and classify its color.

    Args:
        input_img: PIL image from the Gradio input component, or ``None``
            when the form is submitted without an image.

    Returns:
        A 3-tuple ``(text, annotated, crop)``:

        * no input image: ``("No image provided", None, None)``
        * no vehicle found: ``("No vehicle detected", rgb_image, rgb_image)``
        * otherwise: ``"<color> (<confidence>%)"``, the RGB image with the
          chosen bounding box drawn on it, and the cropped vehicle region.
    """
    # Gradio hands the callback None when nothing was uploaded; return a
    # message instead of raising AttributeError on ``.convert`` below.
    if input_img is None:
        return "No image provided", None, None

    img_original = input_img.convert("RGB")
    # The detector is fed a BGR array (OpenCV channel order).
    img_cv2 = cv2.cvtColor(np.array(img_original), cv2.COLOR_RGB2BGR)

    results = yolo_model(img_cv2)
    boxes = results[0].boxes

    # Vehicle class IDs: car, motorcycle, bus, truck
    vehicle_class_ids = {2, 3, 5, 7}
    vehicle_boxes = [box for box in boxes if int(box.cls.item()) in vehicle_class_ids]

    if not vehicle_boxes:
        return "No vehicle detected", img_original, img_original

    def box_area(box):
        """Return the area of a detection's xyxy bounding box."""
        x1, y1, x2, y2 = box.xyxy[0].tolist()
        return (x2 - x1) * (y2 - y1)

    # Only the largest vehicle box is classified.
    largest_vehicle = max(vehicle_boxes, key=box_area)
    x1, y1, x2, y2 = map(int, largest_vehicle.xyxy[0].tolist())

    cropped = img_original.crop((x1, y1, x2, y2))

    # Add a leading batch dimension and move the tensor to the model's device.
    input_tensor = transform(cropped).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor)
        probs = torch.softmax(output, dim=1)[0]
        pred_idx = torch.argmax(probs).item()
        pred_class = class_names[pred_idx]
        confidence = probs[pred_idx].item()

    # Draw bounding box on original image. The array is in RGB order here,
    # so (255, 0, 0) renders as red.
    img_with_box = np.array(img_original).copy()
    cv2.rectangle(img_with_box, (x1, y1), (x2, y2), (255, 0, 0), 3)
    img_with_box_pil = Image.fromarray(img_with_box)

    return f"{pred_class} ({confidence*100:.1f}%)", img_with_box_pil, cropped
76
+
77
+ # ---------- Gradio UI ----------
78
# Single-function web UI: one image in, a text label plus two images out.
demo = gr.Interface(
    title="🚗 Vehicle Color Detection",
    description="Upload an image to detect the most prominent vehicle and its predicted color.",
    fn=detect_vehicle_color,
    # Deliver the upload as a PIL image; the handler calls .convert() on it.
    inputs=gr.Image(type="pil"),
    # One component per element of the tuple returned by detect_vehicle_color.
    outputs=[
        gr.Text(label="Predicted Vehicle Color"),
        gr.Image(label="Detected Vehicle in Original"),
        gr.Image(label="Cropped Vehicle Region"),
    ],
)

# ---------- Launch ----------
# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()