# Page residue from the hosting site, kept as comments so the file parses:
# MoinulwithAI's picture
# Update app.py
# c4bd279 verified
import gradio as gr
import os
import zipfile
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import xml.etree.ElementTree as ET
import torchvision.models.detection
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
# Dataset class
class FaceMaskDataset(Dataset):
    """Detection dataset pairing image files with XML box annotations.

    Each image ``<stem>.<ext>`` in *images_dir* is paired with
    ``<stem>.xml`` in *annotations_dir*. The XML is read for ``object``
    elements, each carrying a ``name`` and a ``bndbox`` with
    ``xmin``/``ymin``/``xmax``/``ymax``.

    ``__getitem__`` returns ``(image, target)``, where ``target`` is
    ``{'boxes': float32 tensor of [xmin, ymin, xmax, ymax] rows,
    'labels': int64 tensor}``. Samples with no annotation file, or no
    usable object, yield ``(None, None)`` so a collate function can drop them.

    Args:
        images_dir: Directory containing the image files.
        annotations_dir: Directory containing the XML annotation files.
        transform: Optional callable applied to the PIL image as the last step.
        resize: ``(width, height)`` every image is resized to. Boxes are
            rescaled to match. ``None`` keeps each image's original size.
    """

    # Matched case-insensitively against file names.
    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, images_dir, annotations_dir, transform=None, resize=(800, 800)):
        self.images_dir = images_dir
        self.annotations_dir = annotations_dir
        self.transform = transform
        self.resize = resize
        # os.listdir returns entries in arbitrary order; sort so that the
        # index -> file mapping is deterministic across runs.
        self.image_files = sorted(
            f for f in os.listdir(images_dir)
            if f.lower().endswith(self._IMAGE_EXTENSIONS)
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_file = self.image_files[idx]

        # Check the annotation first so unannotated images are never decoded.
        annotation_path = self._annotation_path(image_file)
        if not os.path.exists(annotation_path):
            return None, None
        boxes, labels = self.load_annotations(annotation_path)
        if boxes is None or labels is None:
            return None, None

        # The context manager closes the file handle Image.open keeps;
        # convert() loads the pixel data into a new, independent image.
        with Image.open(os.path.join(self.images_dir, image_file)) as img:
            image = img.convert("RGB")

        if self.resize is not None:
            # Boxes are in original-pixel coordinates; rescale them with the image
            # so they still line up after resizing.
            boxes = self._scale_boxes(boxes, image.size, self.resize)
            image = image.resize(self.resize)

        target = {'boxes': boxes, 'labels': labels}
        if self.transform:
            image = self.transform(image)
        return image, target

    def _annotation_path(self, image_file):
        """Return the path of the ``.xml`` annotation matching *image_file*."""
        # splitext only strips the final extension, so names such as
        # "a.jpg.png" map to "a.jpg.xml" rather than being mangled.
        stem, _ = os.path.splitext(image_file)
        return os.path.join(self.annotations_dir, stem + ".xml")

    @staticmethod
    def _scale_boxes(boxes, orig_size, new_size):
        """Rescale ``[xmin, ymin, xmax, ymax]`` rows from *orig_size* to *new_size*.

        Both sizes are ``(width, height)``, the order PIL uses for ``Image.size``
        and ``Image.resize``.
        """
        orig_w, orig_h = orig_size
        new_w, new_h = new_size
        sx = new_w / orig_w
        sy = new_h / orig_h
        return boxes * torch.tensor([sx, sy, sx, sy], dtype=boxes.dtype)

    def load_annotations(self, annotation_path):
        """Parse one XML annotation file into box and label tensors.

        Returns:
            ``(boxes, labels)``: a float32 tensor with one
            ``[xmin, ymin, xmax, ymax]`` row per kept object and an int64
            tensor of labels. Returns ``(None, None)`` when no object with a
            positive-area box is present.
        """
        root = ET.parse(annotation_path).getroot()
        boxes = []
        labels = []
        for obj in root.iter('object'):
            label = obj.find('name').text
            bndbox = obj.find('bndbox')
            xmin = float(bndbox.find('xmin').text)
            ymin = float(bndbox.find('ymin').text)
            xmax = float(bndbox.find('xmax').text)
            ymax = float(bndbox.find('ymax').text)
            # torchvision's detection models raise on boxes with non-positive
            # width or height during training; skip them instead of crashing.
            if xmax <= xmin or ymax <= ymin:
                continue
            boxes.append([xmin, ymin, xmax, ymax])
            # NOTE(review): only the exact name "mask" maps to 1; every other name
            # maps to 0. torchvision detection models reserve label 0 for
            # background, so those objects are likely learned as background.
            # Confirm the intended label scheme and class names.
            labels.append(1 if label == "mask" else 0)
        if not boxes:
            return None, None
        return (
            torch.tensor(boxes, dtype=torch.float32),
            torch.tensor(labels, dtype=torch.int64),
        )
def collate_fn(batch):
    """Collate ``(image, target)`` samples into ``(images, targets)`` lists.

    Samples whose image or target is ``None`` (as the dataset returns for
    unannotated images) are dropped. If every sample in the batch is
    dropped, ``([], [])`` is returned instead of raising.
    """
    kept = [b for b in batch if b[0] is not None and b[1] is not None]
    if not kept:
        # zip(*[]) yields nothing, so unpacking it into two names would
        # raise ValueError and abort the whole epoch.
        return [], []
    images, targets = zip(*kept)
    return list(images), list(targets)
def get_model(num_classes):
    """Build a COCO-pretrained Faster R-CNN (ResNet-50 FPN) with a fresh box head.

    The pretrained box predictor is replaced by a ``FastRCNNPredictor`` sized
    for *num_classes* outputs. The input width of the new predictor is taken
    from the original classification layer.
    """
    detector = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    # NOTE(review): newer torchvision versions deprecate `pretrained=` in favour
    # of `weights=`; confirm the deployed version before switching.
    head_in = detector.roi_heads.box_predictor.cls_score.in_features
    detector.roi_heads.box_predictor = FastRCNNPredictor(head_in, num_classes)
    return detector
def extract_zip(zip_file, extract_to):
    """Unpack every member of the ZIP archive *zip_file* into *extract_to*.

    *zip_file* may be anything ``zipfile.ZipFile`` accepts (a path or a
    binary file-like object). The archive is closed when extraction ends.
    """
    # NOTE(review): archives arrive via user upload. zipfile strips absolute
    # paths and ".." components when extracting, but nothing here limits the
    # total extracted size.
    with zipfile.ZipFile(zip_file, mode='r') as archive:
        archive.extractall(path=extract_to)
def _run_epoch(model, loader, optimizer=None):
    """Make one pass over *loader* and return the mean loss per processed batch.

    With an *optimizer*, each batch is trained on (zero_grad, backward, step).
    Without one, the loss is only measured, with gradient tracking disabled.
    Empty batches are skipped. Returns ``nan`` if no batch was processed.
    """
    training = optimizer is not None
    # The detector is called as model(images, targets) in train mode to obtain
    # its loss dict; validation reuses that path, just without gradients.
    # NOTE(review): train mode may update normalization statistics if the
    # backbone has non-frozen BatchNorm layers; confirm for this model.
    model.train()
    total_loss = 0.0
    num_batches = 0
    with torch.set_grad_enabled(training):
        for images, targets in loader:
            if not images:
                # Skip empty batches (e.g. every sample was filtered out
                # by the collate function).
                continue
            images = [img.to(device) for img in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            loss_dict = model(images, targets)
            loss = sum(loss_dict.values())
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            num_batches += 1
    # Average over batches actually processed, not len(loader), and avoid
    # ZeroDivisionError when nothing usable was loaded.
    return total_loss / num_batches if num_batches else float("nan")


def train_model(train_zip, val_zip):
    """Train the face-mask detector from two uploaded ZIP archives.

    The archives are extracted to ``./data/train`` and ``./data/val``. Images
    and XML annotations are then read from ``train/images`` and
    ``train/annotations`` (or ``val/...``) inside them. The model trains for
    3 epochs; mean train and validation losses are printed after each epoch,
    and the weights are saved to ``model.pth``.

    Args:
        train_zip: Training archive, in any form ``extract_zip`` accepts.
            ``None`` is treated as "not uploaded".
        val_zip: Validation archive, same form as *train_zip*.

    Returns:
        A status message shown in the UI.
    """
    # Presumably the upload widgets hand over None when no file was chosen;
    # report that instead of failing inside zipfile.
    if train_zip is None or val_zip is None:
        return "Please upload both a train ZIP and a val ZIP."

    extract_zip(train_zip, './data/train')
    extract_zip(val_zip, './data/val')

    transform = transforms.Compose([transforms.ToTensor()])
    train_dataset = FaceMaskDataset(
        images_dir='./data/train/train/images',
        annotations_dir='./data/train/train/annotations',
        transform=transform,
    )
    val_dataset = FaceMaskDataset(
        images_dir='./data/val/val/images',
        annotations_dir='./data/val/val/annotations',
        transform=transform,
    )
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

    model = get_model(num_classes=2).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)

    for epoch in range(3):  # Reduce for demo
        train_loss = _run_epoch(model, train_loader, optimizer)
        val_loss = _run_epoch(model, val_loader)
        print(f"Epoch {epoch + 1}, Loss: {train_loss}, Val Loss: {val_loss}")

    torch.save(model.state_dict(), "model.pth")
    return "Training completed. Model saved as model.pth"
# Gradio upload interface: two file pickers feed train_model, and the status
# string it returns is displayed as plain text.
iface = gr.Interface(
    fn=train_model,
    inputs=[gr.File(label="Upload Train ZIP"), gr.File(label="Upload Val ZIP")],
    outputs="text",
)

if __name__ == "__main__":
    # Start the web UI only when executed as a script, not when imported.
    iface.launch()