import functools
import os
import xml.etree.ElementTree as ET

import gradio as gr
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import torchvision
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# Ensure device is set to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# torchvision detection models reserve label 0 for "background": real classes
# must start at 1, and the predictor's num_classes must include background.
LABEL_MASK = 1
LABEL_NO_MASK = 2
NUM_CLASSES = 3  # background + mask + no_mask

# Image files picked up by FaceMaskDataset (compared case-insensitively).
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

# Environment variable that may point at a fine-tuned state_dict for the app.
# When unset, the box-predictor head is freshly initialised (untrained).
WEIGHTS_ENV_VAR = "FACE_MASK_WEIGHTS"


class FaceMaskDataset(Dataset):
    """Detection dataset pairing images with per-image XML box annotations.

    For every image ``<stem>.<ext>`` in ``images_dir`` the annotation is read
    from ``<stem>.xml`` in ``annotations_dir``; each ``object`` element must
    have a ``name`` and a ``bndbox`` with ``xmin/ymin/xmax/ymax`` children.

    An item is ``(image, target)``. ``target["boxes"]`` is a float32 tensor of
    shape [N, 4] (``xmin, ymin, xmax, ymax`` in pixels of the returned image)
    and ``target["labels"]`` an int64 tensor of shape [N] holding
    ``LABEL_MASK`` or ``LABEL_NO_MASK``. Items whose annotation is missing,
    unparsable or empty are returned as ``(None, None)``; ``collate_fn`` drops
    them.

    Args:
        images_dir: Directory containing the images.
        annotations_dir: Directory containing the XML annotation files.
        transform: Optional callable applied to the (resized) PIL image. Boxes
            are NOT adjusted by it, so it should not change image geometry
            (no crops/flips/resizes).
        resize: ``(width, height)`` each image is resized to. Aspect ratio is
            not preserved; boxes are rescaled to match. ``None`` keeps the
            original size.
    """

    def __init__(self, images_dir, annotations_dir, transform=None, resize=(800, 800)):
        self.images_dir = images_dir
        self.annotations_dir = annotations_dir
        self.transform = transform
        self.resize = resize
        # Sorted for a deterministic index -> file mapping; non-image files
        # (e.g. stray XML or hidden files) would otherwise crash Image.open.
        self.image_files = sorted(
            name
            for name in os.listdir(images_dir)
            if name.lower().endswith(IMAGE_EXTENSIONS)
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        filename = self.image_files[idx]
        # splitext handles any extension (.jpg/.jpeg/.png, any case) and,
        # unlike str.replace, never touches the rest of the file name.
        stem, _ext = os.path.splitext(filename)
        annotation_path = os.path.join(self.annotations_dir, stem + ".xml")
        if not os.path.exists(annotation_path):
            print(f"Warning: Annotation file {annotation_path} does not exist. Skipping image {filename}.")
            return None, None  # Return a tuple with None to skip the image/annotation pair

        boxes, labels = self.load_annotations(annotation_path)
        if boxes is None or labels is None:
            return None, None  # Skip this item if annotations are invalid

        image_path = os.path.join(self.images_dir, filename)
        # The context manager closes the file handle; convert() loads the pixels.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        if self.resize is not None:
            orig_w, orig_h = image.size
            new_w, new_h = self.resize
            image = image.resize((new_w, new_h))
            # Boxes are in original-image pixels; rescale them so they keep
            # lining up with the resized image.
            sx, sy = new_w / orig_w, new_h / orig_h
            boxes = boxes * torch.tensor([sx, sy, sx, sy], dtype=torch.float32)

        target = {"boxes": boxes, "labels": labels}

        if self.transform:
            image = self.transform(image)

        return image, target

    def load_annotations(self, annotation_path):
        """Parse one XML annotation file into box and label tensors.

        Objects lacking a ``name`` or ``bndbox`` element, and boxes with
        non-positive width or height (which torchvision detection models
        reject during training), are skipped.

        Returns:
            ``(boxes, labels)`` as float32 [N, 4] and int64 [N] tensors, or
            ``(None, None)`` when the file cannot be parsed or has no valid box.
        """
        try:
            root = ET.parse(annotation_path).getroot()
        except ET.ParseError as err:
            print(f"Warning: Could not parse annotation file {annotation_path}: {err}")
            return None, None

        boxes = []
        labels = []
        for obj in root.iter("object"):
            name_node = obj.find("name")
            bndbox = obj.find("bndbox")
            if name_node is None or bndbox is None:
                continue
            xmin = float(bndbox.find("xmin").text)
            ymin = float(bndbox.find("ymin").text)
            xmax = float(bndbox.find("xmax").text)
            ymax = float(bndbox.find("ymax").text)
            if xmax <= xmin or ymax <= ymin:
                continue  # degenerate box
            boxes.append([xmin, ymin, xmax, ymax])
            label = (name_node.text or "").strip()
            # "mask" -> LABEL_MASK; every other class name -> LABEL_NO_MASK.
            labels.append(LABEL_MASK if label == "mask" else LABEL_NO_MASK)

        if not boxes:
            return None, None  # If no boxes or labels are found, return None

        return (
            torch.as_tensor(boxes, dtype=torch.float32),
            torch.tensor(labels, dtype=torch.int64),
        )


def collate_fn(batch):
    """Collate detection samples into ``(images, targets)`` tuples.

    Samples returned as ``(None, None)`` by ``FaceMaskDataset`` are dropped.
    If every sample in the batch was dropped, ``((), ())`` is returned so
    ``images, targets = batch`` still unpacks (callers should skip it).
    """
    valid = [item for item in batch if item[0] is not None and item[1] is not None]
    if not valid:
        return (), ()
    return tuple(zip(*valid))


def load_model(weights_path=None):
    """Build a Faster R-CNN (ResNet-50 FPN) with a face-mask box predictor.

    The backbone and RPN come from torchvision's pretrained checkpoint; the
    box predictor is replaced by a fresh ``FastRCNNPredictor`` with
    ``NUM_CLASSES`` outputs (background + mask + no_mask).

    Args:
        weights_path: Optional path to a state_dict saved from this
            architecture. Without it the predictor head is untrained.

    Returns:
        The model, moved to ``device`` (left in its default train mode).
    """
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
        in_features, NUM_CLASSES
    )
    if weights_path is not None:
        # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
        state_dict = torch.load(weights_path, map_location=device)
        model.load_state_dict(state_dict)
    model.to(device)
    return model


@functools.lru_cache(maxsize=1)
def _get_inference_model():
    """Build the eval-mode inference model once and reuse it across calls."""
    model = load_model(os.environ.get(WEIGHTS_ENV_VAR) or None)
    model.eval()
    return model


def infer(image, score_threshold=0.0):
    """Run face-mask detection on an RGB image array.

    The image is only converted to a tensor: Faster R-CNN resizes and
    normalises internally and returns boxes in the coordinates of the image
    it was given, i.e. the uploaded image's own pixels.

    Args:
        image: Image as a numpy array accepted by ``PIL.Image.fromarray``.
        score_threshold: Detections scoring below this are dropped
            (default 0.0 keeps everything the model returns).

    Returns:
        ``(boxes, labels)`` numpy arrays: boxes [N, 4] as
        ``xmin, ymin, xmax, ymax``; labels ``LABEL_MASK``/``LABEL_NO_MASK``.
    """
    model = _get_inference_model()
    pil_image = Image.fromarray(image).convert("RGB")
    tensor = transforms.ToTensor()(pil_image).to(device)

    with torch.no_grad():
        prediction = model([tensor])[0]

    keep = prediction["scores"] >= score_threshold
    boxes = prediction["boxes"][keep].cpu().numpy()
    labels = prediction["labels"][keep].cpu().numpy()
    return boxes, labels


def gradio_interface(image):
    """Gradio callback returning detections as a JSON-serialisable dict.

    Labels: 1 = mask, 2 = no mask (0 is background and never returned).
    """
    if image is None:
        # Gradio passes None when the form is submitted without an image.
        return {"boxes": [], "labels": []}
    boxes, labels = infer(image)
    # Plain Python lists serialise as JSON regardless of Gradio version.
    return {"boxes": boxes.tolist(), "labels": labels.tolist()}


# Create Gradio interface
iface = gr.Interface(fn=gradio_interface, inputs=gr.Image(type="numpy"), outputs="json")

if __name__ == "__main__":
    # Launch only when run as a script, not when imported (e.g. for training).
    iface.launch()