Spaces:
Configuration error
Configuration error
import os

import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image

from resnet import SupCEResNet  # Ensure the correct import path
# Clothing1M category names. Position i holds the name for class index i,
# which is how predict_clothing maps the model's argmax back to a label.
class_labels = [
    "T-Shirt",
    "Shirt",
    "Knitwear",
    "Chiffon",
    "Sweater",
    "Hoodie",
    "Windbreaker",
    "Jacket",
    "Downcoat",
    "Suit",
    "Shawl",
    "Dress",
    "Vest",
    "Underwear",
]
# β Function to load the model | |
def create_model_selfsup(net='resnet50', num_class=14, checkpoint_path='/content/ckpt_clothing_resnet50.pth'): | |
"""Loads a self-supervised pretrained model for Clothing1M classification""" | |
print(f"π Loading model from: {checkpoint_path}") | |
# Load the checkpoint safely | |
checkpoint = torch.load(checkpoint_path, map_location="cuda" if torch.cuda.is_available() else "cpu", weights_only=False) | |
# Remove 'module.' prefix if using DataParallel | |
state_dict = {k.replace('module.', ''): v for k, v in checkpoint['model'].items()} | |
# Initialize and load model | |
model = SupCEResNet(net, num_classes=num_class, pool=True) | |
model.load_state_dict(state_dict, strict=False) | |
# Move model to GPU if available | |
model = model.to("cuda" if torch.cuda.is_available() else "cpu") | |
model.eval() # Set model to evaluation mode | |
print("β Model loaded successfully!") | |
return model | |
# β Load the model once | |
model = create_model_selfsup() | |
# β Define image preprocessing function | |
def preprocess_image(image): | |
"""Transforms input image for the model""" | |
transform = transforms.Compose([ | |
transforms.Resize(256), | |
transforms.CenterCrop(224), | |
transforms.ToTensor(), | |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) | |
]) | |
return transform(image).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu") | |
# β Define inference function | |
def predict_clothing(image): | |
"""Runs inference on an uploaded image""" | |
image = Image.fromarray(image) # Convert numpy array to PIL Image | |
image = preprocess_image(image) # Preprocess image | |
with torch.no_grad(): | |
output = model(image) | |
predicted_class = torch.argmax(output, dim=1).item() # Get class index | |
return class_labels[predicted_class] # Return class name | |
# Web UI: a single image upload in, the predicted clothing label out.
image_input = gr.Image(type="numpy")
label_output = gr.Textbox(label="Predicted Clothing Type")
demo = gr.Interface(
    fn=predict_clothing,
    inputs=image_input,
    outputs=label_output,
    title="Clothing1M Classification",
    description="Upload an image to classify clothing into one of 14 categories.",
)
demo.launch()