import gradio as gr
import os
import zipfile
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import xml.etree.ElementTree as ET
import torchvision.models.detection
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Dataset that pairs images with their Pascal VOC-style XML annotations
class FaceMaskDataset(Dataset):
    def __init__(self, images_dir, annotations_dir, transform=None, resize=(800, 800)):
        self.images_dir = images_dir
        self.annotations_dir = annotations_dir
        self.transform = transform
        self.resize = resize
        self.image_files = [f for f in os.listdir(images_dir) if f.endswith(('.jpg', '.png'))]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_path = os.path.join(self.images_dir, self.image_files[idx])
        image = Image.open(image_path).convert("RGB")
        orig_w, orig_h = image.size
        image = image.resize(self.resize)
        # Annotation file shares the image's base name but uses an .xml extension
        annotation_path = os.path.join(
            self.annotations_dir,
            self.image_files[idx].replace(".jpg", ".xml").replace(".png", ".xml")
        )
        if not os.path.exists(annotation_path):
            # Samples without annotations are filtered out later in collate_fn
            return None, None
        boxes, labels = self.load_annotations(annotation_path)
        if boxes is None or labels is None:
            return None, None
        # Rescale the boxes so they match the resized image
        scale_x = self.resize[0] / orig_w
        scale_y = self.resize[1] / orig_h
        boxes = boxes * torch.tensor([scale_x, scale_y, scale_x, scale_y], dtype=torch.float32)
        target = {'boxes': boxes, 'labels': labels}
        if self.transform:
            image = self.transform(image)
        return image, target
    def load_annotations(self, annotation_path):
        tree = ET.parse(annotation_path)
        root = tree.getroot()
        boxes = []
        labels = []
        for obj in root.iter('object'):
            label = obj.find('name').text
            bndbox = obj.find('bndbox')
            xmin = float(bndbox.find('xmin').text)
            ymin = float(bndbox.find('ymin').text)
            xmax = float(bndbox.find('xmax').text)
            ymax = float(bndbox.find('ymax').text)
            boxes.append([xmin, ymin, xmax, ymax])
            # Class 1 = mask; anything else maps to 0, which torchvision treats as background
            labels.append(1 if label == "mask" else 0)
        if not boxes or not labels:
            return None, None
        boxes = torch.tensor(boxes, dtype=torch.float32)
        labels = torch.tensor(labels, dtype=torch.int64)
        return boxes, labels
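
# For reference, load_annotations expects Pascal VOC-style XML like the sketch below
# (the numeric values here are purely illustrative):
# <annotation>
#   <object>
#     <name>mask</name>
#     <bndbox>
#       <xmin>79</xmin><ymin>105</ymin><xmax>109</xmax><ymax>142</ymax>
#     </bndbox>
#   </object>
# </annotation>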

def collate_fn(batch):
    # Detection targets vary in size per image, so batch them as lists
    # and drop samples whose annotation file was missing or empty
    batch = [b for b in batch if b[0] is not None and b[1] is not None]
    images, targets = zip(*batch)
    return list(images), list(targets)

def get_model(num_classes):
    # Start from a COCO-pretrained Faster R-CNN and swap in a new box predictor head;
    # num_classes includes the background class (here: background + mask)
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model

def extract_zip(zip_file, extract_to):
    with zipfile.ZipFile(zip_file, 'r') as zip_ref:
        zip_ref.extractall(extract_to)

def train_model(train_zip, val_zip):
    # Each uploaded ZIP is expected to contain images/ and annotations/ folders
    # under a top-level train/ (or val/) directory, matching the paths below
    extract_zip(train_zip, './data/train')
    extract_zip(val_zip, './data/val')

    transform = transforms.Compose([transforms.ToTensor()])
    train_dataset = FaceMaskDataset(
        images_dir='./data/train/train/images',
        annotations_dir='./data/train/train/annotations',
        transform=transform
    )
    val_dataset = FaceMaskDataset(
        images_dir='./data/val/val/images',
        annotations_dir='./data/val/val/annotations',
        transform=transform
    )
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
    # Note: val_loader is built but not used in this demo training loop
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

    model = get_model(num_classes=2).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)

    for epoch in range(3):  # Small number of epochs for the demo
        model.train()
        total_loss = 0
        for images, targets in train_loader:
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            optimizer.zero_grad()
            # In training mode the detection model returns a dict of losses
            loss_dict = model(images, targets)
            loss = sum(loss for loss in loss_dict.values())
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader)}")

    torch.save(model.state_dict(), "model.pth")
    return "Training completed. Model saved as model.pth"

# Gradio upload interface
iface = gr.Interface(
    fn=train_model,
    inputs=[
        gr.File(label="Upload Train ZIP"),
        gr.File(label="Upload Val ZIP")
    ],
    outputs="text"
)

if __name__ == "__main__":
    iface.launch()