# Spaces: Sleeping
import functools
import os
import xml.etree.ElementTree as ET

import gradio as gr
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import torchvision
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class FaceMaskDataset(Dataset):
    """Detection dataset pairing images with Pascal-VOC style XML annotations.

    Each item is ``(image, target)`` where ``target`` is a dict with
    ``'boxes'`` (float32 tensor, one ``[xmin, ymin, xmax, ymax]`` row per
    object) and ``'labels'`` (int64 tensor). Items that cannot be used
    (missing or unusable annotation) are returned as ``(None, None)`` so a
    collate function can drop them.

    Args:
        images_dir: Directory containing the image files.
        annotations_dir: Directory containing one ``<image stem>.xml`` per image.
        transform: Optional callable applied to the (resized) PIL image.
        resize: ``(width, height)`` every image is resized to, or ``None`` to
            keep the original size. Bounding boxes are rescaled to match.
    """

    # Only these extensions are treated as images; anything else in
    # ``images_dir`` (sub-directories, stray files) is ignored.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(self, images_dir, annotations_dir, transform=None, resize=(800, 800)):
        self.images_dir = images_dir
        self.annotations_dir = annotations_dir
        self.transform = transform
        self.resize = resize
        # os.listdir() order is arbitrary; sort so that indices are stable
        # between runs and machines.
        self.image_files = sorted(
            name
            for name in os.listdir(images_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for ``idx``, or ``(None, None)`` to skip it."""
        file_name = self.image_files[idx]

        # Derive the annotation name from the stem so that any extension
        # (and any casing of it) maps to "<stem>.xml".
        stem = os.path.splitext(file_name)[0]
        annotation_path = os.path.join(self.annotations_dir, stem + ".xml")
        if not os.path.exists(annotation_path):
            print(f"Warning: Annotation file {annotation_path} does not exist. Skipping image {self.image_files[idx]}.")
            return None, None  # Return a tuple with None to skip the image/annotation pair

        boxes, labels = self.load_annotations(annotation_path)
        if boxes is None or labels is None:
            return None, None  # Skip this item if annotations are invalid

        image_path = os.path.join(self.images_dir, file_name)
        # convert() returns an independent image, so the file can be closed
        # as soon as the pixels have been decoded.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        if self.resize is not None:
            original_size = image.size  # PIL reports (width, height)
            # The resize does NOT preserve aspect ratio, so the boxes must be
            # scaled independently along x and y to stay aligned with the image.
            image = image.resize(self.resize)
            boxes = self._scale_boxes(boxes, original_size, self.resize)

        target = {'boxes': boxes, 'labels': labels}
        if self.transform:
            image = self.transform(image)
        return image, target

    @staticmethod
    def _scale_boxes(boxes, original_size, target_size):
        """Rescale ``[xmin, ymin, xmax, ymax]`` rows from one image size to another.

        Both sizes are ``(width, height)`` tuples.
        """
        original_w, original_h = original_size
        target_w, target_h = target_size
        x_ratio = target_w / original_w
        y_ratio = target_h / original_h
        factors = torch.tensor([x_ratio, y_ratio, x_ratio, y_ratio], dtype=boxes.dtype)
        return boxes * factors

    @staticmethod
    def _parse_box(bndbox):
        """Return ``[xmin, ymin, xmax, ymax]`` from a ``<bndbox>`` element, or ``None``.

        ``None`` is returned when the element is missing, a coordinate is
        missing or not numeric, or the box has no positive width and height.
        """
        if bndbox is None:
            return None
        try:
            xmin, ymin, xmax, ymax = (
                float(bndbox.findtext(tag)) for tag in ('xmin', 'ymin', 'xmax', 'ymax')
            )
        except (TypeError, ValueError):
            # findtext() gives None for a missing tag (TypeError); non-numeric
            # text raises ValueError.
            return None
        if xmax <= xmin or ymax <= ymin:
            return None
        return [xmin, ymin, xmax, ymax]

    def load_annotations(self, annotation_path):
        """Parse one annotation file into ``(boxes, labels)`` tensors.

        Returns ``(None, None)`` when the XML cannot be parsed or contains no
        usable ``<object>`` entries. Objects with a missing name or an
        unusable bounding box are skipped individually.
        """
        try:
            root = ET.parse(annotation_path).getroot()
        except ET.ParseError as err:
            print(f"Warning: Annotation file {annotation_path} could not be parsed ({err}).")
            return None, None

        boxes = []
        labels = []
        for obj in root.iter('object'):
            name = obj.findtext('name')
            box = self._parse_box(obj.find('bndbox'))
            if name is None or box is None:
                continue
            boxes.append(box)
            # "mask" = 1, anything else (e.g. "no_mask") = 0.
            # NOTE(review): torchvision detection models reserve label 0 for
            # background, so label-0 objects are likely treated as background
            # during training - confirm the intended mapping and num_classes.
            labels.append(1 if name.strip() == "mask" else 0)

        if not boxes:
            return None, None  # If no boxes or labels are found, return None

        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.tensor(labels, dtype=torch.int64)
        return boxes, labels
# Batch assembly for the DataLoader
def collate_fn(batch):
    """Drop skipped samples and transpose the rest into per-field tuples.

    A sample is dropped when either of its first two fields is ``None``.
    The surviving samples are regrouped field by field, e.g.
    ``[(img1, tgt1), (img2, tgt2)]`` becomes ``((img1, img2), (tgt1, tgt2))``.
    If nothing survives, the result is an empty tuple.
    """
    kept = []
    for sample in batch:
        if sample[0] is None or sample[1] is None:
            continue
        kept.append(sample)
    return tuple(zip(*kept))
# Build the detector (optionally restoring fine-tuned weights)
def load_model(num_classes=2, weights_path=None):
    """Create a Faster R-CNN (ResNet-50 FPN) detector with a replaced box head.

    Args:
        num_classes: Number of outputs of the new box predictor. Defaults to
            2, the value this file has always used.
            NOTE(review): torchvision counts the background as one of these
            classes, so "mask" + "no_mask" + background would need 3 - confirm
            against the label mapping used for training.
        weights_path: Optional path to a ``state_dict`` checkpoint saved from
            a model built by this function. When omitted, the new box
            predictor keeps its random initialisation, so predictions are not
            meaningful until the model is trained or weights are loaded.

    Returns:
        The model, moved to the module-level ``device``.
    """
    # NOTE(review): newer torchvision releases deprecate ``pretrained=`` in
    # favour of ``weights=``; kept as-is for compatibility with older versions.
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

    # Swap the pre-trained classification head for one sized to our classes.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)

    if weights_path is not None:
        # torch.load() unpickles the file: only load checkpoints you trust.
        state_dict = torch.load(weights_path, map_location=device)
        model.load_state_dict(state_dict)

    model.to(device)
    return model
# Inference
@functools.lru_cache(maxsize=1)
def _get_inference_model():
    """Build the detector once, in eval mode, and reuse it for every request."""
    # NOTE(review): load_model() is called without a checkpoint here, so no
    # fine-tuned weights are restored - confirm where trained weights should
    # come from.
    model = load_model()
    model.eval()
    return model


def infer(image):
    """Run the detector on one image.

    Args:
        image: Image as a numpy array accepted by ``PIL.Image.fromarray``
            (this is what ``gr.Image(type="numpy")`` passes in).

    Returns:
        ``(boxes, labels)`` as numpy arrays. ``boxes`` holds one
        ``[xmin, ymin, xmax, ymax]`` row per detection, expressed in the pixel
        coordinates of the ORIGINAL input image.
    """
    model = _get_inference_model()

    # Force three channels so RGBA / grayscale uploads do not break the model.
    pil_image = Image.fromarray(image).convert("RGB")
    original_w, original_h = pil_image.size  # PIL reports (width, height)

    input_h, input_w = 224, 224
    transform = transforms.Compose([
        transforms.Resize((input_h, input_w)),  # Resize all images to 224x224
        transforms.ToTensor(),
    ])
    batch = transform(pil_image).unsqueeze(0).to(device)  # Add batch dimension

    with torch.no_grad():
        prediction = model(batch)

    boxes = prediction[0]['boxes'].cpu().numpy()
    labels = prediction[0]['labels'].cpu().numpy()

    # The model saw the 224x224 version, so its boxes are in that frame;
    # scale them back so they line up with the image the caller supplied.
    boxes[:, [0, 2]] *= original_w / input_w
    boxes[:, [1, 3]] *= original_h / input_h
    return boxes, labels
# Gradio callback
def gradio_interface(image):
    """Gradio handler: detect faces in ``image`` and return a JSON-ready dict.

    Args:
        image: The uploaded image as a numpy array, or ``None`` when the user
            submitted the form without an image.

    Returns:
        ``{"boxes": [...], "labels": [...]}`` built from plain Python lists.
        Labels follow the file's convention: 0 = no mask, 1 = mask.
    """
    if image is None:
        # Nothing uploaded: answer with an empty result instead of crashing.
        return {"boxes": [], "labels": []}

    boxes, labels = infer(image)
    # Convert numpy arrays to nested lists so the payload is serialisable by
    # any JSON encoder, not only numpy-aware ones.
    return {"boxes": boxes.tolist(), "labels": labels.tolist()}
# Create Gradio interface: numpy image in, JSON detections out
iface = gr.Interface(fn=gradio_interface, inputs=gr.Image(type="numpy"), outputs="json")

# Start the web server only when this file is executed as a script, so that
# importing it (e.g. to reuse FaceMaskDataset) does not launch the app.
if __name__ == "__main__":
    iface.launch()