MoinulwithAI's picture
Create app.py
ef4706e verified
raw
history blame
4.41 kB
import os
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision
from torchvision import transforms
import xml.etree.ElementTree as ET
import torch.optim as optim
import matplotlib.pyplot as plt
import gradio as gr
# Select the compute device once at import time: CUDA when torch reports it
# as available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class FaceMaskDataset(Dataset):
    """Pascal-VOC-style face-mask detection dataset.

    Each file in ``images_dir`` is paired with an XML annotation file that has
    the same stem in ``annotations_dir`` (``img.jpg`` -> ``img.xml``). Images
    are optionally resized. Bounding boxes are rescaled into the resized
    image's coordinate frame so that boxes and pixels stay consistent.

    Args:
        images_dir: Directory containing the image files.
        annotations_dir: Directory containing ``<stem>.xml`` annotation files.
        transform: Optional callable applied to the (resized) PIL image,
            e.g. ``transforms.ToTensor()``.
        resize: ``(width, height)`` passed to ``PIL.Image.resize``, or ``None``
            to keep the original size. This stretches the image; it does NOT
            preserve the aspect ratio.
    """

    def __init__(self, images_dir, annotations_dir, transform=None, resize=(800, 800)):
        self.images_dir = images_dir
        self.annotations_dir = annotations_dir
        self.transform = transform
        self.resize = resize
        # os.listdir order is arbitrary; sort so index -> file is reproducible.
        self.image_files = sorted(os.listdir(images_dir))

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for ``idx``, or ``(None, None)`` to skip it.

        ``target`` is ``{'boxes': FloatTensor[N, 4], 'labels': Int64Tensor[N]}``.
        Boxes are ``[xmin, ymin, xmax, ymax]`` in the returned image's pixel
        coordinates. ``(None, None)`` is returned when the annotation file is
        missing or contains no objects.
        """
        file_name = self.image_files[idx]
        annotation_path = self._annotation_path(file_name)
        # Check the annotation first so skipped items never open the image.
        if not os.path.exists(annotation_path):
            print(f"Warning: Annotation file {annotation_path} does not exist. Skipping image {file_name}.")
            return None, None  # Return a tuple with None to skip the image/annotation pair

        boxes, labels = self.load_annotations(annotation_path)
        if boxes is None or labels is None:
            return None, None  # Skip this item if annotations are invalid

        image_path = os.path.join(self.images_dir, file_name)
        # Context manager closes the underlying file handle; convert() loads
        # the pixel data into a new, independent image first.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        if self.resize is not None:
            original_size = image.size
            image = image.resize(self.resize)
            # Annotations are in original-image pixels; map them to the new size.
            boxes = self._scale_boxes(boxes, original_size, image.size)

        target = {'boxes': boxes, 'labels': labels}
        if self.transform:
            image = self.transform(image)
        return image, target

    def _annotation_path(self, image_file):
        """Return the path of the XML annotation sharing ``image_file``'s stem."""
        stem, _ = os.path.splitext(image_file)
        return os.path.join(self.annotations_dir, stem + ".xml")

    @staticmethod
    def _scale_boxes(boxes, original_size, new_size):
        """Rescale ``[xmin, ymin, xmax, ymax]`` boxes between ``(width, height)`` sizes."""
        orig_w, orig_h = original_size
        new_w, new_h = new_size
        sx = new_w / orig_w
        sy = new_h / orig_h
        scale = torch.tensor([sx, sy, sx, sy], dtype=boxes.dtype)
        return boxes * scale

    def load_annotations(self, annotation_path):
        """Parse a VOC-style XML file into ``(boxes, labels)`` tensors.

        Returns ``(FloatTensor[N, 4], Int64Tensor[N])``. Boxes are
        ``[xmin, ymin, xmax, ymax]`` as written in the file, i.e. in the
        original image's pixel coordinates. Returns ``(None, None)`` if the
        file has no ``<object>`` entries. Objects named ``"mask"`` get label 1,
        and every other name gets label 0.
        """
        root = ET.parse(annotation_path).getroot()
        boxes = []
        labels = []
        for obj in root.iter('object'):
            # Tolerate surrounding whitespace and an empty <name/> element.
            name = (obj.find('name').text or "").strip()
            bndbox = obj.find('bndbox')
            boxes.append([float(bndbox.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax')])
            # NOTE(review): torchvision detection models reserve label 0 for the
            # background, so objects labelled 0 here may be treated as
            # background during training. Confirm the intended encoding (and
            # the model's num_classes) before relying on it.
            labels.append(1 if name == "mask" else 0)
        if not boxes:
            return None, None  # No objects -> caller skips this item
        return (
            torch.as_tensor(boxes, dtype=torch.float32),
            torch.tensor(labels, dtype=torch.int64),
        )
# Define the collate function for DataLoader
def collate_fn(batch):
    """Collate ``(image, target)`` samples into an ``(images, targets)`` pair.

    Samples whose image or target is ``None`` (the dataset's "skip" marker)
    are dropped. The remaining images and targets are returned as two
    parallel tuples rather than stacked into one tensor.

    Returns:
        ``(images, targets)``. Both are empty tuples if every sample was dropped.
    """
    kept = [item for item in batch if item[0] is not None and item[1] is not None]
    if not kept:
        # zip(*[]) yields (), which would break ``images, targets = batch``.
        return (), ()
    return tuple(zip(*kept))
# Load your pre-trained model (or initialize if required)
def load_model(num_classes=2, checkpoint_path=None):
    """Build a Faster R-CNN (ResNet-50 FPN) detector for mask detection.

    The model is created with torchvision's pretrained weights, and its box
    predictor is replaced with a fresh ``FastRCNNPredictor`` sized for
    ``num_classes``. Without ``checkpoint_path``, that new head is randomly
    initialised, so its predictions are not meaningful until fine-tuned
    weights are loaded.

    Args:
        num_classes: Number of output classes for the new box predictor.
            NOTE(review): torchvision counts the background as one of these
            classes. Confirm the default of 2 matches the training labels.
        checkpoint_path: Optional path to a ``state_dict`` saved from a model
            built the same way (e.g. after fine-tuning). Loaded onto
            ``device`` when given.

    Returns:
        The model, moved to the module-level ``device``.
    """
    # NOTE: ``pretrained=`` is deprecated in newer torchvision in favour of
    # ``weights=``; kept for compatibility with older torchvision versions.
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(in_features, num_classes)
    if checkpoint_path is not None:
        # SECURITY: torch.load unpickles the file. Only load trusted checkpoints.
        state_dict = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(state_dict)
    model.to(device)
    return model
# Inference function
# The model is built once and reused. Calling load_model() per request would
# rebuild the network and re-read the pretrained weights every time.
_inference_model = None


def _get_inference_model():
    """Return the shared model in eval mode, building it on first use."""
    global _inference_model
    if _inference_model is None:
        _inference_model = load_model()
        _inference_model.eval()
    return _inference_model


def infer(image, score_threshold=0.0):
    """Run the detector on a single image.

    Args:
        image: Image as a numpy array, as delivered by ``gr.Image(type="numpy")``.
        score_threshold: Detections scoring below this are dropped. The default
            0.0 keeps every detection the model returns.

    Returns:
        ``(boxes, labels)`` numpy arrays from the model's first (only)
        prediction. ``boxes`` has one ``[xmin, ymin, xmax, ymax]`` row per
        detection.

    Raises:
        ValueError: If ``image`` is None (e.g. the form was submitted empty).
    """
    if image is None:
        raise ValueError("No image provided.")
    model = _get_inference_model()
    # No Resize: per torchvision's detection docs, the model resizes inputs
    # internally and returns boxes in the input image's coordinates. Resizing
    # to 224x224 first would report boxes in the wrong coordinate frame.
    to_tensor = transforms.ToTensor()
    # convert("RGB") normalises grayscale/RGBA uploads to 3 channels.
    pil_image = Image.fromarray(image).convert("RGB")
    batch = [to_tensor(pil_image).to(device)]
    with torch.no_grad():
        prediction = model(batch)[0]
    keep = prediction['scores'] >= score_threshold
    boxes = prediction['boxes'][keep].cpu().numpy()
    labels = prediction['labels'][keep].cpu().numpy()
    return boxes, labels
# Gradio interface
def gradio_interface(image):
    """Gradio callback: run detection on ``image`` and return JSON-ready results.

    Returns:
        ``{"boxes": [[xmin, ymin, xmax, ymax], ...], "labels": [int, ...]}``.
        The numpy arrays from ``infer`` are converted to plain Python lists, so
        the ``"json"`` output component can serialise them without relying on
        numpy support in the serializer.
    """
    boxes, labels = infer(image)
    # NOTE(review): labels are assumed to mean 0 = no mask, 1 = mask (the
    # dataset's encoding). Confirm against the trained checkpoint.
    return {"boxes": boxes.tolist(), "labels": labels.tolist()}
# Create Gradio interface. It is built at import time so other code can
# reference ``iface``, but the server only starts when the file runs as a
# script, so importing this module (e.g. from tests) does not block on launch().
iface = gr.Interface(fn=gradio_interface, inputs=gr.Image(type="numpy"), outputs="json")

if __name__ == "__main__":
    # Launch Gradio interface
    iface.launch()