File size: 4,412 Bytes
ef4706e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision
from torchvision import transforms
import xml.etree.ElementTree as ET
import torch.optim as optim
import matplotlib.pyplot as plt
import gradio as gr

# Run on the default CUDA device when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class FaceMaskDataset(Dataset):
    """Detection dataset pairing images with Pascal-VOC-style XML annotations.

    Each item is ``(image, target)``, where ``target`` is a dict with
    ``'boxes'`` (float32 tensor of ``[xmin, ymin, xmax, ymax]`` rows, in the
    pixel coordinates of the image actually returned, i.e. after ``resize``)
    and ``'labels'`` (int64 tensor). Items whose annotation file is missing or
    holds no usable box are returned as ``(None, None)`` so that
    ``collate_fn`` can drop them.

    Args:
        images_dir: Directory containing the image files.
        annotations_dir: Directory containing one ``<image stem>.xml`` per image.
        transform: Optional callable applied to the (resized) PIL image. Boxes
            are NOT adjusted for it, so it must not change image geometry.
        resize: ``(width, height)`` passed to ``PIL.Image.resize``; boxes are
            rescaled to match. ``None`` keeps each image at its original size.
    """

    # File extensions treated as images. Other files in images_dir (stray
    # metadata, notes, ...) would make Image.open fail, so they are ignored.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp")

    def __init__(self, images_dir, annotations_dir, transform=None, resize=(800, 800)):
        self.images_dir = images_dir
        self.annotations_dir = annotations_dir
        self.transform = transform
        self.resize = resize
        # Sorted so the index -> file mapping is stable (os.listdir order is arbitrary).
        self.image_files = sorted(
            name for name in os.listdir(images_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        return len(self.image_files)

    def _annotation_path(self, image_file):
        """Return the XML path for *image_file*: same stem, ``.xml`` extension."""
        stem, _ = os.path.splitext(image_file)
        return os.path.join(self.annotations_dir, stem + ".xml")

    @staticmethod
    def _scale_boxes(boxes, original_size, new_size):
        """Rescale ``[xmin, ymin, xmax, ymax]`` rows from *original_size* to *new_size*.

        Both sizes are ``(width, height)``, matching ``PIL.Image.size``.
        """
        orig_w, orig_h = original_size
        new_w, new_h = new_size
        sx = new_w / orig_w
        sy = new_h / orig_h
        factors = torch.tensor([sx, sy, sx, sy], dtype=boxes.dtype)
        return boxes * factors

    def __getitem__(self, idx):
        image_file = self.image_files[idx]
        annotation_path = self._annotation_path(image_file)

        # Check the annotation first so unannotated images are skipped without decoding them.
        if not os.path.exists(annotation_path):
            print(f"Warning: Annotation file {annotation_path} does not exist. Skipping image {image_file}.")
            return None, None  # Return a tuple with None to skip the image/annotation pair

        boxes, labels = self.load_annotations(annotation_path)
        if boxes is None or labels is None:
            return None, None  # Skip this item if annotations are invalid

        image_path = os.path.join(self.images_dir, image_file)
        # The context manager closes the file handle; convert() returns a new, fully loaded image.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        if self.resize is not None:
            original_size = image.size
            # Plain stretch to the target size (aspect ratio is NOT preserved),
            # so the boxes must be rescaled to stay aligned with the pixels.
            image = image.resize(self.resize)
            boxes = self._scale_boxes(boxes, original_size, image.size)

        target = {'boxes': boxes, 'labels': labels}

        if self.transform:
            image = self.transform(image)

        return image, target

    def load_annotations(self, annotation_path):
        """Parse a Pascal-VOC-style XML file into box and label tensors.

        Returns:
            ``(boxes, labels)``: a float32 tensor of shape ``(N, 4)`` holding
            ``[xmin, ymin, xmax, ymax]`` in the XML's pixel coordinates, and an
            int64 tensor of shape ``(N,)``. Returns ``(None, None)`` when the
            file has no object with a positive-area box.
        """
        root = ET.parse(annotation_path).getroot()

        boxes = []
        labels = []
        for obj in root.iter('object'):
            label = obj.find('name').text
            bndbox = obj.find('bndbox')
            xmin = float(bndbox.find('xmin').text)
            ymin = float(bndbox.find('ymin').text)
            xmax = float(bndbox.find('xmax').text)
            ymax = float(bndbox.find('ymax').text)
            # torchvision detection models reject boxes with non-positive
            # width/height during training, so degenerate annotations are dropped.
            if xmax <= xmin or ymax <= ymin:
                continue
            boxes.append([xmin, ymin, xmax, ymax])
            # NOTE(review): torchvision detection models reserve label 0 for
            # background, so objects mapped to 0 here are treated as background.
            # Switching to 1/2 would also need num_classes=3 in load_model —
            # confirm the intended class scheme (and the exact XML class names).
            labels.append(1 if label == "mask" else 0)

        if not boxes:
            return None, None  # No usable boxes in this file

        return (
            torch.as_tensor(boxes, dtype=torch.float32),
            torch.tensor(labels, dtype=torch.int64),
        )


# Collate function for the DataLoader: drops skipped samples and regroups the rest.
def collate_fn(batch):
    """Collate detection samples into an ``(images, targets)`` pair of tuples.

    Samples with a ``None`` image or target (how ``FaceMaskDataset`` marks
    unusable items) are dropped. Detection models take lists of variable-size
    images, so samples are regrouped rather than stacked into one tensor.

    Returns:
        ``(images, targets)``, two tuples of equal length. If every sample was
        dropped this is ``((), ())`` rather than ``()``, so
        ``images, targets = batch`` still unpacks; callers should skip empty batches.
    """
    kept = [item for item in batch if item[0] is not None and item[1] is not None]
    if not kept:
        return (), ()
    return tuple(zip(*kept))

# Build the detector, optionally restoring fine-tuned weights.
def load_model(num_classes=2, weights_path=None):
    """Build a Faster R-CNN (ResNet-50 FPN) detector with a replaced box predictor.

    Args:
        num_classes: Number of output classes of the new box predictor.
            torchvision counts background as class 0, so the default of 2 means
            background plus ONE object class.
        weights_path: Optional path to a ``state_dict`` saved from a model built
            the same way (e.g. after fine-tuning). When omitted, the new box
            predictor keeps its random initialization, so its class/box outputs
            are not meaningful until the model is trained.

    Returns:
        The model, moved to the module-level ``device``.
    """
    # NOTE(review): ``pretrained=True`` is deprecated in newer torchvision in
    # favour of ``weights=...``; kept here for compatibility with older installs.
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # Swap the COCO classification/regression head for one sized to num_classes.
    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
    if weights_path is not None:
        # torch.load unpickles the file: only load checkpoints from a trusted source.
        state_dict = torch.load(weights_path, map_location=device)
        model.load_state_dict(state_dict)
    model.to(device)
    return model

# Inference: the detector is built once and reused across calls.
_inference_model = None


def _get_inference_model():
    """Return the shared eval-mode detector, building it on the first call."""
    global _inference_model
    if _inference_model is None:
        model = load_model()
        model.eval()
        _inference_model = model
    return _inference_model


def infer(image):
    """Run the detector on a single image and return its raw detections.

    The model is cached, so repeated calls neither reload the backbone weights
    nor re-randomize the box-predictor head.

    Args:
        image: Image array accepted by ``PIL.Image.fromarray`` (as produced by
            ``gr.Image(type="numpy")``).

    Returns:
        ``(boxes, labels)`` as NumPy arrays. ``boxes`` are
        ``[xmin, ymin, xmax, ymax]`` rows in the pixel coordinates of the input
        image; ``labels`` are predicted class indices. No score threshold is applied.
    """
    model = _get_inference_model()

    # No Resize: Faster R-CNN rescales inputs internally and maps its boxes
    # back to the size of the tensor it is given. Feeding the original
    # resolution therefore yields boxes in the caller's image coordinates.
    to_tensor = transforms.ToTensor()

    # convert("RGB") guarantees 3 channels for grayscale or RGBA inputs.
    pil_image = Image.fromarray(image).convert("RGB")
    tensor = to_tensor(pil_image).to(device)

    with torch.no_grad():
        prediction = model([tensor])

    # Get boxes and labels from the single-image prediction.
    boxes = prediction[0]['boxes'].cpu().numpy()
    labels = prediction[0]['labels'].cpu().numpy()

    return boxes, labels

# Gradio callback
def gradio_interface(image):
    """Run ``infer`` on the uploaded image and return JSON-ready detections.

    Returns:
        A dict with ``"boxes"`` (list of ``[xmin, ymin, xmax, ymax]``) and
        ``"labels"`` (list of ints). When no image was supplied (``image`` is
        ``None``), both lists are empty instead of failing inside ``infer``.
    """
    if image is None:
        return {"boxes": [], "labels": []}

    boxes, labels = infer(image)

    # NOTE(review): label meaning depends on the training scheme; the dataset
    # code maps "mask" -> 1 and everything else -> 0 — confirm before exposing names.
    # Plain Python lists keep the JSON output independent of NumPy serialization support.
    return {"boxes": boxes.tolist(), "labels": labels.tolist()}

# Gradio UI: upload an image, get the detections back as JSON.
iface = gr.Interface(fn=gradio_interface, inputs=gr.Image(type="numpy"), outputs="json")

# Launch only when run as a script, so importing this module (e.g. to reuse
# the dataset or model code for training) does not start a web server.
if __name__ == "__main__":
    iface.launch()