import os
import zipfile
import xml.etree.ElementTree as ET

import gradio as gr
import torch
import torchvision
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


# Dataset class for Pascal VOC-style face mask annotations
class FaceMaskDataset(Dataset):
    def __init__(self, images_dir, annotations_dir, transform=None, resize=(800, 800)):
        self.images_dir = images_dir
        self.annotations_dir = annotations_dir
        self.transform = transform
        self.resize = resize
        self.image_files = [
            f for f in os.listdir(images_dir) if f.endswith((".jpg", ".png"))
        ]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_path = os.path.join(self.images_dir, self.image_files[idx])
        image = Image.open(image_path).convert("RGB")
        orig_width, orig_height = image.size
        image = image.resize(self.resize)

        annotation_path = os.path.join(
            self.annotations_dir,
            self.image_files[idx].replace(".jpg", ".xml").replace(".png", ".xml"),
        )
        if not os.path.exists(annotation_path):
            return None, None

        boxes, labels = self.load_annotations(annotation_path)
        if boxes is None or labels is None:
            return None, None

        # The XML coordinates refer to the original image dimensions, so the
        # boxes must be scaled to match the resized image.
        scale_x = self.resize[0] / orig_width
        scale_y = self.resize[1] / orig_height
        boxes[:, [0, 2]] *= scale_x
        boxes[:, [1, 3]] *= scale_y

        target = {"boxes": boxes, "labels": labels}

        if self.transform:
            image = self.transform(image)

        return image, target

    def load_annotations(self, annotation_path):
        tree = ET.parse(annotation_path)
        root = tree.getroot()

        boxes = []
        labels = []
        for obj in root.iter("object"):
            label = obj.find("name").text
            bndbox = obj.find("bndbox")
            xmin = float(bndbox.find("xmin").text)
            ymin = float(bndbox.find("ymin").text)
            xmax = float(bndbox.find("xmax").text)
            ymax = float(bndbox.find("ymax").text)
            boxes.append([xmin, ymin, xmax, ymax])
            # Label 0 is reserved for the background class in torchvision's
            # Faster R-CNN, so foreground labels must start at 1:
            # 1 = mask, 2 = no mask.
            labels.append(1 if label == "mask" else 2)

        if not boxes or not labels:
            return None, None

        boxes = torch.tensor(boxes, dtype=torch.float32)
        labels = torch.tensor(labels, dtype=torch.int64)
        return boxes, labels


def collate_fn(batch):
    # Drop samples whose annotation file was missing or empty.
    batch = [b for b in batch if b[0] is not None and b[1] is not None]
    if not batch:
        return [], []
    images, targets = zip(*batch)
    return list(images), list(targets)


def get_model(num_classes):
    # Load a Faster R-CNN pre-trained on COCO (torchvision >= 0.13 weights API)
    # and replace the box predictor head so it outputs `num_classes` classes
    # (including background).
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model


def extract_zip(zip_file, extract_to):
    # gr.File may hand us a tempfile wrapper (with a .name path) or a plain path.
    zip_path = getattr(zip_file, "name", zip_file)
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extract_to)


def train_model(train_zip, val_zip):
    extract_zip(train_zip, "./data/train")
    extract_zip(val_zip, "./data/val")

    transform = transforms.Compose([transforms.ToTensor()])

    # The uploaded ZIPs are expected to contain a top-level train/ (resp. val/)
    # folder with images/ and annotations/ inside.
    train_dataset = FaceMaskDataset(
        images_dir="./data/train/train/images",
        annotations_dir="./data/train/train/annotations",
        transform=transform,
    )
    val_dataset = FaceMaskDataset(
        images_dir="./data/val/val/images",
        annotations_dir="./data/val/val/annotations",
        transform=transform,
    )

    train_loader = DataLoader(
        train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn
    )
    val_loader = DataLoader(
        val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn
    )

    # Three classes: background (0), mask (1), no mask (2).
    model = get_model(num_classes=3).to(device)
    optimizer = torch.optim.SGD(
        model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005
    )

    for epoch in range(3):  # Few epochs for demo purposes
        model.train()
        total_loss = 0
        for images, targets in train_loader:
            if not images:
                continue
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            optimizer.zero_grad()
            loss_dict = model(images, targets)
            loss = sum(loss for loss in loss_dict.values())
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch + 1}, Train loss: {total_loss / len(train_loader):.4f}")

        # torchvision detection models return losses only in train mode, so
        # keep train mode but disable gradients for the validation pass.
        val_loss = 0.0
        with torch.no_grad():
            for images, targets in val_loader:
                if not images:
                    continue
                images = [img.to(device) for img in images]
                targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
                loss_dict = model(images, targets)
                val_loss += sum(loss for loss in loss_dict.values()).item()
        print(f"Epoch {epoch + 1}, Val loss: {val_loss / len(val_loader):.4f}")

    torch.save(model.state_dict(), "model.pth")
    return "Training completed. Model saved as model.pth"


# Gradio upload interface
iface = gr.Interface(
    fn=train_model,
    inputs=[
        gr.File(label="Upload Train ZIP"),
        gr.File(label="Upload Val ZIP"),
    ],
    outputs="text",
)

if __name__ == "__main__":
    iface.launch()
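

# --- Inference sketch (illustrative, not part of the original app) ---
# A minimal example of loading the weights saved above and running a single
# prediction. `predict_image` and the 0.5 score threshold are hypothetical
# choices added for illustration; if you wire this into the app, define it
# above the `__main__` guard so it exists before the Gradio server starts.
def predict_image(image_path, weights_path="model.pth", score_threshold=0.5):
    model = get_model(num_classes=3)  # must match the head used in training
    model.load_state_dict(torch.load(weights_path, map_location=device))
    model.to(device)
    model.eval()  # eval mode: the model returns detections, not losses

    image = Image.open(image_path).convert("RGB")
    tensor = transforms.ToTensor()(image).to(device)

    with torch.no_grad():
        prediction = model([tensor])[0]

    # Keep only confident detections; label 1 = mask, label 2 = no mask.
    keep = prediction["scores"] >= score_threshold
    return {
        "boxes": prediction["boxes"][keep].cpu(),
        "labels": prediction["labels"][keep].cpu(),
        "scores": prediction["scores"][keep].cpu(),
    }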