import gradio as gr
import torch
import os
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import xml.etree.ElementTree as ET
import torch.optim as optim
import zipfile

# Device config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Custom Dataset
class FaceMaskDataset(Dataset):
    def __init__(self, images_dir, annotations_dir, transform=None, resize=(800, 800)):
        self.images_dir = images_dir
        self.annotations_dir = annotations_dir
        self.transform = transform
        self.resize = resize
        # Keep only image files so stray directory entries are skipped
        self.image_files = [
            f for f in os.listdir(images_dir)
            if f.lower().endswith((".jpg", ".jpeg", ".png"))
        ]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_path = os.path.join(self.images_dir, self.image_files[idx])
        image = Image.open(image_path).convert("RGB")
        orig_w, orig_h = image.size
        image = image.resize(self.resize)

        annotation_path = os.path.join(
            self.annotations_dir,
            os.path.splitext(self.image_files[idx])[0] + ".xml"
        )
        if not os.path.exists(annotation_path):
            print(f"Warning: Annotation file {annotation_path} not found.")
            return None, None

        boxes, labels = self.load_annotations(annotation_path)
        if boxes is None or labels is None:
            return None, None

        # Scale the boxes to match the resized image; otherwise they stay in
        # the original image's coordinate system and no longer line up.
        scale_x = self.resize[0] / orig_w
        scale_y = self.resize[1] / orig_h
        boxes[:, [0, 2]] *= scale_x
        boxes[:, [1, 3]] *= scale_y

        target = {'boxes': boxes, 'labels': labels}
        if self.transform:
            image = self.transform(image)
        return image, target

    def load_annotations(self, annotation_path):
        tree = ET.parse(annotation_path)
        root = tree.getroot()
        boxes = []
        labels = []
        for obj in root.iter('object'):
            label = obj.find('name').text
            bndbox = obj.find('bndbox')
            xmin = float(bndbox.find('xmin').text)
            ymin = float(bndbox.find('ymin').text)
            xmax = float(bndbox.find('xmax').text)
            ymax = float(bndbox.find('ymax').text)
            boxes.append([xmin, ymin, xmax, ymax])
            # Label 0 is reserved for the background class in torchvision's
            # detection models, so foreground classes must start at 1:
            # 1 = mask, 2 = no mask.
            labels.append(1 if label == "mask" else 2)
        if not boxes or not labels:
            return None, None
        return torch.as_tensor(boxes, dtype=torch.float32), torch.tensor(labels, dtype=torch.int64)
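
# Quick sanity check for the dataset (a sketch, left commented out; the
# "train/images" / "train/annotations" layout is the one assumed by the
# training trigger below):
# ds = FaceMaskDataset("train/images", "train/annotations", transform=transforms.ToTensor())
# img, tgt = ds[0]
# print(img.shape, tgt["boxes"].shape, tgt["labels"])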

# Collate function: detection targets are dicts with a variable number of
# boxes per image, so the default collate (which stacks tensors) cannot be used
def collate_fn(batch):
    batch = [sample for sample in batch if sample[0] is not None]
    if not batch:
        # Every sample in the batch was missing annotations
        return None, None
    images, targets = zip(*batch)
    return images, targets
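
# What the loader yields with this collate_fn (illustrative, commented out;
# assumes the train_loader built below): a tuple of image tensors and a
# tuple of target dicts, one per image.
# images, targets = next(iter(train_loader))
# print(len(images), targets[0]["boxes"].shape, targets[0]["labels"])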

# Model: Faster R-CNN with a ResNet-50 FPN backbone, with the box predictor
# head swapped for one sized to our classes
def get_model(num_classes):
    import torchvision
    # `pretrained=True` is deprecated in recent torchvision; `weights="DEFAULT"`
    # loads the same COCO-pretrained weights
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
    return model
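
# Sanity check (a sketch, commented out): the replaced head should expose
# one logit per class, background included.
# m = get_model(num_classes=3)
# print(m.roi_heads.box_predictor.cls_score.out_features)  # -> 3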

# Validation Function
def evaluate_model(model, val_loader):
    # torchvision detection models only return the loss dict in train mode;
    # in eval mode they return predictions instead. Keep train mode but
    # disable gradients so this is still a pure evaluation pass.
    model.train()
    running_loss = 0.0
    with torch.no_grad():
        for images, targets in val_loader:
            if images is None or targets is None:
                continue
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            loss_dict = model(images, targets)
            total_loss = sum(loss for loss in loss_dict.values())
            running_loss += total_loss.item()
    return running_loss / max(len(val_loader), 1)

# Training Function
def train_model(model, train_loader, val_loader, optimizer, num_epochs=10):
    for epoch in range(num_epochs):
        running_loss = 0.0
        model.train()
        for images, targets in train_loader:
            if images is None or targets is None:
                continue
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            optimizer.zero_grad()
            loss_dict = model(images, targets)
            total_loss = sum(loss for loss in loss_dict.values())
            total_loss.backward()
            optimizer.step()
            running_loss += total_loss.item()
        print(f"[Epoch {epoch+1}] Train Loss: {running_loss / len(train_loader):.4f}")
        val_loss = evaluate_model(model, val_loader)
        print(f"[Epoch {epoch+1}] Validation Loss: {val_loss:.4f}")
    torch.save(model.state_dict(), "facemask_detector.pth")
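
# Reloading the saved weights later for inference (a sketch, commented out;
# assumes the same class count used at training time):
# model = get_model(num_classes=3)
# model.load_state_dict(torch.load("facemask_detector.pth", map_location=device))
# model.eval()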

# Main Training Trigger
def train_from_files_tab():
    train_zip_path = "train.zip"
    val_zip_path = "val.zip"
    if not os.path.exists(train_zip_path) or not os.path.exists(val_zip_path):
        return "❌ 'train.zip' or 'val.zip' not found in the Files section."

    # Extract
    for zip_path, folder in [(train_zip_path, "train"), (val_zip_path, "val")]:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(folder)

    # The dataset already resizes images and scales the boxes to match, so the
    # transform only needs to convert to a tensor. (A transforms.Resize here
    # would change the image without touching the boxes.)
    transform = transforms.ToTensor()
    train_dataset = FaceMaskDataset("train/images", "train/annotations", transform=transform)
    val_dataset = FaceMaskDataset("val/images", "val/annotations", transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

    # 3 classes: background (0), mask (1), no mask (2)
    model = get_model(num_classes=3)
    model.to(device)

    optimizer = optim.SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)
    train_model(model, train_loader, val_loader, optimizer, num_epochs=5)
    return "✅ Training complete. Model saved as 'facemask_detector.pth'."

# Gradio UI
iface = gr.Interface(
    fn=train_from_files_tab,
    inputs=[],
    outputs=gr.Textbox(label="Training Output"),
    title="Face Mask Detector Trainer (from Files Tab)"
)

iface.launch()