MoinulwithAI committed on
Commit a51d4e2 · verified · 1 Parent(s): a631cc3

Update app.py

Files changed (1): app.py (+98, -131)
app.py CHANGED
@@ -1,157 +1,124 @@
- import gradio as gr
- import torch
  import os
- from torch.utils.data import Dataset, DataLoader
- from torchvision import transforms
- from PIL import Image
- import xml.etree.ElementTree as ET
- import torch.optim as optim
  import zipfile
+ from PIL import Image
+ import torch
+ import torch.nn as nn
+ from torchvision import transforms, models
+ from torch.utils.data import Dataset, DataLoader
+ import gradio as gr

- # Device config
+ # ----------- SETUP -----------
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ print("Using device:", device)
+
+ # ----------- UNZIP DATA -----------
+
+ def unzip_file(zip_path, extract_to):
+     if not os.path.exists(extract_to):
+         os.makedirs(extract_to)
+     with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+         zip_ref.extractall(extract_to)
+     print(f"Extracted {zip_path} to {extract_to}")
+
+ unzip_file("train.zip", "./data/train")
+ unzip_file("val.zip", "./data/val")
+
+ # ----------- DATASET -----------

- # Custom Dataset
  class FaceMaskDataset(Dataset):
-     def __init__(self, images_dir, annotations_dir, transform=None, resize=(800, 800)):
-         self.images_dir = images_dir
-         self.annotations_dir = annotations_dir
+     def __init__(self, root_dir, transform=None):
+         self.image_paths = []
+         self.labels = []
          self.transform = transform
-         self.resize = resize
-         self.image_files = os.listdir(images_dir)
+         for label_name in ['mask', 'no_mask']:
+             class_path = os.path.join(root_dir, label_name)
+             for img_name in os.listdir(class_path):
+                 if img_name.endswith(".jpg") or img_name.endswith(".png"):
+                     self.image_paths.append(os.path.join(class_path, img_name))
+                     self.labels.append(0 if label_name == 'mask' else 1)

      def __len__(self):
-         return len(self.image_files)
+         return len(self.image_paths)

      def __getitem__(self, idx):
-         image_path = os.path.join(self.images_dir, self.image_files[idx])
-         image = Image.open(image_path).convert("RGB")
-         image = image.resize(self.resize)
-
-         annotation_path = os.path.join(
-             self.annotations_dir,
-             self.image_files[idx].replace(".jpg", ".xml").replace(".png", ".xml")
-         )
-         if not os.path.exists(annotation_path):
-             print(f"Warning: Annotation file {annotation_path} not found.")
-             return None, None
-
-         boxes, labels = self.load_annotations(annotation_path)
-         if boxes is None or labels is None:
-             return None, None
-
-         target = {'boxes': boxes, 'labels': labels}
+         image = Image.open(self.image_paths[idx]).convert("RGB")
          if self.transform:
              image = self.transform(image)
-
-         return image, target
-
-     def load_annotations(self, annotation_path):
-         tree = ET.parse(annotation_path)
-         root = tree.getroot()
-
-         boxes = []
-         labels = []
-         for obj in root.iter('object'):
-             label = obj.find('name').text
-             bndbox = obj.find('bndbox')
-             xmin = float(bndbox.find('xmin').text)
-             ymin = float(bndbox.find('ymin').text)
-             xmax = float(bndbox.find('xmax').text)
-             ymax = float(bndbox.find('ymax').text)
-             boxes.append([xmin, ymin, xmax, ymax])
-             labels.append(1 if label == "mask" else 0)
-
-         if not boxes or not labels:
-             return None, None
-
-         return torch.as_tensor(boxes, dtype=torch.float32), torch.tensor(labels, dtype=torch.int64)
-
- # Placeholder collate function
- def collate_fn(batch):
-     batch = list(filter(lambda x: x[0] is not None, batch))
-     images, targets = zip(*batch)
-     return images, targets
-
- # Dummy get_model function (replace with real model)
- def get_model(num_classes):
-     import torchvision
-     model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
-     in_features = model.roi_heads.box_predictor.cls_score.in_features
-     model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
-     return model
-
- # Validation Function
- def evaluate_model(model, val_loader):
-     model.eval()
-     running_loss = 0.0
-     with torch.no_grad():
-         for images, targets in val_loader:
-             if images is None or targets is None:
-                 continue
-             images = [img.to(device) for img in images]
-             targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
-             loss_dict = model(images, targets)
-             total_loss = sum(loss for loss in loss_dict.values())
-             running_loss += total_loss.item()
-     return running_loss / len(val_loader)
-
- # Training Function
- def train_model(model, train_loader, val_loader, optimizer, num_epochs=10):
-     for epoch in range(num_epochs):
-         running_loss = 0.0
-         model.train()
-         for images, targets in train_loader:
-             if images is None or targets is None:
-                 continue
-             images = [img.to(device) for img in images]
-             targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
-             optimizer.zero_grad()
-             loss_dict = model(images, targets)
-             total_loss = sum(loss for loss in loss_dict.values())
-             total_loss.backward()
-             optimizer.step()
-             running_loss += total_loss.item()
-
-         print(f"[Epoch {epoch+1}] Train Loss: {running_loss / len(train_loader):.4f}")
-         val_loss = evaluate_model(model, val_loader)
-         print(f"[Epoch {epoch+1}] Validation Loss: {val_loss:.4f}")
-
-     torch.save(model.state_dict(), "facemask_detector.pth")
-
- # Main Training Trigger
- def train_from_files_tab():
-     train_zip_path = "train.zip"
-     val_zip_path = "val.zip"
-
-     if not os.path.exists(train_zip_path) or not os.path.exists(val_zip_path):
-         return "❌ 'train.zip' or 'val.zip' not found in the Files section."
-
-     # Extract
-     for zip_path, folder in [(train_zip_path, "train"), (val_zip_path, "val")]:
-         with zipfile.ZipFile(zip_path, 'r') as zip_ref:
-             zip_ref.extractall(folder)
-
-     transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
-     train_dataset = FaceMaskDataset("train/images", "train/annotations", transform=transform)
-     val_dataset = FaceMaskDataset("val/images", "val/annotations", transform=transform)
-
-     train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
-     val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)
-
-     model = get_model(num_classes=2)
-     model.to(device)
-     optimizer = optim.SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)
-
-     train_model(model, train_loader, val_loader, optimizer, num_epochs=5)
-     return "✅ Training complete. Model saved as 'facemask_detector.pth'."
-
- # Gradio UI
+         return image, self.labels[idx]
+
+ transform = transforms.Compose([
+     transforms.Resize((224, 224)),
+     transforms.ToTensor(),
+ ])
+
+ train_dataset = FaceMaskDataset("./data/train", transform)
+ val_dataset = FaceMaskDataset("./data/val", transform)
+ train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
+ val_loader = DataLoader(val_dataset, batch_size=16)
+
+ # ----------- MODEL -----------
+
+ model = models.mobilenet_v2(pretrained=True)
+ model.classifier[1] = nn.Linear(model.last_channel, 2)
+ model = model.to(device)
+
+ criterion = nn.CrossEntropyLoss()
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+
+ # ----------- TRAINING -----------
+
+ def train_model(model, epochs=2):  # keep epochs small for HF Spaces
+     for epoch in range(epochs):
+         model.train()
+         running_loss = 0.0
+         for imgs, labels in train_loader:
+             imgs, labels = imgs.to(device), labels.to(device)
+             optimizer.zero_grad()
+             outputs = model(imgs)
+             loss = criterion(outputs, labels)
+             loss.backward()
+             optimizer.step()
+             running_loss += loss.item()
+
+         print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}")
+
+         # Validation Accuracy
+         correct = 0
+         total = 0
+         model.eval()
+         with torch.no_grad():
+             for imgs, labels in val_loader:
+                 imgs, labels = imgs.to(device), labels.to(device)
+                 outputs = model(imgs)
+                 _, predicted = torch.max(outputs.data, 1)
+                 total += labels.size(0)
+                 correct += (predicted == labels).sum().item()
+         acc = 100 * correct / total
+         print(f"Validation Accuracy: {acc:.2f}%")
+
+ train_model(model)
+ torch.save(model.state_dict(), "face_mask_model.pth")
+
+ # ----------- INFERENCE -----------
+
+ def predict(image):
+     model.eval()
+     img = image.convert("RGB")
+     img = transform(img).unsqueeze(0).to(device)
+     with torch.no_grad():
+         outputs = model(img)
+         _, predicted = torch.max(outputs.data, 1)
+     return "Mask" if predicted.item() == 0 else "No Mask"
+
+ # ----------- GRADIO APP -----------
+
  iface = gr.Interface(
-     fn=train_from_files_tab,
-     inputs=[],
-     outputs=gr.Textbox(label="Training Output"),
-     title="Face Mask Detector Trainer (from Files Tab)"
+     fn=predict,
+     inputs=gr.Image(source="webcam", tool="editor", type="pil", label="Upload or Webcam"),
+     outputs=gr.Label(label="Prediction"),
+     live=True,
+     title="Face Mask Detection",
+     description="Upload or use webcam to detect if a person is wearing a face mask."
  )

  iface.launch()
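
The new FaceMaskDataset reads class-labelled folders rather than Pascal VOC XML annotations, so the two zips in the Files tab must now contain `mask/` and `no_mask/` directories at their root. A sketch of the layout the code above assumes (file names are placeholders):

    train.zip  →  ./data/train/
        mask/
            img_001.jpg
            ...
        no_mask/
            img_101.jpg
            ...
    val.zip    →  ./data/val/  (same structure)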
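
Since the Space retrains on every startup, a minimal sketch of reloading the saved checkpoint for standalone inference instead (assumes the `face_mask_model.pth` produced by this commit; `example.jpg` is a placeholder):

    import torch
    import torch.nn as nn
    from torchvision import transforms, models
    from PIL import Image

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Rebuild the exact architecture from app.py, then load the trained weights.
    model = models.mobilenet_v2()
    model.classifier[1] = nn.Linear(model.last_channel, 2)
    model.load_state_dict(torch.load("face_mask_model.pth", map_location=device))
    model.to(device).eval()

    transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

    img = transform(Image.open("example.jpg").convert("RGB")).unsqueeze(0).to(device)
    with torch.no_grad():
        pred = model(img).argmax(dim=1).item()
    print("Mask" if pred == 0 else "No Mask")  # label 0 = mask, matching the dataset mapping above

One caveat when reproducing this commit: `gr.Image(source=..., tool=...)` is Gradio 3.x syntax; Gradio 4.x replaced `source` with `sources=[...]` and dropped `tool`, so an older Gradio release may need to be pinned for the app to launch as written.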