"""Gradio demo: face-mask detection with a fine-tuned Faster R-CNN model."""

import gradio as gr
import torch
from PIL import Image, ImageDraw
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.transforms import functional as F

# Label names. The constant keeps its historical name, but the entries are the
# four classes of the mask detector (index 0 is the background class), not the
# COCO categories.
COCO_CLASSES = {
    0: "Background",
    1: "Without Mask",
    2: "With Mask",
    3: "Incorrect Mask",
}

# Detections whose score is not strictly above this value are not drawn.
SCORE_THRESHOLD = 0.5

# Checkpoint loaded at start-up; resolved relative to the working directory.
WEIGHTS_PATH = "fasterrcnn_resnet50_epoch_4.pth"


def get_model(num_classes: int = 4):
    """Build a Faster R-CNN (ResNet-50 FPN) with a ``num_classes``-way box head.

    The detector is created without detection weights (``weights=None``) and
    its box predictor is replaced by a fresh ``FastRCNNPredictor`` that has the
    same input feature size and ``num_classes`` outputs.

    Args:
        num_classes: Number of classes predicted by the box head. The default
            of 4 matches the entries of ``COCO_CLASSES``.

    Returns:
        The constructed model. No checkpoint is loaded here.
    """
    model = fasterrcnn_resnet50_fpn(weights=None)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model


# Setup: build the network, load the checkpoint and switch to inference mode.
device = torch.device("cpu")
model = get_model()
# NOTE(review): torch.load unpickles the file, so only ship a checkpoint from a
# trusted source. Consider ``weights_only=True`` if the minimum supported
# PyTorch version provides it.
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=device))
model.to(device)
model.eval()


def predict(image, threshold: float = SCORE_THRESHOLD):
    """Run the detector on ``image`` and return an annotated copy.

    Args:
        image: PIL image to analyse, or ``None`` (for example when the
            interface is submitted without an upload).
        threshold: Only detections with a score strictly greater than this
            value are drawn.

    Returns:
        A new RGB PIL image with a red rectangle and a ``"<class> (<score>)"``
        caption for every kept detection, or ``None`` when ``image`` is
        ``None``. The input image itself is not modified.
    """
    if image is None:
        return None

    # ``convert`` returns a new image, so drawing below never touches the
    # caller's object, and RGBA / greyscale / palette inputs are normalised to
    # three channels before being turned into a tensor.
    annotated = image.convert("RGB")
    image_tensor = F.to_tensor(annotated).unsqueeze(0).to(device)

    with torch.no_grad():
        prediction = model(image_tensor)[0]

    # Convert to plain Python numbers once, so that drawing and formatting do
    # not depend on 0-dim tensors being accepted.
    detections = zip(
        prediction["boxes"].tolist(),
        prediction["labels"].tolist(),
        prediction["scores"].tolist(),
    )

    draw = ImageDraw.Draw(annotated)
    for (x1, y1, x2, y2), label, score in detections:
        if score > threshold:
            class_name = COCO_CLASSES.get(label, "Unknown")
            draw.rectangle([x1, y1, x2, y2], outline="red", width=3)
            draw.text((x1, y1), f"{class_name} ({score:.2f})", fill="red")

    return annotated


# Gradio interface. It is built at import time, but the server is only started
# when the file is executed as a script.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload a Face Image"),
    outputs=gr.Image(type="pil", label="Detection Result"),
    title="Face Mask Detection - Faster R-CNN",
    description="Detects faces with mask, without mask, or incorrectly worn mask.",
)

if __name__ == "__main__":
    demo.launch()