# Hugging Face Spaces status header captured during scraping: "Spaces: Sleeping"
import os
import xml.etree.ElementTree as ET
import zipfile

import gradio as gr
import torch
import torch.optim as optim
from PIL import Image
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
# Your model training and evaluation functions (already defined in your previous code)
# Define the custom dataset
class FaceMaskDataset(Dataset):
    """PASCAL-VOC-style face-mask detection dataset.

    Pairs every file in ``images_dir`` with an XML annotation of the same
    base name in ``annotations_dir``. Items whose annotation is missing or
    empty yield ``(None, None)`` so a custom ``collate_fn`` can filter them
    out (the default collate would raise on ``None``).
    """

    def __init__(self, images_dir, annotations_dir, transform=None, resize=(800, 800)):
        self.images_dir = images_dir
        self.annotations_dir = annotations_dir
        self.transform = transform  # optional torchvision transform applied last
        self.resize = resize        # (width, height) PIL resize applied to every image
        self.image_files = os.listdir(images_dir)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, target)``, or ``(None, None)`` when the annotation is missing/empty.

        NOTE(review): boxes are NOT rescaled to match ``self.resize`` — they stay
        in the original image's coordinate space; confirm upstream expectation.
        """
        image_path = os.path.join(self.images_dir, self.image_files[idx])
        image = Image.open(image_path).convert("RGB")
        image = image.resize(self.resize)
        # Derive the annotation filename from the image's base name. splitext
        # handles any extension (.jpg, .png, .jpeg, ...) instead of only the
        # two the original chained str.replace covered.
        base, _ = os.path.splitext(self.image_files[idx])
        annotation_path = os.path.join(self.annotations_dir, base + ".xml")
        if not os.path.exists(annotation_path):
            print(f"Warning: Annotation file {annotation_path} does not exist. Skipping image {self.image_files[idx]}.")
            return None, None  # Return None if annotation is missing
        boxes, labels = self.load_annotations(annotation_path)
        if boxes is None or labels is None:
            return None, None  # Skip if annotations are invalid
        target = {'boxes': boxes, 'labels': labels}
        if self.transform:
            image = self.transform(image)
        return image, target

    def load_annotations(self, annotation_path):
        """Parse a VOC XML file into ``(boxes, labels)`` tensors.

        Returns ``(None, None)`` when the file contains no ``<object>`` entries.
        boxes: float32 tensor of ``[xmin, ymin, xmax, ymax]`` rows.
        labels: int64 tensor, ``1`` for the class name "mask", ``0`` otherwise.
        """
        tree = ET.parse(annotation_path)
        root = tree.getroot()
        boxes = []
        labels = []
        for obj in root.iter('object'):
            label = obj.find('name').text
            bndbox = obj.find('bndbox')
            xmin = float(bndbox.find('xmin').text)
            ymin = float(bndbox.find('ymin').text)
            xmax = float(bndbox.find('xmax').text)
            ymax = float(bndbox.find('ymax').text)
            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(1 if label == "mask" else 0)  # "mask" = 1, "no_mask" = 0
        if len(boxes) == 0 or len(labels) == 0:
            return None, None  # If no boxes or labels, return None
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.tensor(labels, dtype=torch.int64)
        return boxes, labels
# Model training loop (referred to from previous code)
def train_model(model, train_loader, val_loader, optimizer, num_epochs=10):
    """Train ``model`` for ``num_epochs`` epochs, printing train/val loss per epoch.

    Assumes ``model(images, targets)`` returns a dict of loss tensors
    (torchvision detection-model style). Batches collated to ``(None, None)``
    (missing/invalid annotations) are skipped.
    """
    # Derive the target device from the model itself. The original code read a
    # module-level ``device`` global that is never defined anywhere in this
    # file, so training would crash with NameError on the first batch.
    device = next((p.device for p in model.parameters()), torch.device("cpu"))
    for epoch in range(num_epochs):
        running_loss = 0.0
        model.train()
        for images, targets in train_loader:
            if images is None or targets is None:
                continue  # Skip invalid images/annotations
            # Move data to the model's device
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            optimizer.zero_grad()
            loss_dict = model(images, targets)
            # Total loss is the sum of all component losses
            total_loss = sum(loss for loss in loss_dict.values())
            total_loss.backward()
            optimizer.step()
            running_loss += total_loss.item()
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss / len(train_loader)}")
        # Evaluate after every epoch
        val_loss = evaluate_model(model, val_loader)
        print(f"Validation Loss: {val_loss}")
# Validation function
def evaluate_model(model, val_loader):
    """Return the mean (per-batch) summed loss of ``model`` over ``val_loader``.

    Skips batches collated to ``(None, None)``. Returns ``0.0`` for an empty
    loader instead of raising ZeroDivisionError.

    NOTE(review): torchvision detection models return predictions (not a loss
    dict) in ``eval()`` mode; this code assumes the model still yields a loss
    dict here — confirm against the actual model implementation.
    """
    model.eval()
    # Use the model's own device; the original relied on an undefined module
    # level ``device`` global, which would raise NameError at runtime.
    device = next((p.device for p in model.parameters()), torch.device("cpu"))
    running_loss = 0.0
    with torch.no_grad():
        for images, targets in val_loader:
            if images is None or targets is None:
                continue  # Skip invalid data
            # Move data to the model's device
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            loss_dict = model(images, targets)
            # Total loss is the sum of all component losses
            total_loss = sum(loss for loss in loss_dict.values())
            running_loss += total_loss.item()
    # Guard against an empty loader (the original divided by zero here).
    return running_loss / max(len(val_loader), 1)
# Function to upload dataset and start training (Gradio callback)
def train_on_uploaded_data(train_data, val_data):
    """Persist two uploaded ZIP datasets, extract them, and train the model.

    Each ZIP is expected to contain ``images/`` and ``annotations/``
    directories. Returns a status string shown in the Gradio UI.
    """
    # Save the uploads under the same fixed names we extract from. The original
    # wrote to ``train_data.name``/``val_data.name`` but then unzipped
    # ``train_data.zip``/``val_data.zip`` — a path mismatch.
    train_data_path = "train_data.zip"
    val_data_path = "val_data.zip"
    with open(train_data_path, 'wb') as f:
        f.write(train_data.read())
    with open(val_data_path, 'wb') as f:
        f.write(val_data.read())
    # Extract with the stdlib instead of os.system(f"unzip ...") — avoids
    # shell injection via attacker-controlled filenames and does not depend on
    # an external unzip binary being installed.
    with zipfile.ZipFile(train_data_path) as zf:
        zf.extractall("./train/")
    with zipfile.ZipFile(val_data_path) as zf:
        zf.extractall("./val/")

    def _collate_fn(batch):
        # Drop (None, None) items produced by FaceMaskDataset for missing or
        # empty annotations, and keep images/targets as lists (detection
        # targets are variable-size, so default tensor stacking won't work).
        # The original passed an undefined global ``collate_fn`` (NameError).
        batch = [item for item in batch if item[0] is not None and item[1] is not None]
        if not batch:
            return None, None
        images, targets = zip(*batch)
        return list(images), list(targets)

    # Load datasets
    transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
    train_dataset = FaceMaskDataset(
        images_dir="train/images",
        annotations_dir="train/annotations",
        transform=transform
    )
    val_dataset = FaceMaskDataset(
        images_dir="val/images",
        annotations_dir="val/annotations",
        transform=transform
    )
    # Dataloaders
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=_collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=_collate_fn)
    # Build the model; ``device`` was previously an undefined global.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = get_model(num_classes=2)  # assumes a get_model() factory is defined elsewhere
    model.to(device)
    optimizer = optim.SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)
    # Train the model and return feedback
    train_model(model, train_loader, val_loader, optimizer, num_epochs=10)
    return "Training completed and model saved."
# Build the Gradio interface: two ZIP uploads in, a status string out.
upload_inputs = [
    gr.File(label="Upload Train Dataset (ZIP)"),
    gr.File(label="Upload Validation Dataset (ZIP)"),
]
iface = gr.Interface(
    fn=train_on_uploaded_data,
    inputs=upload_inputs,
    outputs=gr.Textbox(label="Training Status"),
    live=True,
)

# Start the web app.
iface.launch()