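"""Face mask detector trainer.

Fine-tunes a torchvision Faster R-CNN on a Pascal VOC style face-mask
dataset (images plus XML annotations) shipped as train.zip / val.zip,
and exposes the training run through a one-button Gradio interface.
"""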
import gradio as gr
import torch
import os
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import xml.etree.ElementTree as ET
import torch.optim as optim
import zipfile

# Device config
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Custom Dataset
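# Expects Pascal VOC style data: a directory of images alongside a
# directory of XML annotation files with one <object> per bounding box.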
class FaceMaskDataset(Dataset):
    def __init__(self, images_dir, annotations_dir, transform=None, resize=(800, 800)):
        self.images_dir = images_dir
        self.annotations_dir = annotations_dir
        self.transform = transform
        self.resize = resize
        # Sort for determinism and skip any non-image files in the folder.
        self.image_files = sorted(
            f for f in os.listdir(images_dir)
            if f.lower().endswith((".jpg", ".jpeg", ".png"))
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_path = os.path.join(self.images_dir, self.image_files[idx])
        image = Image.open(image_path).convert("RGB")
        orig_w, orig_h = image.size
        image = image.resize(self.resize)

        # The annotation file shares the image's base name with an .xml suffix.
        base_name = os.path.splitext(self.image_files[idx])[0]
        annotation_path = os.path.join(self.annotations_dir, base_name + ".xml")
        if not os.path.exists(annotation_path):
            print(f"Warning: Annotation file {annotation_path} not found.")
            return None, None

        boxes, labels = self.load_annotations(annotation_path)
        if boxes is None or labels is None:
            return None, None

        # The XML boxes are in the original image's coordinate system,
        # so rescale them to match the resized image.
        boxes[:, [0, 2]] *= self.resize[0] / orig_w
        boxes[:, [1, 3]] *= self.resize[1] / orig_h

        target = {'boxes': boxes, 'labels': labels}
        if self.transform:
            image = self.transform(image)

        return image, target

    def load_annotations(self, annotation_path):
        tree = ET.parse(annotation_path)
        root = tree.getroot()

        boxes = []
        labels = []
        for obj in root.iter('object'):
            label = obj.find('name').text
            bndbox = obj.find('bndbox')
            xmin = float(bndbox.find('xmin').text)
            ymin = float(bndbox.find('ymin').text)
            xmax = float(bndbox.find('xmax').text)
            ymax = float(bndbox.find('ymax').text)
            boxes.append([xmin, ymin, xmax, ymax])
            # Label 0 is reserved for background in torchvision's
            # Faster R-CNN, so foreground classes must start at 1.
            labels.append(1 if label == "mask" else 2)

        if not boxes or not labels:
            return None, None

        return torch.as_tensor(boxes, dtype=torch.float32), torch.tensor(labels, dtype=torch.int64)

# Custom collate: detection targets hold a variable number of boxes per
# image, so the default collate (which stacks tensors) cannot be used.
def collate_fn(batch):
    # Drop samples whose annotations were missing or empty.
    batch = [sample for sample in batch if sample[0] is not None]
    if not batch:
        return None, None
    images, targets = zip(*batch)
    return images, targets

# Faster R-CNN with a ResNet-50 FPN backbone, pretrained on COCO; the
# box predictor head is swapped out to match our number of classes.
def get_model(num_classes):
    import torchvision
    # weights="DEFAULT" (torchvision >= 0.13) replaces the deprecated
    # pretrained=True and loads the COCO-pretrained weights.
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
    return model

# Validation Function
def evaluate_model(model, val_loader):
    # Torchvision detection models only return a loss dict in train mode
    # (eval mode returns predictions instead), so keep train() here;
    # no_grad() still prevents any gradient tracking or updates.
    model.train()
    running_loss = 0.0
    with torch.no_grad():
        for images, targets in val_loader:
            if images is None or targets is None:
                continue
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            loss_dict = model(images, targets)
            total_loss = sum(loss for loss in loss_dict.values())
            running_loss += total_loss.item()
    return running_loss / max(len(val_loader), 1)

# Training Function
def train_model(model, train_loader, val_loader, optimizer, num_epochs=10):
    for epoch in range(num_epochs):
        running_loss = 0.0
        model.train()
        for images, targets in train_loader:
            if images is None or targets is None:
                continue
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            optimizer.zero_grad()
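            # In train mode the model returns a dict of loss components
            # (classification, box regression, objectness, RPN box reg).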
            loss_dict = model(images, targets)
            total_loss = sum(loss for loss in loss_dict.values())
            total_loss.backward()
            optimizer.step()
            running_loss += total_loss.item()

        print(f"[Epoch {epoch+1}] Train Loss: {running_loss / len(train_loader):.4f}")
        val_loss = evaluate_model(model, val_loader)
        print(f"[Epoch {epoch+1}] Validation Loss: {val_loss:.4f}")

    torch.save(model.state_dict(), "facemask_detector.pth")

# Main Training Trigger
def train_from_files_tab():
    train_zip_path = "train.zip"
    val_zip_path = "val.zip"
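    # Expected zip layout (inferred from the paths used below):
    #   train.zip -> train/images/*.jpg, train/annotations/*.xml
    #   val.zip   -> val/images/*.jpg,   val/annotations/*.xml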

    if not os.path.exists(train_zip_path) or not os.path.exists(val_zip_path):
        return "❌ 'train.zip' or 'val.zip' not found in the Files section."

    # Extract
    for zip_path, folder in [(train_zip_path, "train"), (val_zip_path, "val")]:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(folder)

    # No Resize here: the dataset already resizes images (and rescales the
    # boxes to match), and Faster R-CNN normalizes/rescales internally.
    transform = transforms.Compose([transforms.ToTensor()])
    train_dataset = FaceMaskDataset("train/images", "train/annotations", transform=transform)
    val_dataset = FaceMaskDataset("val/images", "val/annotations", transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

    model = get_model(num_classes=3)  # background + mask + no-mask
    model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)

    train_model(model, train_loader, val_loader, optimizer, num_epochs=5)
    return "✅ Training complete. Model saved as 'facemask_detector.pth'."

# Gradio UI
iface = gr.Interface(
    fn=train_from_files_tab,
    inputs=[],
    outputs=gr.Textbox(label="Training Output"),
    title="Face Mask Detector Trainer (from Files Tab)"
)

iface.launch()