import os
import zipfile
from PIL import Image
import torch
import torch.nn as nn
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
import gradio as gr
# ----------- SETUP -----------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
# ----------- UNZIP DATA -----------
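# train.zip and val.zip are expected next to this script (e.g. in the Space repo root)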
def unzip_file(zip_path, extract_to):
    if not os.path.exists(extract_to):
        os.makedirs(extract_to)
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)
    print(f"Extracted {zip_path} to {extract_to}")
unzip_file("train.zip", "./data/train")
unzip_file("val.zip", "./data/val")
# ----------- DATASET -----------
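# Each split directory must contain 'mask/' and 'no_mask/' subfolders of .jpg/.png images;
# label 0 = mask, label 1 = no_mask.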
class FaceMaskDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.image_paths = []
        self.labels = []
        self.transform = transform
        for label_name in ['mask', 'no_mask']:
            class_path = os.path.join(root_dir, label_name)
            for img_name in os.listdir(class_path):
                if img_name.endswith(".jpg") or img_name.endswith(".png"):
                    self.image_paths.append(os.path.join(class_path, img_name))
                    self.labels.append(0 if label_name == 'mask' else 1)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = Image.open(self.image_paths[idx]).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, self.labels[idx]
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
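# Note: ImageNet mean/std normalization (transforms.Normalize) is omitted here; adding it
# would match the preprocessing the pretrained MobileNetV2 weights were trained with.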
train_dataset = FaceMaskDataset("./data/train", transform)
val_dataset = FaceMaskDataset("./data/val", transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)
# ----------- MODEL -----------
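# Transfer learning: MobileNetV2 pretrained on ImageNet, with the final classifier layer
# replaced by a 2-class head (mask / no_mask).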
model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.DEFAULT)  # pretrained=True is deprecated in recent torchvision
model.classifier[1] = nn.Linear(model.last_channel, 2)  # swap the 1000-class ImageNet head for a 2-class head
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# ----------- TRAINING -----------
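# The model is fine-tuned at startup; a couple of epochs keeps the Space's boot time short.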
def train_model(model, epochs=2):  # keep epochs small for HF Spaces
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for imgs, labels in train_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}")

    # Validation accuracy after training
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    acc = 100 * correct / total
    print(f"Validation Accuracy: {acc:.2f}%")
train_model(model)
torch.save(model.state_dict(), "face_mask_model.pth")
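# On later runs the saved weights could be restored with
# model.load_state_dict(torch.load("face_mask_model.pth", map_location=device)) instead of retraining.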
# ----------- INFERENCE -----------
def predict(image):
    if image is None:  # live webcam input can deliver empty frames
        return "No input"
    model.eval()
    img = image.convert("RGB")
    img = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        outputs = model(img)
        _, predicted = torch.max(outputs.data, 1)
    return "Mask" if predicted.item() == 0 else "No Mask"
# ----------- GRADIO APP -----------
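# live=True re-runs predict whenever the input image changes (uploads or webcam frames).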
iface = gr.Interface(
    fn=predict,
    # Gradio 4+ API: the old source="webcam"/tool="editor" arguments were removed
    inputs=gr.Image(sources=["upload", "webcam"], type="pil", label="Upload or Webcam"),
    outputs=gr.Label(label="Prediction"),
    live=True,
    title="Face Mask Detection",
    description="Upload or use webcam to detect if a person is wearing a face mask.",
)
iface.launch()