import os
import zipfile
from PIL import Image
import torch
import torch.nn as nn
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
import gradio as gr

# ----------- SETUP -----------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ----------- UNZIP DATA -----------
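# train.zip and val.zip are expected in the working directory; extraction is
# skipped when the target folder already exists, so delete ./data to force a
# fresh unzip.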

def unzip_file(zip_path, extract_to):
    if not os.path.exists(extract_to):
        os.makedirs(extract_to)
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(extract_to)
        print(f"Extracted {zip_path} to {extract_to}")

unzip_file("train.zip", "./data/train")
unzip_file("val.zip", "./data/val")

# ----------- DATASET -----------
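# Assumed folder layout: <root>/mask/*.jpg|*.png and <root>/no_mask/*.jpg|*.png.
# Label convention: 0 = mask, 1 = no_mask.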

class FaceMaskDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.image_paths = []
        self.labels = []
        self.transform = transform
        for label_name in ['mask', 'no_mask']:
            class_path = os.path.join(root_dir, label_name)
            for img_name in os.listdir(class_path):
                if img_name.lower().endswith((".jpg", ".jpeg", ".png")):
                    self.image_paths.append(os.path.join(class_path, img_name))
                    self.labels.append(0 if label_name == 'mask' else 1)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, self.labels[idx]

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # ImageNet statistics expected by the pretrained MobileNetV2 backbone.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = FaceMaskDataset("./data/train", transform)
val_dataset = FaceMaskDataset("./data/val", transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)
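
# Optional sanity check: make sure both splits were actually found.
print(f"Train samples: {len(train_dataset)} | Val samples: {len(val_dataset)}")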

# ----------- MODEL -----------

# torchvision >= 0.13 deprecates `pretrained=True` in favour of the weights API.
model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.DEFAULT)
model.classifier[1] = nn.Linear(model.last_channel, 2)
model = model.to(device)
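
# Optional (sketch): freeze the pretrained backbone and train only the new
# classifier head; often faster on the CPU-only hardware of free Spaces.
# for param in model.features.parameters():
#     param.requires_grad = False
# (If enabled, construct the optimizer with only the trainable parameters.)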

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# ----------- TRAINING -----------

def train_model(model, epochs=2):  # keep epochs small for HF Spaces
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}")

        # Validation Accuracy
        correct = 0
        total = 0
        model.eval()
        with torch.no_grad():
            for imgs, labels in val_loader:
                imgs, labels = imgs.to(device), labels.to(device)
                outputs = model(imgs)
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        acc = 100 * correct / total
        print(f"Validation Accuracy: {acc:.2f}%")

train_model(model)
torch.save(model.state_dict(), "face_mask_model.pth")
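
# To reuse the saved weights later without retraining (e.g. after a Space
# restart), rebuild the same model and load the state dict:
# model.load_state_dict(torch.load("face_mask_model.pth", map_location=device))
# model.eval()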

# ----------- INFERENCE -----------

def predict(image):
    model.eval()
    img = image.convert("RGB")
    img = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(img)
        predicted = outputs.argmax(dim=1)
    return "Mask" if predicted.item() == 0 else "No Mask"

# ----------- GRADIO APP -----------

iface = gr.Interface(
    fn=predict,
    # Gradio 4.x: `sources` replaces the removed `source`/`tool` arguments.
    inputs=gr.Image(sources=["upload", "webcam"], type="pil", label="Upload or Webcam"),
    outputs=gr.Label(label="Prediction"),
    live=True,
    title="Face Mask Detection",
    description="Upload an image or use your webcam to detect whether a person is wearing a face mask.",
)

iface.launch()