MoinulwithAI committed on
Commit
ef4706e
·
verified ·
1 Parent(s): 14283e1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -0
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from torch.utils.data import Dataset, DataLoader
4
+ from PIL import Image
5
+ import torchvision
6
+ from torchvision import transforms
7
+ import xml.etree.ElementTree as ET
8
+ import torch.optim as optim
9
+ import matplotlib.pyplot as plt
10
+ import gradio as gr
11
+
12
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
14
+
15
class FaceMaskDataset(Dataset):
    """Object-detection dataset pairing images with XML box annotations.

    An image ``<stem>.<ext>`` in ``images_dir`` is matched with
    ``<stem>.xml`` in ``annotations_dir``.  The XML is expected to contain
    ``<object>`` elements, each with a ``<name>`` and a ``<bndbox>`` holding
    ``xmin``/``ymin``/``xmax``/``ymax``.

    Samples whose annotation file is missing, or whose annotation contains no
    objects, are returned as ``(None, None)`` so a collate function can drop
    them instead of aborting the whole epoch.

    Args:
        images_dir: Directory containing the image files.
        annotations_dir: Directory containing the XML annotation files.
        transform: Optional callable applied to the PIL image after resizing.
        resize: ``(width, height)`` every image is resized to, or ``None`` to
            keep each image at its original size.  Aspect ratio is NOT
            preserved; bounding boxes are rescaled by the same factors so
            they keep matching the resized pixels.
    """

    # File extensions (compared case-insensitively) treated as images.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, images_dir, annotations_dir, transform=None, resize=(800, 800)):
        self.images_dir = images_dir
        self.annotations_dir = annotations_dir
        self.transform = transform
        self.resize = resize
        # Sort for a reproducible index -> file mapping (os.listdir order is
        # arbitrary) and ignore stray non-image files in the directory.
        self.image_files = sorted(
            name
            for name in os.listdir(images_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        """Return the number of image files found in ``images_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx``.

        ``target`` is ``{'boxes': FloatTensor[N, 4], 'labels': Int64Tensor[N]}``
        with boxes as ``[xmin, ymin, xmax, ymax]`` in the coordinates of the
        returned (resized) image.  Returns ``(None, None)`` when the sample
        has no usable annotation.
        """
        file_name = self.image_files[idx]

        # Swap only the real extension; chained str.replace would also hit
        # ".jpg"/".png" occurring elsewhere in the name and misses other cases.
        stem, _ = os.path.splitext(file_name)
        annotation_path = os.path.join(self.annotations_dir, stem + ".xml")

        # Check the annotation first so skipped samples never pay for
        # decoding the image.
        if not os.path.exists(annotation_path):
            print(f"Warning: Annotation file {annotation_path} does not exist. Skipping image {file_name}.")
            return None, None  # Signals the collate function to drop this pair

        boxes, labels = self.load_annotations(annotation_path)
        if boxes is None or labels is None:
            return None, None  # Skip this item if annotations are invalid

        image_path = os.path.join(self.images_dir, file_name)
        image = Image.open(image_path).convert("RGB")

        if self.resize is not None:
            original_size = image.size  # PIL reports (width, height)
            image = image.resize(self.resize)
            # Annotations are in original-image pixels; rescale them so they
            # still line up with the resized image.
            boxes = self._scale_boxes(boxes, original_size, self.resize)

        target = {'boxes': boxes, 'labels': labels}

        if self.transform:
            image = self.transform(image)

        return image, target

    @staticmethod
    def _scale_boxes(boxes, original_size, new_size):
        """Rescale ``[xmin, ymin, xmax, ymax]`` boxes from one image size to another.

        Both sizes are ``(width, height)`` pairs.
        """
        original_w, original_h = original_size
        new_w, new_h = new_size
        scale_x = new_w / original_w
        scale_y = new_h / original_h
        factors = torch.tensor([scale_x, scale_y, scale_x, scale_y], dtype=boxes.dtype)
        return boxes * factors

    def load_annotations(self, annotation_path):
        """Parse one XML annotation file.

        Returns:
            ``(boxes, labels)`` where ``boxes`` is a float32 tensor of
            ``[xmin, ymin, xmax, ymax]`` rows and ``labels`` an int64 tensor,
            or ``(None, None)`` when the file contains no ``<object>``.
        """
        root = ET.parse(annotation_path).getroot()

        boxes = []
        labels = []
        for obj in root.iter('object'):
            label = obj.find('name').text
            bndbox = obj.find('bndbox')
            boxes.append(
                [float(bndbox.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
            )
            # "mask" -> 1, anything else -> 0.
            # NOTE(review): torchvision detection models reserve label 0 for
            # background, so every non-"mask" object is effectively trained
            # as background. Confirm this is intended before changing it:
            # the mapping must stay in sync with the model's num_classes.
            labels.append(1 if label == "mask" else 0)

        if not boxes:
            return None, None  # No objects found in this annotation

        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.tensor(labels, dtype=torch.int64)

        return boxes, labels
75
+
76
+
77
+ # Define the collate function for DataLoader
78
def collate_fn(batch):
    """Drop skipped samples and regroup the rest column-wise.

    A sample is dropped when either of its first two fields is ``None``.
    The surviving samples are transposed, so a batch of ``(image, target)``
    pairs becomes ``(images, targets)``.  If nothing survives, the result is
    an empty tuple.
    """
    kept = []
    for sample in batch:
        if sample[0] is None or sample[1] is None:
            continue
        kept.append(sample)
    return tuple(zip(*kept))
82
+
83
+ # Load your pre-trained model (or initialize if required)
84
def load_model(num_classes=2, weights_path=None):
    """Build a Faster R-CNN (ResNet-50 FPN) detector with a resized box head.

    The network starts from torchvision's pretrained detection weights and
    its box predictor is replaced by a fresh ``FastRCNNPredictor`` with
    ``num_classes`` outputs.  That replacement head is randomly initialised,
    so predictions are only meaningful after fine-tuning or after loading a
    fine-tuned checkpoint through ``weights_path``.

    Args:
        num_classes: Number of outputs of the box predictor.  torchvision
            counts the background as class 0, so this is
            "foreground classes + 1".
            NOTE(review): the default of 2 leaves room for a single
            foreground class; confirm it matches the label mapping used
            for training.
        weights_path: Optional path to a ``state_dict`` saved from a model
            built by this function.  Only load checkpoints you trust:
            ``torch.load`` unpickles the file.

    Returns:
        The model, moved to the module-level ``device``.
    """
    # NOTE: ``pretrained=True`` is deprecated in newer torchvision releases in
    # favour of ``weights=...``; kept as-is so older installs keep working.
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

    # Swap the classification/regression head for one sized to our classes.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
        in_features, num_classes
    )

    if weights_path is not None:
        state_dict = torch.load(weights_path, map_location=device)
        model.load_state_dict(state_dict)

    model.to(device)
    return model
92
+
93
+ # Inference function
94
# Detector shared by every call to infer(); built lazily on first use.
_INFERENCE_MODEL = None


def _get_inference_model():
    """Return the shared detector in eval mode, building it on first use."""
    global _INFERENCE_MODEL
    if _INFERENCE_MODEL is None:
        model = load_model()
        model.eval()
        _INFERENCE_MODEL = model
    return _INFERENCE_MODEL


def infer(image):
    """Run the detector on a single image.

    Args:
        image: Image as a NumPy array accepted by ``PIL.Image.fromarray``
            (e.g. what ``gr.Image(type="numpy")`` delivers).

    Returns:
        ``(boxes, labels)`` as NumPy arrays: ``boxes`` holds one
        ``[xmin, ymin, xmax, ymax]`` row per detection, expressed in the
        pixel coordinates of the input image; ``labels`` holds the predicted
        class index of each detection.
    """
    # Building Faster R-CNN is expensive; reuse one instance across requests
    # instead of reconstructing it on every call.
    model = _get_inference_model()

    # Normalise to 3-channel RGB so RGBA / grayscale inputs do not break the
    # model's per-channel normalisation.
    pil_image = Image.fromarray(image).convert("RGB")

    # No manual resize: torchvision detection models rescale inputs
    # internally and report boxes in the coordinates of the tensor they were
    # given.  Squashing to a fixed square beforehand would distort the image
    # and return boxes in that distorted frame.
    tensor = transforms.ToTensor()(pil_image).to(device)

    with torch.no_grad():
        prediction = model([tensor])

    # Get boxes and labels from the predictions
    boxes = prediction[0]['boxes'].cpu().numpy()
    labels = prediction[0]['labels'].cpu().numpy()

    return boxes, labels
115
+
116
+ # Gradio interface
117
def gradio_interface(image):
    """Gradio callback: detect objects in ``image`` and return them as JSON data.

    Args:
        image: The uploaded image as a NumPy array, or ``None`` when the
            request was submitted without an image.

    Returns:
        ``{"boxes": [...], "labels": [...]}`` built from plain Python lists.
        Both lists are empty when no image was supplied.
    """
    if image is None:
        # Nothing to analyse; answer with an empty result instead of failing
        # inside the image conversion.
        return {"boxes": [], "labels": []}

    boxes, labels = infer(image)

    # NumPy arrays are not serialisable by the standard json encoder, so hand
    # the JSON output plain nested lists.
    result = {}
    for key, values in (("boxes", boxes), ("labels", labels)):
        result[key] = values.tolist() if hasattr(values, "tolist") else list(values)
    return result
123
+
124
# Web UI: takes an image as a NumPy array and shows the callback's result as JSON.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="numpy"),
    outputs="json",
)

# Start serving the interface.
iface.launch()