MoinulwithAI committed (verified)
Commit 227e832 · 1 Parent(s): 5d42441

Update app.py

Files changed (1):
  app.py +102 -55
app.py CHANGED
@@ -1,17 +1,15 @@
-import os
+import gradio as gr
 import torch
+import os
 from torch.utils.data import Dataset, DataLoader
-from PIL import Image
-import torchvision
 from torchvision import transforms
+from PIL import Image
 import xml.etree.ElementTree as ET
 import torch.optim as optim
-import matplotlib.pyplot as plt
-import gradio as gr
-
-# Ensure device is set to GPU if available
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+from torch import nn
 
+# Your model training and evaluation functions (already defined in your previous code)
+# Define the custom dataset
 class FaceMaskDataset(Dataset):
     def __init__(self, images_dir, annotations_dir, transform=None, resize=(800, 800)):
         self.images_dir = images_dir
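Review note: this hunk deletes the module-level device definition, but the train_model() and evaluate_model() functions added further down still reference device, so the committed app.py will raise a NameError once training starts. A minimal sketch of the line that would need to be restored (the same CUDA-or-CPU fallback the old file used):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")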
@@ -26,21 +24,17 @@ class FaceMaskDataset(Dataset):
     def __getitem__(self, idx):
         image_path = os.path.join(self.images_dir, self.image_files[idx])
         image = Image.open(image_path).convert("RGB")
-
-        # Resize the image to a fixed size, while maintaining aspect ratio
         image = image.resize(self.resize)
 
-        # Handle both .jpg and .png files
         annotation_path = os.path.join(self.annotations_dir, self.image_files[idx].replace(".jpg", ".xml").replace(".png", ".xml"))
 
         if not os.path.exists(annotation_path):
             print(f"Warning: Annotation file {annotation_path} does not exist. Skipping image {self.image_files[idx]}.")
-            return None, None  # Return a tuple with None to skip the image/annotation pair
+            return None, None  # Return None if annotation is missing
 
         boxes, labels = self.load_annotations(annotation_path)
-
         if boxes is None or labels is None:
-            return None, None  # Skip this item if annotations are invalid
+            return None, None  # Skip if annotations are invalid
 
         target = {'boxes': boxes, 'labels': labels}
 
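Review note: the deleted comment claimed the resize preserves aspect ratio, but PIL's Image.resize() stretches to exactly self.resize, and neither version rescales the XML boxes to match, so the targets no longer line up with the resized image. A minimal sketch of the scaling __getitem__ would need (orig_w, orig_h, sx, sy are hypothetical names; boxes is the tensor returned by load_annotations):

    orig_w, orig_h = image.size        # PIL size is (width, height)
    image = image.resize(self.resize)  # stretches to exactly self.resize
    sx = self.resize[0] / orig_w       # horizontal scale factor
    sy = self.resize[1] / orig_h       # vertical scale factor
    boxes = boxes * torch.tensor([sx, sy, sx, sy])  # rescale [xmin, ymin, xmax, ymax]

Separately, the (None, None) sentinels returned above only work with a None-aware collate function; torch's default collation would raise on them (see the note after the last hunk).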
 
@@ -63,66 +57,119 @@
             xmax = float(bndbox.find('xmax').text)
             ymax = float(bndbox.find('ymax').text)
             boxes.append([xmin, ymin, xmax, ymax])
-            labels.append(1 if label == "mask" else 0)  # Assuming "mask" = 1, "no_mask" = 0
+            labels.append(1 if label == "mask" else 0)  # "mask" = 1, "no_mask" = 0
 
         if len(boxes) == 0 or len(labels) == 0:
-            return None, None  # If no boxes or labels are found, return None
+            return None, None  # If no boxes or labels, return None
 
         boxes = torch.as_tensor(boxes, dtype=torch.float32)
         labels = torch.tensor(labels, dtype=torch.int64)
 
         return boxes, labels
 
+# Model Training Loop (referred to from previous code)
+def train_model(model, train_loader, val_loader, optimizer, num_epochs=10):
+    for epoch in range(num_epochs):
+        # Training loop
+        running_loss = 0.0
+        model.train()
+        for images, targets in train_loader:
+            if images is None or targets is None:
+                continue  # Skip invalid images/annotations
 
-# Define the collate function for DataLoader
-def collate_fn(batch):
-    # Filter out None values and pack the rest into a batch
-    batch = [item for item in batch if item[0] is not None and item[1] is not None]
-    return tuple(zip(*batch))
+            # Move data to device
+            images = [image.to(device) for image in images]
+            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
 
-# Load your pre-trained model (or initialize if required)
-def load_model():
-    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
-    # Assuming 2 classes: mask and no_mask
-    num_classes = 2
-    in_features = model.roi_heads.box_predictor.cls_score.in_features
-    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
-    model.to(device)
-    return model
+            optimizer.zero_grad()
+            loss_dict = model(images, targets)
 
-# Inference function
-def infer(image):
-    model = load_model()  # Load the model
-    model.eval()
-
-    # Apply transformations
-    transform = transforms.Compose([
-        transforms.Resize((224, 224)),  # Resize all images to 224x224
-        transforms.ToTensor(),
-    ])
+            # Calculate total loss
+            total_loss = sum(loss for loss in loss_dict.values())
+            total_loss.backward()
+            optimizer.step()
+
+            running_loss += total_loss.item()
 
-    image = Image.fromarray(image)
-    image = transform(image).unsqueeze(0).to(device)  # Add batch dimension
+        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss / len(train_loader)}")
 
+        # Evaluate after every epoch
+        val_loss = evaluate_model(model, val_loader)
+        print(f"Validation Loss: {val_loss}")
+
+# Validation function
+def evaluate_model(model, val_loader):
+    model.eval()
+    running_loss = 0.0
     with torch.no_grad():
-        prediction = model(image)
+        for images, targets in val_loader:
+            if images is None or targets is None:
+                continue  # Skip invalid data
 
-    # Get boxes and labels from the predictions
-    boxes = prediction[0]['boxes'].cpu().numpy()
-    labels = prediction[0]['labels'].cpu().numpy()
+            # Move data to device
+            images = [image.to(device) for image in images]
+            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
 
-    return boxes, labels
+            loss_dict = model(images, targets)
 
-# Gradio interface
-def gradio_interface(image):
-    boxes, labels = infer(image)
+            # Calculate total loss
+            total_loss = sum(loss for loss in loss_dict.values())
+            running_loss += total_loss.item()
+
+    return running_loss / len(val_loader)
+
+# Function to upload dataset and start training
+def train_on_uploaded_data(train_data, val_data):
+    # Save the uploaded dataset (files)
+    train_data_path = "train_data.zip"
+    val_data_path = "val_data.zip"
 
-    # Assuming labels: 0 = no mask, 1 = mask
-    result = {"boxes": boxes, "labels": labels}
-    return result
+    # Unzip and prepare directories (assuming you upload zip files for simplicity)
+    with open(train_data.name, 'wb') as f:
+        f.write(train_data.read())
+    with open(val_data.name, 'wb') as f:
+        f.write(val_data.read())
+
+    # Extract zip files
+    os.system(f"unzip {train_data_path} -d ./train/")
+    os.system(f"unzip {val_data_path} -d ./val/")
+
+    # Load datasets
+    train_dataset = FaceMaskDataset(
+        images_dir="train/images",
+        annotations_dir="train/annotations",
+        transform=transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
+    )
+    val_dataset = FaceMaskDataset(
+        images_dir="val/images",
+        annotations_dir="val/annotations",
+        transform=transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
+    )
+
+    # Dataloaders
+    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
+    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)
 
-# Create Gradio interface
-iface = gr.Interface(fn=gradio_interface, inputs=gr.Image(type="numpy"), outputs="json")
+    # Train the model
+    model = get_model(num_classes=2)  # Assuming you have a model function
+    model.to(device)
+    optimizer = optim.SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)
+
+    # Train the model and return feedback
+    train_model(model, train_loader, val_loader, optimizer, num_epochs=10)
+
+    return "Training completed and model saved."
+
+# Create Gradio Interface
+iface = gr.Interface(
+    fn=train_on_uploaded_data,
+    inputs=[
+        gr.File(label="Upload Train Dataset (ZIP)"),
+        gr.File(label="Upload Validation Dataset (ZIP)")
+    ],
+    outputs=gr.Textbox(label="Training Status"),
+    live=True
+)
 
 # Launch Gradio interface
 iface.launch()
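Review note: this hunk deletes collate_fn and load_model, and the first hunk drops the torchvision import, yet the new code still passes collate_fn to both DataLoaders and calls get_model(num_classes=2), which is defined nowhere in the file. A minimal sketch of the two helpers the call sites appear to expect (get_model is a hypothetical name taken from the call site; the Faster R-CNN head swap mirrors the deleted load_model):

    import torchvision
    from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

    def collate_fn(batch):
        # Drop (None, None) pairs emitted for missing or invalid annotations
        batch = [item for item in batch if item[0] is not None and item[1] is not None]
        return tuple(zip(*batch))

    def get_model(num_classes=2):
        # Hypothetical helper mirroring the deleted load_model(): swap the box-predictor
        # head of a pretrained Faster R-CNN so it outputs num_classes classes
        model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
        return model

Two further caveats: with this collate_fn, a batch whose items are all (None, None) collates to an empty tuple, so the `if images is None` guards in train_model()/evaluate_model() never fire (a `len(images) == 0` check would be safer), and torchvision detection models only return a loss dict in train mode, so evaluate_model()'s model.eval() call means model(images, targets) returns detections rather than losses; computing a validation loss usually means keeping the model in train mode under torch.no_grad().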
 
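Review note: train_on_uploaded_data() writes each upload to train_data.name / val_data.name but then unzips the hardcoded train_data.zip / val_data.zip, so extraction only succeeds if the uploads happen to carry those exact filenames. A sketch that keeps the paths consistent (assuming gr.File hands the function a temp-file object whose .name is its on-disk path, which is what the existing .name usage implies):

    def train_on_uploaded_data(train_data, val_data):
        # Unzip the uploaded archives straight from their temporary paths
        os.system(f"unzip -o {train_data.name} -d ./train/")
        os.system(f"unzip -o {val_data.name} -d ./val/")
        # ... build the datasets and loaders, then train, as in the committed version ...

Also, live=True makes Gradio re-run the function on every input change, which is a poor fit for a long-running training job; omitting it so training starts only on an explicit submit would be safer.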