MoinulwithAI commited on
Commit
c6c2d72
·
verified ·
1 Parent(s): 925ec04

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -0
app.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision.models.detection import fasterrcnn_resnet50_fpn
3
+ from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
4
+ from torchvision.transforms import functional as F
5
+ from PIL import Image, ImageDraw
6
+ import gradio as gr
7
+
8
+ # Label names
9
+ COCO_CLASSES = {
10
+ 0: "Background",
11
+ 1: "Without Mask",
12
+ 2: "With Mask",
13
+ 3: "Incorrect Mask"
14
+ }
15
+
16
+ # Load model
17
+ def get_model(num_classes=4):
18
+ model = fasterrcnn_resnet50_fpn(weights=None)
19
+ in_features = model.roi_heads.box_predictor.cls_score.in_features
20
+ model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
21
+ return model
22
+
23
+ # Setup
24
+ device = torch.device("cpu")
25
+ model = get_model()
26
+ model.load_state_dict(torch.load("fasterrcnn_resnet50_epoch_4.pth", map_location=device))
27
+ model.to(device)
28
+ model.eval()
29
+
30
+ # Inference function
31
+ def predict(image):
32
+ image_tensor = F.to_tensor(image).unsqueeze(0).to(device)
33
+
34
+ with torch.no_grad():
35
+ prediction = model(image_tensor)
36
+
37
+ boxes = prediction[0]["boxes"]
38
+ labels = prediction[0]["labels"]
39
+ scores = prediction[0]["scores"]
40
+
41
+ draw = ImageDraw.Draw(image)
42
+ threshold = 0.5
43
+
44
+ for box, label, score in zip(boxes, labels, scores):
45
+ if score > threshold:
46
+ x1, y1, x2, y2 = box
47
+ class_name = COCO_CLASSES.get(label.item(), "Unknown")
48
+ draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
49
+ draw.text((x1, y1), f"{class_name} ({score:.2f})", fill="red")
50
+
51
+ return image
52
+
53
+ # Gradio Interface
54
+ gr.Interface(
55
+ fn=predict,
56
+ inputs=gr.Image(type="pil", label="Upload a Face Image"),
57
+ outputs=gr.Image(type="pil", label="Detection Result"),
58
+ title="Face Mask Detection - Faster R-CNN",
59
+ description="Detects faces with mask, without mask, or incorrectly worn mask."
60
+ ).launch()