# Clothing / app.py
# Author: Moditha24
# Commit 49d337b: "Update app.py" (verified)
import importlib
import subprocess
import sys

import gradio as gr
from PIL import Image

# Import torch/torchvision, installing them on the fly if the environment lacks
# them. Nothing touches torch before this guarded import, so the fallback is
# actually reachable (an unguarded `import torch` above it would crash first).
try:
    import torch
    from torchvision import transforms
except ModuleNotFoundError:
    print("🚨 Torch not found! Installing...")
    subprocess.run([sys.executable, "-m", "pip", "install", "torch", "torchvision", "torchaudio"], check=True)
    # Packages installed after the interpreter started may be invisible to the
    # import system's cached directory listings until the caches are cleared.
    importlib.invalidate_caches()
    import torch  # Try again after installation
    from torchvision import transforms
print("βœ… PyTorch version:", torch.__version__)

# NOTE(review): SupCEResNet (used by create_model_selfsup) is not imported
# anywhere in this file; it must be provided by the project's model module.
# Clothing1M category names in class-index order: position i of the model's
# output corresponds to class_labels[i].
class_labels = [
    "T-Shirt",
    "Shirt",
    "Knitwear",
    "Chiffon",
    "Sweater",
    "Hoodie",
    "Windbreaker",
    "Jacket",
    "Downcoat",
    "Suit",
    "Shawl",
    "Dress",
    "Vest",
    "Underwear",
]
# Model loading
def create_model_selfsup(net='resnet50', num_class=14, checkpoint_path='/content/ckpt_clothing_resnet50.pth'):
    """Load a self-supervised pretrained model for Clothing1M classification.

    Args:
        net: Backbone name forwarded to ``SupCEResNet``.
        num_class: Number of output classes forwarded to ``SupCEResNet``.
        checkpoint_path: Path of a checkpoint whose ``'model'`` entry holds the
            model's state dict.

    Returns:
        The model with the checkpoint weights applied, moved to CUDA when
        available (CPU otherwise) and switched to evaluation mode.

    Raises:
        KeyError: If the loaded checkpoint has no ``'model'`` entry.
    """
    print(f"πŸ”„ Loading model from: {checkpoint_path}")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # weights_only=False unpickles arbitrary objects, which can execute code
    # stored in the file: only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Strip a leading 'module.' (added when a model wrapped in DataParallel is
    # saved). Only the prefix is removed; str.replace would also rewrite keys
    # that contain 'module.' further along their path.
    prefix = 'module.'
    state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in checkpoint['model'].items()
    }

    model = SupCEResNet(net, num_classes=num_class, pool=True)
    # strict=False tolerates key mismatches; report them so a wrong checkpoint
    # cannot silently leave layers at their initial weights.
    incompatible = model.load_state_dict(state_dict, strict=False)
    if incompatible.missing_keys:
        print(f"Warning: keys missing from checkpoint: {incompatible.missing_keys}")
    if incompatible.unexpected_keys:
        print(f"Warning: unexpected checkpoint keys ignored: {incompatible.unexpected_keys}")

    model = model.to(device)
    model.eval()  # Disable dropout / use running batch-norm statistics
    print("βœ… Model loaded successfully!")
    return model
# Load the model once, at import time, so every call to predict_clothing reuses
# the same instance. Uses create_model_selfsup()'s default arguments, including
# its default checkpoint_path.
model = create_model_selfsup()
# Image preprocessing
def preprocess_image(image):
    """Transform a PIL image into a normalized single-image batch for the model.

    The image is converted to RGB if needed, resized so its shorter side is
    256 px, center-cropped to 224x224, converted to a tensor, and normalized
    with the fixed per-channel mean/std below (the usual ImageNet statistics).

    Args:
        image: A PIL image.

    Returns:
        A tensor of shape (1, 3, 224, 224) on CUDA when available, else CPU.
    """
    # Normalize below expects exactly three channels: grayscale ('L'), palette
    # ('P') or RGBA images would yield 1 or 4 channels and fail, so force RGB.
    if image.mode != "RGB":
        image = image.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return transform(image).unsqueeze(0).to("cuda" if torch.cuda.is_available() else "cpu")
# Inference
def predict_clothing(image):
    """Classify one uploaded image and return its clothing label.

    Args:
        image: Numpy image array, as delivered by the ``gr.Image(type="numpy")``
            input of the interface.

    Returns:
        The entry of ``class_labels`` at the index of the largest model output.
    """
    batch = preprocess_image(Image.fromarray(image))
    with torch.no_grad():  # inference only: no autograd bookkeeping
        logits = model(batch)
    best_index = torch.argmax(logits, dim=1).item()
    return class_labels[best_index]
# Web UI: a numpy image goes in, the predicted label comes out. launch() starts
# the server as soon as this module runs.
image_input = gr.Image(type="numpy")
label_output = gr.Textbox(label="Predicted Clothing Type")
demo = gr.Interface(
    fn=predict_clothing,
    inputs=image_input,
    outputs=label_output,
    title="Clothing1M Classification",
    description="Upload an image to classify clothing into one of 14 categories.",
)
demo.launch()