import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision
from torchvision import transforms
import xml.etree.ElementTree as ET
import torch.optim as optim
import matplotlib.pyplot as plt
import gradio as gr
# Use the GPU when available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class FaceMaskDataset(Dataset):
    def __init__(self, images_dir, annotations_dir, transform=None, resize=(800, 800)):
        self.images_dir = images_dir
        self.annotations_dir = annotations_dir
        self.transform = transform
        self.resize = resize
        self.image_files = os.listdir(images_dir)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_path = os.path.join(self.images_dir, self.image_files[idx])
        image = Image.open(image_path).convert("RGB")
        # Resize to a fixed size; note that PIL's resize() stretches the
        # image and does not preserve the aspect ratio, and the annotation
        # boxes are not rescaled to match
        image = image.resize(self.resize)
        # Annotations share the image's base name; handle both .jpg and .png
        annotation_path = os.path.join(
            self.annotations_dir,
            self.image_files[idx].replace(".jpg", ".xml").replace(".png", ".xml"),
        )
        if not os.path.exists(annotation_path):
            print(f"Warning: annotation file {annotation_path} does not exist. Skipping image {self.image_files[idx]}.")
            return None, None  # Return a (None, None) pair so collate_fn can drop it
        boxes, labels = self.load_annotations(annotation_path)
        if boxes is None or labels is None:
            return None, None  # Skip this item if the annotations are empty or invalid
        target = {'boxes': boxes, 'labels': labels}
        if self.transform:
            image = self.transform(image)
        return image, target

    def load_annotations(self, annotation_path):
        tree = ET.parse(annotation_path)
        root = tree.getroot()
        boxes = []
        labels = []
        for obj in root.iter('object'):
            label = obj.find('name').text
            bndbox = obj.find('bndbox')
            xmin = float(bndbox.find('xmin').text)
            ymin = float(bndbox.find('ymin').text)
            xmax = float(bndbox.find('xmax').text)
            ymax = float(bndbox.find('ymax').text)
            boxes.append([xmin, ymin, xmax, ymax])
            # torchvision's detection models reserve label 0 for the
            # background, so foreground classes must start at 1:
            # "mask" = 1, "no_mask" = 2
            labels.append(1 if label == "mask" else 2)
        if len(boxes) == 0 or len(labels) == 0:
            return None, None  # No objects found in this annotation file
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.tensor(labels, dtype=torch.int64)
        return boxes, labels
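
# For reference, load_annotations() expects Pascal VOC-style XML, with the
# field names used by the parser above. The coordinate values here are
# purely illustrative:
#
#   <annotation>
#       <object>
#           <name>mask</name>
#           <bndbox>
#               <xmin>79</xmin>
#               <ymin>105</ymin>
#               <xmax>109</xmax>
#               <ymax>142</ymax>
#           </bndbox>
#       </object>
#   </annotation>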

# Collate function for the DataLoader: detection targets vary in size
# per image, so batches are kept as tuples rather than stacked tensors
def collate_fn(batch):
    # Filter out the (None, None) pairs produced for missing or empty annotations
    batch = [item for item in batch if item[0] is not None and item[1] is not None]
    return tuple(zip(*batch))
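
# A minimal sketch (not called by the app) showing how the dataset and
# collate_fn above could feed a DataLoader; the default directory paths
# are placeholders, not files shipped with this Space
def build_train_loader(images_dir="data/images", annotations_dir="data/annotations", batch_size=4):
    dataset = FaceMaskDataset(images_dir, annotations_dir, transform=transforms.ToTensor())
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)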

# Load the pre-trained detector and swap in a new classification head
def load_model():
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    # Three classes for the predictor: background (0), mask (1), no_mask (2);
    # torchvision counts the background as a class of its own
    num_classes = 3
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
    model.to(device)
    return model
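
# A minimal fine-tuning sketch (not called by the app), assuming a loader
# built with the build_train_loader() helper above. In train mode,
# torchvision detection models return a dict of losses instead of predictions
def train_one_epoch(model, data_loader, lr=0.005):
    model.train()
    optimizer = optim.SGD([p for p in model.parameters() if p.requires_grad], lr=lr, momentum=0.9)
    for images, targets in data_loader:
        images = [img.to(device) for img in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        loss_dict = model(images, targets)
        loss = sum(loss_dict.values())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()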

# Load the model once at startup rather than on every request
model = load_model()
model.eval()

# Inference function
def infer(image):
    # Apply transformations. Note: the predicted boxes are in the
    # coordinates of the resized 224x224 image, not the original input
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # Resize all images to 224x224
        transforms.ToTensor(),
    ])
    image = Image.fromarray(image)
    image = transform(image).unsqueeze(0).to(device)  # Add batch dimension
    with torch.no_grad():
        prediction = model(image)
    # Get boxes and labels from the first (and only) image in the batch
    boxes = prediction[0]['boxes'].cpu().numpy()
    labels = prediction[0]['labels'].cpu().numpy()
    return boxes, labels
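
# A small helper (not wired into the Gradio app) that uses the imported
# matplotlib to draw the predicted boxes; the label names assume the
# mapping used in load_annotations() above
def plot_predictions(image, boxes, labels):
    class_names = {1: "mask", 2: "no_mask"}
    fig, ax = plt.subplots(1)
    ax.imshow(image)
    for box, label in zip(boxes, labels):
        xmin, ymin, xmax, ymax = box
        rect = plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                             fill=False, edgecolor="red", linewidth=2)
        ax.add_patch(rect)
        ax.text(xmin, ymin, class_names.get(int(label), str(label)), color="red")
    plt.show()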

# Gradio callback; numpy arrays are converted to plain lists so the
# JSON output component can serialize them
def gradio_interface(image):
    boxes, labels = infer(image)
    # Label mapping: 1 = mask, 2 = no_mask (0 is the background class)
    result = {"boxes": boxes.tolist(), "labels": labels.tolist()}
    return result

# Create Gradio interface
iface = gr.Interface(fn=gradio_interface, inputs=gr.Image(type="numpy"), outputs="json")
# Launch Gradio interface
iface.launch()