import gradio as gr
import os
import zipfile
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import xml.etree.ElementTree as ET
import torchvision.models.detection
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Dataset of images with Pascal VOC-style XML annotations (one file per image)
class FaceMaskDataset(Dataset):
    def __init__(self, images_dir, annotations_dir, transform=None, resize=(800, 800)):
        self.images_dir = images_dir
        self.annotations_dir = annotations_dir
        self.transform = transform
        self.resize = resize
        self.image_files = [f for f in os.listdir(images_dir) if f.endswith(('.jpg', '.png'))]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_path = os.path.join(self.images_dir, self.image_files[idx])
        image = Image.open(image_path).convert("RGB")
        orig_w, orig_h = image.size
        image = image.resize(self.resize)

        annotation_path = os.path.join(self.annotations_dir, self.image_files[idx].replace(".jpg", ".xml").replace(".png", ".xml"))
        if not os.path.exists(annotation_path):
            return None, None

        boxes, labels = self.load_annotations(annotation_path)
        if boxes is None or labels is None:
            return None, None

        # The XML boxes are in original-image coordinates; rescale them so
        # they still line up with the resized image.
        boxes[:, [0, 2]] *= self.resize[0] / orig_w
        boxes[:, [1, 3]] *= self.resize[1] / orig_h

        target = {'boxes': boxes, 'labels': labels}

        if self.transform:
            image = self.transform(image)

        return image, target

    def load_annotations(self, annotation_path):
        tree = ET.parse(annotation_path)
        root = tree.getroot()

        boxes = []
        labels = []
        for obj in root.iter('object'):
            label = obj.find('name').text
            bndbox = obj.find('bndbox')
            xmin = float(bndbox.find('xmin').text)
            ymin = float(bndbox.find('ymin').text)
            xmax = float(bndbox.find('xmax').text)
            ymax = float(bndbox.find('ymax').text)
            boxes.append([xmin, ymin, xmax, ymax])
            # Torchvision detection models reserve label 0 for the background
            # class, so foreground labels must start at 1: mask -> 1, no-mask -> 2.
            labels.append(1 if label == "mask" else 2)

        if not boxes or not labels:
            return None, None

        boxes = torch.tensor(boxes, dtype=torch.float32)
        labels = torch.tensor(labels, dtype=torch.int64)

        return boxes, labels

# Detection targets have a variable number of boxes per image, so the default
# collate (which stacks tensors) cannot be used; keep images/targets as lists.
def collate_fn(batch):
    batch = [b for b in batch if b[0] is not None and b[1] is not None]
    if not batch:  # every sample was filtered out (missing/empty annotations)
        return [], []
    images, targets = zip(*batch)
    return list(images), list(targets)

def get_model(num_classes):
    # `pretrained=True` is deprecated in recent torchvision; `weights="DEFAULT"`
    # loads the same COCO-pretrained weights. num_classes includes background.
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model

def extract_zip(zip_file, extract_to):
    # Depending on the Gradio version, an upload arrives as a path string or
    # as a tempfile-like object exposing `.name`; handle both.
    path = zip_file if isinstance(zip_file, str) else zip_file.name
    os.makedirs(extract_to, exist_ok=True)
    with zipfile.ZipFile(path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)
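
# Expected ZIP layout (inferred from the dataset paths used below):
#   train.zip -> train/images/*.jpg|*.png and train/annotations/*.xml
#   val.zip   -> val/images/*.jpg|*.png  and val/annotations/*.xml
# Each XML file is a Pascal VOC-style annotation named after its image.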

def train_model(train_zip, val_zip):
    extract_zip(train_zip, './data/train')
    extract_zip(val_zip, './data/val')

    transform = transforms.Compose([transforms.ToTensor()])

    train_dataset = FaceMaskDataset(
        images_dir='./data/train/train/images',
        annotations_dir='./data/train/train/annotations',
        transform=transform
    )
    val_dataset = FaceMaskDataset(
        images_dir='./data/val/val/images',
        annotations_dir='./data/val/val/annotations',
        transform=transform
    )

    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

    # Three classes: 0 = background (reserved), 1 = mask, 2 = no-mask.
    model = get_model(num_classes=3).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)

    for epoch in range(3):  # epoch count kept small for demo purposes
        model.train()
        total_loss = 0
        for images, targets in train_loader:
            if not images:  # whole batch was filtered out by collate_fn
                continue
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            optimizer.zero_grad()
            # In train mode the model returns a dict of losses, not predictions.
            loss_dict = model(images, targets)
            loss = sum(loss_dict.values())
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader)}")

    torch.save(model.state_dict(), "model.pth")
    return "Training completed. Model saved as model.pth"

# Gradio upload interface
iface = gr.Interface(
    fn=train_model,
    inputs=[
        gr.File(label="Upload Train ZIP"),
        gr.File(label="Upload Val ZIP")
    ],
    outputs="text"
)

if __name__ == "__main__":
    iface.launch()