import gradio as gr
import os
import zipfile
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import xml.etree.ElementTree as ET
import torchvision.models.detection
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Dataset class
class FaceMaskDataset(Dataset):
    def __init__(self, images_dir, annotations_dir, transform=None, resize=(800, 800)):
        self.images_dir = images_dir
        self.annotations_dir = annotations_dir
        self.transform = transform
        self.resize = resize
        self.image_files = [f for f in os.listdir(images_dir) if f.endswith(('.jpg', '.png'))]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_path = os.path.join(self.images_dir, self.image_files[idx])
        image = Image.open(image_path).convert("RGB")
        orig_w, orig_h = image.size
        image = image.resize(self.resize)
        annotation_path = os.path.join(
            self.annotations_dir,
            self.image_files[idx].replace(".jpg", ".xml").replace(".png", ".xml")
        )
        if not os.path.exists(annotation_path):
            return None, None
        boxes, labels = self.load_annotations(annotation_path)
        if boxes is None or labels is None:
            return None, None
        # Scale the boxes to the resized image; the annotations are in the
        # source image's coordinate system.
        scale_x = self.resize[0] / orig_w
        scale_y = self.resize[1] / orig_h
        boxes[:, [0, 2]] *= scale_x
        boxes[:, [1, 3]] *= scale_y
        target = {'boxes': boxes, 'labels': labels}
        if self.transform:
            image = self.transform(image)
        return image, target

    def load_annotations(self, annotation_path):
        tree = ET.parse(annotation_path)
        root = tree.getroot()
        boxes = []
        labels = []
        for obj in root.iter('object'):
            label = obj.find('name').text
            # Label 0 is reserved for background in torchvision detection
            # models, so keep only "mask" objects and assign them label 1.
            if label != "mask":
                continue
            bndbox = obj.find('bndbox')
            xmin = float(bndbox.find('xmin').text)
            ymin = float(bndbox.find('ymin').text)
            xmax = float(bndbox.find('xmax').text)
            ymax = float(bndbox.find('ymax').text)
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(1)
        if not boxes:
            return None, None
        boxes = torch.tensor(boxes, dtype=torch.float32)
        labels = torch.tensor(labels, dtype=torch.int64)
        return boxes, labels
def collate_fn(batch):
    # Drop samples whose annotations were missing or empty.
    batch = [b for b in batch if b[0] is not None and b[1] is not None]
    if not batch:
        return [], []
    images, targets = zip(*batch)
    return list(images), list(targets)
def get_model(num_classes):
    # pretrained=True is deprecated in torchvision >= 0.13 in favor of the
    # weights= argument, but it still loads the COCO-pretrained backbone.
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model
def extract_zip(zip_file, extract_to):
    # gr.File may pass a tempfile-like object or a plain filepath string,
    # depending on the Gradio version; handle both.
    path = zip_file.name if hasattr(zip_file, "name") else zip_file
    with zipfile.ZipFile(path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)
def train_model(train_zip, val_zip):
    extract_zip(train_zip, './data/train')
    extract_zip(val_zip, './data/val')
    transform = transforms.Compose([transforms.ToTensor()])
    train_dataset = FaceMaskDataset(
        images_dir='./data/train/train/images',
        annotations_dir='./data/train/train/annotations',
        transform=transform
    )
    val_dataset = FaceMaskDataset(
        images_dir='./data/val/val/images',
        annotations_dir='./data/val/val/annotations',
        transform=transform
    )
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)
    model = get_model(num_classes=2).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)
    for epoch in range(3):  # Epoch count kept low for the demo
        model.train()
        total_loss = 0
        for images, targets in train_loader:
            if not images:  # Batch was empty after filtering bad samples
                continue
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            optimizer.zero_grad()
            loss_dict = model(images, targets)
            loss = sum(loss for loss in loss_dict.values())
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader)}")
    torch.save(model.state_dict(), "model.pth")
    return "Training completed. Model saved as model.pth"
# Gradio upload interface
iface = gr.Interface(
    fn=train_model,
    inputs=[
        gr.File(label="Upload Train ZIP"),
        gr.File(label="Upload Val ZIP")
    ],
    outputs="text"
)
if __name__ == "__main__":
    iface.launch()